Skip to content

Commit dcac211

Browse files
committed
wip: improve coverage
1 parent 41a38e2 commit dcac211

File tree

4 files changed

+283
-33
lines changed

4 files changed

+283
-33
lines changed

pydantic_ai_slim/pydantic_ai/_parts_manager.py

Lines changed: 30 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,25 @@ def get_parts(self) -> list[ModelResponsePart]:
7373
"""
7474
return [p for p in self._parts if not isinstance(p, ToolCallPartDelta)]
7575

76+
def has_incomplete_parts(self) -> bool:
77+
"""Check if there are any incomplete ToolCallPartDeltas being managed.
78+
79+
Returns:
80+
True if there are any ToolCallPartDelta objects in the internal parts list.
81+
"""
82+
return any(isinstance(p, ToolCallPartDelta) for p in self._parts)
83+
84+
def is_vendor_id_mapped(self, vendor_id: VendorId) -> bool:
85+
"""Check if a vendor ID is currently mapped to a part index.
86+
87+
Args:
88+
vendor_id: The vendor ID to check.
89+
90+
Returns:
91+
True if the vendor ID is mapped to a part index, False otherwise.
92+
"""
93+
return vendor_id in self._vendor_id_to_part_index
94+
7695
def finalize(self) -> Generator[ModelResponseStreamEvent, None, None]:
7796
"""Flush any buffered content, appending to ThinkingParts or creating TextParts.
7897
@@ -106,7 +125,7 @@ def finalize(self) -> Generator[ModelResponseStreamEvent, None, None]:
106125

107126
# flush any remaining buffered content
108127
for vendor_part_id, buffered_content in list(self._thinking_tag_buffer.items()):
109-
if buffered_content:
128+
if buffered_content: # pragma: no branch - buffer should never contain empty string
110129
part_index = self._vendor_id_to_part_index.get(vendor_part_id)
111130

112131
# If buffered content belongs to a ThinkingPart, append it to the ThinkingPart
@@ -208,33 +227,7 @@ def _handle_text_delta_simple( # noqa: C901
208227
if part_index is not None:
209228
existing_part = self._parts[part_index]
210229

211-
if thinking_tags and isinstance(existing_part, ThinkingPart):
212-
end_tag = thinking_tags[1]
213-
if end_tag in content:
214-
before_end, after_end = content.split(end_tag, 1)
215-
216-
if before_end:
217-
yield from self.handle_thinking_delta(vendor_part_id=vendor_part_id, content=before_end)
218-
219-
self._vendor_id_to_part_index.pop(vendor_part_id)
220-
221-
if after_end:
222-
yield from self._handle_text_delta_simple(
223-
vendor_part_id=vendor_part_id,
224-
content=after_end,
225-
id=id,
226-
thinking_tags=thinking_tags,
227-
ignore_leading_whitespace=ignore_leading_whitespace,
228-
)
229-
return
230-
231-
if content == end_tag:
232-
self._vendor_id_to_part_index.pop(vendor_part_id)
233-
return
234-
235-
yield from self.handle_thinking_delta(vendor_part_id=vendor_part_id, content=content)
236-
return
237-
elif isinstance(existing_part, TextPart):
230+
if isinstance(existing_part, TextPart):
238231
existing_text_part_and_index = existing_part, part_index
239232
else:
240233
raise UnexpectedModelBehavior(f'Cannot apply a text delta to {existing_part=}')
@@ -267,19 +260,17 @@ def _handle_text_delta_simple( # noqa: C901
267260
# Create ThinkingPart but defer PartStartEvent until there is content
268261
new_part_index = len(self._parts)
269262
part = ThinkingPart(content='')
270-
if vendor_part_id is not None:
271-
self._vendor_id_to_part_index[vendor_part_id] = new_part_index
272263
self._parts.append(part)
273264

274-
if after_start:
265+
if after_start: # pragma: no branch
275266
yield from self.handle_thinking_delta(vendor_part_id=vendor_part_id, content=after_start)
276267
return
277268

278269
if existing_text_part_and_index is None:
279270
# This is a workaround for models that emit `<think>\n</think>\n\n` or an empty text part ahead of tool calls (e.g. Ollama + Qwen3),
280271
# which we don't want to end up treating as a final result when using `run_stream` with `str` as a valid `output_type`.
281272
if ignore_leading_whitespace and (len(content) == 0 or content.isspace()):
282-
return None
273+
return
283274

284275
new_part_index = len(self._parts)
285276
part = TextPart(content=content, id=id)
@@ -294,7 +285,9 @@ def _handle_text_delta_simple( # noqa: C901
294285

295286
updated_text_part = part_delta.apply(existing_text_part)
296287
self._parts[part_index] = updated_text_part
297-
if part_index not in self._started_part_indices:
288+
if (
289+
part_index not in self._started_part_indices
290+
): # pragma: no cover - defensive: TextPart should always be started
298291
self._started_part_indices.add(part_index)
299292
yield PartStartEvent(index=part_index, part=updated_text_part)
300293
else:
@@ -458,6 +451,10 @@ def handle_thinking_delta(
458451
latest_part = self._parts[part_index]
459452
if isinstance(latest_part, ThinkingPart):
460453
existing_thinking_part_and_index = latest_part, part_index
454+
elif isinstance(latest_part, TextPart):
455+
raise UnexpectedModelBehavior(
456+
'Cannot create ThinkingPart after TextPart: thinking must come before text in response'
457+
)
461458
else:
462459
# Otherwise, attempt to look up an existing ThinkingPart by vendor_part_id
463460
part_index = self._vendor_id_to_part_index.get(vendor_part_id)

tests/test_parts_manager.py

Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -581,6 +581,9 @@ def test_handle_thinking_delta_new_part_with_vendor_id():
581581
parts = manager.get_parts()
582582
assert parts == snapshot([ThinkingPart(content='new thought')])
583583

584+
# Verify vendor_part_id was mapped to the part index
585+
assert manager.is_vendor_id_mapped('thinking')
586+
584587

585588
def test_handle_thinking_delta_no_content():
586589
manager = ModelResponsePartsManager()
@@ -603,6 +606,98 @@ def test_handle_thinking_delta_no_content_or_signature():
603606
pass
604607

605608

609+
def test_handle_text_delta_append_to_thinking_part_without_vendor_id():
610+
"""Test appending to ThinkingPart when vendor_part_id is None (lines 202-203)."""
611+
manager = ModelResponsePartsManager()
612+
thinking_tags = ('<think>', '</think>')
613+
614+
# Create a ThinkingPart using handle_text_delta with thinking tags and vendor_part_id=None
615+
events = list(manager.handle_text_delta(vendor_part_id=None, content='<think>initial', thinking_tags=thinking_tags))
616+
assert len(events) == 1
617+
assert isinstance(events[0], PartStartEvent)
618+
assert isinstance(events[0].part, ThinkingPart)
619+
assert events[0].part.content == 'initial'
620+
621+
# Now append more content with vendor_part_id=None - should append to existing ThinkingPart
622+
events = list(manager.handle_text_delta(vendor_part_id=None, content=' reasoning', thinking_tags=thinking_tags))
623+
assert len(events) == 1
624+
assert isinstance(events[0], PartDeltaEvent)
625+
assert events[0].index == 0
626+
627+
parts = manager.get_parts()
628+
assert len(parts) == 1
629+
assert isinstance(parts[0], ThinkingPart)
630+
assert parts[0].content == 'initial reasoning'
631+
632+
633+
def test_simple_path_whitespace_handling():
634+
"""Test whitespace-only prefix with ignore_leading_whitespace in simple path (S10 → S11).
635+
636+
This tests the branch where whitespace before a start tag is ignored when
637+
vendor_part_id=None (which routes to simple path).
638+
"""
639+
manager = ModelResponsePartsManager()
640+
thinking_tags = ('<think>', '</think>')
641+
642+
events = list(
643+
manager.handle_text_delta(
644+
vendor_part_id=None,
645+
content=' \n<think>reasoning',
646+
thinking_tags=thinking_tags,
647+
ignore_leading_whitespace=True,
648+
)
649+
)
650+
651+
assert len(events) == 1
652+
assert isinstance(events[0], PartStartEvent)
653+
assert isinstance(events[0].part, ThinkingPart)
654+
assert events[0].part.content == 'reasoning'
655+
656+
parts = manager.get_parts()
657+
assert len(parts) == 1
658+
assert isinstance(parts[0], ThinkingPart)
659+
assert parts[0].content == 'reasoning'
660+
661+
662+
def test_simple_path_text_prefix_rejection():
663+
"""Test that text before start tag disables thinking tag detection in simple path (S12).
664+
665+
When there's non-whitespace text before the start tag, the entire content should be
666+
treated as a TextPart with the tag included as literal text.
667+
"""
668+
manager = ModelResponsePartsManager()
669+
thinking_tags = ('<think>', '</think>')
670+
671+
events = list(
672+
manager.handle_text_delta(vendor_part_id=None, content='foo<think>reasoning', thinking_tags=thinking_tags)
673+
)
674+
675+
assert len(events) == 1
676+
assert isinstance(events[0], PartStartEvent)
677+
assert isinstance(events[0].part, TextPart)
678+
assert events[0].part.content == 'foo<think>reasoning'
679+
680+
parts = manager.get_parts()
681+
assert len(parts) == 1
682+
assert isinstance(parts[0], TextPart)
683+
assert parts[0].content == 'foo<think>reasoning'
684+
685+
686+
def test_empty_whitespace_content_with_ignore_leading_whitespace():
687+
"""Test that empty/whitespace content is ignored when ignore_leading_whitespace=True (line 282)."""
688+
manager = ModelResponsePartsManager()
689+
690+
# Empty content with ignore_leading_whitespace should yield no events
691+
events = list(manager.handle_text_delta(vendor_part_id='id1', content='', ignore_leading_whitespace=True))
692+
assert len(events) == 0
693+
assert manager.get_parts() == []
694+
695+
# Whitespace-only content with ignore_leading_whitespace should yield no events
696+
events = list(manager.handle_text_delta(vendor_part_id='id2', content=' \n\t', ignore_leading_whitespace=True))
697+
assert len(events) == 0
698+
assert manager.get_parts() == []
699+
700+
606701
def test_handle_part():
607702
manager = ModelResponsePartsManager()
608703

@@ -632,3 +727,60 @@ def test_handle_part():
632727
event = manager.handle_part(vendor_part_id=None, part=part3)
633728
assert event == snapshot(PartStartEvent(index=1, part=part3))
634729
assert manager.get_parts() == snapshot([part2, part3])
730+
731+
732+
def test_handle_tool_call_delta_no_vendor_id_with_non_tool_latest_part():
733+
"""Test handle_tool_call_delta with vendor_part_id=None when latest part is NOT a tool call (line 515->526)."""
734+
manager = ModelResponsePartsManager()
735+
736+
# Create a TextPart first
737+
for _ in manager.handle_text_delta(vendor_part_id=None, content='some text'):
738+
pass
739+
740+
# Try to send a tool call delta with vendor_part_id=None and tool_name=None
741+
# Since latest part is NOT a tool call, this should create a new incomplete tool call delta
742+
event = manager.handle_tool_call_delta(vendor_part_id=None, tool_name=None, args='{"arg":')
743+
744+
# Since tool_name is None for a new part, we get a ToolCallPartDelta with no event
745+
assert event is None
746+
747+
# The ToolCallPartDelta is created internally but not returned by get_parts() since it's incomplete
748+
assert manager.has_incomplete_parts()
749+
assert len(manager.get_parts()) == 1
750+
assert isinstance(manager.get_parts()[0], TextPart)
751+
752+
753+
def test_handle_thinking_delta_raises_error_when_thinking_after_text():
754+
"""Test that handle_thinking_delta raises error when trying to create ThinkingPart after TextPart."""
755+
manager = ModelResponsePartsManager()
756+
757+
# Create a TextPart first
758+
for _ in manager.handle_text_delta(vendor_part_id=None, content='some text'):
759+
pass
760+
761+
# Now try to create a ThinkingPart with vendor_part_id=None
762+
# This should raise an error because thinking must come before text
763+
with pytest.raises(
764+
UnexpectedModelBehavior, match='Cannot create ThinkingPart after TextPart: thinking must come before text'
765+
):
766+
for _ in manager.handle_thinking_delta(vendor_part_id=None, content='thinking'):
767+
pass
768+
769+
770+
def test_handle_thinking_delta_create_new_part_with_no_vendor_id():
771+
"""Test creating new ThinkingPart when vendor_part_id is None and no parts exist yet."""
772+
manager = ModelResponsePartsManager()
773+
774+
# Create ThinkingPart with vendor_part_id=None (no parts exist yet, so no constraint violation)
775+
events = list(manager.handle_thinking_delta(vendor_part_id=None, content='thinking'))
776+
777+
assert len(events) == 1
778+
assert isinstance(events[0], PartStartEvent)
779+
assert events[0].index == 0
780+
781+
parts = manager.get_parts()
782+
assert len(parts) == 1
783+
assert parts[0] == snapshot(ThinkingPart(content='thinking'))
784+
785+
# Verify vendor_part_id was NOT mapped (it's None)
786+
assert not manager.is_vendor_id_mapped('thinking')

tests/test_parts_manager_split_tags.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -305,3 +305,68 @@ def test_thinking_interrupted_by_incomplete_end_tag_and_vendor_switch():
305305
assert parts[0].content == 'thinking foo</th'
306306
assert isinstance(parts[1], TextPart)
307307
assert parts[1].content == 'new content'
308+
309+
310+
def test_split_end_tag_with_content_before():
311+
"""Test content before split end tag in buffered chunks (line 337)."""
312+
events, parts = stream_text_deltas(['<think>', 'reasoning content</th', 'ink>'])
313+
314+
assert len(parts) == 1
315+
assert isinstance(parts[0], ThinkingPart)
316+
assert parts[0].content == 'reasoning content'
317+
318+
# Verify events
319+
assert any(isinstance(e, PartStartEvent) and isinstance(e.part, ThinkingPart) for e in events)
320+
321+
322+
def test_split_end_tag_with_content_after():
323+
"""Test content after split end tag in buffered chunks (line 343)."""
324+
events, parts = stream_text_deltas(['<think>', 'reasoning', '</thi', 'nk>after text'])
325+
326+
assert len(parts) == 2
327+
assert isinstance(parts[0], ThinkingPart)
328+
assert parts[0].content == 'reasoning'
329+
assert isinstance(parts[1], TextPart)
330+
assert parts[1].content == 'after text'
331+
332+
# Verify events
333+
assert any(isinstance(e, PartStartEvent) and isinstance(e.part, ThinkingPart) for e in events)
334+
assert any(isinstance(e, PartStartEvent) and isinstance(e.part, TextPart) for e in events)
335+
336+
337+
def test_split_end_tag_with_content_before_and_after():
338+
"""Test content both before and after split end tag."""
339+
_, parts = stream_text_deltas(['<think>', 'reason</th', 'ink>after'])
340+
341+
assert len(parts) == 2
342+
assert isinstance(parts[0], ThinkingPart)
343+
assert parts[0].content == 'reason'
344+
assert isinstance(parts[1], TextPart)
345+
assert parts[1].content == 'after'
346+
347+
348+
def test_cross_path_end_tag_handling():
349+
"""Test end tag handling when buffering fallback delegates to simple path (C2 → S5).
350+
351+
This tests the scenario where buffering creates a ThinkingPart, then non-matching
352+
content triggers the C2 fallback to simple path, which then handles the end tag.
353+
"""
354+
_, parts = stream_text_deltas(['<think>initial', 'x', 'more</think>after'])
355+
356+
assert len(parts) == 2
357+
assert isinstance(parts[0], ThinkingPart)
358+
assert parts[0].content == 'initialxmore'
359+
assert isinstance(parts[1], TextPart)
360+
assert parts[1].content == 'after'
361+
362+
363+
def test_cross_path_bare_end_tag():
364+
"""Test bare end tag when buffering fallback delegates to simple path (C2 → S5).
365+
366+
This tests the specific branch where content equals exactly the end tag.
367+
"""
368+
_, parts = stream_text_deltas(['<think>done', 'x', '</think>'])
369+
370+
assert len(parts) == 1
371+
assert isinstance(parts[0], ThinkingPart)
372+
assert parts[0].content == 'donex'

tests/test_streaming.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1892,3 +1892,39 @@ async def ret_a(x: str) -> str:
18921892
AgentRunResultEvent(result=AgentRunResult(output='{"ret_a":"a-apple"}')),
18931893
]
18941894
)
1895+
1896+
1897+
async def test_streaming_finalize_with_incomplete_thinking_tag():
1898+
"""Test that incomplete thinking tags are flushed via finalize during streaming (lines 585-591 in models/__init__.py)."""
1899+
1900+
async def stream_with_incomplete_thinking(
1901+
_messages: list[ModelMessage], _agent_info: AgentInfo
1902+
) -> AsyncIterator[str]:
1903+
# Stream incomplete thinking tag that will be buffered
1904+
yield '<thi'
1905+
1906+
agent = Agent(FunctionModel(stream_function=stream_with_incomplete_thinking))
1907+
1908+
events: list[AgentStreamEvent] = []
1909+
1910+
async def event_stream_handler(_ctx: RunContext[None], stream: AsyncIterable[AgentStreamEvent]):
1911+
async for event in stream:
1912+
events.append(event)
1913+
1914+
# This will trigger the finalize logic in models/__init__.py when the stream completes
1915+
result = await agent.run('Hello', event_stream_handler=event_stream_handler)
1916+
1917+
# The incomplete tag should be flushed as TextPart
1918+
assert result.output == '<thi'
1919+
1920+
# Verify that PartStartEvent and PartEndEvent were emitted from finalize
1921+
part_start_events = [e for e in events if isinstance(e, PartStartEvent)]
1922+
part_end_events = [e for e in events if isinstance(e, PartEndEvent)]
1923+
1924+
assert len(part_start_events) == 1
1925+
assert isinstance(part_start_events[0].part, TextPart)
1926+
assert part_start_events[0].part.content == '<thi'
1927+
1928+
assert len(part_end_events) == 1
1929+
assert isinstance(part_end_events[0].part, TextPart)
1930+
assert part_end_events[0].part.content == '<thi'

0 commit comments

Comments
 (0)