@@ -538,7 +538,7 @@ def __init__(
538538 self .unrunnable = defaultdict (list )
539539 # Prepare the state to be run
540540 if node .state :
541- self .state = deepcopy ( node .state )
541+ self .state = node .state
542542 self .state .prepare_states (self .node .state_values )
543543 self .state .prepare_inputs ()
544544 else :
@@ -564,34 +564,43 @@ def _definition(self) -> "Node":
564564 return self .node ._definition
565565
566566 @property
567- def tasks (self ) -> ty .Iterable ["Task[DefType]" ]:
567+ def tasks (self ) -> ty .Generator ["Task[DefType]" , None , None ]:
568568 if self ._tasks is None :
569569 self ._tasks = {t .state_index : t for t in self ._generate_tasks ()}
570570 return self ._tasks .values ()
571571
572- def matching_jobs (self , index : int | None = None ) -> "StateArray[Task]" :
572+ def get_jobs (
573+ self , index : int | None = None , as_array : bool = False
574+ ) -> "Task | StateArray[Task]" :
573575 """Get the jobs that match a given state index.
574576
575577 Parameters
576578 ----------
577579 index : int, optional
578580 The index of the state of the task to get, by default None
581+ as_array : bool, optional
582+ Whether to return the tasks in a state-array object, by default if the index
583+ matches
579584
580585 Returns
581586 -------
582- matching : StateArray[Task]
583- The tasks that match the given index
587+ matching : Task | StateArray[Task]
588+ The task or tasks that match the given index
584589 """
585590 matching = StateArray ()
586591 if self .tasks :
587592 try :
588- matching . append ( self ._tasks [index ])
593+ task = self ._tasks [index ]
589594 except KeyError :
595+ if index is None :
596+ return StateArray (self ._tasks .values ())
590597 # Select matching tasks and return them in nested state-array objects
591598 for ind , task in self ._tasks .items ():
592- if ind .matches (index ):
593- matching .append (task )
594-
599+ matching .append (task )
600+ else :
601+ if not as_array :
602+ return task
603+ matching .append (task )
595604 return matching
596605
597606 @property
@@ -657,7 +666,7 @@ def _generate_tasks(self) -> ty.Iterable["Task[DefType]"]:
657666 name = self .node .name ,
658667 )
659668 else :
660- for index , split_defn in self ._split_definition (). items ( ):
669+ for index , split_defn in enumerate ( self ._split_definition ()):
661670 yield Task (
662671 definition = split_defn ,
663672 submitter = self .submitter ,
@@ -707,7 +716,7 @@ def _split_definition(self) -> dict[int, "TaskDef[OutputType]"]:
707716 if not self .node .state :
708717 return {None : self .node ._definition }
709718 split_defs = []
710- for i , input_ind in enumerate ( self .node .state .inputs_ind ) :
719+ for input_ind in self .node .state .inputs_ind :
711720 resolved = {}
712721 for inpt_name in set (self .node .input_names ):
713722 value = getattr (self ._definition , inpt_name )
@@ -716,13 +725,13 @@ def _split_definition(self) -> dict[int, "TaskDef[OutputType]"]:
716725 resolved [inpt_name ] = value ._get_value (
717726 workflow = self .workflow ,
718727 graph = self .graph ,
719- state_index = i ,
728+ state_index = input_ind [ state_key ] ,
720729 )
721730 elif state_key in input_ind :
722731 resolved [inpt_name ] = self .node .state ._get_element (
723732 value = value ,
724733 field_name = inpt_name ,
725- ind = i ,
734+ ind = input_ind [ state_key ] ,
726735 )
727736 split_defs .append (attrs .evolve (self .node ._definition , ** resolved ))
728737 return split_defs
@@ -754,7 +763,8 @@ def get_runnable_tasks(self, graph: DiGraph) -> list["Task[DefType]"]:
754763 pred : NodeExecution
755764 is_runnable = True
756765 for pred in graph .predecessors [self .node .name ]:
757- pred_inds = [j .state_index for j in pred .matching_jobs (index )]
766+ pred_jobs : StateArray [Task ] = pred .get_jobs (index , as_array = True )
767+ pred_inds = [j .state_index for j in pred_jobs ]
758768 if not all (i in pred .successful for i in pred_inds ):
759769 is_runnable = False
760770 blocked = True
0 commit comments