@@ -346,7 +346,7 @@ def __init__(
346346 self .inputs = inputs
347347 """The initial input data."""
348348
349- self ._active_reducers : dict [tuple [JoinId , NodeRunId ], Reducer [Any , Any , Any , Any ]] = {}
349+ self ._active_reducers : dict [tuple [JoinId , NodeRunId ], tuple [ Reducer [Any , Any , Any , Any ], ForkStack ]] = {}
350350 """Active reducers for join operations."""
351351
352352 self ._next : EndMarker [OutputT ] | JoinItem | Sequence [GraphTask ] | None = None
@@ -469,39 +469,82 @@ def _handle_result(result: EndMarker[OutputT] | JoinItem | Sequence[GraphTask])
469469
470470 if isinstance (result , JoinItem ):
471471 parent_fork_id = self .graph .get_parent_fork (result .join_id ).fork_id
472- fork_run_id = [x .node_run_id for x in result .fork_stack [::- 1 ] if x .fork_id == parent_fork_id ][0 ]
473- reducer = self ._active_reducers .get ((result .join_id , fork_run_id ))
474- if reducer is None :
472+ for i , x in enumerate (result .fork_stack [::- 1 ]):
473+ if x .fork_id == parent_fork_id :
474+ downstream_fork_stack = result .fork_stack [: len (result .fork_stack ) - i ]
475+ fork_run_id = x .node_run_id
476+ break
477+ else :
478+ raise RuntimeError ('Parent fork run not found' )
479+
480+ reducer_and_fork_stack = self ._active_reducers .get ((result .join_id , fork_run_id ))
481+ if reducer_and_fork_stack is None :
475482 join_node = self .graph .nodes [result .join_id ]
476483 assert isinstance (join_node , Join )
477- reducer = join_node .create_reducer (StepContext ( self . state , self . deps , result . inputs ) )
478- self ._active_reducers [(result .join_id , fork_run_id )] = reducer
484+ reducer = join_node .create_reducer ()
485+ self ._active_reducers [(result .join_id , fork_run_id )] = reducer , downstream_fork_stack
479486 else :
487+ reducer , _ = reducer_and_fork_stack
488+
489+ try :
480490 reducer .reduce (StepContext (self .state , self .deps , result .inputs ))
491+ except StopIteration :
492+ # cancel all concurrently running tasks with the same fork_run_id of the parent fork
493+ task_ids_to_cancel = set [TaskId ]()
494+ for task_id , t in tasks_by_id .items ():
495+ for item in t .fork_stack :
496+ if item .fork_id == parent_fork_id and item .node_run_id == fork_run_id :
497+ task_ids_to_cancel .add (task_id )
498+ break
499+ for task in list (pending ):
500+ if task .get_name () in task_ids_to_cancel :
501+ task .cancel ()
502+ pending .remove (task )
481503 else :
482504 for new_task in result :
483505 _start_task (new_task )
484506 return False
485507
486- while pending :
487- done , pending = await asyncio .wait (pending , return_when = asyncio .FIRST_COMPLETED )
488- for task in done :
489- task_result = task .result ()
490- source_task = tasks_by_id .pop (TaskId (task .get_name ()))
491- maybe_overridden_result = yield task_result
492- if _handle_result (maybe_overridden_result ):
493- return
494-
495- for join_id , fork_run_id , fork_stack in self ._get_completed_fork_runs (
496- source_task , tasks_by_id .values ()
497- ):
498- reducer = self ._active_reducers .pop ((join_id , fork_run_id ))
508+ while pending or self ._active_reducers :
509+ while pending :
510+ done , pending = await asyncio .wait (pending , return_when = asyncio .FIRST_COMPLETED )
511+ for task in done :
512+ task_result = task .result ()
513+ source_task = tasks_by_id .pop (TaskId (task .get_name ()))
514+ maybe_overridden_result = yield task_result
515+ if _handle_result (maybe_overridden_result ):
516+ return
499517
518+ for join_id , fork_run_id in self ._get_completed_fork_runs (source_task , tasks_by_id .values ()):
519+ reducer , fork_stack = self ._active_reducers .pop ((join_id , fork_run_id ))
520+ output = reducer .finalize (StepContext (self .state , self .deps , None ))
521+ join_node = self .graph .nodes [join_id ]
522+ assert isinstance (
523+ join_node , Join
524+ ) # We could drop this but if it fails it means there is a bug.
525+ new_tasks = self ._handle_edges (join_node , output , fork_stack )
526+ maybe_overridden_result = yield new_tasks # give an opportunity to override these
527+ if _handle_result (maybe_overridden_result ):
528+ return
529+
530+ if self ._active_reducers :
531+ # In this case, there are no pending tasks. We can therefore finalize all active reducers whose
532+ # downstream fork stacks are not a strict "prefix" of another active reducer. (If it was, finalizing the
533+ # deeper reducer could produce new tasks in the "prefix" reducer.)
534+ active_fork_stacks = [fork_stack for _ , fork_stack in self ._active_reducers .values ()]
535+ for (join_id , fork_run_id ), (reducer , fork_stack ) in list (self ._active_reducers .items ()):
536+ if any (
537+ len (afs ) > len (fork_stack ) and fork_stack == afs [: len (fork_stack )]
538+ for afs in active_fork_stacks
539+ ):
540+ continue # this reducer is a strict prefix for one of the other active reducers
541+
542+ self ._active_reducers .pop ((join_id , fork_run_id )) # we're finalizing it now
500543 output = reducer .finalize (StepContext (self .state , self .deps , None ))
501544 join_node = self .graph .nodes [join_id ]
502545 assert isinstance (join_node , Join ) # We could drop this but if it fails it means there is a bug.
503546 new_tasks = self ._handle_edges (join_node , output , fork_stack )
504- maybe_overridden_result = yield new_tasks # Need to give an opportunity to override these
547+ maybe_overridden_result = yield new_tasks # give an opportunity to override these
505548 if _handle_result (maybe_overridden_result ):
506549 return
507550
@@ -588,19 +631,18 @@ def _get_completed_fork_runs(
588631 self ,
589632 t : GraphTask ,
590633 active_tasks : Iterable [GraphTask ],
591- ) -> list [tuple [JoinId , NodeRunId , ForkStack ]]:
592- completed_fork_runs : list [tuple [JoinId , NodeRunId , ForkStack ]] = []
634+ ) -> list [tuple [JoinId , NodeRunId ]]:
635+ completed_fork_runs : list [tuple [JoinId , NodeRunId ]] = []
593636
594637 fork_run_indices = {fsi .node_run_id : i for i , fsi in enumerate (t .fork_stack )}
595638 for join_id , fork_run_id in self ._active_reducers .keys ():
596639 fork_run_index = fork_run_indices .get (fork_run_id )
597640 if fork_run_index is None :
598641 continue # The fork_run_id is not in the current task's fork stack, so this task didn't complete it.
599642
600- new_fork_stack = t .fork_stack [:fork_run_index ]
601643 # This reducer _may_ now be ready to finalize:
602644 if self ._is_fork_run_completed (active_tasks , join_id , fork_run_id ):
603- completed_fork_runs .append ((join_id , fork_run_id , new_fork_stack ))
645+ completed_fork_runs .append ((join_id , fork_run_id ))
604646
605647 return completed_fork_runs
606648
@@ -612,13 +654,27 @@ def _handle_path(self, path: Path, inputs: Any, fork_stack: ForkStack) -> Sequen
612654 if isinstance (item , DestinationMarker ):
613655 return [GraphTask (item .destination_id , inputs , fork_stack )]
614656 elif isinstance (item , SpreadMarker ):
657+ # Eagerly raise a clear error if the input value is not iterable as expected
658+ try :
659+ iter (inputs )
660+ except TypeError :
661+ raise RuntimeError (f'Cannot spread non-iterable value: { inputs !r} ' )
662+
615663 node_run_id = NodeRunId (str (uuid .uuid4 ()))
616- return [
617- GraphTask (
618- item .fork_id , input_item , fork_stack + (ForkStackItem (item .fork_id , node_run_id , thread_index ),)
664+
665+ # If the spread specifies a downstream join id, eagerly create a reducer for it
666+ if item .downstream_join_id is not None :
667+ join_node = self .graph .nodes [item .downstream_join_id ]
668+ assert isinstance (join_node , Join )
669+ self ._active_reducers [(item .downstream_join_id , node_run_id )] = join_node .create_reducer (), fork_stack
670+
671+ spread_tasks : list [GraphTask ] = []
672+ for thread_index , input_item in enumerate (inputs ):
673+ item_tasks = self ._handle_path (
674+ path .next_path , input_item , fork_stack + (ForkStackItem (item .fork_id , node_run_id , thread_index ),)
619675 )
620- for thread_index , input_item in enumerate ( inputs )
621- ]
676+ spread_tasks += item_tasks
677+ return spread_tasks
622678 elif isinstance (item , BroadcastMarker ):
623679 return [GraphTask (item .fork_id , inputs , fork_stack )]
624680 elif isinstance (item , TransformMarker ):
@@ -644,6 +700,6 @@ def _is_fork_run_completed(self, tasks: Iterable[GraphTask], join_id: JoinId, fo
644700 parent_fork = self .graph .get_parent_fork (join_id )
645701 for t in tasks :
646702 if fork_run_id in {x .node_run_id for x in t .fork_stack }:
647- if t .node_id in parent_fork .intermediate_nodes :
703+ if t .node_id in parent_fork .intermediate_nodes or t . node_id == join_id :
648704 return False
649705 return True
0 commit comments