Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions pydantic_ai_slim/pydantic_ai/models/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -1171,6 +1171,10 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
except IndexError:
continue

# When using Azure OpenAI and an async content filter is enabled, the openai SDK can return None deltas.
if choice.delta is None: # pyright: ignore[reportUnnecessaryComparison]
continue

# Handle the text part of the response
content = choice.delta.content
if content is not None:
Expand Down
31 changes: 31 additions & 0 deletions tests/models/test_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -603,6 +603,37 @@ async def test_no_delta(allow_model_requests: None):
assert result.usage() == snapshot(RunUsage(requests=1, input_tokens=6, output_tokens=3))


def none_delta_chunk(finish_reason: FinishReason | None = None) -> chat.ChatCompletionChunk:
choice = ChunkChoice(index=0, delta=ChoiceDelta())
# When using Azure OpenAI and an async content filter is enabled, the openai SDK can return None deltas.
choice.delta = None # pyright: ignore[reportAttributeAccessIssue]
return chat.ChatCompletionChunk(
id='x',
choices=[choice],
created=1704067200, # 2024-01-01
model='gpt-4o',
object='chat.completion.chunk',
usage=CompletionUsage(completion_tokens=1, prompt_tokens=2, total_tokens=3),
)


async def test_none_delta(allow_model_requests: None):
stream = [
none_delta_chunk(),
text_chunk('hello '),
text_chunk('world'),
]
mock_client = MockOpenAI.create_mock_stream(stream)
m = OpenAIChatModel('gpt-4o', provider=OpenAIProvider(openai_client=mock_client))
agent = Agent(m)

async with agent.run_stream('') as result:
assert not result.is_complete
assert [c async for c in result.stream_text(debounce_by=None)] == snapshot(['hello ', 'hello world'])
assert result.is_complete
assert result.usage() == snapshot(RunUsage(requests=1, input_tokens=6, output_tokens=3))


@pytest.mark.filterwarnings('ignore:Set the `system_prompt_role` in the `OpenAIModelProfile` instead.')
@pytest.mark.parametrize('system_prompt_role', ['system', 'developer', 'user', None])
async def test_system_prompt_role(
Expand Down