1010import inspect
1111import types
1212import uuid
13- from collections .abc import AsyncGenerator , AsyncIterator , Iterable , Sequence
13+ from collections .abc import AsyncGenerator , AsyncIterable , AsyncIterator , Iterable , Sequence
1414from contextlib import AbstractContextManager , ExitStack , asynccontextmanager
1515from dataclasses import dataclass , field
1616from typing import TYPE_CHECKING , Any , Generic , Literal , TypeGuard , cast , get_args , get_origin , overload
1717
18- import anyio
19- from anyio import CancelScope , WouldBlock , create_memory_object_stream , create_task_group
18+ from anyio import CancelScope , create_memory_object_stream , create_task_group
2019from anyio .abc import TaskGroup
2120from anyio .streams .memory import MemoryObjectReceiveStream , MemoryObjectSendStream
2221from typing_extensions import TypeVar , assert_never
@@ -377,16 +376,17 @@ def __init__(
377376 run_id = GraphRunID (str (uuid .uuid4 ()))
378377 initial_fork_stack : ForkStack = (ForkStackItem (StartNode .id , NodeRunID (run_id ), 0 ),)
379378 self ._first_task = GraphTask (node_id = StartNode .id , inputs = inputs , fork_stack = initial_fork_stack )
380- self ._iterator = _GraphIterator [StateT , DepsT , OutputT ](self .graph , self .state , self .deps ).iter_graph (
381- self ._first_task
382- )
379+ self ._iterator_instance = _GraphIterator [StateT , DepsT , OutputT ](self .graph , self .state , self .deps )
380+ self ._iterator = self ._iterator_instance .iter_graph (self ._first_task )
383381
384382 self .__traceparent = traceparent
385383
386384 async def __aenter__ (self ):
387385 return self
388386
389387 async def __aexit__ (self , exc_type : Any , exc_val : Any , exc_tb : Any ):
388+ self ._iterator_instance .iter_stream_sender .close ()
389+ self ._iterator_instance .iter_stream_receiver .close ()
390390 await self ._iterator .aclose ()
391391
392392 @overload
@@ -472,10 +472,17 @@ def output(self) -> OutputT | None:
472472 return None
473473
474474
475+ @dataclass
476+ class _GraphTaskAsyncIterable :
477+ iterable : AsyncIterable [Sequence [GraphTask ]]
478+ fork_stack : ForkStack
479+
480+
475481@dataclass
476482class _GraphTaskResult :
477483 source : GraphTask
478- result : EndMarker [Any ] | Sequence [GraphTask ]
484+ result : EndMarker [Any ] | Sequence [GraphTask ] | JoinItem
485+ source_is_finished : bool = True
479486
480487
481488@dataclass
@@ -486,6 +493,8 @@ class _GraphIterator(Generic[StateT, DepsT, OutputT]):
486493
487494 cancel_scopes : dict [TaskID , CancelScope ] = field (init = False )
488495 active_tasks : dict [TaskID , GraphTask ] = field (init = False )
496+ pending_task_results : set [TaskID ] = field (init = False )
497+ cancelled_tasks : set [TaskID ] = field (init = False )
489498 active_reducers : dict [tuple [JoinID , NodeRunID ], JoinState ] = field (init = False )
490499 iter_stream_sender : MemoryObjectSendStream [_GraphTaskResult ] = field (init = False )
491500 iter_stream_receiver : MemoryObjectReceiveStream [_GraphTaskResult ] = field (init = False )
@@ -497,6 +506,9 @@ def __post_init__(self):
497506 self .active_reducers = {}
498507 self .iter_stream_sender , self .iter_stream_receiver = create_memory_object_stream [_GraphTaskResult ]()
499508
509+ self .pending_task_results = set ()
510+ self .cancelled_tasks = set ()
511+
500512 @property
501513 def task_group (self ) -> TaskGroup :
502514 if self ._task_group is None :
@@ -516,30 +528,70 @@ async def iter_graph( # noqa C901
516528 # Handle task results
517529 async with self .iter_stream_receiver :
518530 while self .active_tasks or self .active_reducers :
519- while self .active_tasks :
520- try :
521- task_result = self .iter_stream_receiver .receive_nowait ()
522- except WouldBlock :
523- await anyio .sleep (0.0 )
531+ async for task_result in self .iter_stream_receiver :
532+ # If we encounter a mock task, add it to the active tasks to ensure we don't proceed until everything downstream is handled
533+ if (
534+ not task_result .source_is_finished
535+ and task_result .source .task_id not in self .active_tasks
536+ ):
537+ self .active_tasks [task_result .source .task_id ] = task_result .source
538+
539+ if task_result .source .task_id in self .cancelled_tasks :
540+ if task_result .source_is_finished :
541+ self .cancelled_tasks .remove (task_result .source .task_id )
524542 continue
525543
526- maybe_overridden_result = yield task_result .result
544+ if task_result .source_is_finished :
545+ self .pending_task_results .discard (task_result .source .task_id )
546+
547+ if isinstance (task_result .result , JoinItem ):
548+ maybe_overridden_result = task_result .result
549+ else :
550+ maybe_overridden_result = yield task_result .result
527551 if isinstance (maybe_overridden_result , EndMarker ):
528552 self .task_group .cancel_scope .cancel ()
529553 return
530- for new_task in maybe_overridden_result :
531- self .active_tasks [new_task .task_id ] = new_task
532- await self ._finish_task (task_result .source .task_id )
554+ elif isinstance (maybe_overridden_result , JoinItem ):
555+ result = maybe_overridden_result
556+ parent_fork_id = self .graph .get_parent_fork (result .join_id ).fork_id
557+ for i , x in enumerate (result .fork_stack [::- 1 ]):
558+ if x .fork_id == parent_fork_id :
559+ downstream_fork_stack = result .fork_stack [: len (result .fork_stack ) - i ]
560+ fork_run_id = x .node_run_id
561+ break
562+ else : # pragma: no cover
563+ raise RuntimeError ('Parent fork run not found' )
564+
565+ join_node = self .graph .nodes [result .join_id ]
566+ assert isinstance (join_node , Join ), f'Expected a `Join` but got { join_node } '
567+ join_state = self .active_reducers .get ((result .join_id , fork_run_id ))
568+ if join_state is None :
569+ current = join_node .initial_factory ()
570+ join_state = self .active_reducers [(result .join_id , fork_run_id )] = JoinState (
571+ current , downstream_fork_stack
572+ )
573+ context = ReducerContext (state = self .state , deps = self .deps , join_state = join_state )
574+ join_state .current = join_node .reduce (context , join_state .current , result .inputs )
575+ if join_state .cancelled_sibling_tasks :
576+ await self ._cancel_sibling_tasks (parent_fork_id , fork_run_id )
577+ if task_result .source_is_finished :
578+ await self ._finish_task (task_result .source .task_id )
579+ else :
580+ for new_task in maybe_overridden_result :
581+ self .active_tasks [new_task .task_id ] = new_task
582+ if task_result .source_is_finished :
583+ await self ._finish_task (task_result .source .task_id )
533584
534585 tasks_by_id_values = list (self .active_tasks .values ())
535586 join_tasks : list [GraphTask ] = []
587+
536588 for join_id , fork_run_id in self ._get_completed_fork_runs (
537589 task_result .source , tasks_by_id_values
538590 ):
539591 join_state = self .active_reducers .pop ((join_id , fork_run_id ))
540592 join_node = self .graph .nodes [join_id ]
541593 assert isinstance (join_node , Join ), f'Expected a `Join` but got { join_node } '
542- new_tasks = self ._handle_edges (
594+ new_tasks = self ._handle_non_fork_edges (
543595 join_node , join_state .current , join_state .downstream_fork_stack
544596 )
545597 join_tasks .extend (new_tasks )
@@ -548,14 +600,17 @@ async def iter_graph( # noqa C901
548600 self .active_tasks [new_task .task_id ] = new_task
549601 self ._handle_execution_request (join_tasks )
550602
551- if not isinstance (task_result .result , EndMarker ):
552- new_task_ids = {t .task_id for t in maybe_overridden_result }
553- for t in task_result .result :
554- if t .task_id not in new_task_ids :
555- await self ._finish_task (
556- t .task_id
557- ) # TODO: Rename to cancel_task or something, instead of "finish", these didn't really get started
558- self ._handle_execution_request (maybe_overridden_result )
603+ if isinstance (maybe_overridden_result , Sequence ):
604+ if isinstance (task_result .result , Sequence ):
605+ new_task_ids = {t .task_id for t in maybe_overridden_result }
606+ for t in task_result .result :
607+ if t .task_id not in new_task_ids :
608+ await self ._finish_task (t .task_id )
609+ self ._handle_execution_request (maybe_overridden_result )
610+
611+ if not self .active_tasks :
612+ # if there are no active tasks, we'll be waiting forever for the next result..
613+ break
559614
560615 if self .active_reducers : # pragma: no branch
561616 # In this case, there are no pending tasks. We can therefore finalize all active reducers whose
@@ -577,7 +632,7 @@ async def iter_graph( # noqa C901
577632 ) # we're handling it now, so we can pop it
578633 join_node = self .graph .nodes [join_id ]
579634 assert isinstance (join_node , Join ), f'Expected a `Join` but got { join_node } '
580- new_tasks = self ._handle_edges (
635+ new_tasks = self ._handle_non_fork_edges (
581636 join_node , join_state .current , join_state .downstream_fork_stack
582637 )
583638 maybe_overridden_result = yield new_tasks
@@ -620,45 +675,18 @@ async def _run_tracked_task(self, t_: GraphTask):
620675 with CancelScope () as scope :
621676 self .cancel_scopes [t_ .task_id ] = scope
622677 result = await self ._run_task (t_ )
623-
624- if isinstance (result , EndMarker ):
625- await self .iter_stream_sender .send (_GraphTaskResult (t_ , result ))
626- return
627-
628- new_tasks : list [GraphTask ] = []
629- if isinstance (result , JoinItem ):
630- parent_fork_id = self .graph .get_parent_fork (result .join_id ).fork_id
631- for i , x in enumerate (result .fork_stack [::- 1 ]):
632- if x .fork_id == parent_fork_id :
633- downstream_fork_stack = result .fork_stack [: len (result .fork_stack ) - i ]
634- fork_run_id = x .node_run_id
635- break
636- else : # pragma: no cover
637- raise RuntimeError ('Parent fork run not found' )
638-
639- join_node = self .graph .nodes [result .join_id ]
640- assert isinstance (join_node , Join ), f'Expected a `Join` but got { join_node } '
641- join_state = self .active_reducers .get ((result .join_id , fork_run_id ))
642- if join_state is None :
643- current = join_node .initial_factory ()
644- join_state = self .active_reducers [(result .join_id , fork_run_id )] = JoinState (
645- current , downstream_fork_stack
646- )
647-
648- context = ReducerContext (state = self .state , deps = self .deps , join_state = join_state )
649- join_state .current = join_node .reduce (context , join_state .current , result .inputs )
650- if join_state .cancelled_sibling_tasks :
651- await self ._cancel_sibling_tasks (parent_fork_id , fork_run_id )
678+ if isinstance (result , _GraphTaskAsyncIterable ):
679+ async for new_tasks in result .iterable :
680+ await self .iter_stream_sender .send (_GraphTaskResult (t_ , new_tasks , False ))
652681 await self .iter_stream_sender .send (_GraphTaskResult (t_ , []))
653682 else :
654- new_tasks .extend (result )
655-
656- await self .iter_stream_sender .send (_GraphTaskResult (t_ , new_tasks ))
683+ self .pending_task_results .add (t_ .task_id )
684+ await self .iter_stream_sender .send (_GraphTaskResult (t_ , result ))
657685
658686 async def _run_task (
659687 self ,
660688 task : GraphTask ,
661- ) -> EndMarker [OutputT ] | Sequence [GraphTask ] | JoinItem :
689+ ) -> EndMarker [OutputT ] | Sequence [GraphTask ] | _GraphTaskAsyncIterable | JoinItem :
662690 state = self .state
663691 deps = self .deps
664692
@@ -769,53 +797,63 @@ def _handle_path(self, path: Path, inputs: Any, fork_stack: ForkStack) -> Sequen
769797 else :
770798 assert_never (item )
771799
772- def _handle_edges (self , node : AnyNode , inputs : Any , fork_stack : ForkStack ) -> Sequence [GraphTask ]:
800+ def _handle_edges (
801+ self , node : AnyNode , inputs : Any , fork_stack : ForkStack
802+ ) -> Sequence [GraphTask ] | _GraphTaskAsyncIterable :
803+ if isinstance (node , Fork ):
804+ return self ._handle_fork_edges (node , inputs , fork_stack )
805+ else :
806+ return self ._handle_non_fork_edges (node , inputs , fork_stack )
807+
808+ def _handle_non_fork_edges (self , node : AnyNode , inputs : Any , fork_stack : ForkStack ) -> Sequence [GraphTask ]:
809+ # TODO: Replace `node` with a type that implies it is not a Fork
810+ edges = self .graph .edges_by_source .get (node .id , [])
811+ assert len (edges ) == 1 # this should have already been ensured during graph building
812+ return self ._handle_path (edges [0 ], inputs , fork_stack )
813+
814+ def _handle_fork_edges (
815+ self , node : Fork [Any , Any ], inputs : Any , fork_stack : ForkStack
816+ ) -> Sequence [GraphTask ] | _GraphTaskAsyncIterable :
773817 edges = self .graph .edges_by_source .get (node .id , [])
774818 assert len (edges ) == 1 or (isinstance (node , Fork ) and not node .is_map ), (
775819 edges ,
776820 node .id ,
777821 ) # this should have already been ensured during graph building
778822
779823 new_tasks : list [GraphTask ] = []
824+ node_run_id = NodeRunID (str (uuid .uuid4 ()))
825+ if node .is_map :
826+ # If the map specifies a downstream join id, eagerly create a join state for it
827+ if (join_id := node .downstream_join_id ) is not None :
828+ join_node = self .graph .nodes [join_id ]
829+ assert isinstance (join_node , Join )
830+ self .active_reducers [(join_id , node_run_id )] = JoinState (join_node .initial_factory (), fork_stack )
831+
832+ # Eagerly raise a clear error if the input value is not iterable as expected
833+ if _is_any_iterable (inputs ):
834+ for thread_index , input_item in enumerate (inputs ):
835+ item_tasks = self ._handle_path (
836+ edges [0 ], input_item , fork_stack + (ForkStackItem (node .id , node_run_id , thread_index ),)
837+ )
838+ new_tasks += item_tasks
839+ elif _is_any_async_iterable (inputs ):
780840
781- if isinstance (node , Fork ):
782- node_run_id = NodeRunID (str (uuid .uuid4 ()))
783- if node .is_map :
784- # If the map specifies a downstream join id, eagerly create a join state for it
785- if (join_id := node .downstream_join_id ) is not None :
786- join_node = self .graph .nodes [join_id ]
787- assert isinstance (join_node , Join )
788- self .active_reducers [(join_id , node_run_id )] = JoinState (join_node .initial_factory (), fork_stack )
789-
790- # Eagerly raise a clear error if the input value is not iterable as expected
791- if _is_any_iterable (inputs ):
792- for thread_index , input_item in enumerate (inputs ):
841+ async def handle_async_iterable () -> AsyncIterator [Sequence [GraphTask ]]:
842+ thread_index = 0
843+ async for input_item in inputs :
793844 item_tasks = self ._handle_path (
794845 edges [0 ], input_item , fork_stack + (ForkStackItem (node .id , node_run_id , thread_index ),)
795846 )
796- new_tasks += item_tasks
797- # elif isinstance(inputs, AsyncIterable):
798- #
799- # async def handle_async_iterable():
800- # thread_index = 0
801- # async for input_item in inputs:
802- # item_tasks = self._handle_path(
803- # edges[0], input_item, fork_stack + (ForkStackItem(node.id, node_run_id, thread_index),)
804- # )
805- # yield item_tasks
806- # thread_index += 1
807- #
808- # task = GraphTask(node.id, inputs, fork_stack)
809- #
810- # self.tg.start_soon(self._run_tracked_task, task)
811- else :
812- raise RuntimeError (f'Cannot map non-iterable value: { inputs !r} ' )
847+ yield item_tasks
848+ thread_index += 1
849+
850+ return _GraphTaskAsyncIterable (handle_async_iterable (), fork_stack )
851+
813852 else :
814- for i , path in enumerate (edges ):
815- new_tasks += self ._handle_path (path , inputs , fork_stack + (ForkStackItem (node .id , node_run_id , i ),))
853+ raise RuntimeError (f'Cannot map non-iterable value: { inputs !r} ' )
816854 else :
817- new_tasks += self . _handle_path ( edges [ 0 ], inputs , fork_stack )
818-
855+ for i , path in enumerate ( edges ):
856+ new_tasks += self . _handle_path ( path , inputs , fork_stack + ( ForkStackItem ( node . id , node_run_id , i ),))
819857 return new_tasks
820858
821859 def _is_fork_run_completed (self , tasks : Iterable [GraphTask ], join_id : JoinID , fork_run_id : NodeRunID ) -> bool :
@@ -840,8 +878,14 @@ async def _cancel_sibling_tasks(self, parent_fork_id: ForkID, node_run_id: NodeR
840878 else :
841879 pass
842880 for task_id in task_ids_to_cancel :
881+ if task_id in self .pending_task_results :
882+ self .cancelled_tasks .add (task_id )
843883 await self ._finish_task (task_id )
844884
845885
846886def _is_any_iterable (x : Any ) -> TypeGuard [Iterable [Any ]]:
847887 return isinstance (x , Iterable )
888+
889+
890+ def _is_any_async_iterable (x : Any ) -> TypeGuard [AsyncIterable [Any ]]:
891+ return isinstance (x , AsyncIterable )
0 commit comments