
Commit 196d918

Use .iter() API to fully replace existing streaming implementation (#951)
1 parent bef69c8 commit 196d918

File tree

8 files changed: +561 -80 lines changed

docs/agents.md

Lines changed: 140 additions & 0 deletions
@@ -220,6 +220,146 @@ Once the run finishes, `agent_run.final_result` becomes a [`AgentRunResult`][pyd
 
 ---
 
+### Streaming
+
+Here is an example of streaming an agent run in combination with `async for` iteration:
+
+```python {title="streaming.py"}
+import asyncio
+from dataclasses import dataclass
+from datetime import date
+
+from pydantic_ai import Agent
+from pydantic_ai.messages import (
+    FinalResultEvent,
+    FunctionToolCallEvent,
+    FunctionToolResultEvent,
+    PartDeltaEvent,
+    PartStartEvent,
+    TextPartDelta,
+    ToolCallPartDelta,
+)
+from pydantic_ai.tools import RunContext
+
+
+@dataclass
+class WeatherService:
+    async def get_forecast(self, location: str, forecast_date: date) -> str:
+        # In real code: call weather API, DB queries, etc.
+        return f'The forecast in {location} on {forecast_date} is 24°C and sunny.'
+
+    async def get_historic_weather(self, location: str, forecast_date: date) -> str:
+        # In real code: call a historical weather API or DB
+        return (
+            f'The weather in {location} on {forecast_date} was 18°C and partly cloudy.'
+        )
+
+
+weather_agent = Agent[WeatherService, str](
+    'openai:gpt-4o',
+    deps_type=WeatherService,
+    result_type=str,  # We'll produce a final answer as plain text
+    system_prompt='Providing a weather forecast at the locations the user provides.',
+)
+
+
+@weather_agent.tool
+async def weather_forecast(
+    ctx: RunContext[WeatherService],
+    location: str,
+    forecast_date: date,
+) -> str:
+    if forecast_date >= date.today():
+        return await ctx.deps.get_forecast(location, forecast_date)
+    else:
+        return await ctx.deps.get_historic_weather(location, forecast_date)
+
+
+output_messages: list[str] = []
+
+
+async def main():
+    user_prompt = 'What will the weather be like in Paris on Tuesday?'
+
+    # Begin a node-by-node, streaming iteration
+    with weather_agent.iter(user_prompt, deps=WeatherService()) as run:
+        async for node in run:
+            if Agent.is_user_prompt_node(node):
+                # A user prompt node => The user has provided input
+                output_messages.append(f'=== UserPromptNode: {node.user_prompt} ===')
+            elif Agent.is_model_request_node(node):
+                # A model request node => We can stream tokens from the model's request
+                output_messages.append(
+                    '=== ModelRequestNode: streaming partial request tokens ==='
+                )
+                async with node.stream(run.ctx) as request_stream:
+                    async for event in request_stream:
+                        if isinstance(event, PartStartEvent):
+                            output_messages.append(
+                                f'[Request] Starting part {event.index}: {event.part!r}'
+                            )
+                        elif isinstance(event, PartDeltaEvent):
+                            if isinstance(event.delta, TextPartDelta):
+                                output_messages.append(
+                                    f'[Request] Part {event.index} text delta: {event.delta.content_delta!r}'
+                                )
+                            elif isinstance(event.delta, ToolCallPartDelta):
+                                output_messages.append(
+                                    f'[Request] Part {event.index} args_delta={event.delta.args_delta}'
+                                )
+                        elif isinstance(event, FinalResultEvent):
+                            output_messages.append(
+                                f'[Result] The model produced a final result (tool_name={event.tool_name})'
+                            )
+            elif Agent.is_handle_response_node(node):
+                # A handle-response node => The model returned some data, potentially calls a tool
+                output_messages.append(
+                    '=== HandleResponseNode: streaming partial response & tool usage ==='
+                )
+                async with node.stream(run.ctx) as handle_stream:
+                    async for event in handle_stream:
+                        if isinstance(event, FunctionToolCallEvent):
+                            output_messages.append(
+                                f'[Tools] The LLM calls tool={event.part.tool_name!r} with args={event.part.args} (tool_call_id={event.part.tool_call_id!r})'
+                            )
+                        elif isinstance(event, FunctionToolResultEvent):
+                            output_messages.append(
+                                f'[Tools] Tool call {event.tool_call_id!r} returned => {event.result.content}'
+                            )
+            elif Agent.is_end_node(node):
+                assert run.result.data == node.data.data
+                # Once an End node is reached, the agent run is complete
+                output_messages.append(f'=== Final Agent Output: {run.result.data} ===')
+
+
+if __name__ == '__main__':
+    asyncio.run(main())
+
+    print(output_messages)
+    """
+    [
+        '=== ModelRequestNode: streaming partial request tokens ===',
+        '[Request] Starting part 0: ToolCallPart(tool_name=\'weather_forecast\', args=\'{"location":"Pa\', tool_call_id=\'0001\', part_kind=\'tool-call\')',
+        '[Request] Part 0 args_delta=ris","forecast_',
+        '[Request] Part 0 args_delta=date":"2030-01-',
+        '[Request] Part 0 args_delta=01"}',
+        '=== HandleResponseNode: streaming partial response & tool usage ===',
+        '[Tools] The LLM calls tool=\'weather_forecast\' with args={"location":"Paris","forecast_date":"2030-01-01"} (tool_call_id=\'0001\')',
+        "[Tools] Tool call '0001' returned => The forecast in Paris on 2030-01-01 is 24°C and sunny.",
+        '=== ModelRequestNode: streaming partial request tokens ===',
+        "[Request] Starting part 0: TextPart(content='It will be ', part_kind='text')",
+        '[Result] The model produced a final result (tool_name=None)',
+        "[Request] Part 0 text delta: 'warm and sunny '",
+        "[Request] Part 0 text delta: 'in Paris on '",
+        "[Request] Part 0 text delta: 'Tuesday.'",
+        '=== HandleResponseNode: streaming partial response & tool usage ===',
+        '=== Final Agent Output: It will be warm and sunny in Paris on Tuesday. ===',
+    ]
+    """
+```
+
+---
+
 ### Additional Configuration
 
 #### Usage Limits
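
For orientation, the essential shape of the new `.iter()` API in the docs example above is small. Below is a minimal sketch that keeps only the node loop and the end-node check; it assumes the same `weather_agent` and `WeatherService` defined in `streaming.py` above (the `from streaming import ...` line is illustrative, not part of this commit):

```python
import asyncio

from pydantic_ai import Agent

# Hypothetical import for illustration: reuses the agent defined in streaming.py above.
from streaming import WeatherService, weather_agent


async def main():
    # iter() exposes the run as a graph: each iteration yields the next node.
    with weather_agent.iter('What will the weather be like in Paris on Tuesday?', deps=WeatherService()) as run:
        async for node in run:
            if Agent.is_end_node(node):
                # The End node means the run is complete and the final result is available.
                print(run.result.data)


if __name__ == '__main__':
    asyncio.run(main())
```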

pydantic_ai_slim/pydantic_ai/_agent_graph.py

Lines changed: 48 additions & 7 deletions
@@ -2,15 +2,14 @@
 
 import asyncio
 import dataclasses
-from abc import ABC
 from collections.abc import AsyncIterator, Iterator, Sequence
 from contextlib import asynccontextmanager, contextmanager
 from contextvars import ContextVar
 from dataclasses import field
 from typing import Any, Generic, Literal, Union, cast
 
 import logfire_api
-from typing_extensions import TypeVar, assert_never
+from typing_extensions import TypeGuard, TypeVar, assert_never
 
 from pydantic_graph import BaseNode, Graph, GraphRunContext
 from pydantic_graph.nodes import End, NodeRunEndT
@@ -55,6 +54,7 @@
 logfire._internal.stack_info.NON_USER_CODE_PREFIXES += (str(Path(__file__).parent.absolute()),)
 
 T = TypeVar('T')
+S = TypeVar('S')
 NoneType = type(None)
 EndStrategy = Literal['early', 'exhaustive']
 """The strategy for handling multiple tool calls when a final result is found.
@@ -107,8 +107,31 @@ class GraphAgentDeps(Generic[DepsT, ResultDataT]):
     run_span: logfire_api.LogfireSpan
 
 
+class AgentNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], result.FinalResult[NodeRunEndT]]):
+    """The base class for all agent nodes.
+
+    Using subclass of `BaseNode` for all nodes reduces the amount of boilerplate of generics everywhere
+    """
+
+
+def is_agent_node(
+    node: BaseNode[GraphAgentState, GraphAgentDeps[T, Any], result.FinalResult[S]] | End[result.FinalResult[S]],
+) -> TypeGuard[AgentNode[T, S]]:
+    """Check if the provided node is an instance of `AgentNode`.
+
+    Usage:
+
+        if is_agent_node(node):
+            # `node` is an AgentNode
+            ...
+
+    This method preserves the generic parameters on the narrowed type, unlike `isinstance(node, AgentNode)`.
+    """
+    return isinstance(node, AgentNode)
+
+
 @dataclasses.dataclass
-class UserPromptNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], result.FinalResult[NodeRunEndT]], ABC):
+class UserPromptNode(AgentNode[DepsT, NodeRunEndT]):
     user_prompt: str | Sequence[_messages.UserContent]
 
     system_prompts: tuple[str, ...]
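
A note on the `TypeGuard` above: returning `TypeGuard[AgentNode[T, S]]` lets the type checker keep the caller's generic parameters when narrowing, which a bare `isinstance` check would erase. A self-contained sketch of the same pattern, with illustrative names not taken from the codebase:

```python
from dataclasses import dataclass
from typing import Generic, TypeVar

from typing_extensions import TypeGuard

T = TypeVar('T')


@dataclass
class Box(Generic[T]):
    value: T


def is_box(obj: Box[T] | str) -> TypeGuard[Box[T]]:
    # The runtime check is plain isinstance; the TypeGuard return annotation
    # tells the type checker to narrow `obj` to Box[T], preserving T.
    return isinstance(obj, Box)


def demo(obj: Box[int] | str) -> int:
    if is_box(obj):
        return obj.value  # narrowed to Box[int], so .value is int
    return 0
```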
@@ -215,7 +238,7 @@ async def add_tool(tool: Tool[DepsT]) -> None:
 
 
 @dataclasses.dataclass
-class ModelRequestNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], result.FinalResult[NodeRunEndT]]):
+class ModelRequestNode(AgentNode[DepsT, NodeRunEndT]):
     """Make a request to the model using the last message in state.message_history."""
 
     request: _messages.ModelRequest
@@ -236,12 +259,30 @@ async def run(
 
         return await self._make_request(ctx)
 
+    @asynccontextmanager
+    async def stream(
+        self,
+        ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, T]],
+    ) -> AsyncIterator[result.AgentStream[DepsT, T]]:
+        async with self._stream(ctx) as streamed_response:
+            agent_stream = result.AgentStream[DepsT, T](
+                streamed_response,
+                ctx.deps.result_schema,
+                ctx.deps.result_validators,
+                build_run_context(ctx),
+                ctx.deps.usage_limits,
+            )
+            yield agent_stream
+            # In case the user didn't manually consume the full stream, ensure it is fully consumed here,
+            # otherwise usage won't be properly counted:
+            async for _ in agent_stream:
+                pass
+
     @asynccontextmanager
     async def _stream(
         self,
         ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, T]],
     ) -> AsyncIterator[models.StreamedResponse]:
-        # TODO: Consider changing this to return something more similar to a `StreamedRunResult`, then make it public
         assert not self._did_stream, 'stream() should only be called once per node'
 
         model_settings, model_request_parameters = await self._prepare_request(ctx)
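
The comment inside the new `stream()` method points at a general pattern: yield a stream from an async context manager, then drain whatever the caller left unconsumed so bookkeeping (here, usage counting) still runs. A minimal self-contained sketch of that pattern, with illustrative names (requires Python 3.10+ for `anext`):

```python
import asyncio
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager


@asynccontextmanager
async def counted_stream(items: list[str]) -> AsyncIterator[AsyncIterator[str]]:
    consumed = 0

    async def gen() -> AsyncIterator[str]:
        nonlocal consumed
        for item in items:
            consumed += 1
            yield item

    stream = gen()
    yield stream
    # In case the caller didn't consume the full stream, drain it here so the
    # count stays accurate (mirrors the usage-counting comment above).
    async for _ in stream:
        pass
    print(f'consumed {consumed} of {len(items)} items')


async def main():
    async with counted_stream(['a', 'b', 'c']) as stream:
        print(await anext(stream))  # caller reads only the first item


if __name__ == '__main__':
    asyncio.run(main())
```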
@@ -319,7 +360,7 @@ def _finish_handling(
 
 
 @dataclasses.dataclass
-class HandleResponseNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], result.FinalResult[NodeRunEndT]]):
+class HandleResponseNode(AgentNode[DepsT, NodeRunEndT]):
     """Process a model response, and decide whether to end the run or make a new request."""
 
     model_response: _messages.ModelResponse
@@ -575,7 +616,7 @@ async def process_function_tools(
         for task in done:
             index = tasks.index(task)
             result = task.result()
-            yield _messages.FunctionToolResultEvent(result, call_id=call_index_to_event_id[index])
+            yield _messages.FunctionToolResultEvent(result, tool_call_id=call_index_to_event_id[index])
             if isinstance(result, (_messages.ToolReturnPart, _messages.RetryPromptPart)):
                 results_by_index[index] = result
             else:
