Skip to content

Commit db8f799

Browse files
committed
finished debugging test_specs
1 parent b2034d5 commit db8f799

File tree

7 files changed

+220
-196
lines changed

7 files changed

+220
-196
lines changed

pydra/engine/core.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -817,9 +817,7 @@ def node_names(self) -> list[str]:
817817
def execution_graph(self, submitter: "Submitter") -> DiGraph:
818818
from pydra.engine.submitter import NodeExecution
819819

820-
exec_nodes = [
821-
NodeExecution(n, submitter, workflow_inputs=self.inputs) for n in self.nodes
822-
]
820+
exec_nodes = [NodeExecution(n, submitter, workflow=self) for n in self.nodes]
823821
graph = self._create_graph(exec_nodes)
824822
# Set the graph attribute of the nodes so lazy fields can be resolved as tasks
825823
# are created

pydra/engine/lazy.py

Lines changed: 22 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from . import node
77

88
if ty.TYPE_CHECKING:
9-
from .submitter import NodeExecution
9+
from .submitter import DiGraph, NodeExecution
1010
from .core import Task, Workflow
1111
from .specs import TaskDef
1212
from .state import StateIndex
@@ -47,15 +47,18 @@ def _apply_cast(self, value):
4747

4848
def _get_value(
4949
self,
50-
node_exec: "NodeExecution",
50+
workflow: "Workflow",
51+
graph: "DiGraph[NodeExecution]",
5152
state_index: "StateIndex | None" = None,
5253
) -> ty.Any:
5354
"""Return the value of a lazy field.
5455
5556
Parameters
5657
----------
57-
node_exec: NodeExecution
58-
the object representing the execution state of the current node
58+
workflow: Workflow
59+
the workflow object
60+
graph: DiGraph[NodeExecution]
61+
the graph representing the execution state of the workflow
5962
state_index : StateIndex, optional
6063
the state index of the field to access
6164
@@ -90,25 +93,27 @@ def _source(self):
9093

9194
def _get_value(
9295
self,
93-
node_exec: "NodeExecution",
96+
workflow: "Workflow",
97+
graph: "DiGraph[NodeExecution]",
9498
state_index: "StateIndex | None" = None,
9599
) -> ty.Any:
96100
"""Return the value of a lazy field.
97101
98102
Parameters
99103
----------
100-
node_exec: NodeExecution
101-
the object representing the execution state of the current node
104+
workflow: Workflow
105+
the workflow object
106+
graph: DiGraph[NodeExecution]
107+
the graph representing the execution state of the workflow
102108
state_index : StateIndex, optional
103-
the state index of the field to access (ignored, used for duck-typing with
104-
LazyOutField)
109+
the state index of the field to access
105110
106111
Returns
107112
-------
108113
value : Any
109114
the resolved value of the lazy-field
110115
"""
111-
value = node_exec.workflow_inputs[self._field]
116+
value = workflow.inputs[self._field]
112117
value = self._apply_cast(value)
113118
return value
114119

@@ -127,15 +132,18 @@ def __repr__(self):
127132

128133
def _get_value(
129134
self,
130-
node_exec: "NodeExecution",
135+
workflow: "Workflow",
136+
graph: "DiGraph[NodeExecution]",
131137
state_index: "StateIndex | None" = None,
132138
) -> ty.Any:
133139
"""Return the value of a lazy field.
134140
135141
Parameters
136142
----------
137-
node_exec: NodeExecution
138-
the object representing the execution state of the current node
143+
workflow: Workflow
144+
the workflow object
145+
graph: DiGraph[NodeExecution]
146+
the graph representing the execution state of the workflow
139147
state_index : StateIndex, optional
140148
the state index of the field to access
141149
@@ -152,7 +160,7 @@ def _get_value(
152160
if state_index is None:
153161
state_index = StateIndex()
154162

155-
task = node_exec.graph.node(self._node.name).task(state_index)
163+
task = graph.node(self._node.name).task(state_index)
156164
_, split_depth = TypeParser.strip_splits(self._type)
157165

158166
def get_nested(task: "Task[DefType]", depth: int):

pydra/engine/node.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -269,7 +269,11 @@ def _get_upstream_states(self) -> dict[str, tuple["State", list[str]]]:
269269
"""Get the states of the upstream nodes that are connected to this node"""
270270
upstream_states = {}
271271
for inpt_name, val in self.input_values:
272-
if isinstance(val, lazy.LazyOutField) and val._node.state:
272+
if (
273+
isinstance(val, lazy.LazyOutField)
274+
and val._node.state
275+
and val._node.state.depth
276+
):
273277
node: Node = val._node
274278
# variables that are part of inner splitters should be treated as a containers
275279
if node.state and f"{node.name}.{inpt_name}" in node.state.splitter:

pydra/engine/specs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -736,7 +736,7 @@ def _from_task(cls, task: "Task[WorkflowDef]") -> Self:
736736
nodes_dict = {n.name: n for n in exec_graph.nodes}
737737
for name, lazy_field in attrs_values(workflow.outputs).items():
738738
try:
739-
val_out = lazy_field._get_value(exec_graph)
739+
val_out = lazy_field._get_value(workflow=workflow, graph=exec_graph)
740740
output_wf[name] = val_out
741741
except (ValueError, AttributeError):
742742
output_wf[name] = None

pydra/engine/state.py

Lines changed: 47 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -41,23 +41,63 @@ def __init__(self, indices: dict[str, int] | None = None):
4141
else:
4242
self.indices = OrderedDict(sorted(indices.items()))
4343

44-
def __repr__(self):
44+
def __len__(self) -> int:
45+
return len(self.indices)
46+
47+
def __iter__(self) -> ty.Generator[str, None, None]:
48+
return iter(self.indices)
49+
50+
def __repr__(self) -> str:
4551
return (
4652
"StateIndex(" + ", ".join(f"{n}={v}" for n, v in self.indices.items()) + ")"
4753
)
4854

4955
def __hash__(self):
5056
return hash(tuple(self.indices.items()))
5157

52-
def __eq__(self, other):
58+
def __eq__(self, other) -> bool:
5359
return self.indices == other.indices
5460

55-
def __str__(self):
61+
def __str__(self) -> str:
5662
return "__".join(f"{n}-{i}" for n, i in self.indices.items())
5763

58-
def __bool__(self):
64+
def __bool__(self) -> bool:
5965
return bool(self.indices)
6066

67+
def subset(self, state_names: ty.Iterable[str]) -> ty.Self:
68+
"""Create a new StateIndex with only the specified fields
69+
70+
Parameters
71+
----------
72+
fields : list[str]
73+
the fields to keep in the new StateIndex
74+
75+
Returns
76+
-------
77+
StateIndex
78+
a new StateIndex with only the specified fields
79+
"""
80+
return type(self)({k: v for k, v in self.indices.items() if k in state_names})
81+
82+
def matches(self, other: "StateIndex") -> bool:
83+
"""Check if the indices that are present in the other StateIndex match
84+
85+
Parameters
86+
----------
87+
other : StateIndex
88+
the other StateIndex to compare against
89+
90+
Returns
91+
-------
92+
bool
93+
True if all the indices in the other StateIndex match
94+
"""
95+
if not set(self.indices).issuperset(other.indices):
96+
raise ValueError(
97+
f"StateIndex {self} does not contain all the indices in {other}"
98+
)
99+
return all(self.indices[k] == v for k, v in other.indices.items())
100+
61101

62102
class State:
63103
"""
@@ -172,6 +212,9 @@ def __str__(self):
172212
def names(self):
173213
"""Return the names of the states."""
174214
# analysing states from connected tasks if inner_inputs
215+
if not hasattr(self, "keys_final"):
216+
self.prepare_states()
217+
self.prepare_inputs()
175218
previous_states_keys = {
176219
f"_{v.name}": v.keys_final for v in self.inner_inputs.values()
177220
}

pydra/engine/submitter.py

Lines changed: 50 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,8 @@
3232

3333
if ty.TYPE_CHECKING:
3434
from .node import Node
35-
from .specs import TaskDef, TaskOutputs, WorkflowDef, TaskHooks, Result
35+
from .specs import WorkflowDef, TaskDef, TaskOutputs, TaskHooks, Result
36+
from .core import Workflow
3637
from .environments import Environment
3738
from .state import State
3839

@@ -501,15 +502,15 @@ class NodeExecution(ty.Generic[DefType]):
501502

502503
_tasks: dict[StateIndex | None, "Task[DefType]"] | None
503504

504-
workflow_inputs: "WorkflowDef"
505+
workflow: "Workflow"
505506

506507
graph: DiGraph["NodeExecution"] | None
507508

508509
def __init__(
509510
self,
510511
node: "Node",
511512
submitter: Submitter,
512-
workflow_inputs: "WorkflowDef",
513+
workflow: "Workflow",
513514
):
514515
self.name = node.name
515516
self.node = node
@@ -523,9 +524,17 @@ def __init__(
523524
self.running = {} # Not used in logic, but may be useful for progress tracking
524525
self.unrunnable = defaultdict(list)
525526
self.state_names = self.node.state.names if self.node.state else []
526-
self.workflow_inputs = workflow_inputs
527+
self.workflow = workflow
527528
self.graph = None
528529

530+
def __repr__(self):
531+
return (
532+
f"NodeExecution(name={self.name!r}, blocked={list(self.blocked)}, "
533+
f"queued={list(self.queued)}, running={list(self.running)}, "
534+
f"successful={list(self.successful)}, errored={list(self.errored)}, "
535+
f"unrunnable={list(self.unrunnable)})"
536+
)
537+
529538
@property
530539
def inputs(self) -> "Node.Inputs":
531540
return self.node.inputs
@@ -547,12 +556,16 @@ def tasks(self) -> ty.Iterable["Task[DefType]"]:
547556
def task(self, index: StateIndex = StateIndex()) -> "Task | list[Task[DefType]]":
548557
"""Get a task object for a given state index."""
549558
self.tasks # Ensure tasks are loaded
550-
try:
551-
return self._tasks[index]
552-
except KeyError:
553-
if not index:
554-
return StateArray(self._tasks.values())
555-
raise
559+
task_index = next(iter(self._tasks))
560+
if len(task_index) > len(index):
561+
tasks = []
562+
for ind, task in self._tasks.items():
563+
if ind.matches(index):
564+
tasks.append(task)
565+
return StateArray(tasks)
566+
elif len(index) > len(task_index):
567+
index = index.subset(task_index)
568+
return self._tasks[index]
556569

557570
@property
558571
def started(self) -> bool:
@@ -651,10 +664,12 @@ def _resolve_lazy_inputs(
651664
The task definition with all lazy fields resolved
652665
"""
653666
resolved = {}
654-
for name, value in attrs_values(self).items():
667+
for name, value in attrs_values(task_def).items():
655668
if isinstance(value, LazyField):
656-
resolved[name] = value._get_value(self, state_index)
657-
return attrs.evolve(self, **resolved)
669+
resolved[name] = value._get_value(
670+
workflow=self.workflow, graph=self.graph, state_index=state_index
671+
)
672+
return attrs.evolve(task_def, **resolved)
658673

659674
def get_runnable_tasks(self, graph: DiGraph) -> list["Task[DefType]"]:
660675
"""For a given node, check to see which tasks have been successfully run, are ready
@@ -676,19 +691,35 @@ def get_runnable_tasks(self, graph: DiGraph) -> list["Task[DefType]"]:
676691
runnable: list["Task[DefType]"] = []
677692
self.tasks # Ensure tasks are loaded
678693
if not self.started:
694+
assert self._tasks
679695
self.blocked = copy(self._tasks)
680696
# Check to see if any blocked tasks are now runnable/unrunnable
681697
for index, task in list(self.blocked.items()):
682698
pred: NodeExecution
683699
is_runnable = True
684700
for pred in graph.predecessors[self.node.name]:
685-
if index not in pred.successful:
701+
pred_jobs = pred.task(index)
702+
if isinstance(pred_jobs, StateArray):
703+
pred_inds = [j.state_index for j in pred_jobs]
704+
else:
705+
pred_inds = [pred_jobs.state_index]
706+
if not all(i in pred.successful for i in pred_inds):
686707
is_runnable = False
687-
if index in pred.errored:
688-
self.unrunnable[index].append(self.blocked.pop(index))
689-
if index in pred.unrunnable:
690-
self.unrunnable[index].extend(pred.unrunnable[index])
691-
self.blocked.pop(index)
708+
blocked = True
709+
if pred_errored := [i for i in pred_inds if i in pred.errored]:
710+
self.unrunnable[index].extend(
711+
[pred.errored[i] for i in pred_errored]
712+
)
713+
blocked = False
714+
if pred_unrunnable := [
715+
i for i in pred_inds if i in pred.unrunnable
716+
]:
717+
self.unrunnable[index].extend(
718+
[pred.unrunnable[i] for i in pred_unrunnable]
719+
)
720+
blocked = False
721+
if not blocked:
722+
del self.blocked[index]
692723
break
693724
if is_runnable:
694725
runnable.append(self.blocked.pop(index))

0 commit comments

Comments
 (0)