@@ -607,6 +607,7 @@ async def main():
607
607
print(node_states)
608
608
'''
609
609
[
610
+ (Increment(), MyState(number=1)),
610
611
(Increment(), MyState(number=1)),
611
612
(Check42(), MyState(number=2)),
612
613
(End(data=2), MyState(number=2)),
@@ -621,6 +622,7 @@ async def main():
621
622
print(node_states)
622
623
'''
623
624
[
625
+ (Increment(), MyState(number=41)),
624
626
(Increment(), MyState(number=41)),
625
627
(Check42(), MyState(number=42)),
626
628
(Increment(), MyState(number=42)),
@@ -665,6 +667,7 @@ def __init__(
665
667
self .deps = deps
666
668
667
669
self ._next_node : BaseNode [StateT , DepsT , RunEndT ] | End [RunEndT ] = start_node
670
+ self ._is_started : bool = False
668
671
669
672
@property
670
673
def next_node (self ) -> BaseNode [StateT , DepsT , RunEndT ] | End [RunEndT ]:
@@ -777,8 +780,13 @@ def __aiter__(self) -> AsyncIterator[BaseNode[StateT, DepsT, RunEndT] | End[RunE
777
780
778
781
async def __anext__ (self ) -> BaseNode [StateT , DepsT , RunEndT ] | End [RunEndT ]:
779
782
"""Use the last returned node as the input to `Graph.next`."""
783
+ if not self ._is_started :
784
+ self ._is_started = True
785
+ return self ._next_node
786
+
780
787
if isinstance (self ._next_node , End ):
781
788
raise StopAsyncIteration
789
+
782
790
return await self .next (self ._next_node )
783
791
784
792
def __repr__ (self ) -> str :
0 commit comments