Skip to content

Commit 5027f3d

Browse files
authored
Streamed response messages (#274)
1 parent 0475da8 commit 5027f3d

File tree

8 files changed

+448
-162
lines changed

8 files changed

+448
-162
lines changed

pydantic_ai_slim/pydantic_ai/_result.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import inspect
44
import sys
55
import types
6-
from collections.abc import Awaitable
6+
from collections.abc import Awaitable, Iterable
77
from dataclasses import dataclass, field
88
from typing import Any, Callable, Generic, Literal, Union, cast, get_args, get_origin
99

@@ -113,14 +113,24 @@ def _build_tool(a: Any, tool_name_: str, multiple: bool) -> ResultTool[ResultDat
113113

114114
return cls(tools=tools, allow_text_result=allow_text_result)
115115

116+
def find_named_tool(
117+
self, parts: Iterable[_messages.ModelResponsePart], tool_name: str
118+
) -> tuple[_messages.ToolCallPart, ResultTool[ResultData]] | None:
119+
"""Find a tool that matches one of the calls, with a specific name."""
120+
for part in parts:
121+
if isinstance(part, _messages.ToolCallPart):
122+
if part.tool_name == tool_name:
123+
return part, self.tools[tool_name]
124+
116125
def find_tool(
117-
self, message: _messages.ModelResponse
126+
self,
127+
parts: Iterable[_messages.ModelResponsePart],
118128
) -> tuple[_messages.ToolCallPart, ResultTool[ResultData]] | None:
119129
"""Find a tool that matches one of the calls."""
120-
for item in message.parts:
121-
if isinstance(item, _messages.ToolCallPart):
122-
if result := self.tools.get(item.tool_name):
123-
return item, result
130+
for part in parts:
131+
if isinstance(part, _messages.ToolCallPart):
132+
if result := self.tools.get(part.tool_name):
133+
return part, result
124134

125135
def tool_names(self) -> list[str]:
126136
"""Return the names of the tools."""

pydantic_ai_slim/pydantic_ai/agent.py

Lines changed: 114 additions & 105 deletions
Large diffs are not rendered by default.

pydantic_ai_slim/pydantic_ai/result.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from __future__ import annotations as _annotations
22

33
from abc import ABC, abstractmethod
4-
from collections.abc import AsyncIterator, Callable
4+
from collections.abc import AsyncIterator, Awaitable, Callable
55
from dataclasses import dataclass, field
66
from datetime import datetime
77
from typing import Generic, TypeVar, cast
@@ -122,7 +122,8 @@ class StreamedRunResult(_BaseRunResult[ResultData], Generic[AgentDeps, ResultDat
122122
_result_schema: _result.ResultSchema[ResultData] | None
123123
_deps: AgentDeps
124124
_result_validators: list[_result.ResultValidator[AgentDeps, ResultData]]
125-
_on_complete: Callable[[list[_messages.ModelMessage]], None]
125+
_result_tool_name: str | None
126+
_on_complete: Callable[[], Awaitable[None]]
126127
is_complete: bool = field(default=False, init=False)
127128
"""Whether the stream has all been received.
128129
@@ -205,7 +206,7 @@ async def stream_text(self, *, delta: bool = False, debounce_by: float | None =
205206
combined = await self._validate_text_result(''.join(chunks))
206207
yield combined
207208
lf_span.set_attribute('combined_text', combined)
208-
self._marked_completed(_messages.ModelResponse.from_text(combined))
209+
await self._marked_completed(_messages.ModelResponse.from_text(combined))
209210

210211
async def stream_structured(
211212
self, *, debounce_by: float | None = 0.1
@@ -244,7 +245,7 @@ async def stream_structured(
244245
msg = self._stream_response.get(final=True)
245246
yield msg, True
246247
lf_span.set_attribute('structured_response', msg)
247-
self._marked_completed(msg)
248+
await self._marked_completed(msg)
248249

249250
async def get_data(self) -> ResultData:
250251
"""Stream the whole response, validate and return it."""
@@ -253,11 +254,11 @@ async def get_data(self) -> ResultData:
253254
if isinstance(self._stream_response, models.StreamTextResponse):
254255
text = ''.join(self._stream_response.get(final=True))
255256
text = await self._validate_text_result(text)
256-
self._marked_completed(_messages.ModelResponse.from_text(text))
257+
await self._marked_completed(_messages.ModelResponse.from_text(text))
257258
return cast(ResultData, text)
258259
else:
259260
message = self._stream_response.get(final=True)
260-
self._marked_completed(message)
261+
await self._marked_completed(message)
261262
return await self.validate_structured_result(message)
262263

263264
@property
@@ -282,7 +283,8 @@ async def validate_structured_result(
282283
) -> ResultData:
283284
"""Validate a structured result message."""
284285
assert self._result_schema is not None, 'Expected _result_schema to not be None'
285-
match = self._result_schema.find_tool(message)
286+
assert self._result_tool_name is not None, 'Expected _result_tool_name to not be None'
287+
match = self._result_schema.find_named_tool(message.parts, self._result_tool_name)
286288
if match is None:
287289
raise exceptions.UnexpectedModelBehavior(
288290
f'Invalid message, unable to find tool: {self._result_schema.tool_names()}'
@@ -306,7 +308,7 @@ async def _validate_text_result(self, text: str) -> str:
306308
)
307309
return text
308310

309-
def _marked_completed(self, message: _messages.ModelResponse) -> None:
311+
async def _marked_completed(self, message: _messages.ModelResponse) -> None:
310312
self.is_complete = True
311313
self._all_messages.append(message)
312-
self._on_complete(self._all_messages)
314+
await self._on_complete()

tests/models/test_gemini.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -669,13 +669,6 @@ async def bar(y: str) -> str:
669669
ToolReturnPart(tool_name='bar', content='b', timestamp=IsNow(tz=timezone.utc)),
670670
]
671671
),
672-
ModelRequest(
673-
parts=[
674-
ToolReturnPart(
675-
tool_name='final_result', content='Final result processed.', timestamp=IsNow(tz=timezone.utc)
676-
)
677-
]
678-
),
679672
ModelResponse(
680673
parts=[
681674
ToolCallPart(
@@ -685,6 +678,13 @@ async def bar(y: str) -> str:
685678
],
686679
timestamp=IsNow(tz=timezone.utc),
687680
),
681+
ModelRequest(
682+
parts=[
683+
ToolReturnPart(
684+
tool_name='final_result', content='Final result processed.', timestamp=IsNow(tz=timezone.utc)
685+
)
686+
]
687+
),
688688
]
689689
)
690690
assert tool_calls == snapshot(["foo(x='a')", "bar(y='b')"])

tests/models/test_groq.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -415,19 +415,21 @@ async def test_stream_structured(allow_model_requests: None):
415415
assert result.all_messages() == snapshot(
416416
[
417417
ModelRequest(parts=[UserPromptPart(content='', timestamp=IsNow(tz=timezone.utc))]),
418-
ModelRequest(
419-
parts=[
420-
ToolReturnPart(
421-
tool_name='final_result', content='Final result processed.', timestamp=IsNow(tz=timezone.utc)
422-
)
423-
]
424-
),
425418
ModelResponse(
426419
parts=[
427420
ToolCallPart(tool_name='final_result', args=ArgsJson(args_json='{"first": "One", "second": "Two"}'))
428421
],
429422
timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc),
430423
),
424+
ModelRequest(
425+
parts=[
426+
ToolReturnPart(
427+
tool_name='final_result',
428+
content='Final result processed.',
429+
timestamp=IsNow(tz=timezone.utc),
430+
)
431+
]
432+
),
431433
]
432434
)
433435

tests/models/test_mistral.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1327,6 +1327,12 @@ async def get_location(loc_name: str) -> str:
13271327
)
13281328
]
13291329
),
1330+
ModelResponse(
1331+
parts=[
1332+
ToolCallPart(tool_name='final_result', args=ArgsJson(args_json='{"won": true}'), tool_call_id='1')
1333+
],
1334+
timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc),
1335+
),
13301336
ModelRequest(
13311337
parts=[
13321338
ToolReturnPart(
@@ -1337,12 +1343,6 @@ async def get_location(loc_name: str) -> str:
13371343
)
13381344
]
13391345
),
1340-
ModelResponse(
1341-
parts=[
1342-
ToolCallPart(tool_name='final_result', args=ArgsJson(args_json='{"won": true}'), tool_call_id='1')
1343-
],
1344-
timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc),
1345-
),
13461346
]
13471347
)
13481348

tests/test_agent.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1060,13 +1060,13 @@ def another_tool(y: int) -> int: # pragma: no cover
10601060
ModelRequest(
10611061
parts=[
10621062
ToolReturnPart(
1063-
tool_name='final_result',
1064-
content='Final result processed.',
1063+
tool_name='regular_tool',
1064+
content='Tool not executed - a final result was already processed.',
10651065
timestamp=IsNow(tz=timezone.utc),
10661066
),
10671067
ToolReturnPart(
1068-
tool_name='regular_tool',
1069-
content='Tool not executed - a final result was already processed.',
1068+
tool_name='final_result',
1069+
content='Final result processed.',
10701070
timestamp=IsNow(tz=timezone.utc),
10711071
),
10721072
ToolReturnPart(

0 commit comments

Comments
 (0)