@@ -346,7 +346,7 @@ def __init__(
346
346
self .inputs = inputs
347
347
"""The initial input data."""
348
348
349
- self ._active_reducers : dict [tuple [JoinId , NodeRunId ], Reducer [Any , Any , Any , Any ]] = {}
349
+ self ._active_reducers : dict [tuple [JoinId , NodeRunId ], tuple [ Reducer [Any , Any , Any , Any ], ForkStack ]] = {}
350
350
"""Active reducers for join operations."""
351
351
352
352
self ._next : EndMarker [OutputT ] | JoinItem | Sequence [GraphTask ] | None = None
@@ -469,39 +469,82 @@ def _handle_result(result: EndMarker[OutputT] | JoinItem | Sequence[GraphTask])
469
469
470
470
if isinstance (result , JoinItem ):
471
471
parent_fork_id = self .graph .get_parent_fork (result .join_id ).fork_id
472
- fork_run_id = [x .node_run_id for x in result .fork_stack [::- 1 ] if x .fork_id == parent_fork_id ][0 ]
473
- reducer = self ._active_reducers .get ((result .join_id , fork_run_id ))
474
- if reducer is None :
472
+ for i , x in enumerate (result .fork_stack [::- 1 ]):
473
+ if x .fork_id == parent_fork_id :
474
+ downstream_fork_stack = result .fork_stack [: len (result .fork_stack ) - i ]
475
+ fork_run_id = x .node_run_id
476
+ break
477
+ else :
478
+ raise RuntimeError ('Parent fork run not found' )
479
+
480
+ reducer_and_fork_stack = self ._active_reducers .get ((result .join_id , fork_run_id ))
481
+ if reducer_and_fork_stack is None :
475
482
join_node = self .graph .nodes [result .join_id ]
476
483
assert isinstance (join_node , Join )
477
- reducer = join_node .create_reducer (StepContext ( self . state , self . deps , result . inputs ) )
478
- self ._active_reducers [(result .join_id , fork_run_id )] = reducer
484
+ reducer = join_node .create_reducer ()
485
+ self ._active_reducers [(result .join_id , fork_run_id )] = reducer , downstream_fork_stack
479
486
else :
487
+ reducer , _ = reducer_and_fork_stack
488
+
489
+ try :
480
490
reducer .reduce (StepContext (self .state , self .deps , result .inputs ))
491
+ except StopIteration :
492
+ # cancel all concurrently running tasks with the same fork_run_id of the parent fork
493
+ task_ids_to_cancel = set [TaskId ]()
494
+ for task_id , t in tasks_by_id .items ():
495
+ for item in t .fork_stack :
496
+ if item .fork_id == parent_fork_id and item .node_run_id == fork_run_id :
497
+ task_ids_to_cancel .add (task_id )
498
+ break
499
+ for task in list (pending ):
500
+ if task .get_name () in task_ids_to_cancel :
501
+ task .cancel ()
502
+ pending .remove (task )
481
503
else :
482
504
for new_task in result :
483
505
_start_task (new_task )
484
506
return False
485
507
486
- while pending :
487
- done , pending = await asyncio .wait (pending , return_when = asyncio .FIRST_COMPLETED )
488
- for task in done :
489
- task_result = task .result ()
490
- source_task = tasks_by_id .pop (TaskId (task .get_name ()))
491
- maybe_overridden_result = yield task_result
492
- if _handle_result (maybe_overridden_result ):
493
- return
494
-
495
- for join_id , fork_run_id , fork_stack in self ._get_completed_fork_runs (
496
- source_task , tasks_by_id .values ()
497
- ):
498
- reducer = self ._active_reducers .pop ((join_id , fork_run_id ))
508
+ while pending or self ._active_reducers :
509
+ while pending :
510
+ done , pending = await asyncio .wait (pending , return_when = asyncio .FIRST_COMPLETED )
511
+ for task in done :
512
+ task_result = task .result ()
513
+ source_task = tasks_by_id .pop (TaskId (task .get_name ()))
514
+ maybe_overridden_result = yield task_result
515
+ if _handle_result (maybe_overridden_result ):
516
+ return
499
517
518
+ for join_id , fork_run_id in self ._get_completed_fork_runs (source_task , tasks_by_id .values ()):
519
+ reducer , fork_stack = self ._active_reducers .pop ((join_id , fork_run_id ))
520
+ output = reducer .finalize (StepContext (self .state , self .deps , None ))
521
+ join_node = self .graph .nodes [join_id ]
522
+ assert isinstance (
523
+ join_node , Join
524
+ ) # We could drop this but if it fails it means there is a bug.
525
+ new_tasks = self ._handle_edges (join_node , output , fork_stack )
526
+ maybe_overridden_result = yield new_tasks # give an opportunity to override these
527
+ if _handle_result (maybe_overridden_result ):
528
+ return
529
+
530
+ if self ._active_reducers :
531
+ # In this case, there are no pending tasks. We can therefore finalize all active reducers whose
532
+ # downstream fork stacks are not a strict "prefix" of another active reducer. (If it was, finalizing the
533
+ # deeper reducer could produce new tasks in the "prefix" reducer.)
534
+ active_fork_stacks = [fork_stack for _ , fork_stack in self ._active_reducers .values ()]
535
+ for (join_id , fork_run_id ), (reducer , fork_stack ) in list (self ._active_reducers .items ()):
536
+ if any (
537
+ len (afs ) > len (fork_stack ) and fork_stack == afs [: len (fork_stack )]
538
+ for afs in active_fork_stacks
539
+ ):
540
+ continue # this reducer is a strict prefix for one of the other active reducers
541
+
542
+ self ._active_reducers .pop ((join_id , fork_run_id )) # we're finalizing it now
500
543
output = reducer .finalize (StepContext (self .state , self .deps , None ))
501
544
join_node = self .graph .nodes [join_id ]
502
545
assert isinstance (join_node , Join ) # We could drop this but if it fails it means there is a bug.
503
546
new_tasks = self ._handle_edges (join_node , output , fork_stack )
504
- maybe_overridden_result = yield new_tasks # Need to give an opportunity to override these
547
+ maybe_overridden_result = yield new_tasks # give an opportunity to override these
505
548
if _handle_result (maybe_overridden_result ):
506
549
return
507
550
@@ -588,19 +631,18 @@ def _get_completed_fork_runs(
588
631
self ,
589
632
t : GraphTask ,
590
633
active_tasks : Iterable [GraphTask ],
591
- ) -> list [tuple [JoinId , NodeRunId , ForkStack ]]:
592
- completed_fork_runs : list [tuple [JoinId , NodeRunId , ForkStack ]] = []
634
+ ) -> list [tuple [JoinId , NodeRunId ]]:
635
+ completed_fork_runs : list [tuple [JoinId , NodeRunId ]] = []
593
636
594
637
fork_run_indices = {fsi .node_run_id : i for i , fsi in enumerate (t .fork_stack )}
595
638
for join_id , fork_run_id in self ._active_reducers .keys ():
596
639
fork_run_index = fork_run_indices .get (fork_run_id )
597
640
if fork_run_index is None :
598
641
continue # The fork_run_id is not in the current task's fork stack, so this task didn't complete it.
599
642
600
- new_fork_stack = t .fork_stack [:fork_run_index ]
601
643
# This reducer _may_ now be ready to finalize:
602
644
if self ._is_fork_run_completed (active_tasks , join_id , fork_run_id ):
603
- completed_fork_runs .append ((join_id , fork_run_id , new_fork_stack ))
645
+ completed_fork_runs .append ((join_id , fork_run_id ))
604
646
605
647
return completed_fork_runs
606
648
@@ -612,13 +654,27 @@ def _handle_path(self, path: Path, inputs: Any, fork_stack: ForkStack) -> Sequen
612
654
if isinstance (item , DestinationMarker ):
613
655
return [GraphTask (item .destination_id , inputs , fork_stack )]
614
656
elif isinstance (item , SpreadMarker ):
657
+ # Eagerly raise a clear error if the input value is not iterable as expected
658
+ try :
659
+ iter (inputs )
660
+ except TypeError :
661
+ raise RuntimeError (f'Cannot spread non-iterable value: { inputs !r} ' )
662
+
615
663
node_run_id = NodeRunId (str (uuid .uuid4 ()))
616
- return [
617
- GraphTask (
618
- item .fork_id , input_item , fork_stack + (ForkStackItem (item .fork_id , node_run_id , thread_index ),)
664
+
665
+ # If the spread specifies a downstream join id, eagerly create a reducer for it
666
+ if item .downstream_join_id is not None :
667
+ join_node = self .graph .nodes [item .downstream_join_id ]
668
+ assert isinstance (join_node , Join )
669
+ self ._active_reducers [(item .downstream_join_id , node_run_id )] = join_node .create_reducer (), fork_stack
670
+
671
+ spread_tasks : list [GraphTask ] = []
672
+ for thread_index , input_item in enumerate (inputs ):
673
+ item_tasks = self ._handle_path (
674
+ path .next_path , input_item , fork_stack + (ForkStackItem (item .fork_id , node_run_id , thread_index ),)
619
675
)
620
- for thread_index , input_item in enumerate ( inputs )
621
- ]
676
+ spread_tasks += item_tasks
677
+ return spread_tasks
622
678
elif isinstance (item , BroadcastMarker ):
623
679
return [GraphTask (item .fork_id , inputs , fork_stack )]
624
680
elif isinstance (item , TransformMarker ):
@@ -644,6 +700,6 @@ def _is_fork_run_completed(self, tasks: Iterable[GraphTask], join_id: JoinId, fo
644
700
parent_fork = self .graph .get_parent_fork (join_id )
645
701
for t in tasks :
646
702
if fork_run_id in {x .node_run_id for x in t .fork_stack }:
647
- if t .node_id in parent_fork .intermediate_nodes :
703
+ if t .node_id in parent_fork .intermediate_nodes or t . node_id == join_id :
648
704
return False
649
705
return True
0 commit comments