@@ -196,7 +196,6 @@ def _chat_completion_wrapper(self, wrapped, instance, args, kwargs):
             _record_operation_duration_metric(self.operation_duration_metric, error_attributes, start_time)
             raise

-        is_raw_response = _is_raw_response(result)
         if kwargs.get("stream"):
             return StreamWrapper(
                 stream=result,
@@ -208,12 +207,12 @@ def _chat_completion_wrapper(self, wrapped, instance, args, kwargs):
                 start_time=start_time,
                 token_usage_metric=self.token_usage_metric,
                 operation_duration_metric=self.operation_duration_metric,
-                is_raw_response=is_raw_response,
             )

         logger.debug(f"openai.resources.chat.completions.Completions.create result: {result}")

         # if the caller is using with_raw_response we need to parse the output to get the response class we expect
+        is_raw_response = _is_raw_response(result)
         if is_raw_response:
             result = result.parse()
         response_attributes = _get_attributes_from_response(
@@ -271,7 +270,6 @@ async def _async_chat_completion_wrapper(self, wrapped, instance, args, kwargs):
             _record_operation_duration_metric(self.operation_duration_metric, error_attributes, start_time)
             raise

-        is_raw_response = _is_raw_response(result)
         if kwargs.get("stream"):
             return StreamWrapper(
                 stream=result,
@@ -283,12 +281,12 @@ async def _async_chat_completion_wrapper(self, wrapped, instance, args, kwargs):
                 start_time=start_time,
                 token_usage_metric=self.token_usage_metric,
                 operation_duration_metric=self.operation_duration_metric,
-                is_raw_response=is_raw_response,
             )

         logger.debug(f"openai.resources.chat.completions.AsyncCompletions.create result: {result}")

         # if the caller is using with_raw_response we need to parse the output to get the response class we expect
+        is_raw_response = _is_raw_response(result)
         if is_raw_response:
             result = result.parse()
         response_attributes = _get_attributes_from_response(
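For context: in the openai-python v1 client, with_raw_response.create() returns a raw APIResponse wrapper whose parse() yields the usual ChatCompletion, or, when stream=True, a lazy Stream of ChatCompletionChunk. That is why the non-streaming path above parses eagerly before reading response attributes, while the streaming path now returns the wrapper unparsed. A minimal caller-side sketch (model name and prompt are placeholders):

from openai import OpenAI

client = OpenAI()

# Non-streaming: parse() returns the ChatCompletion object the
# instrumentation reads its response attributes from.
raw = client.chat.completions.with_raw_response.create(
    model="gpt-4o-mini",
    messages=[{"role": "user", "content": "Say hi"}],
)
completion = raw.parse()
print(completion.choices[0].message.content)

# Streaming: parse() instead returns a lazy Stream of chunks, so the
# instrumentation hands back its StreamWrapper unparsed and defers
# parsing to the caller.
raw_stream = client.chat.completions.with_raw_response.create(
    model="gpt-4o-mini",
    messages=[{"role": "user", "content": "Say hi"}],
    stream=True,
)
for chunk in raw_stream.parse():
    if chunk.choices:
        print(chunk.choices[0].delta.content or "", end="")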
@@ -47,7 +47,6 @@ def __init__(
         start_time: float,
         token_usage_metric: Histogram,
         operation_duration_metric: Histogram,
-        is_raw_response: bool,
     ):
         # we need to wrap the original response even in case of raw_responses
         super().__init__(stream)
@@ -60,7 +59,6 @@ def __init__(
         self.token_usage_metric = token_usage_metric
         self.operation_duration_metric = operation_duration_metric
         self.start_time = start_time
-        self.is_raw_response = is_raw_response

         self.response_id = None
         self.model = None
@@ -125,8 +123,6 @@ def __exit__(self, exc_type, exc_value, traceback):
     def __iter__(self):
         stream = self.__wrapped__
         try:
-            if self.is_raw_response:
-                stream = stream.parse()
             for chunk in stream:
                 self.process_chunk(chunk)
                 yield chunk
@@ -145,12 +141,34 @@ async def __aexit__(self, exc_type, exc_value, traceback):
     async def __aiter__(self):
         stream = self.__wrapped__
         try:
-            if self.is_raw_response:
-                stream = stream.parse()
             async for chunk in stream:
                 self.process_chunk(chunk)
                 yield chunk
         except Exception as exc:
             self.end(exc)
             raise
         self.end()
+
+    def parse(self):
+        """
+        Handles a direct parse() call on the raw response in order to maintain instrumentation on the parsed iterator.
+        """
+        parsed_iterator = self.__wrapped__.parse()
+
+        parsed_wrapper = StreamWrapper(
+            stream=parsed_iterator,
+            span=self.span,
+            span_attributes=self.span_attributes,
+            capture_message_content=self.capture_message_content,
+            event_attributes=self.event_attributes,
+            event_logger=self.event_logger,
+            start_time=self.start_time,
+            token_usage_metric=self.token_usage_metric,
+            operation_duration_metric=self.operation_duration_metric,
+        )
+
+        # Dispatch on the original iterator type: async streams expose __aiter__
+        if hasattr(parsed_iterator, "__aiter__"):
+            return parsed_wrapper.__aiter__()
+
+        return parsed_wrapper.__iter__()
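The hasattr(parsed_iterator, "__aiter__") check above lets a single parse() serve both clients: the async SDK stream defines __aiter__, the sync one only __iter__. A standalone toy showing the same dispatch (not the instrumentation code itself):

import asyncio


class Wrapper:
    """Proxies a sync or async iterator, re-yielding every item."""

    def __init__(self, inner):
        self.inner = inner

    def __iter__(self):
        for item in self.inner:
            yield item

    async def __aiter__(self):
        async for item in self.inner:
            yield item

    def parse(self):
        # Async iterables define __aiter__; otherwise fall back to sync.
        if hasattr(self.inner, "__aiter__"):
            return self.__aiter__()
        return self.__iter__()


async def agen():
    yield 1
    yield 2


async def main():
    assert [x async for x in Wrapper(agen()).parse()] == [1, 2]
    assert list(Wrapper(iter([1, 2])).parse()) == [1, 2]


asyncio.run(main())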
@@ -1171,10 +1171,13 @@ def test_chat_stream_with_raw_response(default_openai_env, trace_exporter, metri
         }
     ]

-    chat_completion = client.chat.completions.with_raw_response.create(
+    raw_response = client.chat.completions.with_raw_response.create(
         model=TEST_CHAT_MODEL, messages=messages, stream=True
     )

+    # Explicitly parse the raw response
+    chat_completion = raw_response.parse()
+
     chunks = [chunk.choices[0].delta.content or "" for chunk in chat_completion if chunk.choices]
     assert "".join(chunks) == "Atlantic Ocean"

@@ -2226,10 +2229,13 @@ async def test_chat_async_stream_with_raw_response(default_openai_env, trace_exp
         }
     ]

-    chat_completion = await client.chat.completions.with_raw_response.create(
+    raw_response = await client.chat.completions.with_raw_response.create(
         model=TEST_CHAT_MODEL, messages=messages, stream=True
     )

+    # Explicitly parse the raw response
+    chat_completion = raw_response.parse()
+
     chunks = [chunk.choices[0].delta.content or "" async for chunk in chat_completion if chunk.choices]
     assert "".join(chunks) == "Atlantic Ocean"

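A closing design note: super().__init__(stream) and the self.__wrapped__ accesses suggest StreamWrapper is a wrapt.ObjectProxy-style transparent proxy. Under that assumption, the new parse() override is what keeps telemetry alive: a proxy forwards any attribute it does not define to the wrapped object, so without the override raw_response.parse() would return the bare SDK stream and every chunk would bypass process_chunk(). A toy illustration, assuming the wrapt library (Raw and the print stand in for the real classes and hooks):

import wrapt


class Raw:
    """Stand-in for the SDK's raw streaming response."""

    def parse(self):
        return iter(["chunk-1", "chunk-2"])


class Instrumented(wrapt.ObjectProxy):
    def parse(self):
        # Without this override the proxy would forward parse() straight
        # to Raw and return an uninstrumented iterator.
        for chunk in self.__wrapped__.parse():
            print(f"processed: {chunk}")  # stand-in for process_chunk()
            yield chunk


for chunk in Instrumented(Raw()).parse():
    pass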