@@ -569,39 +569,39 @@ def tasks(self) -> ty.Generator["Task[DefType]", None, None]:
569569 self ._tasks = {t .state_index : t for t in self ._generate_tasks ()}
570570 return self ._tasks .values ()
571571
572- def get_jobs (
573- self , index : int | None = None , as_array : bool = False
574- ) -> "Task | StateArray[Task]" :
572+ def get_jobs (self , final_index : int | None = None ) -> "Task | StateArray[Task]" :
575573 """Get the jobs that match a given state index.
576574
577575 Parameters
578576 ----------
579- index : int, optional
580- The index of the state of the task to get, by default None
581- as_array : bool, optional
582- Whether to return the tasks in a state-array object, by default if the index
583- matches
577+ final_index : int, optional
578+ The index of the output state array (i.e. after any combinations) of the
579+ job to get, by default None
584580
585581 Returns
586582 -------
587583 matching : Task | StateArray[Task]
588584 The task or tasks that match the given index
589585 """
590- matching = StateArray ()
591- if self .tasks :
592- try :
593- task = self ._tasks [index ]
594- except KeyError :
595- if index is None :
596- return StateArray (self ._tasks .values ())
597- # Select matching tasks and return them in nested state-array objects
598- for ind , task in self ._tasks .items ():
599- matching .append (task )
600- else :
601- if not as_array :
602- return task
603- matching .append (task )
604- return matching
586+ if not self .tasks : # No jobs, return empty state array
587+ return StateArray ()
588+ if not self .node .state : # Return the singular job
589+ assert final_index is None
590+ task = self ._tasks [None ]
591+ return task
592+ if final_index is None : # return all jobs in a state array
593+ return StateArray (self ._tasks .values ())
594+ if not self .node .state .combiner : # Select the job that matches the index
595+ task = self ._tasks [final_index ]
596+ return task
597+ # Get a slice of the tasks that match the given index of the state array of the
598+ # combined values
599+ final_index = set (self .node .state .states_ind_final [final_index ].items ())
600+ return StateArray (
601+ self ._tasks [i ]
602+ for i , ind in enumerate (self .node .state .states_ind )
603+ if set (ind .items ()).issuperset (final_index )
604+ )
605605
606606 @property
607607 def started (self ) -> bool :
@@ -762,9 +762,23 @@ def get_runnable_tasks(self, graph: DiGraph) -> list["Task[DefType]"]:
762762 for index , task in list (self .blocked .items ()):
763763 pred : NodeExecution
764764 is_runnable = True
765+ states_ind = (
766+ list (self .node .state .states_ind [index ].items ())
767+ if self .node .state
768+ else []
769+ )
765770 for pred in graph .predecessors [self .node .name ]:
766- pred_jobs : StateArray [Task ] = pred .get_jobs (index , as_array = True )
767- pred_inds = [j .state_index for j in pred_jobs ]
771+ if pred .node .state :
772+ pred_states_ind = {
773+ (k , i ) for k , i in states_ind if k .startswith (pred .name + "." )
774+ }
775+ pred_inds = [
776+ i
777+ for i , ind in enumerate (pred .node .state .states_ind )
778+ if set (ind .items ()).issuperset (pred_states_ind )
779+ ]
780+ else :
781+ pred_inds = [None ]
768782 if not all (i in pred .successful for i in pred_inds ):
769783 is_runnable = False
770784 blocked = True
0 commit comments