@@ -338,6 +338,39 @@ def _accumulate_response(self, item, accumulate: Dict[str, Any]) -> None:
338338 if tool_call .function and tool_call .function .arguments :
339339 accumulate ["message" ]["tool_calls" ][- 1 ]["function" ]["arguments" ] += tool_call .function .arguments
340340
341+ def _accumulate_async_streaming_response (self , item , accumulate : Dict [str , Any ]) -> None :
342+ if not "choices" in item :
343+ return
344+ if "finish_reason" in item ["choices" ][0 ] and item ["choices" ][0 ]["finish_reason" ]:
345+ accumulate ["finish_reason" ] = item ["choices" ][0 ]["finish_reason" ]
346+ if "index" in item ["choices" ][0 ] and item ["choices" ][0 ]["index" ]:
347+ accumulate ["index" ] = item ["choices" ][0 ]["index" ]
348+ if not "delta" in item ["choices" ][0 ]:
349+ return
350+ if "content" in item ["choices" ][0 ]["delta" ] and item ["choices" ][0 ]["delta" ]["content" ]:
351+ accumulate .setdefault ("message" , {})
352+ accumulate ["message" ].setdefault ("content" , "" )
353+ accumulate ["message" ]["content" ] += item ["choices" ][0 ]["delta" ]["content" ]
354+ if "tool_calls" in item ["choices" ][0 ]["delta" ] and item ["choices" ][0 ]["delta" ]["tool_calls" ]:
355+ accumulate .setdefault ("message" , {})
356+ accumulate ["message" ].setdefault ("tool_calls" , [])
357+ if item ["choices" ][0 ]["delta" ]["tool_calls" ] is not None :
358+ for tool_call in item ["choices" ][0 ]["delta" ]["tool_calls" ]:
359+ if tool_call .id :
360+ accumulate ["message" ]["tool_calls" ].append (
361+ {
362+ "id" : tool_call .id ,
363+ "type" : "" ,
364+ "function" : {"name" : "" , "arguments" : "" },
365+ }
366+ )
367+ if tool_call .function :
368+ accumulate ["message" ]["tool_calls" ][- 1 ]["type" ] = "function"
369+ if tool_call .function and tool_call .function .name :
370+ accumulate ["message" ]["tool_calls" ][- 1 ]["function" ]["name" ] = tool_call .function .name
371+ if tool_call .function and tool_call .function .arguments :
372+ accumulate ["message" ]["tool_calls" ][- 1 ]["function" ]["arguments" ] += tool_call .function .arguments
373+
341374 def _wrapped_stream (
342375 self , stream_obj : _models .StreamingChatCompletions , span : "AbstractSpan"
343376 ) -> _models .StreamingChatCompletions :
@@ -408,6 +441,63 @@ def __iter__( # pyright: ignore [reportIncompatibleMethodOverride]
408441
409442 return StreamWrapper (stream_obj , self )
410443
def _async_wrapped_stream(
    self, stream_obj: _models.AsyncStreamingChatCompletions, span: "AbstractSpan"
) -> _models.AsyncStreamingChatCompletions:
    """Wrap an async chat-completions stream so that consuming it is traced.

    The returned object is an ``AsyncStreamingChatCompletions`` subclass built
    from ``stream_obj._response``. Each update it yields is folded into an
    accumulator; when the stream is exhausted a single ``gen_ai.choice`` event
    holding the accumulated data is added to the span and the span is
    finished. If iteration fails with any other exception the span is
    finished as well (without the event) so that it is not left open.

    :param stream_obj: The async streaming response returned by the client.
    :param span: The span covering the request; finished by the wrapper.
    :return: A traced wrapper that yields the same updates.
    """

    class AsyncStreamWrapper(_models.AsyncStreamingChatCompletions):
        """Async stream that records what it yields and closes the span once."""

        def __init__(self, stream_obj, instrumentor, span):
            super().__init__(stream_obj._response)
            self._instrumentor = instrumentor
            self._accumulate: Dict[str, Any] = {}
            self._stream_obj = stream_obj
            self.span = span
            self._last_result = None
            # Guards against emitting the event / finishing the span twice,
            # e.g. when __anext__ is called again after exhaustion.
            self._span_finished = False

        async def __anext__(self) -> "_models.StreamingChatCompletionsUpdate":
            try:
                result = await super().__anext__()
                self._instrumentor._accumulate_async_streaming_response(  # pylint: disable=protected-access, line-too-long # pyright: ignore [reportFunctionMemberAccess]
                    result, self._accumulate
                )
                self._last_result = result
            except StopAsyncIteration:
                # Normal end of stream: publish the accumulated response.
                self._trace_stream_content()
                raise
            except Exception:
                # Abnormal termination: do not leak an unfinished span.
                self._finish_span()
                raise
            return result

        def _finish_span(self) -> None:
            """Finish the span the first time this is called; no-op afterwards."""
            if self._span_finished:
                return
            self._span_finished = True
            self.span.finish()

        def _trace_stream_content(self) -> None:
            """Add the accumulated response to the span and finish it (once)."""
            if self._span_finished:
                return
            try:
                if self._last_result:
                    self._instrumentor._add_response_chat_attributes(  # pylint: disable=protected-access, line-too-long # pyright: ignore [reportFunctionMemberAccess]
                        self.span, self._last_result
                    )
                # Only one choice expected with streaming
                self._accumulate["index"] = 0
                # Delete message if content tracing is not enabled
                if not _trace_inference_content:
                    if "message" in self._accumulate:
                        if "content" in self._accumulate["message"]:
                            del self._accumulate["message"]["content"]
                        if not self._accumulate["message"]:
                            del self._accumulate["message"]
                    if "message" in self._accumulate:
                        if "tool_calls" in self._accumulate["message"]:
                            tools_no_recording = self._instrumentor._remove_function_call_names_and_arguments(  # pylint: disable=protected-access, line-too-long # pyright: ignore [reportFunctionMemberAccess]
                                self._accumulate["message"]["tool_calls"]
                            )
                            self._accumulate["message"]["tool_calls"] = list(tools_no_recording)

                self.span.span_instance.add_event(
                    name="gen_ai.choice",
                    attributes={
                        "gen_ai.system": _INFERENCE_GEN_AI_SYSTEM_NAME,
                        "gen_ai.event.content": json.dumps(self._accumulate),
                    },
                )
            finally:
                # Finish the span even if building the event failed.
                self._finish_span()

    async_stream_wrapper = AsyncStreamWrapper(stream_obj, self, span)
    return async_stream_wrapper
500+
411501 def _trace_sync_function (
412502 self ,
413503 function : Callable ,
@@ -534,7 +624,7 @@ async def inner(*args, **kwargs):
534624 self ._add_request_span_attributes (span , span_name , args , kwargs )
535625 result = await function (* args , ** kwargs )
536626 if kwargs .get ("stream" ) is True :
537- return self ._wrapped_stream (result , span )
627+ return self ._async_wrapped_stream (result , span )
538628 self ._add_response_span_attributes (span , result )
539629
540630 except Exception as exc :
0 commit comments