Skip to content

Commit 0974d5f

Browse files
authored
Set StreamedResponse.model_name from later streamed chunk for Azure OpenAI with content filter (#2951)
1 parent 9a630a9 commit 0974d5f

File tree

2 files changed

+8
-4
lines changed

2 files changed

+8
-4
lines changed

pydantic_ai_slim/pydantic_ai/models/openai.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -586,7 +586,8 @@ async def _process_streamed_response(
586586
'Streamed response ended without content or tool calls'
587587
)
588588

589-
# ChatCompletionChunk.model is required to be set, but Azure OpenAI omits it so we fall back to the model name set by the user.
589+
# When using Azure OpenAI and a content filter is enabled, the first chunk will contain a `''` model name,
590+
# so we set it from a later chunk in `OpenAIChatStreamedResponse`.
590591
model_name = first_chunk.model or self._model_name
591592

592593
return OpenAIStreamedResponse(
@@ -1352,9 +1353,12 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
13521353
async for chunk in self._response:
13531354
self._usage += _map_usage(chunk)
13541355

1355-
if chunk.id and self.provider_response_id is None:
1356+
if chunk.id: # pragma: no branch
13561357
self.provider_response_id = chunk.id
13571358

1359+
if chunk.model:
1360+
self._model_name = chunk.model
1361+
13581362
try:
13591363
choice = chunk.choices[0]
13601364
except IndexError:

tests/models/test_openai.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -398,7 +398,7 @@ async def test_stream_text(allow_model_requests: None):
398398

399399
async def test_stream_text_finish_reason(allow_model_requests: None):
400400
first_chunk = text_chunk('hello ')
401-
# Test that we fall back to the model name set by the user if the model name is not set in the first chunk, like on Azure OpenAI.
401+
# Test that we get the model name from a later chunk if it is not set on the first one, like on Azure OpenAI with content filter enabled.
402402
first_chunk.model = ''
403403
stream = [
404404
first_chunk,
@@ -421,7 +421,7 @@ async def test_stream_text_finish_reason(allow_model_requests: None):
421421
ModelResponse(
422422
parts=[TextPart(content='hello world.')],
423423
usage=RequestUsage(input_tokens=6, output_tokens=3),
424-
model_name='gpt-4o',
424+
model_name='gpt-4o-123',
425425
timestamp=IsDatetime(),
426426
provider_name='openai',
427427
provider_details={'finish_reason': 'stop'},

0 commit comments

Comments (0)