|
1 | | -import inspect |
2 | 1 | import pprint |
3 | 2 | from typing import Optional |
4 | 3 |
|
|
16 | 15 | class OpenAiProvider(InstrumentedProvider): |
17 | 16 | original_create = None |
18 | 17 | original_create_async = None |
| 18 | + original_assistant_methods = None |
| 19 | + assistants_run_steps = {} |
19 | 20 |
|
20 | 21 | def __init__(self, client): |
21 | 22 | super().__init__(client) |
@@ -138,6 +139,7 @@ async def async_generator(): |
138 | 139 | def override(self): |
139 | 140 | self._override_openai_v1_completion() |
140 | 141 | self._override_openai_v1_async_completion() |
| 142 | + self._override_openai_assistants_beta() |
141 | 143 |
|
142 | 144 | def _override_openai_v1_completion(self): |
143 | 145 | from openai.resources.chat import completions |
@@ -228,9 +230,114 @@ async def patched_function(*args, **kwargs): |
228 | 230 | # Override the original method with the patched one |
229 | 231 | completions.AsyncCompletions.create = patched_function |
230 | 232 |
|
| 233 | + def _override_openai_assistants_beta(self): |
| 234 | + """Override OpenAI Assistants API methods""" |
| 235 | + from openai._legacy_response import LegacyAPIResponse |
| 236 | + from openai.resources import beta |
| 237 | + from openai.pagination import BasePage |
| 238 | + |
| 239 | + def handle_response(response, kwargs, init_timestamp, session: Optional[Session] = None): |
| 240 | + """Handle response based on return type""" |
| 241 | + action_event = ActionEvent(init_timestamp=init_timestamp, params=kwargs) |
| 242 | + if session is not None: |
| 243 | + action_event.session_id = session.session_id |
| 244 | + |
| 245 | + try: |
| 246 | + # Set action type and returns |
| 247 | + action_event.action_type = ( |
| 248 | + response.__class__.__name__.split("[")[1][:-1] |
| 249 | + if isinstance(response, BasePage) |
| 250 | + else response.__class__.__name__ |
| 251 | + ) |
| 252 | + action_event.returns = response.model_dump() if hasattr(response, "model_dump") else response |
| 253 | + action_event.end_timestamp = get_ISO_time() |
| 254 | + self._safe_record(session, action_event) |
| 255 | + |
| 256 | + # Create LLMEvent if usage data exists |
| 257 | + response_dict = response.model_dump() if hasattr(response, "model_dump") else {} |
| 258 | + |
| 259 | + if "id" in response_dict and response_dict.get("id").startswith("run"): |
| 260 | + if response_dict["id"] not in self.assistants_run_steps: |
| 261 | + self.assistants_run_steps[response_dict.get("id")] = {"model": response_dict.get("model")} |
| 262 | + |
| 263 | + if "usage" in response_dict and response_dict["usage"] is not None: |
| 264 | + llm_event = LLMEvent(init_timestamp=init_timestamp, params=kwargs) |
| 265 | + if session is not None: |
| 266 | + llm_event.session_id = session.session_id |
| 267 | + |
| 268 | + llm_event.model = response_dict.get("model") |
| 269 | + llm_event.prompt_tokens = response_dict["usage"]["prompt_tokens"] |
| 270 | + llm_event.completion_tokens = response_dict["usage"]["completion_tokens"] |
| 271 | + llm_event.end_timestamp = get_ISO_time() |
| 272 | + self._safe_record(session, llm_event) |
| 273 | + |
| 274 | + elif "data" in response_dict: |
| 275 | + for item in response_dict["data"]: |
| 276 | + if "usage" in item and item["usage"] is not None: |
| 277 | + llm_event = LLMEvent(init_timestamp=init_timestamp, params=kwargs) |
| 278 | + if session is not None: |
| 279 | + llm_event.session_id = session.session_id |
| 280 | + |
| 281 | + llm_event.model = self.assistants_run_steps[item["run_id"]]["model"] |
| 282 | + llm_event.prompt_tokens = item["usage"]["prompt_tokens"] |
| 283 | + llm_event.completion_tokens = item["usage"]["completion_tokens"] |
| 284 | + llm_event.end_timestamp = get_ISO_time() |
| 285 | + self._safe_record(session, llm_event) |
| 286 | + |
| 287 | + except Exception as e: |
| 288 | + self._safe_record(session, ErrorEvent(trigger_event=action_event, exception=e)) |
| 289 | + |
| 290 | + kwargs_str = pprint.pformat(kwargs) |
| 291 | + response_str = pprint.pformat(response) |
| 292 | + logger.warning( |
| 293 | + f"Unable to parse response for Assistants API. Skipping upload to AgentOps\n" |
| 294 | + f"response:\n {response_str}\n" |
| 295 | + f"kwargs:\n {kwargs_str}\n" |
| 296 | + ) |
| 297 | + |
| 298 | + return response |
| 299 | + |
| 300 | + def create_patched_function(original_func): |
| 301 | + def patched_function(*args, **kwargs): |
| 302 | + init_timestamp = get_ISO_time() |
| 303 | + |
| 304 | + session = kwargs.get("session", None) |
| 305 | + if "session" in kwargs.keys(): |
| 306 | + del kwargs["session"] |
| 307 | + |
| 308 | + response = original_func(*args, **kwargs) |
| 309 | + if isinstance(response, LegacyAPIResponse): |
| 310 | + return response |
| 311 | + |
| 312 | + return handle_response(response, kwargs, init_timestamp, session=session) |
| 313 | + |
| 314 | + return patched_function |
| 315 | + |
| 316 | + # Store and patch Assistant API methods |
| 317 | + assistant_api_methods = { |
| 318 | + beta.Assistants: ["create", "retrieve", "update", "delete", "list"], |
| 319 | + beta.Threads: ["create", "retrieve", "update", "delete"], |
| 320 | + beta.threads.Messages: ["create", "retrieve", "update", "list"], |
| 321 | + beta.threads.Runs: ["create", "retrieve", "update", "list", "submit_tool_outputs", "cancel"], |
| 322 | + beta.threads.runs.steps.Steps: ["retrieve", "list"], |
| 323 | + } |
| 324 | + |
| 325 | + self.original_assistant_methods = { |
| 326 | + (cls, method): getattr(cls, method) for cls, methods in assistant_api_methods.items() for method in methods |
| 327 | + } |
| 328 | + |
| 329 | + # Override methods and verify |
| 330 | + for (cls, method), original_func in self.original_assistant_methods.items(): |
| 331 | + patched_function = create_patched_function(original_func) |
| 332 | + setattr(cls, method, patched_function) |
| 333 | + |
231 | 334 | def undo_override(self): |
232 | 335 | if self.original_create is not None and self.original_create_async is not None: |
233 | 336 | from openai.resources.chat import completions |
234 | 337 |
|
235 | 338 | completions.AsyncCompletions.create = self.original_create_async |
236 | 339 | completions.Completions.create = self.original_create |
| 340 | + |
| 341 | + if self.original_assistant_methods is not None: |
| 342 | + for (cls, method), original in self.original_assistant_methods.items(): |
| 343 | + setattr(cls, method, original) |
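
As a rough usage sketch (not part of this commit), here is how the patched Assistants surface might be exercised once `override()` is applied. The `agentops.llms.openai` import path, the model name, and the message text are assumptions for illustration; only `OpenAiProvider`, `override()`, and `undo_override()` come from the diff above.

```python
# Sketch only: assumes openai>=1.x and the OpenAiProvider class shown in this diff.
from openai import OpenAI

from agentops.llms.openai import OpenAiProvider  # module path is an assumption

client = OpenAI()  # reads OPENAI_API_KEY from the environment
provider = OpenAiProvider(client)
provider.override()  # patches Completions plus the beta Assistants/Threads/Messages/Runs/Steps methods

# Each call below now routes through patched_function -> handle_response, which
# records an ActionEvent and, when usage data is present, an LLMEvent.
assistant = client.beta.assistants.create(model="gpt-4o-mini", name="demo-assistant")
thread = client.beta.threads.create()
client.beta.threads.messages.create(thread_id=thread.id, role="user", content="Hello")

# runs.create returns an object whose id starts with "run", so its model is
# cached in assistants_run_steps and reused later for run-step usage events.
run = client.beta.threads.runs.create(thread_id=thread.id, assistant_id=assistant.id)

# A list call returns a paginated page such as SyncCursorPage[Message];
# handle_response derives the action_type "Message" from the bracketed class name.
client.beta.threads.messages.list(thread_id=thread.id)

provider.undo_override()  # restores every method saved in original_assistant_methods
```

Because the patching is done with `setattr` on the resource classes themselves (`beta.Assistants`, `beta.Threads`, and so on), every client in the process is instrumented, and `undo_override()` simply puts back the original functions captured in `original_assistant_methods`. A `session=` keyword passed to any patched method is popped before the call and, when present, attaches the recorded events to that AgentOps session.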