Skip to content

Commit 2664509

Browse files
committed
handle empty state arrays, i.e. nodes that don't run any jobs
1 parent 7f1b259 commit 2664509

File tree

2 files changed

+16
-13
lines changed

2 files changed

+16
-13
lines changed

pydra/engine/submitter.py

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -503,17 +503,17 @@ class NodeExecution(ty.Generic[DefType]):
503503
submitter: Submitter
504504

505505
# List of tasks that were completed successfully
506-
successful: dict[StateIndex | None, list["Task[DefType]"]]
506+
successful: dict[StateIndex, list["Task[DefType]"]]
507507
# List of tasks that failed
508-
errored: dict[StateIndex | None, "Task[DefType]"]
508+
errored: dict[StateIndex, "Task[DefType]"]
509509
# List of tasks that couldn't be run due to upstream errors
510-
unrunnable: dict[StateIndex | None, list["Task[DefType]"]]
510+
unrunnable: dict[StateIndex, list["Task[DefType]"]]
511511
# List of tasks that are queued
512-
queued: dict[StateIndex | None, "Task[DefType]"]
512+
queued: dict[StateIndex, "Task[DefType]"]
513513
# List of tasks that are queued
514-
running: dict[StateIndex | None, tuple["Task[DefType]", datetime]]
514+
running: dict[StateIndex, tuple["Task[DefType]", datetime]]
515515
# List of tasks that are blocked on other tasks to complete before they can be run
516-
blocked: dict[StateIndex | None, "Task[DefType]"]
516+
blocked: dict[StateIndex, "Task[DefType]"] | None
517517

518518
_tasks: dict[StateIndex | None, "Task[DefType]"] | None
519519

@@ -532,7 +532,7 @@ def __init__(
532532
self.submitter = submitter
533533
# Initialize the state dictionaries
534534
self._tasks = None
535-
self.blocked = {}
535+
self.blocked = None
536536
self.successful = {}
537537
self.errored = {}
538538
self.queued = {}
@@ -568,10 +568,13 @@ def tasks(self) -> ty.Iterable["Task[DefType]"]:
568568
self._tasks = {t.state_index: t for t in self._generate_tasks()}
569569
return self._tasks.values()
570570

571-
def task(self, index: StateIndex = StateIndex()) -> "Task | list[Task[DefType]]":
571+
def task(
572+
self, index: StateIndex = StateIndex()
573+
) -> "Task | StateArray[Task[DefType]]":
572574
"""Get a task object for a given state index."""
573-
self.tasks # Ensure tasks are loaded
574-
task_index = next(iter(self._tasks))
575+
if not self.tasks:
576+
return StateArray([])
577+
task_index = next(iter(self._tasks)) if self._tasks else StateIndex()
575578
if len(task_index) > len(index):
576579
tasks = []
577580
for ind, task in self._tasks.items():
@@ -589,7 +592,7 @@ def started(self) -> bool:
589592
or self.errored
590593
or self.unrunnable
591594
or self.queued
592-
or self.blocked
595+
or self.blocked is not None
593596
)
594597

595598
@property
@@ -730,7 +733,7 @@ def get_runnable_tasks(self, graph: DiGraph) -> list["Task[DefType]"]:
730733
runnable: list["Task[DefType]"] = []
731734
self.tasks # Ensure tasks are loaded
732735
if not self.started:
733-
assert self._tasks
736+
assert self._tasks is not None
734737
self.blocked = copy(self._tasks)
735738
# Check to see if any blocked tasks are now runnable/unrunnable
736739
for index, task in list(self.blocked.items()):

pydra/engine/tests/test_node_task.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -887,7 +887,7 @@ def test_task_state_3(plugin, tmp_path):
887887
assert state.splitter_rpn == ["NA.a"]
888888
assert nn.a == []
889889

890-
with Submitter(worker=plugin, cache_dir=tmp_path) as sub:
890+
with Submitter(worker="debug", cache_dir=tmp_path) as sub:
891891
results = sub(nn)
892892
assert not results.errored, "\n".join(results.errors["error message"])
893893

0 commit comments

Comments
 (0)