@@ -570,7 +570,9 @@ def start(self) -> None:
570570 if self .state :
571571 values = {}
572572 for name , value in self .node .state_values .items ():
573- if name in self .node .state .names :
573+ if name in self .node .state .current_splitter_rpn :
574+ if name in ("*" , "." ):
575+ continue
574576 if isinstance (value , LazyField ):
575577 values [name ] = value ._get_value (
576578 workflow = self .workflow , graph = self .graph
@@ -692,23 +694,20 @@ def _split_definition(self) -> dict[int, "TaskDef[OutputType]"]:
692694 if not self .node .state :
693695 return {None : self .node ._definition }
694696 split_defs = []
695- for input_ind in self .node .state .inputs_ind :
697+ for index , vals in zip ( self .node .state .inputs_ind , self . node . state . states_val ) :
696698 resolved = {}
697699 for inpt_name in set (self .node .input_names ):
698700 value = getattr (self ._definition , inpt_name )
699701 state_key = f"{ self .node .name } .{ inpt_name } "
700- if isinstance (value , LazyField ):
701- resolved [inpt_name ] = value ._get_value (
702- workflow = self .workflow ,
703- graph = self .graph ,
704- state_index = input_ind .get (state_key ),
705- )
706- elif state_key in input_ind :
707- resolved [inpt_name ] = self .node .state ._get_element (
708- value = value ,
709- field_name = inpt_name ,
710- ind = input_ind [state_key ],
711- )
702+ try :
703+ resolved [inpt_name ] = vals [state_key ]
704+ except KeyError :
705+ if isinstance (value , LazyField ):
706+ resolved [inpt_name ] = value ._get_value (
707+ workflow = self .workflow ,
708+ graph = self .graph ,
709+ state_index = index .get (state_key ),
710+ )
712711 split_defs .append (attrs .evolve (self .node ._definition , ** resolved ))
713712 return split_defs
714713
@@ -736,21 +735,28 @@ def get_runnable_tasks(self, graph: DiGraph) -> list["Task[DefType]"]:
736735 for index , task in list (self .blocked .items ()):
737736 pred : NodeExecution
738737 is_runnable = True
739- states_ind = (
740- list (self .node .state .states_ind [index ].items ())
741- if self .node .state
742- else []
743- )
738+ # This is required for the commented-out code below
739+ # states_ind = (
740+ # list(self.node.state.states_ind[index].items())
741+ # if self.node.state
742+ # else []
743+ # )
744744 for pred in graph .predecessors [self .node .name ]:
745745 if pred .node .state :
746- pred_states_ind = {
747- (k , i ) for k , i in states_ind if k .startswith (pred .name + "." )
748- }
749- pred_inds = [
750- i
751- for i , ind in enumerate (pred .node .state .states_ind )
752- if set (ind .items ()).issuperset (pred_states_ind )
753- ]
746+ # FIXME: These should be the only predecessor jobs that are required to have
747+ # completed before the job can be run, however, due to how the state
748+ # is currently built, all predecessors are required to have completed.
749+ # If/when this is relaxed, then the following code should be used instead.
750+ #
751+ # pred_states_ind = {
752+ # (k, i) for k, i in states_ind if k.startswith(pred.name + ".")
753+ # }
754+ # pred_inds = [
755+ # i
756+ # for i, ind in enumerate(pred.node.state.states_ind)
757+ # if set(ind.items()).issuperset(pred_states_ind)
758+ # ]
759+ pred_inds = list (range (len (pred .node .state .states_ind )))
754760 else :
755761 pred_inds = [None ]
756762 if not all (i in pred .successful for i in pred_inds ):
0 commit comments