@@ -83,6 +83,7 @@ class Submitter:
8383 messenger_args : dict [str , ty .Any ]
8484 clean_stale_locks : bool
8585 run_start_time : datetime | None
86+ propagate_rerun : bool
8687
8788 def __init__ (
8889 self ,
@@ -94,6 +95,7 @@ def __init__(
9495 audit_flags : AuditFlag = AuditFlag .NONE ,
9596 messengers : ty .Iterable [Messenger ] | None = None ,
9697 messenger_args : dict [str , ty .Any ] | None = None ,
98+ propagate_rerun : bool = True ,
9799 clean_stale_locks : bool | None = None ,
98100 ** kwargs ,
99101 ):
@@ -121,6 +123,7 @@ def __init__(
121123
122124 self .cache_dir = cache_dir
123125 self .cache_locations = cache_locations
126+ self .propagate_rerun = propagate_rerun
124127 self .environment = environment if environment is not None else Native ()
125128 self .loop = get_open_loop ()
126129 self ._own_loop = not self .loop .is_running ()
@@ -188,6 +191,8 @@ def __call__(
188191 rerun : bool, optional
189192 Whether to force the re-computation of the task results even if existing
190193 results are found, by default False
194+ propagate_rerun : bool, optional
195+ Whether to propagate the rerun flag to all tasks in the workflow, by default True
191196
192197 Returns
193198 -------
@@ -312,12 +317,12 @@ def expand_workflow(self, workflow_task: "Task[WorkflowDef]", rerun: bool) -> No
312317 wf = workflow_task .definition .construct ()
313318 # Generate the execution graph
314319 exec_graph = wf .execution_graph (submitter = self )
320+ workflow_task .return_values = {"workflow" : wf , "exec_graph" : exec_graph }
315321 tasks = self .get_runnable_tasks (exec_graph )
316322 while tasks or any (not n .done for n in exec_graph .nodes ):
317323 for task in tasks :
318324 self .worker .run (task , rerun = rerun )
319325 tasks = self .get_runnable_tasks (exec_graph )
320- workflow_task .return_values = {"workflow" : wf , "exec_graph" : exec_graph }
321326
322327 async def expand_workflow_async (
323328 self , workflow_task : "Task[WorkflowDef]" , rerun : bool
@@ -333,6 +338,7 @@ async def expand_workflow_async(
333338 wf = workflow_task .definition .construct ()
334339 # Generate the execution graph
335340 exec_graph = wf .execution_graph (submitter = self )
341+ workflow_task .return_values = {"workflow" : wf , "exec_graph" : exec_graph }
336342 # keep track of pending futures
337343 task_futures = set ()
338344 tasks = self .get_runnable_tasks (exec_graph )
@@ -417,7 +423,6 @@ async def expand_workflow_async(
417423 task_futures .add (self .worker .run (task , rerun = rerun ))
418424 task_futures = await self .worker .fetch_finished (task_futures )
419425 tasks = self .get_runnable_tasks (exec_graph )
420- workflow_task .return_values = {"workflow" : wf , "exec_graph" : exec_graph }
421426
422427 def __enter__ (self ):
423428 return self
@@ -467,7 +472,8 @@ def get_runnable_tasks(self, graph: DiGraph) -> list["Task[DefType]"]:
467472 continue
468473 # since the list is sorted (breadth-first) we can stop
469474 # when we find a task that depends on any task that is already in tasks
470- if set (graph .predecessors [node .name ]).intersection (not_started ):
475+ preds = set (graph .predecessors [node .name ])
476+ if preds .intersection (not_started ):
471477 break
472478 # Record if the node has not been started
473479 if not node .started :
@@ -619,6 +625,11 @@ def done(self) -> bool:
619625 # Check to see if any previously queued tasks have completed
620626 return not (self .queued or self .blocked or self .running )
621627
628+ @property
629+ def has_errored (self ) -> bool :
630+ self .update_status ()
631+ return bool (self .errored )
632+
622633 def update_status (self ) -> None :
623634 """Updates the status of the tasks in the node."""
624635 if not self .started :
@@ -729,56 +740,80 @@ def get_runnable_tasks(self, graph: DiGraph) -> list["Task[DefType]"]:
729740 List of tasks that are ready to run
730741 """
731742 runnable : list ["Task[DefType]" ] = []
732- if not self .started :
733- self .start ()
734- # Check to see if any blocked tasks are now runnable/unrunnable
735- for index , task in list (self .blocked .items ()):
736- pred : NodeExecution
737- is_runnable = True
738- # This is required for the commented-out code below
739- # states_ind = (
740- # list(self.node.state.states_ind[index].items())
741- # if self.node.state
742- # else []
743- # )
744- for pred in graph .predecessors [self .node .name ]:
745- if pred .node .state :
746- # FIXME: These should be the only predecessor jobs that are required to have
747- # completed before the job can be run, however, due to how the state
748- # is currently built, all predecessors are required to have completed.
749- # If/when this is relaxed, then the following code should be used instead.
750- #
751- # pred_states_ind = {
752- # (k, i) for k, i in states_ind if k.startswith(pred.name + ".")
753- # }
754- # pred_inds = [
755- # i
756- # for i, ind in enumerate(pred.node.state.states_ind)
757- # if set(ind.items()).issuperset(pred_states_ind)
758- # ]
759- pred_inds = list (range (len (pred .node .state .states_ind )))
760- else :
761- pred_inds = [None ]
762- if not all (i in pred .successful for i in pred_inds ):
763- is_runnable = False
764- blocked = True
765- if pred_errored := [i for i in pred_inds if i in pred .errored ]:
766- self .unrunnable [index ].extend (
767- [pred .errored [i ] for i in pred_errored ]
768- )
769- blocked = False
770- if pred_unrunnable := [
771- i for i in pred_inds if i in pred .unrunnable
772- ]:
773- self .unrunnable [index ].extend (
774- [pred .unrunnable [i ] for i in pred_unrunnable ]
775- )
776- blocked = False
777- if not blocked :
778- del self .blocked [index ]
779- break
780- if is_runnable :
781- runnable .append (self .blocked .pop (index ))
743+ predecessors : list ["Task[DefType]" ] = graph .predecessors [self .node .name ]
744+
745+ # If there is a split, we need to wait for all predecessor nodes to finish
746+ # In theory, if the current splitter splits an already split state we should
747+ # only need to wait for the direct predecessor jobs to finish, however, this
748+ # would require a deep refactor of the State class as we need the whole state
749+ # in order to assign consistent state indices across the new split
750+
751+ # FIXME: The branch for handling partially completed/errored/unrunnable
752+ # predecessor nodes can't be used until the State class can be partially
753+ # initialised with lazy-fields.
754+ if True : # self.node.splitter:
755+ if unrunnable := [p for p in predecessors if p .errored or p .unrunnable ]:
756+ self .unrunnable = {None : unrunnable }
757+ self .blocked = {}
758+ assert self .done
759+ else :
760+ if all (p .done for p in predecessors ):
761+ if not self .started :
762+ self .start ()
763+ if self .node .state is None :
764+ inds = [None ]
765+ else :
766+ inds = list (range (len (self .node .state .states_ind )))
767+ if self .blocked :
768+ for i in inds :
769+ runnable .append (self .blocked .pop (i ))
770+ else :
771+ if not self .started :
772+ self .start ()
773+
774+ # Check to see if any blocked tasks are now runnable/unrunnable
775+ for index , task in list (self .blocked .items ()):
776+ pred : NodeExecution
777+ is_runnable = True
778+ states_ind = (
779+ list (self .node .state .states_ind [index ].items ())
780+ if self .node .state
781+ else []
782+ )
783+ for pred in predecessors :
784+ if pred .node .state :
785+ pred_states_ind = {
786+ (k , i )
787+ for k , i in states_ind
788+ if k .startswith (pred .name + "." )
789+ }
790+ pred_inds = [
791+ i
792+ for i , ind in enumerate (pred .node .state .states_ind )
793+ if set (ind .items ()).issuperset (pred_states_ind )
794+ ]
795+ else :
796+ pred_inds = [None ]
797+ if not all (i in pred .successful for i in pred_inds ):
798+ is_runnable = False
799+ blocked = True
800+ if pred_errored := [
801+ pred .errored [i ] for i in pred_inds if i in pred .errored
802+ ]:
803+ self .unrunnable [index ].extend (pred_errored )
804+ blocked = False
805+ if pred_unrunnable := [
806+ pred .unrunnable [i ]
807+ for i in pred_inds
808+ if i in pred .unrunnable
809+ ]:
810+ self .unrunnable [index ].extend (pred_unrunnable )
811+ blocked = False
812+ if not blocked :
813+ del self .blocked [index ]
814+ break
815+ if is_runnable :
816+ runnable .append (self .blocked .pop (index ))
782817 self .queued .update ({t .state_index : t for t in runnable })
783818 return list (self .queued .values ())
784819
0 commit comments