Skip to content

Commit e296d38

Browse files
authored
Merge branch 'main' into psl/allow-partial
2 parents bc00460 + dad1994 commit e296d38

File tree

7 files changed

+225
-11
lines changed

7 files changed

+225
-11
lines changed

pydantic_ai_slim/pydantic_ai/messages.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet',
3535
'text/html',
3636
'text/markdown',
37+
'application/msword',
3738
'application/vnd.ms-excel',
3839
]
3940
VideoMediaType: TypeAlias = Literal[
@@ -434,8 +435,12 @@ def _infer_media_type(self) -> str:
434435
return 'application/pdf'
435436
elif self.url.endswith('.rtf'):
436437
return 'application/rtf'
438+
elif self.url.endswith('.doc'):
439+
return 'application/msword'
437440
elif self.url.endswith('.docx'):
438441
return 'application/vnd.openxmlformats-officedocument.wordprocessingml.document'
442+
elif self.url.endswith('.xls'):
443+
return 'application/vnd.ms-excel'
439444
elif self.url.endswith('.xlsx'):
440445
return 'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet'
441446

@@ -645,6 +650,7 @@ class ToolReturn:
645650
'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet': 'xlsx',
646651
'text/html': 'html',
647652
'text/markdown': 'md',
653+
'application/msword': 'doc',
648654
'application/vnd.ms-excel': 'xls',
649655
}
650656
_audio_format_lookup: dict[str, AudioFormat] = {
@@ -882,7 +888,10 @@ def model_response(self) -> str:
882888
description = self.content
883889
else:
884890
json_errors = error_details_ta.dump_json(self.content, exclude={'__all__': {'ctx'}}, indent=2)
885-
description = f'{len(self.content)} validation errors: {json_errors.decode()}'
891+
plural = isinstance(self.content, list) and len(self.content) != 1
892+
description = (
893+
f'{len(self.content)} validation error{"s" if plural else ""}:\n```json\n{json_errors.decode()}\n```'
894+
)
886895
return f'{description}\n\nFix the errors and try again.'
887896

888897
def otel_event(self, settings: InstrumentationSettings) -> Event:

pydantic_ai_slim/pydantic_ai/ui/_event_stream.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -404,31 +404,31 @@ async def before_request(self) -> AsyncIterator[EventT]:
404404
405405
Override this to inject custom events at the start of the request.
406406
"""
407-
return
407+
return # pragma: lax no cover
408408
yield # Make this an async generator
409409

410410
async def after_request(self) -> AsyncIterator[EventT]:
411411
"""Yield events after a model request is processed.
412412
413413
Override this to inject custom events at the end of the request.
414414
"""
415-
return
415+
return # pragma: lax no cover
416416
yield # Make this an async generator
417417

418418
async def before_response(self) -> AsyncIterator[EventT]:
419419
"""Yield events before a model response is processed.
420420
421421
Override this to inject custom events at the start of the response.
422422
"""
423-
return
423+
return # pragma: no cover
424424
yield # Make this an async generator
425425

426426
async def after_response(self) -> AsyncIterator[EventT]:
427427
"""Yield events after a model response is processed.
428428
429429
Override this to inject custom events at the end of the response.
430430
"""
431-
return
431+
return # pragma: lax no cover
432432
yield # Make this an async generator
433433

434434
async def handle_text_start(self, part: TextPart, follows_text: bool = False) -> AsyncIterator[EventT]:

pydantic_ai_slim/pydantic_ai/ui/ag_ui/_event_stream.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,13 @@ async def before_stream(self) -> AsyncIterator[BaseEvent]:
9292
run_id=self.run_input.run_id,
9393
)
9494

95+
async def before_response(self) -> AsyncIterator[BaseEvent]:
96+
# Prevent parts from a subsequent response being tied to parts from an earlier response.
97+
# See https://github.com/pydantic/pydantic-ai/issues/3316
98+
self.new_message_id()
99+
return
100+
yield # Make this an async generator
101+
95102
async def after_stream(self) -> AsyncIterator[BaseEvent]:
96103
if not self._error:
97104
yield RunFinishedEvent(
@@ -167,9 +174,11 @@ async def _handle_tool_call_start(
167174
self, part: ToolCallPart | BuiltinToolCallPart, tool_call_id: str | None = None
168175
) -> AsyncIterator[BaseEvent]:
169176
tool_call_id = tool_call_id or part.tool_call_id
170-
message_id = self.message_id or self.new_message_id()
177+
parent_message_id = self.message_id
171178

172-
yield ToolCallStartEvent(tool_call_id=tool_call_id, tool_call_name=part.tool_name, parent_message_id=message_id)
179+
yield ToolCallStartEvent(
180+
tool_call_id=tool_call_id, tool_call_name=part.tool_name, parent_message_id=parent_message_id
181+
)
173182
if part.args:
174183
yield ToolCallArgsEvent(tool_call_id=tool_call_id, delta=part.args_as_json_str())
175184

tests/models/cassettes/test_gemini/test_gemini_drop_exclusive_maximum.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -259,7 +259,7 @@ interactions:
259259
name: get_chinese_zodiac
260260
response:
261261
call_error: |-
262-
1 validation errors: [
262+
1 validation error: [
263263
{
264264
"type": "greater_than",
265265
"loc": [

tests/test_ag_ui.py

Lines changed: 191 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
from pydantic_ai import (
2020
BuiltinToolCallPart,
2121
BuiltinToolReturnPart,
22+
FunctionToolCallEvent,
23+
FunctionToolResultEvent,
2224
ModelMessage,
2325
ModelRequest,
2426
ModelResponse,
@@ -29,6 +31,7 @@
2931
TextPart,
3032
TextPartDelta,
3133
ToolCallPart,
34+
ToolCallPartDelta,
3235
ToolReturn,
3336
ToolReturnPart,
3437
UserPromptPart,
@@ -1661,6 +1664,194 @@ async def event_generator():
16611664
)
16621665

16631666

1667+
async def test_event_stream_multiple_responses_with_tool_calls():
1668+
async def event_generator():
1669+
yield PartStartEvent(index=0, part=TextPart(content='Hello'))
1670+
yield PartDeltaEvent(index=0, delta=TextPartDelta(content_delta=' world'))
1671+
yield PartEndEvent(index=0, part=TextPart(content='Hello world'), next_part_kind='tool-call')
1672+
1673+
yield PartStartEvent(
1674+
index=1,
1675+
part=ToolCallPart(tool_name='tool_call_1', args='{}', tool_call_id='tool_call_1'),
1676+
previous_part_kind='text',
1677+
)
1678+
yield PartDeltaEvent(
1679+
index=1, delta=ToolCallPartDelta(args_delta='{"query": "Hello world"}', tool_call_id='tool_call_1')
1680+
)
1681+
yield PartEndEvent(
1682+
index=1,
1683+
part=ToolCallPart(tool_name='tool_call_1', args='{"query": "Hello world"}', tool_call_id='tool_call_1'),
1684+
next_part_kind='tool-call',
1685+
)
1686+
1687+
yield PartStartEvent(
1688+
index=2,
1689+
part=ToolCallPart(tool_name='tool_call_2', args='{}', tool_call_id='tool_call_2'),
1690+
previous_part_kind='tool-call',
1691+
)
1692+
yield PartDeltaEvent(
1693+
index=2, delta=ToolCallPartDelta(args_delta='{"query": "Goodbye world"}', tool_call_id='tool_call_2')
1694+
)
1695+
yield PartEndEvent(
1696+
index=2,
1697+
part=ToolCallPart(tool_name='tool_call_2', args='{"query": "Hello world"}', tool_call_id='tool_call_2'),
1698+
next_part_kind=None,
1699+
)
1700+
1701+
yield FunctionToolCallEvent(
1702+
part=ToolCallPart(tool_name='tool_call_1', args='{"query": "Hello world"}', tool_call_id='tool_call_1')
1703+
)
1704+
yield FunctionToolCallEvent(
1705+
part=ToolCallPart(tool_name='tool_call_2', args='{"query": "Goodbye world"}', tool_call_id='tool_call_2')
1706+
)
1707+
1708+
yield FunctionToolResultEvent(
1709+
result=ToolReturnPart(tool_name='tool_call_1', content='Hi!', tool_call_id='tool_call_1')
1710+
)
1711+
yield FunctionToolResultEvent(
1712+
result=ToolReturnPart(tool_name='tool_call_2', content='Bye!', tool_call_id='tool_call_2')
1713+
)
1714+
1715+
yield PartStartEvent(
1716+
index=0,
1717+
part=ToolCallPart(tool_name='tool_call_3', args='{}', tool_call_id='tool_call_3'),
1718+
previous_part_kind=None,
1719+
)
1720+
yield PartDeltaEvent(
1721+
index=0, delta=ToolCallPartDelta(args_delta='{"query": "Hello world"}', tool_call_id='tool_call_3')
1722+
)
1723+
yield PartEndEvent(
1724+
index=0,
1725+
part=ToolCallPart(tool_name='tool_call_3', args='{"query": "Hello world"}', tool_call_id='tool_call_3'),
1726+
next_part_kind='tool-call',
1727+
)
1728+
1729+
yield PartStartEvent(
1730+
index=1,
1731+
part=ToolCallPart(tool_name='tool_call_4', args='{}', tool_call_id='tool_call_4'),
1732+
previous_part_kind='tool-call',
1733+
)
1734+
yield PartDeltaEvent(
1735+
index=1, delta=ToolCallPartDelta(args_delta='{"query": "Goodbye world"}', tool_call_id='tool_call_4')
1736+
)
1737+
yield PartEndEvent(
1738+
index=1,
1739+
part=ToolCallPart(tool_name='tool_call_4', args='{"query": "Goodbye world"}', tool_call_id='tool_call_4'),
1740+
next_part_kind=None,
1741+
)
1742+
1743+
yield FunctionToolCallEvent(
1744+
part=ToolCallPart(tool_name='tool_call_3', args='{"query": "Hello world"}', tool_call_id='tool_call_3')
1745+
)
1746+
yield FunctionToolCallEvent(
1747+
part=ToolCallPart(tool_name='tool_call_4', args='{"query": "Goodbye world"}', tool_call_id='tool_call_4')
1748+
)
1749+
1750+
yield FunctionToolResultEvent(
1751+
result=ToolReturnPart(tool_name='tool_call_3', content='Hi!', tool_call_id='tool_call_3')
1752+
)
1753+
yield FunctionToolResultEvent(
1754+
result=ToolReturnPart(tool_name='tool_call_4', content='Bye!', tool_call_id='tool_call_4')
1755+
)
1756+
1757+
run_input = create_input(
1758+
UserMessage(
1759+
id='msg_1',
1760+
content='Tell me about Hello World',
1761+
),
1762+
)
1763+
event_stream = AGUIEventStream(run_input=run_input)
1764+
events = [
1765+
json.loads(event.removeprefix('data: '))
1766+
async for event in event_stream.encode_stream(event_stream.transform_stream(event_generator()))
1767+
]
1768+
1769+
assert events == snapshot(
1770+
[
1771+
{
1772+
'type': 'RUN_STARTED',
1773+
'threadId': (thread_id := IsSameStr()),
1774+
'runId': (run_id := IsSameStr()),
1775+
},
1776+
{'type': 'TEXT_MESSAGE_START', 'messageId': (message_id := IsSameStr()), 'role': 'assistant'},
1777+
{'type': 'TEXT_MESSAGE_CONTENT', 'messageId': message_id, 'delta': 'Hello'},
1778+
{'type': 'TEXT_MESSAGE_CONTENT', 'messageId': message_id, 'delta': ' world'},
1779+
{'type': 'TEXT_MESSAGE_END', 'messageId': message_id},
1780+
{
1781+
'type': 'TOOL_CALL_START',
1782+
'toolCallId': 'tool_call_1',
1783+
'toolCallName': 'tool_call_1',
1784+
'parentMessageId': message_id,
1785+
},
1786+
{'type': 'TOOL_CALL_ARGS', 'toolCallId': 'tool_call_1', 'delta': '{}'},
1787+
{'type': 'TOOL_CALL_ARGS', 'toolCallId': 'tool_call_1', 'delta': '{"query": "Hello world"}'},
1788+
{'type': 'TOOL_CALL_END', 'toolCallId': 'tool_call_1'},
1789+
{
1790+
'type': 'TOOL_CALL_START',
1791+
'toolCallId': 'tool_call_2',
1792+
'toolCallName': 'tool_call_2',
1793+
'parentMessageId': message_id,
1794+
},
1795+
{'type': 'TOOL_CALL_ARGS', 'toolCallId': 'tool_call_2', 'delta': '{}'},
1796+
{'type': 'TOOL_CALL_ARGS', 'toolCallId': 'tool_call_2', 'delta': '{"query": "Goodbye world"}'},
1797+
{'type': 'TOOL_CALL_END', 'toolCallId': 'tool_call_2'},
1798+
{
1799+
'type': 'TOOL_CALL_RESULT',
1800+
'messageId': IsStr(),
1801+
'toolCallId': 'tool_call_1',
1802+
'content': 'Hi!',
1803+
'role': 'tool',
1804+
},
1805+
{
1806+
'type': 'TOOL_CALL_RESULT',
1807+
'messageId': (result_message_id := IsSameStr()),
1808+
'toolCallId': 'tool_call_2',
1809+
'content': 'Bye!',
1810+
'role': 'tool',
1811+
},
1812+
{
1813+
'type': 'TOOL_CALL_START',
1814+
'toolCallId': 'tool_call_3',
1815+
'toolCallName': 'tool_call_3',
1816+
'parentMessageId': (new_message_id := IsSameStr()),
1817+
},
1818+
{'type': 'TOOL_CALL_ARGS', 'toolCallId': 'tool_call_3', 'delta': '{}'},
1819+
{'type': 'TOOL_CALL_ARGS', 'toolCallId': 'tool_call_3', 'delta': '{"query": "Hello world"}'},
1820+
{'type': 'TOOL_CALL_END', 'toolCallId': 'tool_call_3'},
1821+
{
1822+
'type': 'TOOL_CALL_START',
1823+
'toolCallId': 'tool_call_4',
1824+
'toolCallName': 'tool_call_4',
1825+
'parentMessageId': new_message_id,
1826+
},
1827+
{'type': 'TOOL_CALL_ARGS', 'toolCallId': 'tool_call_4', 'delta': '{}'},
1828+
{'type': 'TOOL_CALL_ARGS', 'toolCallId': 'tool_call_4', 'delta': '{"query": "Goodbye world"}'},
1829+
{'type': 'TOOL_CALL_END', 'toolCallId': 'tool_call_4'},
1830+
{
1831+
'type': 'TOOL_CALL_RESULT',
1832+
'messageId': IsStr(),
1833+
'toolCallId': 'tool_call_3',
1834+
'content': 'Hi!',
1835+
'role': 'tool',
1836+
},
1837+
{
1838+
'type': 'TOOL_CALL_RESULT',
1839+
'messageId': IsStr(),
1840+
'toolCallId': 'tool_call_4',
1841+
'content': 'Bye!',
1842+
'role': 'tool',
1843+
},
1844+
{
1845+
'type': 'RUN_FINISHED',
1846+
'threadId': thread_id,
1847+
'runId': run_id,
1848+
},
1849+
]
1850+
)
1851+
1852+
assert result_message_id != new_message_id
1853+
1854+
16641855
async def test_handle_ag_ui_request():
16651856
agent = Agent(model=TestModel())
16661857
run_input = create_input(

tests/test_agent.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,7 @@ def return_model(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse
201201
),
202202
ModelResponse(
203203
parts=[ToolCallPart(tool_name='final_result', args='{"a": 42, "b": "foo"}', tool_call_id=IsStr())],
204-
usage=RequestUsage(input_tokens=87, output_tokens=14),
204+
usage=RequestUsage(input_tokens=89, output_tokens=14),
205205
model_name='function:return_model:',
206206
timestamp=IsNow(tz=timezone.utc),
207207
),
@@ -260,7 +260,9 @@ def check_b(cls, v: str) -> str:
260260
retry_prompt = user_retry.parts[0]
261261
assert isinstance(retry_prompt, RetryPromptPart)
262262
assert retry_prompt.model_response() == snapshot("""\
263-
1 validation errors: [
263+
1 validation error:
264+
```json
265+
[
264266
{
265267
"type": "value_error",
266268
"loc": [
@@ -270,6 +272,7 @@ def check_b(cls, v: str) -> str:
270272
"input": "foo"
271273
}
272274
]
275+
```
273276
274277
Fix the errors and try again.""")
275278

@@ -1875,7 +1878,7 @@ class CityLocation(BaseModel):
18751878
),
18761879
ModelResponse(
18771880
parts=[TextPart(content='{"city": "Mexico City", "country": "Mexico"}')],
1878-
usage=RequestUsage(input_tokens=85, output_tokens=12),
1881+
usage=RequestUsage(input_tokens=87, output_tokens=12),
18791882
model_name='function:return_city_location:',
18801883
timestamp=IsDatetime(),
18811884
),

tests/test_messages.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,7 @@ def test_binary_content_video(media_type: str, format: str):
147147
('application/pdf', 'pdf'),
148148
('text/plain', 'txt'),
149149
('text/csv', 'csv'),
150+
('application/msword', 'doc'),
150151
('application/vnd.openxmlformats-officedocument.wordprocessingml.document', 'docx'),
151152
('application/vnd.openxmlformats-officedocument.spreadsheetml.sheet', 'xlsx'),
152153
('text/html', 'html'),
@@ -209,6 +210,7 @@ def test_image_url_invalid():
209210
pytest.param(DocumentUrl('foobar.pdf'), 'application/pdf', 'pdf', id='pdf'),
210211
pytest.param(DocumentUrl('foobar.txt'), 'text/plain', 'txt', id='txt'),
211212
pytest.param(DocumentUrl('foobar.csv'), 'text/csv', 'csv', id='csv'),
213+
pytest.param(DocumentUrl('foobar.doc'), 'application/msword', 'doc', id='doc'),
212214
pytest.param(
213215
DocumentUrl('foobar.docx'),
214216
'application/vnd.openxmlformats-officedocument.wordprocessingml.document',

0 commit comments

Comments
 (0)