Skip to content

Commit 0c7f494

Browse files
authored
fix(inference_store): on duplicate chat completion IDs, replace (#3408)
# What does this PR do? Duplicate chat completion IDs can be generated during tests, especially when recorded responses are replayed across different tests. There is no need to warn or raise an error in that case; the existing record is replaced instead. In the wild this is unlikely to happen at all (there is no evidence of it), so this change does not hide a real problem.
1 parent c04f1c1 commit 0c7f494

File tree

2 files changed

+54
-10
lines changed

2 files changed

+54
-10
lines changed

llama_stack/providers/utils/inference/inference_store.py

Lines changed: 40 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
import asyncio
77
from typing import Any
88

9+
from sqlalchemy.exc import IntegrityError
10+
911
from llama_stack.apis.inference import (
1012
ListOpenAIChatCompletionResponse,
1113
OpenAIChatCompletion,
@@ -129,16 +131,44 @@ async def _write_chat_completion(
129131
raise ValueError("Inference store is not initialized")
130132

131133
data = chat_completion.model_dump()
132-
133-
await self.sql_store.insert(
134-
table="chat_completions",
135-
data={
136-
"id": data["id"],
137-
"created": data["created"],
138-
"model": data["model"],
139-
"choices": data["choices"],
140-
"input_messages": [message.model_dump() for message in input_messages],
141-
},
134+
record_data = {
135+
"id": data["id"],
136+
"created": data["created"],
137+
"model": data["model"],
138+
"choices": data["choices"],
139+
"input_messages": [message.model_dump() for message in input_messages],
140+
}
141+
142+
try:
143+
await self.sql_store.insert(
144+
table="chat_completions",
145+
data=record_data,
146+
)
147+
except IntegrityError as e:
148+
# Duplicate chat completion IDs can be generated during tests especially if they are replaying
149+
# recorded responses across different tests. No need to warn or error under those circumstances.
150+
# In the wild, this is not likely to happen at all (no evidence) so we aren't really hiding any problem.
151+
152+
# Check if it's a unique constraint violation
153+
error_message = str(e.orig) if e.orig else str(e)
154+
if self._is_unique_constraint_error(error_message):
155+
# Update the existing record instead
156+
await self.sql_store.update(table="chat_completions", data=record_data, where={"id": data["id"]})
157+
else:
158+
# Re-raise if it's not a unique constraint error
159+
raise
160+
161+
def _is_unique_constraint_error(self, error_message: str) -> bool:
    """Return True if *error_message* looks like a unique-constraint violation.

    The check is a case-insensitive substring search for the phrases listed
    below; it does not inspect driver-specific error codes.

    :param error_message: text of the database error to classify
    :returns: True when any known indicator occurs in the message
    """
    # Lower-case once so every indicator comparison is case-insensitive.
    error_lower = error_message.lower()
    indicators = (
        "unique constraint failed",  # SQLite
        "duplicate key",  # PostgreSQL
        "unique violation",  # PostgreSQL alternative
        "duplicate entry",  # MySQL
    )
    return any(indicator in error_lower for indicator in indicators)
143173

144174
async def list_chat_completions(

llama_stack/providers/utils/sqlstore/authorized_sqlstore.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,20 @@ async def fetch_one(
172172

173173
return results.data[0] if results.data else None
174174

175+
async def update(self, table: str, data: Mapping[str, Any], where: Mapping[str, Any]) -> None:
    """Update rows, stamping them with the current user's access-control attributes.

    The ``owner_principal`` and ``access_attributes`` columns are always
    written: from the authenticated user when there is one, otherwise as None.

    :param table: name of the table to update
    :param data: column values to write; copied, not mutated
    :param where: row filter, forwarded unchanged to the wrapped store
    """
    # Work on a copy so the caller's mapping is left untouched.
    enhanced_data = dict(data)

    current_user = get_authenticated_user()
    if current_user:
        enhanced_data["owner_principal"] = current_user.principal
        enhanced_data["access_attributes"] = current_user.attributes
    else:
        enhanced_data["owner_principal"] = None
        enhanced_data["access_attributes"] = None

    # NOTE(review): `where` is passed through as-is while the ownership columns
    # are overwritten; no check against the existing row's owner is visible in
    # this method — confirm that is intended.
    await self.sql_store.update(table, enhanced_data, where)
188+
175189
async def delete(self, table: str, where: Mapping[str, Any]) -> None:
    """Delete the rows of *table* matching *where* through the wrapped SQL store.

    :param table: name of the table to delete from
    :param where: row filter, forwarded unchanged to the wrapped store
    """
    # NOTE(review): no access-control condition is added to `where` in this
    # method; if filtering is expected it must happen elsewhere — confirm.
    await self.sql_store.delete(table, where)

0 commit comments

Comments
 (0)