Skip to content

Commit 7f43968

Browse files
authored
Ignore empty text deltas when streaming gpt-oss via Ollama (#3216)
1 parent c829449 commit 7f43968

File tree

8 files changed

+105 -108 lines changed

8 files changed

+105 -108 lines changed

pydantic_ai_slim/pydantic_ai/models/groq.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -330,7 +330,7 @@ def _process_response(self, response: chat.ChatCompletion) -> ModelResponse:
330330
if call_part and return_part: # pragma: no branch
331331
items.append(call_part)
332332
items.append(return_part)
333-
if choice.message.content is not None:
333+
if choice.message.content:
334334
# NOTE: The `<think>` tag is only present if `groq_reasoning_format` is set to `raw`.
335335
items.extend(split_content_into_text_and_thinking(choice.message.content, self.profile.thinking_tags))
336336
if choice.message.tool_calls is not None:
@@ -563,7 +563,7 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
563563

564564
# Handle the text part of the response
565565
content = choice.delta.content
566-
if content is not None:
566+
if content:
567567
maybe_event = self._parts_manager.handle_text_delta(
568568
vendor_part_id='content',
569569
content=content,

pydantic_ai_slim/pydantic_ai/models/huggingface.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -277,7 +277,7 @@ def _process_response(self, response: ChatCompletionOutput) -> ModelResponse:
277277

278278
items: list[ModelResponsePart] = []
279279

280-
if content is not None:
280+
if content:
281281
items.extend(split_content_into_text_and_thinking(content, self.profile.thinking_tags))
282282
if tool_calls is not None:
283283
for c in tool_calls:
@@ -482,7 +482,7 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
482482

483483
# Handle the text part of the response
484484
content = choice.delta.content
485-
if content is not None:
485+
if content:
486486
maybe_event = self._parts_manager.handle_text_delta(
487487
vendor_part_id='content',
488488
content=content,

pydantic_ai_slim/pydantic_ai/models/openai.py

Lines changed: 33 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -559,24 +559,7 @@ def _process_response(self, response: chat.ChatCompletion | str) -> ModelRespons
559559
# - https://openrouter.ai/docs/use-cases/reasoning-tokens#preserving-reasoning-blocks
560560
# If you need this, please file an issue.
561561

562-
vendor_details: dict[str, Any] = {}
563-
564-
# Add logprobs to vendor_details if available
565-
if choice.logprobs is not None and choice.logprobs.content:
566-
# Convert logprobs to a serializable format
567-
vendor_details['logprobs'] = [
568-
{
569-
'token': lp.token,
570-
'bytes': lp.bytes,
571-
'logprob': lp.logprob,
572-
'top_logprobs': [
573-
{'token': tlp.token, 'bytes': tlp.bytes, 'logprob': tlp.logprob} for tlp in lp.top_logprobs
574-
],
575-
}
576-
for lp in choice.logprobs.content
577-
]
578-
579-
if choice.message.content is not None:
562+
if choice.message.content:
580563
items.extend(
581564
(replace(part, id='content', provider_name=self.system) if isinstance(part, ThinkingPart) else part)
582565
for part in split_content_into_text_and_thinking(choice.message.content, self.profile.thinking_tags)
@@ -594,6 +577,23 @@ def _process_response(self, response: chat.ChatCompletion | str) -> ModelRespons
594577
part.tool_call_id = _guard_tool_call_id(part)
595578
items.append(part)
596579

580+
vendor_details: dict[str, Any] = {}
581+
582+
# Add logprobs to vendor_details if available
583+
if choice.logprobs is not None and choice.logprobs.content:
584+
# Convert logprobs to a serializable format
585+
vendor_details['logprobs'] = [
586+
{
587+
'token': lp.token,
588+
'bytes': lp.bytes,
589+
'logprob': lp.logprob,
590+
'top_logprobs': [
591+
{'token': tlp.token, 'bytes': tlp.bytes, 'logprob': tlp.logprob} for tlp in lp.top_logprobs
592+
],
593+
}
594+
for lp in choice.logprobs.content
595+
]
596+
597597
raw_finish_reason = choice.finish_reason
598598
vendor_details['finish_reason'] = raw_finish_reason
599599
finish_reason = _CHAT_FINISH_REASON_MAP.get(raw_finish_reason)
@@ -1616,21 +1616,6 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
16161616
self.provider_details = {'finish_reason': raw_finish_reason}
16171617
self.finish_reason = _CHAT_FINISH_REASON_MAP.get(raw_finish_reason)
16181618

1619-
# Handle the text part of the response
1620-
content = choice.delta.content
1621-
if content is not None:
1622-
maybe_event = self._parts_manager.handle_text_delta(
1623-
vendor_part_id='content',
1624-
content=content,
1625-
thinking_tags=self._model_profile.thinking_tags,
1626-
ignore_leading_whitespace=self._model_profile.ignore_streamed_leading_whitespace,
1627-
)
1628-
if maybe_event is not None: # pragma: no branch
1629-
if isinstance(maybe_event, PartStartEvent) and isinstance(maybe_event.part, ThinkingPart):
1630-
maybe_event.part.id = 'content'
1631-
maybe_event.part.provider_name = self.provider_name
1632-
yield maybe_event
1633-
16341619
# The `reasoning_content` field is only present in DeepSeek models.
16351620
# https://api-docs.deepseek.com/guides/reasoning_model
16361621
if reasoning_content := getattr(choice.delta, 'reasoning_content', None):
@@ -1652,6 +1637,21 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
16521637
provider_name=self.provider_name,
16531638
)
16541639

1640+
# Handle the text part of the response
1641+
content = choice.delta.content
1642+
if content:
1643+
maybe_event = self._parts_manager.handle_text_delta(
1644+
vendor_part_id='content',
1645+
content=content,
1646+
thinking_tags=self._model_profile.thinking_tags,
1647+
ignore_leading_whitespace=self._model_profile.ignore_streamed_leading_whitespace,
1648+
)
1649+
if maybe_event is not None: # pragma: no branch
1650+
if isinstance(maybe_event, PartStartEvent) and isinstance(maybe_event.part, ThinkingPart):
1651+
maybe_event.part.id = 'content'
1652+
maybe_event.part.provider_name = self.provider_name
1653+
yield maybe_event
1654+
16551655
for dtc in choice.delta.tool_calls or []:
16561656
maybe_event = self._parts_manager.handle_tool_call_delta(
16571657
vendor_part_id=dtc.index,

tests/models/test_deepseek.py

Lines changed: 37 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,17 @@
11
from __future__ import annotations as _annotations
22

3-
from typing import Any
4-
53
import pytest
6-
from dirty_equals import IsListOrTuple
74
from inline_snapshot import snapshot
85

96
from pydantic_ai import (
107
Agent,
11-
FinalResultEvent,
128
ModelRequest,
139
ModelResponse,
14-
PartDeltaEvent,
15-
PartStartEvent,
1610
TextPart,
17-
TextPartDelta,
1811
ThinkingPart,
19-
ThinkingPartDelta,
2012
UserPromptPart,
2113
)
14+
from pydantic_ai.run import AgentRunResult, AgentRunResultEvent
2215
from pydantic_ai.usage import RequestUsage
2316

2417
from ..conftest import IsDatetime, IsStr, try_import
@@ -71,27 +64,42 @@ async def test_deepseek_model_thinking_stream(allow_model_requests: None, deepse
7164
deepseek_model = OpenAIChatModel('deepseek-reasoner', provider=DeepSeekProvider(api_key=deepseek_api_key))
7265
agent = Agent(model=deepseek_model)
7366

74-
event_parts: list[Any] = []
75-
async with agent.iter(user_prompt='Hello') as agent_run:
76-
async for node in agent_run:
77-
if Agent.is_model_request_node(node) or Agent.is_call_tools_node(node):
78-
async with node.stream(agent_run.ctx) as request_stream:
79-
async for event in request_stream:
80-
event_parts.append(event)
67+
result: AgentRunResult | None = None
68+
async for event in agent.run_stream_events(user_prompt='How do I cross the street?'):
69+
if isinstance(event, AgentRunResultEvent):
70+
result = event.result
8171

82-
assert event_parts == IsListOrTuple(
83-
positions={
84-
0: snapshot(
85-
PartStartEvent(
86-
index=0, part=ThinkingPart(content='H', id='reasoning_content', provider_name='deepseek')
87-
)
72+
assert result is not None
73+
assert result.all_messages() == snapshot(
74+
[
75+
ModelRequest(
76+
parts=[
77+
UserPromptPart(
78+
content='How do I cross the street?',
79+
timestamp=IsDatetime(),
80+
)
81+
]
82+
),
83+
ModelResponse(
84+
parts=[
85+
ThinkingPart(
86+
content=IsStr(),
87+
id='reasoning_content',
88+
provider_name='deepseek',
89+
),
90+
TextPart(content='Hello there! 😊 How can I help you today?'),
91+
],
92+
usage=RequestUsage(
93+
input_tokens=6,
94+
output_tokens=212,
95+
details={'prompt_cache_hit_tokens': 0, 'prompt_cache_miss_tokens': 6, 'reasoning_tokens': 198},
96+
),
97+
model_name='deepseek-reasoner',
98+
timestamp=IsDatetime(),
99+
provider_name='deepseek',
100+
provider_details={'finish_reason': 'stop'},
101+
provider_response_id='33be18fc-3842-486c-8c29-dd8e578f7f20',
102+
finish_reason='stop',
88103
),
89-
1: snapshot(PartDeltaEvent(index=0, delta=ThinkingPartDelta(content_delta='mm', provider_name='deepseek'))),
90-
2: snapshot(PartDeltaEvent(index=0, delta=ThinkingPartDelta(content_delta=',', provider_name='deepseek'))),
91-
198: snapshot(PartStartEvent(index=1, part=TextPart(content='Hello'))),
92-
199: snapshot(FinalResultEvent(tool_name=None, tool_call_id=None)),
93-
200: snapshot(PartDeltaEvent(index=1, delta=TextPartDelta(content_delta=' there'))),
94-
201: snapshot(PartDeltaEvent(index=1, delta=TextPartDelta(content_delta='!'))),
95-
},
96-
length=211,
104+
]
97105
)

tests/models/test_groq.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5160,7 +5160,6 @@ async def get_something_by_name(name: str) -> str:
51605160
),
51615161
ModelResponse(
51625162
parts=[
5163-
TextPart(content=''),
51645163
ThinkingPart(
51655164
content="""\
51665165
The user requests to call the tool with non-existent parameters to test error handling. We need to call the function "get_something_by_name" with wrong parameters. The function expects a single argument object with "name". Non-existent parameters means we could provide a wrong key, or missing name. Let's provide an object with wrong key "nonexistent": "value". That should cause error. So we call the function with {"nonexistent": "test"}.
@@ -5205,7 +5204,6 @@ async def get_something_by_name(name: str) -> str:
52055204
),
52065205
ModelResponse(
52075206
parts=[
5208-
TextPart(content=''),
52095207
ThinkingPart(content='We need to call with correct param: name. Use a placeholder name.'),
52105208
ToolCallPart(
52115209
tool_name='get_something_by_name',

tests/models/test_huggingface.py

Lines changed: 29 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
from unittest.mock import Mock
1010

1111
import pytest
12-
from dirty_equals import IsListOrTuple
1312
from inline_snapshot import snapshot
1413
from typing_extensions import TypedDict
1514

@@ -18,26 +17,22 @@
1817
AudioUrl,
1918
BinaryContent,
2019
DocumentUrl,
21-
FinalResultEvent,
2220
ImageUrl,
2321
ModelRequest,
2422
ModelResponse,
2523
ModelRetry,
26-
PartDeltaEvent,
27-
PartStartEvent,
2824
RetryPromptPart,
2925
SystemPromptPart,
3026
TextPart,
31-
TextPartDelta,
3227
ThinkingPart,
33-
ThinkingPartDelta,
3428
ToolCallPart,
3529
ToolReturnPart,
3630
UserPromptPart,
3731
VideoUrl,
3832
)
3933
from pydantic_ai.exceptions import ModelHTTPError
4034
from pydantic_ai.result import RunUsage
35+
from pydantic_ai.run import AgentRunResult, AgentRunResultEvent
4136
from pydantic_ai.settings import ModelSettings
4237
from pydantic_ai.tools import RunContext
4338
from pydantic_ai.usage import RequestUsage
@@ -978,35 +973,32 @@ async def test_hf_model_thinking_part_iter(allow_model_requests: None, huggingfa
978973
)
979974
agent = Agent(m)
980975

981-
event_parts: list[Any] = []
982-
async with agent.iter(user_prompt='How do I cross the street?') as agent_run:
983-
async for node in agent_run:
984-
if Agent.is_model_request_node(node) or Agent.is_call_tools_node(node):
985-
async with node.stream(agent_run.ctx) as request_stream:
986-
async for event in request_stream:
987-
event_parts.append(event)
988-
989-
assert event_parts == snapshot(
990-
IsListOrTuple(
991-
positions={
992-
0: PartStartEvent(index=0, part=ThinkingPart(content='')),
993-
1: PartDeltaEvent(index=0, delta=ThinkingPartDelta(content_delta='\n')),
994-
2: PartDeltaEvent(index=0, delta=ThinkingPartDelta(content_delta='Okay')),
995-
3: PartDeltaEvent(index=0, delta=ThinkingPartDelta(content_delta=',')),
996-
4: PartDeltaEvent(index=0, delta=ThinkingPartDelta(content_delta=' the')),
997-
5: PartDeltaEvent(index=0, delta=ThinkingPartDelta(content_delta=' user')),
998-
6: PartDeltaEvent(index=0, delta=ThinkingPartDelta(content_delta=' is')),
999-
7: PartDeltaEvent(index=0, delta=ThinkingPartDelta(content_delta=' asking')),
1000-
413: PartStartEvent(index=1, part=TextPart(content='Cross')),
1001-
414: FinalResultEvent(tool_name=None, tool_call_id=None),
1002-
415: PartDeltaEvent(index=1, delta=TextPartDelta(content_delta='ing')),
1003-
416: PartDeltaEvent(index=1, delta=TextPartDelta(content_delta=' the')),
1004-
417: PartDeltaEvent(index=1, delta=TextPartDelta(content_delta=' street')),
1005-
418: PartDeltaEvent(index=1, delta=TextPartDelta(content_delta=' safely')),
1006-
419: PartDeltaEvent(index=1, delta=TextPartDelta(content_delta=' requires')),
1007-
420: PartDeltaEvent(index=1, delta=TextPartDelta(content_delta=' attent')),
1008-
421: PartDeltaEvent(index=1, delta=TextPartDelta(content_delta='iveness')),
1009-
},
1010-
length=1062,
1011-
)
976+
result: AgentRunResult | None = None
977+
async for event in agent.run_stream_events(user_prompt='How do I cross the street?'):
978+
if isinstance(event, AgentRunResultEvent):
979+
result = event.result
980+
981+
assert result is not None
982+
assert result.all_messages() == snapshot(
983+
[
984+
ModelRequest(
985+
parts=[
986+
UserPromptPart(
987+
content='How do I cross the street?',
988+
timestamp=IsDatetime(),
989+
)
990+
]
991+
),
992+
ModelResponse(
993+
parts=[
994+
ThinkingPart(content=IsStr()),
995+
TextPart(content=IsStr()),
996+
],
997+
model_name='Qwen/Qwen3-235B-A22B',
998+
timestamp=IsDatetime(),
999+
provider_name='huggingface',
1000+
provider_details={'finish_reason': 'stop'},
1001+
provider_response_id='chatcmpl-357f347a3f5d4897b36a128fb4e4cf7b',
1002+
),
1003+
]
10121004
)

tests/models/test_openai.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -583,7 +583,7 @@ async def test_stream_tool_call_with_empty_text(allow_model_requests: None):
583583
chunk([]),
584584
]
585585
mock_client = MockOpenAI.create_mock_stream(stream)
586-
m = OpenAIChatModel('qwen3', provider=OllamaProvider(openai_client=mock_client))
586+
m = OpenAIChatModel('gpt-oss:20b', provider=OllamaProvider(openai_client=mock_client))
587587
agent = Agent(m, output_type=[str, MyTypedDict])
588588

589589
async with agent.run_stream('') as result:

tests/test_temporal.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1132,9 +1132,8 @@ async def test_temporal_agent_run_stream_events(allow_model_requests: None):
11321132
events = [event async for event in simple_temporal_agent.run_stream_events('What is the capital of Mexico?')]
11331133
assert events == snapshot(
11341134
[
1135-
PartStartEvent(index=0, part=TextPart(content='')),
1135+
PartStartEvent(index=0, part=TextPart(content='The')),
11361136
FinalResultEvent(tool_name=None, tool_call_id=None),
1137-
PartDeltaEvent(index=0, delta=TextPartDelta(content_delta='The')),
11381137
PartDeltaEvent(index=0, delta=TextPartDelta(content_delta=' capital')),
11391138
PartDeltaEvent(index=0, delta=TextPartDelta(content_delta=' of')),
11401139
PartDeltaEvent(index=0, delta=TextPartDelta(content_delta=' Mexico')),

0 commit comments

Comments
 (0)