@@ -536,17 +536,14 @@ def __init__(
536536 self .queued = {}
537537 self .running = {} # Not used in logic, but may be useful for progress tracking
538538 self .unrunnable = defaultdict (list )
539- # Prepare the state to be run
540- if node .state :
541- self .state = node .state
542- self .state .prepare_states (self .node .state_values )
543- self .state .prepare_inputs ()
544- else :
545- self .state = None
546539 self .state_names = self .node .state .names if self .node .state else []
547540 self .workflow = workflow
548541 self .graph = None
549542
543+ @property
544+ def state (self ):
545+ return self .node .state
546+
550547 def __repr__ (self ):
551548 return (
552549 f"NodeExecution(name={ self .name !r} , blocked={ list (self .blocked )} , "
@@ -566,7 +563,7 @@ def _definition(self) -> "Node":
566563 @property
567564 def tasks (self ) -> ty .Generator ["Task[DefType]" , None , None ]:
568565 if self ._tasks is None :
569- self . _tasks = { t . state_index : t for t in self . _generate_tasks ()}
566+ raise RuntimeError ( "Tasks have not been generated" )
570567 return self ._tasks .values ()
571568
572569 def get_jobs (self , final_index : int | None = None ) -> "Task | StateArray[Task]" :
@@ -586,14 +583,11 @@ def get_jobs(self, final_index: int | None = None) -> "Task | StateArray[Task]":
586583 if not self .tasks : # No jobs, return empty state array
587584 return StateArray ()
588585 if not self .node .state : # Return the singular job
589- assert final_index is None
590- task = self ._tasks [None ]
591- return task
586+ return self ._tasks [None ]
592587 if final_index is None : # return all jobs in a state array
593588 return StateArray (self ._tasks .values ())
594589 if not self .node .state .combiner : # Select the job that matches the index
595- task = self ._tasks [final_index ]
596- return task
590+ return self ._tasks [final_index ]
597591 # Get a slice of the tasks that match the given index of the state array of the
598592 # combined values
599593 final_index = set (self .node .state .states_ind_final [final_index ].items ())
@@ -603,6 +597,38 @@ def get_jobs(self, final_index: int | None = None) -> "Task | StateArray[Task]":
603597 if set (ind .items ()).issuperset (final_index )
604598 )
605599
600+ def start (self ) -> None :
601+ """Prepare the execution node so that it can be processed"""
602+ self ._tasks = {}
603+ if self .state :
604+ values = {}
605+ for name , value in self .node .state_values .items ():
606+ if name in self .node .state .names and isinstance (value , LazyField ):
607+ values [name ] = value ._get_value (
608+ workflow = self .workflow , graph = self .graph
609+ )
610+ self .state .prepare_states (values )
611+ self .state .prepare_inputs ()
612+ # Generate the tasks
613+ for index , split_defn in enumerate (self ._split_definition ()):
614+ self ._tasks [index ] = Task (
615+ definition = split_defn ,
616+ submitter = self .submitter ,
617+ environment = self .node ._environment ,
618+ name = self .node .name ,
619+ hooks = self .node ._hooks ,
620+ state_index = index ,
621+ )
622+ else :
623+ self ._tasks [None ] = Task (
624+ definition = self ._resolve_lazy_inputs (task_def = self .node ._definition ),
625+ submitter = self .submitter ,
626+ environment = self .node ._environment ,
627+ hooks = self .node ._hooks ,
628+ name = self .node .name ,
629+ )
630+ self .blocked = copy (self ._tasks )
631+
606632 @property
607633 def started (self ) -> bool :
608634 return (
@@ -656,26 +682,6 @@ def all_failed(self) -> bool:
656682 self .successful or self .blocked or self .queued
657683 )
658684
659- def _generate_tasks (self ) -> ty .Iterable ["Task[DefType]" ]:
660- if not self .node .state :
661- yield Task (
662- definition = self ._resolve_lazy_inputs (task_def = self .node ._definition ),
663- submitter = self .submitter ,
664- environment = self .node ._environment ,
665- hooks = self .node ._hooks ,
666- name = self .node .name ,
667- )
668- else :
669- for index , split_defn in enumerate (self ._split_definition ()):
670- yield Task (
671- definition = split_defn ,
672- submitter = self .submitter ,
673- environment = self .node ._environment ,
674- name = self .node .name ,
675- hooks = self .node ._hooks ,
676- state_index = index ,
677- )
678-
679685 def _resolve_lazy_inputs (
680686 self ,
681687 task_def : "TaskDef" ,
@@ -754,10 +760,8 @@ def get_runnable_tasks(self, graph: DiGraph) -> list["Task[DefType]"]:
754760 List of tasks that are ready to run
755761 """
756762 runnable : list ["Task[DefType]" ] = []
757- self .tasks # Ensure tasks are loaded
758763 if not self .started :
759- assert self ._tasks is not None
760- self .blocked = copy (self ._tasks )
764+ self .start ()
761765 # Check to see if any blocked tasks are now runnable/unrunnable
762766 for index , task in list (self .blocked .items ()):
763767 pred : NodeExecution
0 commit comments