|
6 | 6 | import asyncio |
7 | 7 | from typing import Any |
8 | 8 |
|
| 9 | +from sqlalchemy.exc import IntegrityError |
| 10 | + |
9 | 11 | from llama_stack.apis.inference import ( |
10 | 12 | ListOpenAIChatCompletionResponse, |
11 | 13 | OpenAIChatCompletion, |
@@ -129,16 +131,44 @@ async def _write_chat_completion( |
129 | 131 | raise ValueError("Inference store is not initialized") |
130 | 132 |
|
131 | 133 | data = chat_completion.model_dump() |
132 | | - |
133 | | - await self.sql_store.insert( |
134 | | - table="chat_completions", |
135 | | - data={ |
136 | | - "id": data["id"], |
137 | | - "created": data["created"], |
138 | | - "model": data["model"], |
139 | | - "choices": data["choices"], |
140 | | - "input_messages": [message.model_dump() for message in input_messages], |
141 | | - }, |
| 134 | + record_data = { |
| 135 | + "id": data["id"], |
| 136 | + "created": data["created"], |
| 137 | + "model": data["model"], |
| 138 | + "choices": data["choices"], |
| 139 | + "input_messages": [message.model_dump() for message in input_messages], |
| 140 | + } |
| 141 | + |
| 142 | + try: |
| 143 | + await self.sql_store.insert( |
| 144 | + table="chat_completions", |
| 145 | + data=record_data, |
| 146 | + ) |
| 147 | + except IntegrityError as e: |
| 148 | + # Duplicate chat completion IDs can be generated during tests especially if they are replaying |
| 149 | + # recorded responses across different tests. No need to warn or error under those circumstances. |
| 150 | + # In the wild, this is not likely to happen at all (no evidence) so we aren't really hiding any problem. |
| 151 | + |
| 152 | + # Check if it's a unique constraint violation |
| 153 | + error_message = str(e.orig) if e.orig else str(e) |
| 154 | + if self._is_unique_constraint_error(error_message): |
| 155 | + # Update the existing record instead |
| 156 | + await self.sql_store.update(table="chat_completions", data=record_data, where={"id": data["id"]}) |
| 157 | + else: |
| 158 | + # Re-raise if it's not a unique constraint error |
| 159 | + raise |
| 160 | + |
| 161 | + def _is_unique_constraint_error(self, error_message: str) -> bool: |
| 162 | + """Check if the error is specifically a unique constraint violation.""" |
| 163 | + error_lower = error_message.lower() |
| 164 | + return any( |
| 165 | + indicator in error_lower |
| 166 | + for indicator in [ |
| 167 | + "unique constraint failed", # SQLite |
| 168 | + "duplicate key", # PostgreSQL |
| 169 | + "unique violation", # PostgreSQL alternative |
| 170 | + "duplicate entry", # MySQL |
| 171 | + ] |
142 | 172 | ) |
143 | 173 |
|
144 | 174 | async def list_chat_completions( |
|
0 commit comments