Skip to content

Commit 5c28e02

Browse files
M-Hietala and Marko Hietala authored
inference async streaming tracing fix (Azure#38350)
* initial changes * completing the fix * robustness improvements * trying to fix code check tool related issue --------- Co-authored-by: Marko Hietala <[email protected]>
1 parent 9e3b0e9 commit 5c28e02

File tree

1 file changed

+91
-1
lines changed
  • sdk/ai/azure-ai-inference/azure/ai/inference

1 file changed

+91
-1
lines changed

sdk/ai/azure-ai-inference/azure/ai/inference/tracing.py

Lines changed: 91 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -338,6 +338,39 @@ def _accumulate_response(self, item, accumulate: Dict[str, Any]) -> None:
338338
if tool_call.function and tool_call.function.arguments:
339339
accumulate["message"]["tool_calls"][-1]["function"]["arguments"] += tool_call.function.arguments
340340

341+
def _accumulate_async_streaming_response(self, item, accumulate: Dict[str, Any]) -> None:
342+
if not "choices" in item:
343+
return
344+
if "finish_reason" in item["choices"][0] and item["choices"][0]["finish_reason"]:
345+
accumulate["finish_reason"] = item["choices"][0]["finish_reason"]
346+
if "index" in item["choices"][0] and item["choices"][0]["index"]:
347+
accumulate["index"] = item["choices"][0]["index"]
348+
if not "delta" in item["choices"][0]:
349+
return
350+
if "content" in item["choices"][0]["delta"] and item["choices"][0]["delta"]["content"]:
351+
accumulate.setdefault("message", {})
352+
accumulate["message"].setdefault("content", "")
353+
accumulate["message"]["content"] += item["choices"][0]["delta"]["content"]
354+
if "tool_calls" in item["choices"][0]["delta"] and item["choices"][0]["delta"]["tool_calls"]:
355+
accumulate.setdefault("message", {})
356+
accumulate["message"].setdefault("tool_calls", [])
357+
if item["choices"][0]["delta"]["tool_calls"] is not None:
358+
for tool_call in item["choices"][0]["delta"]["tool_calls"]:
359+
if tool_call.id:
360+
accumulate["message"]["tool_calls"].append(
361+
{
362+
"id": tool_call.id,
363+
"type": "",
364+
"function": {"name": "", "arguments": ""},
365+
}
366+
)
367+
if tool_call.function:
368+
accumulate["message"]["tool_calls"][-1]["type"] = "function"
369+
if tool_call.function and tool_call.function.name:
370+
accumulate["message"]["tool_calls"][-1]["function"]["name"] = tool_call.function.name
371+
if tool_call.function and tool_call.function.arguments:
372+
accumulate["message"]["tool_calls"][-1]["function"]["arguments"] += tool_call.function.arguments
373+
341374
def _wrapped_stream(
342375
self, stream_obj: _models.StreamingChatCompletions, span: "AbstractSpan"
343376
) -> _models.StreamingChatCompletions:
@@ -408,6 +441,63 @@ def __iter__( # pyright: ignore [reportIncompatibleMethodOverride]
408441

409442
return StreamWrapper(stream_obj, self)
410443

444+
def _async_wrapped_stream(
    self, stream_obj: _models.AsyncStreamingChatCompletions, span: "AbstractSpan"
) -> _models.AsyncStreamingChatCompletions:
    """Wrap an async streaming chat completions result so that it is traced while consumed.

    The returned object is an ``AsyncStreamingChatCompletions`` built on the same
    underlying response as ``stream_obj``. Every update pulled from it is folded
    into an accumulator; when iteration ends, a single ``gen_ai.choice`` event
    holding the JSON-serialised accumulator is added to ``span`` and the span is
    finished.

    The span is finalised exactly once: on normal exhaustion, or when pulling or
    accumulating an update raises. In the failure case whatever was accumulated
    so far is recorded before the exception is re-raised, so the span is not
    left open.

    :param stream_obj: The async streaming result returned by the traced call.
    :param span: The span that was opened for the traced call.
    :return: A wrapper stream that yields the same updates as ``stream_obj``.
    """

    class AsyncStreamWrapper(_models.AsyncStreamingChatCompletions):
        """Async stream that records what it yields on the owning span."""

        def __init__(self, stream_obj, instrumentor, span):
            super().__init__(stream_obj._response)
            self._instrumentor = instrumentor
            self._accumulate: Dict[str, Any] = {}
            self._stream_obj = stream_obj
            self.span = span
            self._last_result = None
            # Guards against emitting the event / finishing the span more than once,
            # e.g. when __anext__ is called again after the stream is exhausted.
            self._traced = False

        async def __anext__(self) -> "_models.StreamingChatCompletionsUpdate":
            try:
                result = await super().__anext__()
                self._instrumentor._accumulate_async_streaming_response(  # pylint: disable=protected-access, line-too-long # pyright: ignore [reportFunctionMemberAccess]
                    result, self._accumulate
                )
                self._last_result = result
            except Exception:
                # StopAsyncIteration (normal end) is an Exception subclass, so this
                # covers both exhaustion and failures; either way the span must be
                # closed before the exception continues to the caller.
                self._trace_stream_content()
                raise
            return result

        def _trace_stream_content(self) -> None:
            """Record the accumulated response on the span and finish the span (once)."""
            if self._traced:
                return
            self._traced = True
            try:
                if self._last_result:
                    self._instrumentor._add_response_chat_attributes(  # pylint: disable=protected-access, line-too-long # pyright: ignore [reportFunctionMemberAccess]
                        self.span, self._last_result
                    )
                # Only one choice expected with streaming
                self._accumulate["index"] = 0
                # Delete message if content tracing is not enabled
                if not _trace_inference_content:
                    message = self._accumulate.get("message")
                    if message is not None:
                        message.pop("content", None)
                        if not message:
                            del self._accumulate["message"]
                        elif "tool_calls" in message:
                            # Keep the tool call entries but strip names and arguments.
                            tools_no_recording = self._instrumentor._remove_function_call_names_and_arguments(  # pylint: disable=protected-access, line-too-long # pyright: ignore [reportFunctionMemberAccess]
                                message["tool_calls"]
                            )
                            message["tool_calls"] = list(tools_no_recording)

                self.span.span_instance.add_event(
                    name="gen_ai.choice",
                    attributes={
                        "gen_ai.system": _INFERENCE_GEN_AI_SYSTEM_NAME,
                        "gen_ai.event.content": json.dumps(self._accumulate),
                    },
                )
            finally:
                # Always close the span, even if recording the event failed.
                self.span.finish()

    async_stream_wrapper = AsyncStreamWrapper(stream_obj, self, span)
    return async_stream_wrapper
500+
411501
def _trace_sync_function(
412502
self,
413503
function: Callable,
@@ -534,7 +624,7 @@ async def inner(*args, **kwargs):
534624
self._add_request_span_attributes(span, span_name, args, kwargs)
535625
result = await function(*args, **kwargs)
536626
if kwargs.get("stream") is True:
537-
return self._wrapped_stream(result, span)
627+
return self._async_wrapped_stream(result, span)
538628
self._add_response_span_attributes(span, result)
539629

540630
except Exception as exc:

0 commit comments

Comments
 (0)