2
2
3
3
import inspect
4
4
import types
5
- from collections .abc import AsyncIterator , Iterator , Sequence
6
- from contextlib import ExitStack , contextmanager
5
+ from collections .abc import AsyncIterator , Sequence
6
+ from contextlib import ExitStack , asynccontextmanager
7
7
from dataclasses import dataclass , field
8
8
from functools import cached_property
9
9
from time import perf_counter
@@ -172,24 +172,24 @@ async def main():
172
172
if infer_name and self .name is None :
173
173
self ._infer_name (inspect .currentframe ())
174
174
175
- with self .iter (start_node , state = state , deps = deps , infer_name = infer_name , span = span ) as graph_run :
175
+ async with self .iter (start_node , state = state , deps = deps , infer_name = infer_name , span = span ) as graph_run :
176
176
async for _node in graph_run :
177
177
pass
178
178
179
179
final_result = graph_run .result
180
180
assert final_result is not None , 'GraphRun should have a final result'
181
181
return final_result
182
182
183
- @contextmanager
184
- def iter (
183
+ @asynccontextmanager
184
+ async def iter (
185
185
self : Graph [StateT , DepsT , T ],
186
186
start_node : BaseNode [StateT , DepsT , T ],
187
187
* ,
188
188
state : StateT = None ,
189
189
deps : DepsT = None ,
190
190
infer_name : bool = True ,
191
191
span : LogfireSpan | None = None ,
192
- ) -> Iterator [GraphRun [StateT , DepsT , T ]]:
192
+ ) -> AsyncIterator [GraphRun [StateT , DepsT , T ]]:
193
193
"""A contextmanager which can be used to iterate over the graph's nodes as they are executed.
194
194
195
195
This method returns a `GraphRun` object which can be used to async-iterate over the nodes of this `Graph` as
@@ -569,7 +569,7 @@ class GraphRun(Generic[StateT, DepsT, RunEndT]):
569
569
"""A stateful, async-iterable run of a [`Graph`][pydantic_graph.graph.Graph].
570
570
571
571
You typically get a `GraphRun` instance from calling
572
- `with [my_graph.iter(...)][pydantic_graph.graph.Graph.iter] as graph_run:`. That gives you the ability to iterate
572
+ `async with [my_graph.iter(...)][pydantic_graph.graph.Graph.iter] as graph_run:`. That gives you the ability to iterate
573
573
through nodes as they run, either by `async for` iteration or by repeatedly calling `.next(...)`.
574
574
575
575
Here's an example of iterating over the graph from [above][pydantic_graph.graph.Graph]:
@@ -579,7 +579,7 @@ class GraphRun(Generic[StateT, DepsT, RunEndT]):
579
579
580
580
async def main():
581
581
state = MyState(1)
582
- with never_42_graph.iter(Increment(), state=state) as graph_run:
582
+ async with never_42_graph.iter(Increment(), state=state) as graph_run:
583
583
node_states = [(graph_run.next_node, deepcopy(graph_run.state))]
584
584
async for node in graph_run:
585
585
node_states.append((node, deepcopy(graph_run.state)))
@@ -593,7 +593,7 @@ async def main():
593
593
'''
594
594
595
595
state = MyState(41)
596
- with never_42_graph.iter(Increment(), state=state) as graph_run:
596
+ async with never_42_graph.iter(Increment(), state=state) as graph_run:
597
597
node_states = [(graph_run.next_node, deepcopy(graph_run.state))]
598
598
async for node in graph_run:
599
599
node_states.append((node, deepcopy(graph_run.state)))
@@ -684,7 +684,7 @@ async def next(
684
684
685
685
async def main():
686
686
state = MyState(48)
687
- with never_42_graph.iter(Increment(), state=state) as graph_run:
687
+ async with never_42_graph.iter(Increment(), state=state) as graph_run:
688
688
next_node = graph_run.next_node # start with the first node
689
689
node_states = [(next_node, deepcopy(graph_run.state))]
690
690
0 commit comments