3232
3333if ty .TYPE_CHECKING :
3434 from .node import Node
35- from .specs import TaskDef , TaskOutputs , WorkflowDef , TaskHooks , Result
35+ from .specs import WorkflowDef , TaskDef , TaskOutputs , TaskHooks , Result
36+ from .core import Workflow
3637 from .environments import Environment
3738 from .state import State
3839
@@ -501,15 +502,15 @@ class NodeExecution(ty.Generic[DefType]):
501502
502503 _tasks : dict [StateIndex | None , "Task[DefType]" ] | None
503504
504- workflow_inputs : "WorkflowDef "
505+ workflow : "Workflow "
505506
506507 graph : DiGraph ["NodeExecution" ] | None
507508
508509 def __init__ (
509510 self ,
510511 node : "Node" ,
511512 submitter : Submitter ,
512- workflow_inputs : "WorkflowDef " ,
513+ workflow : "Workflow " ,
513514 ):
514515 self .name = node .name
515516 self .node = node
@@ -523,9 +524,17 @@ def __init__(
523524 self .running = {} # Not used in logic, but may be useful for progress tracking
524525 self .unrunnable = defaultdict (list )
525526 self .state_names = self .node .state .names if self .node .state else []
526- self .workflow_inputs = workflow_inputs
527+ self .workflow = workflow
527528 self .graph = None
528529
530+ def __repr__ (self ):
531+ return (
532+ f"NodeExecution(name={ self .name !r} , blocked={ list (self .blocked )} , "
533+ f"queued={ list (self .queued )} , running={ list (self .running )} , "
534+ f"successful={ list (self .successful )} , errored={ list (self .errored )} , "
535+ f"unrunnable={ list (self .unrunnable )} )"
536+ )
537+
529538 @property
530539 def inputs (self ) -> "Node.Inputs" :
531540 return self .node .inputs
@@ -547,12 +556,16 @@ def tasks(self) -> ty.Iterable["Task[DefType]"]:
547556 def task (self , index : StateIndex = StateIndex ()) -> "Task | list[Task[DefType]]" :
548557 """Get a task object for a given state index."""
549558 self .tasks # Ensure tasks are loaded
550- try :
551- return self ._tasks [index ]
552- except KeyError :
553- if not index :
554- return StateArray (self ._tasks .values ())
555- raise
559+ task_index = next (iter (self ._tasks ))
560+ if len (task_index ) > len (index ):
561+ tasks = []
562+ for ind , task in self ._tasks .items ():
563+ if ind .matches (index ):
564+ tasks .append (task )
565+ return StateArray (tasks )
566+ elif len (index ) > len (task_index ):
567+ index = index .subset (task_index )
568+ return self ._tasks [index ]
556569
557570 @property
558571 def started (self ) -> bool :
@@ -651,10 +664,12 @@ def _resolve_lazy_inputs(
651664 The task definition with all lazy fields resolved
652665 """
653666 resolved = {}
654- for name , value in attrs_values (self ).items ():
667+ for name , value in attrs_values (task_def ).items ():
655668 if isinstance (value , LazyField ):
656- resolved [name ] = value ._get_value (self , state_index )
657- return attrs .evolve (self , ** resolved )
669+ resolved [name ] = value ._get_value (
670+ workflow = self .workflow , graph = self .graph , state_index = state_index
671+ )
672+ return attrs .evolve (task_def , ** resolved )
658673
659674 def get_runnable_tasks (self , graph : DiGraph ) -> list ["Task[DefType]" ]:
660675 """For a given node, check to see which tasks have been successfully run, are ready
@@ -676,19 +691,35 @@ def get_runnable_tasks(self, graph: DiGraph) -> list["Task[DefType]"]:
676691 runnable : list ["Task[DefType]" ] = []
677692 self .tasks # Ensure tasks are loaded
678693 if not self .started :
694+ assert self ._tasks
679695 self .blocked = copy (self ._tasks )
680696 # Check to see if any blocked tasks are now runnable/unrunnable
681697 for index , task in list (self .blocked .items ()):
682698 pred : NodeExecution
683699 is_runnable = True
684700 for pred in graph .predecessors [self .node .name ]:
685- if index not in pred .successful :
701+ pred_jobs = pred .task (index )
702+ if isinstance (pred_jobs , StateArray ):
703+ pred_inds = [j .state_index for j in pred_jobs ]
704+ else :
705+ pred_inds = [pred_jobs .state_index ]
706+ if not all (i in pred .successful for i in pred_inds ):
686707 is_runnable = False
687- if index in pred .errored :
688- self .unrunnable [index ].append (self .blocked .pop (index ))
689- if index in pred .unrunnable :
690- self .unrunnable [index ].extend (pred .unrunnable [index ])
691- self .blocked .pop (index )
708+ blocked = True
709+ if pred_errored := [i for i in pred_inds if i in pred .errored ]:
710+ self .unrunnable [index ].extend (
711+ [pred .errored [i ] for i in pred_errored ]
712+ )
713+ blocked = False
714+ if pred_unrunnable := [
715+ i for i in pred_inds if i in pred .unrunnable
716+ ]:
717+ self .unrunnable [index ].extend (
718+ [pred .unrunnable [i ] for i in pred_unrunnable ]
719+ )
720+ blocked = False
721+ if not blocked :
722+ del self .blocked [index ]
692723 break
693724 if is_runnable :
694725 runnable .append (self .blocked .pop (index ))
0 commit comments