Skip to content

Commit 4d0f8ff

Browse files
authored
Make Graph.iter into an _async_ contextmanager (#958)
1 parent c9b0765 commit 4d0f8ff

File tree

5 files changed

+28
-28
lines changed

5 files changed

+28
-28
lines changed

docs/agents.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ agent = Agent('openai:gpt-4o')
114114
async def main():
115115
nodes = []
116116
# Begin an AgentRun, which is an async-iterable over the nodes of the agent's graph
117-
with agent.iter('What is the capital of France?') as agent_run:
117+
async with agent.iter('What is the capital of France?') as agent_run:
118118
async for node in agent_run:
119119
# Each node represents a step in the agent's execution
120120
nodes.append(node)
@@ -163,7 +163,7 @@ agent = Agent('openai:gpt-4o')
163163

164164

165165
async def main():
166-
with agent.iter('What is the capital of France?') as agent_run:
166+
async with agent.iter('What is the capital of France?') as agent_run:
167167
node = agent_run.next_node # (1)!
168168

169169
all_nodes = [node]
@@ -282,7 +282,7 @@ async def main():
282282
user_prompt = 'What will the weather be like in Paris on Tuesday?'
283283

284284
# Begin a node-by-node, streaming iteration
285-
with weather_agent.iter(user_prompt, deps=WeatherService()) as run:
285+
async with weather_agent.iter(user_prompt, deps=WeatherService()) as run:
286286
async for node in run:
287287
if Agent.is_user_prompt_node(node):
288288
# A user prompt node => The user has provided input

docs/graph.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -679,7 +679,7 @@ count_down_graph = Graph(nodes=[CountDown])
679679

680680
async def main():
681681
state = CountDownState(counter=3)
682-
with count_down_graph.iter(CountDown(), state=state) as run: # (1)!
682+
async with count_down_graph.iter(CountDown(), state=state) as run: # (1)!
683683
async for node in run: # (2)!
684684
print('Node:', node)
685685
#> Node: CountDown()
@@ -712,7 +712,7 @@ from count_down import CountDown, CountDownState, count_down_graph
712712

713713
async def main():
714714
state = CountDownState(counter=5)
715-
with count_down_graph.iter(CountDown(), state=state) as run:
715+
async with count_down_graph.iter(CountDown(), state=state) as run:
716716
node = run.next_node # (1)!
717717
while not isinstance(node, End): # (2)!
718718
print('Node:', node)

pydantic_ai_slim/pydantic_ai/agent.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -294,7 +294,7 @@ async def main():
294294
"""
295295
if infer_name and self.name is None:
296296
self._infer_name(inspect.currentframe())
297-
with self.iter(
297+
async with self.iter(
298298
user_prompt=user_prompt,
299299
result_type=result_type,
300300
message_history=message_history,
@@ -310,8 +310,8 @@ async def main():
310310
assert (final_result := agent_run.result) is not None, 'The graph run did not finish properly'
311311
return final_result
312312

313-
@contextmanager
314-
def iter(
313+
@asynccontextmanager
314+
async def iter(
315315
self,
316316
user_prompt: str | Sequence[_messages.UserContent],
317317
*,
@@ -323,7 +323,7 @@ def iter(
323323
usage_limits: _usage.UsageLimits | None = None,
324324
usage: _usage.Usage | None = None,
325325
infer_name: bool = True,
326-
) -> Iterator[AgentRun[AgentDepsT, Any]]:
326+
) -> AsyncIterator[AgentRun[AgentDepsT, Any]]:
327327
"""A contextmanager which can be used to iterate over the agent graph's nodes as they are executed.
328328
329329
This method builds an internal agent graph (using system prompts, tools and result schemas) and then returns an
@@ -344,7 +344,7 @@ def iter(
344344
345345
async def main():
346346
nodes = []
347-
with agent.iter('What is the capital of France?') as agent_run:
347+
async with agent.iter('What is the capital of France?') as agent_run:
348348
async for node in agent_run:
349349
nodes.append(node)
350350
print(nodes)
@@ -454,7 +454,7 @@ async def main():
454454
system_prompt_dynamic_functions=self._system_prompt_dynamic_functions,
455455
)
456456

457-
with graph.iter(
457+
async with graph.iter(
458458
start_node,
459459
state=state,
460460
deps=graph_deps,
@@ -633,7 +633,7 @@ async def main():
633633
self._infer_name(frame.f_back)
634634

635635
yielded = False
636-
with self.iter(
636+
async with self.iter(
637637
user_prompt,
638638
result_type=result_type,
639639
message_history=message_history,
@@ -1217,7 +1217,7 @@ def is_end_node(
12171217
class AgentRun(Generic[AgentDepsT, ResultDataT]):
12181218
"""A stateful, async-iterable run of an [`Agent`][pydantic_ai.agent.Agent].
12191219
1220-
You generally obtain an `AgentRun` instance by calling `with my_agent.iter(...) as agent_run:`.
1220+
You generally obtain an `AgentRun` instance by calling `async with my_agent.iter(...) as agent_run:`.
12211221
12221222
Once you have an instance, you can use it to iterate through the run's nodes as they execute. When an
12231223
[`End`][pydantic_graph.nodes.End] is reached, the run finishes and [`result`][pydantic_ai.agent.AgentRun.result]
@@ -1232,7 +1232,7 @@ class AgentRun(Generic[AgentDepsT, ResultDataT]):
12321232
async def main():
12331233
nodes = []
12341234
# Iterate through the run, recording each node along the way:
1235-
with agent.iter('What is the capital of France?') as agent_run:
1235+
async with agent.iter('What is the capital of France?') as agent_run:
12361236
async for node in agent_run:
12371237
nodes.append(node)
12381238
print(nodes)
@@ -1346,7 +1346,7 @@ async def next(
13461346
agent = Agent('openai:gpt-4o')
13471347
13481348
async def main():
1349-
with agent.iter('What is the capital of France?') as agent_run:
1349+
async with agent.iter('What is the capital of France?') as agent_run:
13501350
next_node = agent_run.next_node # start with the first node
13511351
nodes = [next_node]
13521352
while not isinstance(next_node, End):

pydantic_graph/pydantic_graph/graph.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22

33
import inspect
44
import types
5-
from collections.abc import AsyncIterator, Iterator, Sequence
6-
from contextlib import ExitStack, contextmanager
5+
from collections.abc import AsyncIterator, Sequence
6+
from contextlib import ExitStack, asynccontextmanager
77
from dataclasses import dataclass, field
88
from functools import cached_property
99
from time import perf_counter
@@ -172,24 +172,24 @@ async def main():
172172
if infer_name and self.name is None:
173173
self._infer_name(inspect.currentframe())
174174

175-
with self.iter(start_node, state=state, deps=deps, infer_name=infer_name, span=span) as graph_run:
175+
async with self.iter(start_node, state=state, deps=deps, infer_name=infer_name, span=span) as graph_run:
176176
async for _node in graph_run:
177177
pass
178178

179179
final_result = graph_run.result
180180
assert final_result is not None, 'GraphRun should have a final result'
181181
return final_result
182182

183-
@contextmanager
184-
def iter(
183+
@asynccontextmanager
184+
async def iter(
185185
self: Graph[StateT, DepsT, T],
186186
start_node: BaseNode[StateT, DepsT, T],
187187
*,
188188
state: StateT = None,
189189
deps: DepsT = None,
190190
infer_name: bool = True,
191191
span: LogfireSpan | None = None,
192-
) -> Iterator[GraphRun[StateT, DepsT, T]]:
192+
) -> AsyncIterator[GraphRun[StateT, DepsT, T]]:
193193
"""A contextmanager which can be used to iterate over the graph's nodes as they are executed.
194194
195195
This method returns a `GraphRun` object which can be used to async-iterate over the nodes of this `Graph` as
@@ -569,7 +569,7 @@ class GraphRun(Generic[StateT, DepsT, RunEndT]):
569569
"""A stateful, async-iterable run of a [`Graph`][pydantic_graph.graph.Graph].
570570
571571
You typically get a `GraphRun` instance from calling
572-
`with [my_graph.iter(...)][pydantic_graph.graph.Graph.iter] as graph_run:`. That gives you the ability to iterate
572+
`async with [my_graph.iter(...)][pydantic_graph.graph.Graph.iter] as graph_run:`. That gives you the ability to iterate
573573
through nodes as they run, either by `async for` iteration or by repeatedly calling `.next(...)`.
574574
575575
Here's an example of iterating over the graph from [above][pydantic_graph.graph.Graph]:
@@ -579,7 +579,7 @@ class GraphRun(Generic[StateT, DepsT, RunEndT]):
579579
580580
async def main():
581581
state = MyState(1)
582-
with never_42_graph.iter(Increment(), state=state) as graph_run:
582+
async with never_42_graph.iter(Increment(), state=state) as graph_run:
583583
node_states = [(graph_run.next_node, deepcopy(graph_run.state))]
584584
async for node in graph_run:
585585
node_states.append((node, deepcopy(graph_run.state)))
@@ -593,7 +593,7 @@ async def main():
593593
'''
594594
595595
state = MyState(41)
596-
with never_42_graph.iter(Increment(), state=state) as graph_run:
596+
async with never_42_graph.iter(Increment(), state=state) as graph_run:
597597
node_states = [(graph_run.next_node, deepcopy(graph_run.state))]
598598
async for node in graph_run:
599599
node_states.append((node, deepcopy(graph_run.state)))
@@ -684,7 +684,7 @@ async def next(
684684
685685
async def main():
686686
state = MyState(48)
687-
with never_42_graph.iter(Increment(), state=state) as graph_run:
687+
async with never_42_graph.iter(Increment(), state=state) as graph_run:
688688
next_node = graph_run.next_node # start with the first node
689689
node_states = [(next_node, deepcopy(graph_run.state))]
690690

tests/test_streaming.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -761,7 +761,7 @@ def result_validator_simple(data: str) -> str:
761761
messages: list[str] = []
762762

763763
stream_usage: Usage | None = None
764-
with agent.iter('Hello') as run:
764+
async with agent.iter('Hello') as run:
765765
async for node in run:
766766
if agent.is_model_request_node(node):
767767
async with node.stream(run.ctx) as stream:
@@ -800,7 +800,7 @@ def result_validator_simple(data: str) -> str:
800800
run: AgentRun
801801
stream: AgentStream
802802
messages: list[ModelResponse] = []
803-
with agent.iter('Hello') as run:
803+
async with agent.iter('Hello') as run:
804804
async for node in run:
805805
if agent.is_model_request_node(node):
806806
async with node.stream(run.ctx) as stream:
@@ -843,7 +843,7 @@ def result_validator(data: ResultType | NotResultType) -> ResultType | NotResult
843843
return ResultType(value=data.value + ' (validated)')
844844

845845
outputs: list[ResultType] = []
846-
with agent.iter('test') as run:
846+
async with agent.iter('test') as run:
847847
async for node in run:
848848
if agent.is_model_request_node(node):
849849
async with node.stream(run.ctx) as stream:

0 commit comments

Comments
 (0)