Skip to content

Commit adc51e6

Browse files
committed
adds a finalize method to prevent lost content from buffered chunks that look like thinking tags
1 parent 876ebb2 commit adc51e6

File tree

4 files changed

+129
-0
lines changed

4 files changed

+129
-0
lines changed

pydantic_ai_slim/pydantic_ai/_parts_manager.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,28 @@ def get_parts(self) -> list[ModelResponsePart]:
6969
"""
7070
return [p for p in self._parts if not isinstance(p, ToolCallPartDelta)]
7171

72+
def finalize(self) -> Generator[ModelResponseStreamEvent, None, None]:
    """Flush any buffered content as text parts.

    This should be called when streaming is complete to ensure no content is lost.
    Any content still held in `_thinking_tag_buffer` is treated as regular text and
    emitted through `_handle_text_delta_simple`.

    Each buffer entry is removed *before* its content is flushed, so a given piece
    of buffered content is emitted at most once, even if the caller stops iterating
    this generator early and `finalize()` is called again later. Entries added to
    the buffer while the flush is in progress are left in place for a later call
    rather than being discarded.

    Yields:
        ModelResponseStreamEvent for any buffered content that gets flushed.
    """
    # Snapshot the keys up front: entries are popped from the dict while we iterate.
    for vendor_part_id in list(self._thinking_tag_buffer):
        buffered_content = self._thinking_tag_buffer.pop(vendor_part_id, '')
        if not buffered_content:
            # Empty entries produce no event; they are simply dropped.
            continue
        yield from self._handle_text_delta_simple(
            vendor_part_id=vendor_part_id,
            content=buffered_content,
            id=None,
            thinking_tags=None,
            ignore_leading_whitespace=False,
        )
93+
7294
def handle_text_delta(
7395
self,
7496
*,

pydantic_ai_slim/pydantic_ai/models/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -553,6 +553,10 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
553553

554554
def get(self) -> ModelResponse:
555555
"""Build a [`ModelResponse`][pydantic_ai.messages.ModelResponse] from the data received from the stream so far."""
556+
# Flush any buffered content before building response
557+
for _ in self._parts_manager.finalize():
558+
pass
559+
556560
return ModelResponse(
557561
parts=self._parts_manager.get_parts(),
558562
model_name=self.model_name,

tests/models/test_model_test.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -342,3 +342,24 @@ def test_different_content_input(content: AudioUrl | VideoUrl | ImageUrl | Binar
342342
result = agent.run_sync(['x', content], model=TestModel(custom_output_text='custom'))
343343
assert result.output == snapshot('custom')
344344
assert result.usage() == snapshot(RunUsage(requests=1, input_tokens=51, output_tokens=1))
345+
346+
347+
@pytest.mark.anyio
async def test_finalize_integration_buffered_content():
    """Smoke test: a streamed run still completes now that `StreamedResponse.get()` calls `finalize()`.

    Note: TestModel doesn't pass thinking_tags during streaming, so nothing is actually
    buffered here; this only checks that the `get()` path keeps working end to end.
    The buffering logic itself is covered in test_parts_manager_split_tags.py, and
    normal streaming is covered in test_streaming.py.
    """
    agent = Agent(TestModel(custom_output_text='Hello <thi'))

    # Stream the run to completion and collect the final output.
    async with agent.run_stream('test prompt') as result:
        output = await result.get_output()

    # The text that merely resembles a tag prefix must come through unchanged.
    assert output == 'Hello <thi'

tests/test_parts_manager_split_tags.py

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,3 +202,85 @@ def test_exact_tag_length_boundary():
202202
assert events[0] == snapshot(
203203
PartStartEvent(index=0, part=ThinkingPart(content='', part_kind='thinking'), event_kind='part_start')
204204
)
205+
206+
207+
def test_buffered_content_flushed_on_finalize():
    """A partial thinking tag held in the buffer is emitted as text by finalize()."""
    manager = ModelResponsePartsManager()

    # '<thi' is a prefix of '<think>', so the delta is buffered and yields nothing yet.
    buffered_events = list(
        manager.handle_text_delta(
            vendor_part_id='content',
            content='<thi',
            thinking_tags=('<think>', '</think>'),
        )
    )
    assert buffered_events == []

    # finalize() releases the buffered text as a single new TextPart.
    flushed_events = list(manager.finalize())
    assert len(flushed_events) == 1
    assert flushed_events[0] == snapshot(
        PartStartEvent(index=0, part=TextPart(content='<thi', part_kind='text'), event_kind='part_start')
    )
222+
223+
224+
def test_finalize_flushes_all_buffers():
    """finalize() flushes the buffer of every vendor_part_id, not just one."""
    manager = ModelResponsePartsManager()
    thinking_tags = ('<think>', '</think>')

    # Leave a partial tag buffered under two distinct vendor part ids.
    for part_id, partial_tag in (('id1', '<th'), ('id2', '<thi')):
        list(manager.handle_text_delta(vendor_part_id=part_id, content=partial_tag, thinking_tags=thinking_tags))

    # One event per flushed buffer.
    flushed_events = list(manager.finalize())
    assert len(flushed_events) == 2

    # Each buffer becomes its own TextPart.
    parts = manager.get_parts()
    assert len(parts) == 2
    assert all(isinstance(p, TextPart) for p in parts)
    # Compare as a set so the test does not depend on the order of the parts.
    assert {p.content for p in parts if isinstance(p, TextPart)} == {'<th', '<thi'}
247+
248+
249+
def test_finalize_with_no_buffer():
    """finalize() on a manager that has buffered nothing yields no events and raises nothing."""
    fresh_manager = ModelResponsePartsManager()
    assert list(fresh_manager.finalize()) == []
254+
255+
256+
def test_finalize_with_empty_buffered_content():
    """finalize() skips empty buffer entries and still flushes the non-empty ones.

    Seeding an empty entry ahead of a non-empty one exercises the branch where
    `buffered_content` is falsy and the loop carries on to the next entry.
    """
    manager = ModelResponsePartsManager()
    # 'id1' is empty and must be skipped; 'id2' holds text and must be flushed.
    manager._thinking_tag_buffer.update({'id1': '', 'id2': 'content'})  # pyright: ignore[reportPrivateUsage]

    flushed_events = list(manager.finalize())

    # Only the non-empty entry produces an event.
    assert len(flushed_events) == 1
    only_event = flushed_events[0]
    assert isinstance(only_event, PartStartEvent)
    assert only_event.part == TextPart(content='content')
    # Both entries, including the skipped one, are gone afterwards.
    assert manager._thinking_tag_buffer == {}  # pyright: ignore[reportPrivateUsage]
268+
269+
270+
def test_get_parts_after_finalize():
    """get_parts() only includes buffered content once finalize() has flushed it (unit test)."""
    # NOTE: This exercises the manager in isolation. The StreamedResponse.get() path is
    # covered by test_finalize_integration_buffered_content in tests/models/test_model_test.py.
    manager = ModelResponsePartsManager()

    list(
        manager.handle_text_delta(
            vendor_part_id='content',
            content='<thi',
            thinking_tags=('<think>', '</think>'),
        )
    )

    # While the partial tag is only buffered, no part exists yet.
    assert manager.get_parts() == []

    # Drain the generator so the flush actually runs.
    list(manager.finalize())

    # The buffered text is now a regular TextPart.
    assert manager.get_parts() == snapshot([TextPart(content='<thi', part_kind='text')])

0 commit comments

Comments
 (0)