Skip to content

Commit 6dac474

Browse files
committed
fix final_flush coverage
1 parent 706ad78 commit 6dac474

File tree

4 files changed

+32
-14
lines changed

4 files changed

+32
-14
lines changed

pydantic_ai_slim/pydantic_ai/models/__init__.py

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -569,7 +569,7 @@ class StreamedResponse(ABC):
569569
_event_iterator: AsyncIterator[ModelResponseStreamEvent] | None = field(default=None, init=False)
570570
_usage: RequestUsage = field(default_factory=RequestUsage, init=False)
571571

572-
def __aiter__(self) -> AsyncIterator[ModelResponseStreamEvent]: # noqa: C901
572+
def __aiter__(self) -> AsyncIterator[ModelResponseStreamEvent]:
573573
"""Stream the response as an async iterable of [`ModelResponseStreamEvent`][pydantic_ai.messages.ModelResponseStreamEvent]s.
574574
575575
This proxies the `_event_iterator()` and emits all events, while also checking for matches
@@ -616,15 +616,7 @@ def part_end_event(next_part: ModelResponsePart | None = None) -> PartEndEvent |
616616
next_part_kind=next_part.part_kind if next_part else None,
617617
)
618618

619-
async def chain_async_and_sync_iters(
620-
iter1: AsyncIterator[ModelResponseStreamEvent], iter2: Iterator[ModelResponseStreamEvent]
621-
) -> AsyncIterator[ModelResponseStreamEvent]:
622-
async for event in iter1:
623-
yield event
624-
for event in iter2:
625-
yield event
626-
627-
async for event in chain_async_and_sync_iters(iterator, self._parts_manager.final_flush()):
619+
async for event in iterator:
628620
if isinstance(event, PartStartEvent):
629621
if last_start_event:
630622
end_event = part_end_event(event.part)
@@ -642,8 +634,7 @@ async def chain_async_and_sync_iters(
642634

643635
self._event_iterator = iterator_with_part_end(
644636
iterator_with_final_event(
645-
# TODO chain_async_and_sync_iters(iterator, self._parts_manager.final_flush())
646-
self._get_event_iterator()
637+
chain_async_and_sync_iters(self._get_event_iterator(), self._parts_manager.final_flush())
647638
)
648639
)
649640
return self._event_iterator
@@ -704,6 +695,16 @@ def timestamp(self) -> datetime:
704695
raise NotImplementedError()
705696

706697

698+
async def chain_async_and_sync_iters(
699+
iter1: AsyncIterator[ModelResponseStreamEvent], iter2: Iterator[ModelResponseStreamEvent]
700+
) -> AsyncIterator[ModelResponseStreamEvent]:
701+
"""Chain an async iterator with a sync iterator."""
702+
async for event in iter1:
703+
yield event
704+
for event in iter2:
705+
yield event
706+
707+
707708
ALLOW_MODEL_REQUESTS = True
708709
"""Whether to allow requests to models.
709710

pydantic_ai_slim/pydantic_ai/models/function.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,7 @@ async def request_stream(
186186

187187
yield FunctionStreamedResponse(
188188
model_request_parameters=model_request_parameters,
189+
_model_profile=self.profile,
189190
_model_name=self._model_name,
190191
_iter=response_stream,
191192
)
@@ -286,6 +287,7 @@ class FunctionStreamedResponse(StreamedResponse):
286287
"""Implementation of `StreamedResponse` for [FunctionModel][pydantic_ai.models.function.FunctionModel]."""
287288

288289
_model_name: str
290+
_model_profile: ModelProfile
289291
_iter: AsyncIterator[str | DeltaToolCalls | DeltaThinkingCalls | BuiltinToolCallsReturns]
290292
_timestamp: datetime = field(default_factory=_utils.now_utc)
291293

@@ -297,7 +299,9 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
297299
if isinstance(item, str):
298300
response_tokens = _estimate_string_tokens(item)
299301
self._usage += usage.RequestUsage(output_tokens=response_tokens)
300-
for event in self._parts_manager.handle_text_delta(vendor_part_id='content', content=item):
302+
for event in self._parts_manager.handle_text_delta(
303+
vendor_part_id='content', content=item, thinking_tags=self._model_profile.thinking_tags
304+
):
301305
yield event
302306
elif isinstance(item, dict) and item:
303307
for dtc_index, delta in item.items():

tests/test_parts_manager_thinking_tags.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -376,6 +376,16 @@ class Case:
376376
],
377377
ignore_leading_whitespace=True,
378378
),
379+
Case(
380+
name='new_part_ignore_whitespace_mixed_with_full_opening',
381+
chunks=[' <think>'],
382+
expected_parts=[TextPart('<think>')],
383+
expected_normal_events=[],
384+
expected_flushed_events=[
385+
PartStartEvent(index=0, part=TextPart('<think>')),
386+
],
387+
ignore_leading_whitespace=True,
388+
),
379389
]
380390

381391
# Category 9: No Vendor ID (updates, new after thinking, closings as text)

tests/test_streaming.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2110,7 +2110,10 @@ async def stream_with_incomplete_thinking(
21102110
) -> AsyncIterator[str]:
21112111
yield '<thi'
21122112

2113-
agent = Agent(FunctionModel(stream_function=stream_with_incomplete_thinking))
2113+
function_model = FunctionModel(stream_function=stream_with_incomplete_thinking)
2114+
function_model.profile.thinking_tags = ('<think>', '</think>')
2115+
2116+
agent = Agent(function_model)
21142117

21152118
events: list[Any] = []
21162119
async for event in agent.run_stream_events('Hello'):

0 commit comments

Comments
 (0)