@@ -503,17 +503,17 @@ class NodeExecution(ty.Generic[DefType]):
503503 submitter : Submitter
504504
505505 # List of tasks that were completed successfully
506- successful : dict [StateIndex | None , list ["Task[DefType]" ]]
506+ successful : dict [StateIndex , list ["Task[DefType]" ]]
507507 # List of tasks that failed
508- errored : dict [StateIndex | None , "Task[DefType]" ]
508+ errored : dict [StateIndex , "Task[DefType]" ]
509509 # List of tasks that couldn't be run due to upstream errors
510- unrunnable : dict [StateIndex | None , list ["Task[DefType]" ]]
510+ unrunnable : dict [StateIndex , list ["Task[DefType]" ]]
511511 # List of tasks that are queued
512- queued : dict [StateIndex | None , "Task[DefType]" ]
512+ queued : dict [StateIndex , "Task[DefType]" ]
513513 # List of tasks that are queued
514- running : dict [StateIndex | None , tuple ["Task[DefType]" , datetime ]]
514+ running : dict [StateIndex , tuple ["Task[DefType]" , datetime ]]
515515 # List of tasks that are blocked on other tasks to complete before they can be run
516- blocked : dict [StateIndex | None , "Task[DefType]" ]
516+ blocked : dict [StateIndex , "Task[DefType]" ] | None
517517
518518 _tasks : dict [StateIndex | None , "Task[DefType]" ] | None
519519
@@ -532,7 +532,7 @@ def __init__(
532532 self .submitter = submitter
533533 # Initialize the state dictionaries
534534 self ._tasks = None
535- self .blocked = {}
535+ self .blocked = None
536536 self .successful = {}
537537 self .errored = {}
538538 self .queued = {}
@@ -568,10 +568,13 @@ def tasks(self) -> ty.Iterable["Task[DefType]"]:
568568 self ._tasks = {t .state_index : t for t in self ._generate_tasks ()}
569569 return self ._tasks .values ()
570570
571- def task (self , index : StateIndex = StateIndex ()) -> "Task | list[Task[DefType]]" :
571+ def task (
572+ self , index : StateIndex = StateIndex ()
573+ ) -> "Task | StateArray[Task[DefType]]" :
572574 """Get a task object for a given state index."""
573- self .tasks # Ensure tasks are loaded
574- task_index = next (iter (self ._tasks ))
575+ if not self .tasks :
576+ return StateArray ([])
577+ task_index = next (iter (self ._tasks )) if self ._tasks else StateIndex ()
575578 if len (task_index ) > len (index ):
576579 tasks = []
577580 for ind , task in self ._tasks .items ():
@@ -589,7 +592,7 @@ def started(self) -> bool:
589592 or self .errored
590593 or self .unrunnable
591594 or self .queued
592- or self .blocked
595+ or self .blocked is not None
593596 )
594597
595598 @property
@@ -730,7 +733,7 @@ def get_runnable_tasks(self, graph: DiGraph) -> list["Task[DefType]"]:
730733 runnable : list ["Task[DefType]" ] = []
731734 self .tasks # Ensure tasks are loaded
732735 if not self .started :
733- assert self ._tasks
736+ assert self ._tasks is not None
734737 self .blocked = copy (self ._tasks )
735738 # Check to see if any blocked tasks are now runnable/unrunnable
736739 for index , task in list (self .blocked .items ()):
0 commit comments