Skip to content

Commit 9b598dd

Browse files
committed
models
- move finalize to aiter - update models to the generator return type parts manager - disallow thinking after text - delay emittion of thinking parts until there's content tests - swap out list calls for iteration - add helper and consolidate tests to make them clearer
1 parent 0998a63 commit 9b598dd

File tree

13 files changed

+449
-627
lines changed

13 files changed

+449
-627
lines changed

pydantic_ai_slim/pydantic_ai/_parts_manager.py

Lines changed: 151 additions & 79 deletions
Large diffs are not rendered by default.

pydantic_ai_slim/pydantic_ai/models/__init__.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -521,7 +521,7 @@ class StreamedResponse(ABC):
521521
_event_iterator: AsyncIterator[ModelResponseStreamEvent] | None = field(default=None, init=False)
522522
_usage: RequestUsage = field(default_factory=RequestUsage, init=False)
523523

524-
def __aiter__(self) -> AsyncIterator[ModelResponseStreamEvent]:
524+
def __aiter__(self) -> AsyncIterator[ModelResponseStreamEvent]: # noqa: C901
525525
"""Stream the response as an async iterable of [`ModelResponseStreamEvent`][pydantic_ai.messages.ModelResponseStreamEvent]s.
526526
527527
This proxies the `_event_iterator()` and emits all events, while also checking for matches
@@ -580,6 +580,16 @@ def part_end_event(next_part: ModelResponsePart | None = None) -> PartEndEvent |
580580

581581
yield event
582582

583+
# Flush any buffered content and stream finalize events
584+
for finalize_event in self._parts_manager.finalize():
585+
if isinstance(finalize_event, PartStartEvent):
586+
if last_start_event:
587+
end_event = part_end_event(finalize_event.part)
588+
if end_event:
589+
yield end_event
590+
last_start_event = finalize_event
591+
yield finalize_event
592+
583593
end_event = part_end_event()
584594
if end_event:
585595
yield end_event

pydantic_ai_slim/pydantic_ai/models/anthropic.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -734,19 +734,21 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
734734
):
735735
yield event_item
736736
elif isinstance(current_block, BetaThinkingBlock):
737-
yield self._parts_manager.handle_thinking_delta(
737+
for e in self._parts_manager.handle_thinking_delta(
738738
vendor_part_id=event.index,
739739
content=current_block.thinking,
740740
signature=current_block.signature,
741741
provider_name=self.provider_name,
742-
)
742+
):
743+
yield e
743744
elif isinstance(current_block, BetaRedactedThinkingBlock):
744-
yield self._parts_manager.handle_thinking_delta(
745+
for e in self._parts_manager.handle_thinking_delta(
745746
vendor_part_id=event.index,
746747
id='redacted_thinking',
747748
signature=current_block.data,
748749
provider_name=self.provider_name,
749-
)
750+
):
751+
yield e
750752
elif isinstance(current_block, BetaToolUseBlock):
751753
maybe_event = self._parts_manager.handle_tool_call_delta(
752754
vendor_part_id=event.index,
@@ -807,17 +809,19 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
807809
):
808810
yield event_item
809811
elif isinstance(event.delta, BetaThinkingDelta):
810-
yield self._parts_manager.handle_thinking_delta(
812+
for e in self._parts_manager.handle_thinking_delta(
811813
vendor_part_id=event.index,
812814
content=event.delta.thinking,
813815
provider_name=self.provider_name,
814-
)
816+
):
817+
yield e
815818
elif isinstance(event.delta, BetaSignatureDelta):
816-
yield self._parts_manager.handle_thinking_delta(
819+
for e in self._parts_manager.handle_thinking_delta(
817820
vendor_part_id=event.index,
818821
signature=event.delta.signature,
819822
provider_name=self.provider_name,
820-
)
823+
):
824+
yield e
821825
elif isinstance(event.delta, BetaInputJSONDelta):
822826
maybe_event = self._parts_manager.handle_tool_call_delta(
823827
vendor_part_id=event.index,

pydantic_ai_slim/pydantic_ai/models/bedrock.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -687,20 +687,22 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
687687
delta = content_block_delta['delta']
688688
if 'reasoningContent' in delta:
689689
if redacted_content := delta['reasoningContent'].get('redactedContent'):
690-
yield self._parts_manager.handle_thinking_delta(
690+
for e in self._parts_manager.handle_thinking_delta(
691691
vendor_part_id=index,
692692
id='redacted_content',
693693
signature=redacted_content.decode('utf-8'),
694694
provider_name=self.provider_name,
695-
)
695+
):
696+
yield e
696697
else:
697698
signature = delta['reasoningContent'].get('signature')
698-
yield self._parts_manager.handle_thinking_delta(
699+
for e in self._parts_manager.handle_thinking_delta(
699700
vendor_part_id=index,
700701
content=delta['reasoningContent'].get('text'),
701702
signature=signature,
702703
provider_name=self.provider_name if signature else None,
703-
)
704+
):
705+
yield e
704706
if text := delta.get('text'):
705707
for event in self._parts_manager.handle_text_delta(vendor_part_id=index, content=text):
706708
yield event

pydantic_ai_slim/pydantic_ai/models/function.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -284,7 +284,7 @@ class FunctionStreamedResponse(StreamedResponse):
284284
def __post_init__(self):
285285
self._usage += _estimate_usage([])
286286

287-
async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
287+
async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: # noqa: C901
288288
async for item in self._iter:
289289
if isinstance(item, str):
290290
response_tokens = _estimate_string_tokens(item)
@@ -297,12 +297,13 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
297297
if delta.content: # pragma: no branch
298298
response_tokens = _estimate_string_tokens(delta.content)
299299
self._usage += usage.RequestUsage(output_tokens=response_tokens)
300-
yield self._parts_manager.handle_thinking_delta(
300+
for e in self._parts_manager.handle_thinking_delta(
301301
vendor_part_id=dtc_index,
302302
content=delta.content,
303303
signature=delta.signature,
304304
provider_name='function' if delta.signature else None,
305-
)
305+
):
306+
yield e
306307
elif isinstance(delta, DeltaToolCall):
307308
if delta.json_args:
308309
response_tokens = _estimate_string_tokens(delta.json_args)

pydantic_ai_slim/pydantic_ai/models/google.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -668,15 +668,19 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
668668
for part in parts:
669669
if part.thought_signature:
670670
signature = base64.b64encode(part.thought_signature).decode('utf-8')
671-
yield self._parts_manager.handle_thinking_delta(
671+
for e in self._parts_manager.handle_thinking_delta(
672672
vendor_part_id='thinking',
673673
signature=signature,
674674
provider_name=self.provider_name,
675-
)
675+
):
676+
yield e
676677

677678
if part.text is not None:
678679
if part.thought:
679-
yield self._parts_manager.handle_thinking_delta(vendor_part_id='thinking', content=part.text)
680+
for e in self._parts_manager.handle_thinking_delta(
681+
vendor_part_id='thinking', content=part.text
682+
):
683+
yield e
680684
else:
681685
for event in self._parts_manager.handle_text_delta(vendor_part_id='content', content=part.text):
682686
yield event

pydantic_ai_slim/pydantic_ai/models/groq.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -547,9 +547,10 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
547547
reasoning = True
548548

549549
# NOTE: The `reasoning` field is only present if `groq_reasoning_format` is set to `parsed`.
550-
yield self._parts_manager.handle_thinking_delta(
550+
for e in self._parts_manager.handle_thinking_delta(
551551
vendor_part_id=f'reasoning-{reasoning_index}', content=choice.delta.reasoning
552-
)
552+
):
553+
yield e
553554
else:
554555
reasoning = False
555556

pydantic_ai_slim/pydantic_ai/models/openai.py

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1680,7 +1680,7 @@ class OpenAIStreamedResponse(StreamedResponse):
16801680
_provider_name: str
16811681
_provider_url: str
16821682

1683-
async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
1683+
async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: # noqa: C901
16841684
async for chunk in self._response:
16851685
self._usage += _map_usage(chunk, self._provider_name, self._provider_url, self._model_name)
16861686

@@ -1706,23 +1706,25 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
17061706
# The `reasoning_content` field is only present in DeepSeek models.
17071707
# https://api-docs.deepseek.com/guides/reasoning_model
17081708
if reasoning_content := getattr(choice.delta, 'reasoning_content', None):
1709-
yield self._parts_manager.handle_thinking_delta(
1709+
for e in self._parts_manager.handle_thinking_delta(
17101710
vendor_part_id='reasoning_content',
17111711
id='reasoning_content',
17121712
content=reasoning_content,
17131713
provider_name=self.provider_name,
1714-
)
1714+
):
1715+
yield e
17151716

17161717
# The `reasoning` field is only present in gpt-oss via Ollama and OpenRouter.
17171718
# - https://cookbook.openai.com/articles/gpt-oss/handle-raw-cot#chat-completions-api
17181719
# - https://openrouter.ai/docs/use-cases/reasoning-tokens#basic-usage-with-reasoning-tokens
17191720
if reasoning := getattr(choice.delta, 'reasoning', None): # pragma: no cover
1720-
yield self._parts_manager.handle_thinking_delta(
1721+
for e in self._parts_manager.handle_thinking_delta(
17211722
vendor_part_id='reasoning',
17221723
id='reasoning',
17231724
content=reasoning,
17241725
provider_name=self.provider_name,
1725-
)
1726+
):
1727+
yield e
17261728

17271729
# Handle the text part of the response
17281730
content = choice.delta.content
@@ -1887,12 +1889,13 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
18871889
if isinstance(chunk.item, responses.ResponseReasoningItem):
18881890
if signature := chunk.item.encrypted_content: # pragma: no branch
18891891
# Add the signature to the part corresponding to the first summary item
1890-
yield self._parts_manager.handle_thinking_delta(
1892+
for e in self._parts_manager.handle_thinking_delta(
18911893
vendor_part_id=f'{chunk.item.id}-0',
18921894
id=chunk.item.id,
18931895
signature=signature,
18941896
provider_name=self.provider_name,
1895-
)
1897+
):
1898+
yield e
18961899
elif isinstance(chunk.item, responses.ResponseCodeInterpreterToolCall):
18971900
_, return_part, file_parts = _map_code_interpreter_tool_call(chunk.item, self.provider_name)
18981901
for i, file_part in enumerate(file_parts):
@@ -1925,11 +1928,12 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
19251928
yield self._parts_manager.handle_part(vendor_part_id=f'{chunk.item.id}-return', part=return_part)
19261929

19271930
elif isinstance(chunk, responses.ResponseReasoningSummaryPartAddedEvent):
1928-
yield self._parts_manager.handle_thinking_delta(
1931+
for e in self._parts_manager.handle_thinking_delta(
19291932
vendor_part_id=f'{chunk.item_id}-{chunk.summary_index}',
19301933
content=chunk.part.text,
19311934
id=chunk.item_id,
1932-
)
1935+
):
1936+
yield e
19331937

19341938
elif isinstance(chunk, responses.ResponseReasoningSummaryPartDoneEvent):
19351939
pass # there's nothing we need to do here
@@ -1938,11 +1942,12 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
19381942
pass # there's nothing we need to do here
19391943

19401944
elif isinstance(chunk, responses.ResponseReasoningSummaryTextDeltaEvent):
1941-
yield self._parts_manager.handle_thinking_delta(
1945+
for e in self._parts_manager.handle_thinking_delta(
19421946
vendor_part_id=f'{chunk.item_id}-{chunk.summary_index}',
19431947
content=chunk.delta,
19441948
id=chunk.item_id,
1945-
)
1949+
):
1950+
yield e
19461951

19471952
elif isinstance(chunk, responses.ResponseOutputTextAnnotationAddedEvent):
19481953
# TODO(Marcelo): We should support annotations in the future.

pydantic_ai_slim/pydantic_ai/models/outlines.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from __future__ import annotations
77

88
import io
9-
from collections.abc import AsyncIterable, AsyncIterator, Sequence
9+
from collections.abc import AsyncIterable, AsyncIterator, Iterator, Sequence
1010
from contextlib import asynccontextmanager
1111
from dataclasses import dataclass
1212
from datetime import datetime, timezone
@@ -537,15 +537,18 @@ class OutlinesStreamedResponse(StreamedResponse):
537537
_provider_name: str
538538

539539
async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
540-
async for event in self._response:
541-
event = self._parts_manager.handle_text_delta(
542-
vendor_part_id='content',
543-
content=event,
544-
thinking_tags=self._model_profile.thinking_tags,
545-
ignore_leading_whitespace=self._model_profile.ignore_streamed_leading_whitespace,
540+
async for chunk in self._response:
541+
events = cast(
542+
Iterator[ModelResponseStreamEvent],
543+
self._parts_manager.handle_text_delta(
544+
vendor_part_id='content',
545+
content=chunk,
546+
thinking_tags=self._model_profile.thinking_tags,
547+
ignore_leading_whitespace=self._model_profile.ignore_streamed_leading_whitespace,
548+
),
546549
)
547-
if event is not None: # pragma: no branch
548-
yield event
550+
for e in events:
551+
yield e
549552

550553
@property
551554
def model_name(self) -> str:

tests/models/test_groq.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2061,8 +2061,7 @@ async def test_groq_model_thinking_part_iter(allow_model_requests: None, groq_ap
20612061

20622062
assert event_parts == snapshot(
20632063
[
2064-
PartStartEvent(index=0, part=ThinkingPart(content='')),
2065-
PartDeltaEvent(index=0, delta=ThinkingPartDelta(content_delta='\n')),
2064+
PartStartEvent(index=0, part=ThinkingPart(content='\n')),
20662065
PartDeltaEvent(index=0, delta=ThinkingPartDelta(content_delta='Okay')),
20672066
PartDeltaEvent(index=0, delta=ThinkingPartDelta(content_delta=',')),
20682067
PartDeltaEvent(index=0, delta=ThinkingPartDelta(content_delta=' so')),

0 commit comments

Comments
 (0)