Skip to content

Commit 2bc1304

Browse files
committed
refactor parts manager and add parametrized cases
1 parent 0838109 commit 2bc1304

File tree

7 files changed

+1054
-878
lines changed

7 files changed

+1054
-878
lines changed

pydantic_ai_slim/pydantic_ai/_parts_manager.py

Lines changed: 444 additions & 416 deletions
Large diffs are not rendered by default.

pydantic_ai_slim/pydantic_ai/messages.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -967,6 +967,11 @@ class TextPart:
967967
part_kind: Literal['text'] = 'text'
968968
"""Part type identifier, this is available on all parts as a discriminator."""
969969

970+
potential_opening_tag_buffer: Annotated[str, pydantic.Field(exclude=True)] = field(
971+
compare=False, default='', repr=False
972+
)
973+
"""A buffer to accumulate a potential opening tag (like '<thi')."""
974+
970975
def has_content(self) -> bool:
971976
"""Return `True` if the text content is non-empty."""
972977
return bool(self.content)
@@ -1006,6 +1011,9 @@ class ThinkingPart:
10061011
part_kind: Literal['thinking'] = 'thinking'
10071012
"""Part type identifier, this is available on all parts as a discriminator."""
10081013

1014+
closing_tag_buffer: Annotated[str, pydantic.Field(exclude=True)] = field(compare=False, default='', repr=False)
1015+
"""A buffer to accumulate a potential closing tag (like '</th')."""
1016+
10091017
def has_content(self) -> bool:
10101018
"""Return `True` if the thinking content is non-empty."""
10111019
return bool(self.content)

pydantic_ai_slim/pydantic_ai/models/__init__.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -569,7 +569,17 @@ def part_end_event(next_part: ModelResponsePart | None = None) -> PartEndEvent |
569569
next_part_kind=next_part.part_kind if next_part else None,
570570
)
571571

572-
async for event in iterator:
572+
async def chain_async_and_sync_iters(
573+
iter1: AsyncIterator[ModelResponseStreamEvent], iter2: Iterator[ModelResponseStreamEvent]
574+
) -> AsyncIterator[ModelResponseStreamEvent]:
575+
async for event in iter1:
576+
yield event
577+
for event in (
578+
iter2
579+
): # pragma: no cover - loop never started - flush_buffer() seems to be being called before
580+
yield event
581+
582+
async for event in chain_async_and_sync_iters(iterator, self._parts_manager.flush_buffer()):
573583
if isinstance(event, PartStartEvent):
574584
if last_start_event:
575585
end_event = part_end_event(event.part)
@@ -581,16 +591,6 @@ def part_end_event(next_part: ModelResponsePart | None = None) -> PartEndEvent |
581591

582592
yield event
583593

584-
# Flush any buffered content and stream finalize events
585-
for finalize_event in self._parts_manager.finalize(): # pragma: no cover
586-
if isinstance(finalize_event, PartStartEvent):
587-
if last_start_event:
588-
end_event = part_end_event(finalize_event.part)
589-
if end_event:
590-
yield end_event
591-
last_start_event = finalize_event
592-
yield finalize_event
593-
594594
end_event = part_end_event()
595595
if end_event:
596596
yield end_event
@@ -616,7 +616,7 @@ def get(self) -> ModelResponse:
616616
# Flush any buffered content before building response
617617
# clone parts manager to avoid modifying the ongoing stream state
618618
cloned_manager = copy.deepcopy(self._parts_manager)
619-
for _ in cloned_manager.finalize():
619+
for _ in cloned_manager.flush_buffer():
620620
pass
621621

622622
return ModelResponse(

tests/models/test_groq.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2061,7 +2061,8 @@ async def test_groq_model_thinking_part_iter(allow_model_requests: None, groq_ap
20612061

20622062
assert event_parts == snapshot(
20632063
[
2064-
PartStartEvent(index=0, part=ThinkingPart(content='\n')),
2064+
PartStartEvent(index=0, part=ThinkingPart(content='')),
2065+
PartDeltaEvent(index=0, delta=ThinkingPartDelta(content_delta='\n')),
20652066
PartDeltaEvent(index=0, delta=ThinkingPartDelta(content_delta='Okay')),
20662067
PartDeltaEvent(index=0, delta=ThinkingPartDelta(content_delta=',')),
20672068
PartDeltaEvent(index=0, delta=ThinkingPartDelta(content_delta=' so')),

0 commit comments

Comments
 (0)