diff --git a/instrumentation/elastic-opentelemetry-instrumentation-openai/src/opentelemetry/instrumentation/openai/__init__.py b/instrumentation/elastic-opentelemetry-instrumentation-openai/src/opentelemetry/instrumentation/openai/__init__.py
index 973e6eb..4020097 100644
--- a/instrumentation/elastic-opentelemetry-instrumentation-openai/src/opentelemetry/instrumentation/openai/__init__.py
+++ b/instrumentation/elastic-opentelemetry-instrumentation-openai/src/opentelemetry/instrumentation/openai/__init__.py
@@ -196,7 +196,6 @@ def _chat_completion_wrapper(self, wrapped, instance, args, kwargs):
             _record_operation_duration_metric(self.operation_duration_metric, error_attributes, start_time)
             raise
 
-        is_raw_response = _is_raw_response(result)
         if kwargs.get("stream"):
             return StreamWrapper(
                 stream=result,
@@ -208,12 +207,12 @@ def _chat_completion_wrapper(self, wrapped, instance, args, kwargs):
                 start_time=start_time,
                 token_usage_metric=self.token_usage_metric,
                 operation_duration_metric=self.operation_duration_metric,
-                is_raw_response=is_raw_response,
             )
 
         logger.debug(f"openai.resources.chat.completions.Completions.create result: {result}")
 
         # if the caller is using with_raw_response we need to parse the output to get the response class we expect
+        is_raw_response = _is_raw_response(result)
         if is_raw_response:
             result = result.parse()
         response_attributes = _get_attributes_from_response(
@@ -271,7 +270,6 @@ async def _async_chat_completion_wrapper(self, wrapped, instance, args, kwargs):
             _record_operation_duration_metric(self.operation_duration_metric, error_attributes, start_time)
             raise
 
-        is_raw_response = _is_raw_response(result)
         if kwargs.get("stream"):
             return StreamWrapper(
                 stream=result,
@@ -283,12 +281,12 @@ async def _async_chat_completion_wrapper(self, wrapped, instance, args, kwargs):
                 start_time=start_time,
                 token_usage_metric=self.token_usage_metric,
                 operation_duration_metric=self.operation_duration_metric,
-                is_raw_response=is_raw_response,
             )
 
         logger.debug(f"openai.resources.chat.completions.AsyncCompletions.create result: {result}")
 
         # if the caller is using with_raw_response we need to parse the output to get the response class we expect
+        is_raw_response = _is_raw_response(result)
         if is_raw_response:
             result = result.parse()
         response_attributes = _get_attributes_from_response(
diff --git a/instrumentation/elastic-opentelemetry-instrumentation-openai/src/opentelemetry/instrumentation/openai/wrappers.py b/instrumentation/elastic-opentelemetry-instrumentation-openai/src/opentelemetry/instrumentation/openai/wrappers.py
index 62ec1c6..22d431d 100644
--- a/instrumentation/elastic-opentelemetry-instrumentation-openai/src/opentelemetry/instrumentation/openai/wrappers.py
+++ b/instrumentation/elastic-opentelemetry-instrumentation-openai/src/opentelemetry/instrumentation/openai/wrappers.py
@@ -47,7 +47,6 @@ def __init__(
         start_time: float,
         token_usage_metric: Histogram,
         operation_duration_metric: Histogram,
-        is_raw_response: bool,
     ):
         # we need to wrap the original response even in case of raw_responses
         super().__init__(stream)
@@ -60,7 +59,6 @@ def __init__(
         self.token_usage_metric = token_usage_metric
         self.operation_duration_metric = operation_duration_metric
         self.start_time = start_time
-        self.is_raw_response = is_raw_response
 
         self.response_id = None
         self.model = None
@@ -125,8 +123,6 @@ def __exit__(self, exc_type, exc_value, traceback):
     def __iter__(self):
         stream = self.__wrapped__
         try:
-            if self.is_raw_response:
-                stream = stream.parse()
             for chunk in stream:
                 self.process_chunk(chunk)
                 yield chunk
@@ -145,8 +141,6 @@ async def __aexit__(self, exc_type, exc_value, traceback):
     async def __aiter__(self):
         stream = self.__wrapped__
         try:
-            if self.is_raw_response:
-                stream = stream.parse()
             async for chunk in stream:
                 self.process_chunk(chunk)
                 yield chunk
@@ -154,3 +148,27 @@ async def __aiter__(self):
             self.end(exc)
             raise
         self.end()
+
+    def parse(self):
+        """
+        Handles direct parse() call on the client in order to maintain instrumentation on the parsed iterator.
+        """
+        parsed_iterator = self.__wrapped__.parse()
+
+        parsed_wrapper = StreamWrapper(
+            stream=parsed_iterator,
+            span=self.span,
+            span_attributes=self.span_attributes,
+            capture_message_content=self.capture_message_content,
+            event_attributes=self.event_attributes,
+            event_logger=self.event_logger,
+            start_time=self.start_time,
+            token_usage_metric=self.token_usage_metric,
+            operation_duration_metric=self.operation_duration_metric,
+        )
+
+        # Handle original sync/async iterators accordingly
+        if hasattr(parsed_iterator, "__aiter__"):
+            return parsed_wrapper.__aiter__()
+
+        return parsed_wrapper.__iter__()
diff --git a/instrumentation/elastic-opentelemetry-instrumentation-openai/tests/test_chat_completions.py b/instrumentation/elastic-opentelemetry-instrumentation-openai/tests/test_chat_completions.py
index 5fbf820..c1a3e8a 100644
--- a/instrumentation/elastic-opentelemetry-instrumentation-openai/tests/test_chat_completions.py
+++ b/instrumentation/elastic-opentelemetry-instrumentation-openai/tests/test_chat_completions.py
@@ -1171,10 +1171,13 @@ def test_chat_stream_with_raw_response(default_openai_env, trace_exporter, metri
         }
     ]
 
-    chat_completion = client.chat.completions.with_raw_response.create(
+    raw_response = client.chat.completions.with_raw_response.create(
         model=TEST_CHAT_MODEL, messages=messages, stream=True
     )
 
+    # Explicit parse of the raw response
+    chat_completion = raw_response.parse()
+
     chunks = [chunk.choices[0].delta.content or "" for chunk in chat_completion if chunk.choices]
     assert "".join(chunks) == "Atlantic Ocean"
 
@@ -2226,10 +2229,13 @@ async def test_chat_async_stream_with_raw_response(default_openai_env, trace_exp
         }
     ]
 
-    chat_completion = await client.chat.completions.with_raw_response.create(
+    raw_response = await client.chat.completions.with_raw_response.create(
         model=TEST_CHAT_MODEL, messages=messages, stream=True
     )
 
+    # Explicit parse of the raw response
+    chat_completion = raw_response.parse()
+
     chunks = [chunk.choices[0].delta.content or "" async for chunk in chat_completion if chunk.choices]
     assert "".join(chunks) == "Atlantic Ocean"
 
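Caller-side usage sketch (illustrative, not part of the diff): after this change the stream returned by with_raw_response is no longer parsed implicitly; the wrapper's new parse() re-wraps the parsed iterator so spans and metrics still close when the stream is consumed, dispatching on __aiter__ for async streams. The client construction and model name below are assumptions for illustration; the call pattern mirrors the updated tests.

# Minimal sketch, assuming a configured OpenAI client (API key in the
# environment) and a placeholder model name.
from openai import OpenAI

client = OpenAI()

raw_response = client.chat.completions.with_raw_response.create(
    model="gpt-4o-mini",  # placeholder model for illustration
    messages=[{"role": "user", "content": "Which ocean borders Portugal?"}],
    stream=True,
)

# Explicit parse() returns the instrumented chunk iterator; iterating it
# drives StreamWrapper.process_chunk() and ends the span on completion.
chat_completion = raw_response.parse()

for chunk in chat_completion:
    if chunk.choices:
        print(chunk.choices[0].delta.content or "", end="")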