
Commit b5a0455

Merge pull request #770 from nipype/test_specs
Develop: Test specs
2 parents 4b97ed6 + db8f799 commit b5a0455

9 files changed: +315 additions, -312 deletions


pydra/engine/core.py

Lines changed: 1 addition & 3 deletions

@@ -817,9 +817,7 @@ def node_names(self) -> list[str]:
     def execution_graph(self, submitter: "Submitter") -> DiGraph:
         from pydra.engine.submitter import NodeExecution
 
-        exec_nodes = [
-            NodeExecution(n, submitter, workflow_inputs=self.inputs) for n in self.nodes
-        ]
+        exec_nodes = [NodeExecution(n, submitter, workflow=self) for n in self.nodes]
         graph = self._create_graph(exec_nodes)
         # Set the graph attribute of the nodes so lazy fields can be resolved as tasks
         # are created

pydra/engine/lazy.py

Lines changed: 41 additions & 11 deletions

@@ -6,10 +6,9 @@
 from . import node
 
 if ty.TYPE_CHECKING:
-    from .graph import DiGraph
-    from .submitter import NodeExecution
+    from .submitter import DiGraph, NodeExecution
     from .core import Task, Workflow
-    from .specs import TaskDef, WorkflowDef
+    from .specs import TaskDef
     from .state import StateIndex
 
 
@@ -46,6 +45,30 @@ def _apply_cast(self, value):
             value = self._type(value)
         return value
 
+    def _get_value(
+        self,
+        workflow: "Workflow",
+        graph: "DiGraph[NodeExecution]",
+        state_index: "StateIndex | None" = None,
+    ) -> ty.Any:
+        """Return the value of a lazy field.
+
+        Parameters
+        ----------
+        workflow: Workflow
+            the workflow object
+        graph: DiGraph[NodeExecution]
+            the graph representing the execution state of the workflow
+        state_index : StateIndex, optional
+            the state index of the field to access
+
+        Returns
+        -------
+        value : Any
+            the resolved value of the lazy-field
+        """
+        raise NotImplementedError("LazyField is an abstract class")
+
 
 @attrs.define(kw_only=True)
 class LazyInField(LazyField[T]):
@@ -70,23 +93,27 @@ def _source(self):
 
     def _get_value(
         self,
-        workflow_def: "WorkflowDef",
+        workflow: "Workflow",
+        graph: "DiGraph[NodeExecution]",
+        state_index: "StateIndex | None" = None,
     ) -> ty.Any:
         """Return the value of a lazy field.
 
         Parameters
         ----------
-        wf : Workflow
-            the workflow the lazy field references
-        state_index : int, optional
+        workflow: Workflow
+            the workflow object
+        graph: DiGraph[NodeExecution]
+            the graph representing the execution state of the workflow
+        state_index : StateIndex, optional
             the state index of the field to access
 
        Returns
        -------
        value : Any
            the resolved value of the lazy-field
        """
-        value = workflow_def[self._field]
+        value = workflow.inputs[self._field]
         value = self._apply_cast(value)
         return value
 
@@ -105,16 +132,19 @@ def __repr__(self):
 
     def _get_value(
         self,
+        workflow: "Workflow",
         graph: "DiGraph[NodeExecution]",
         state_index: "StateIndex | None" = None,
     ) -> ty.Any:
         """Return the value of a lazy field.
 
         Parameters
         ----------
-        wf : Workflow
-            the workflow the lazy field references
-        state_index : int, optional
+        workflow: Workflow
+            the workflow object
+        graph: DiGraph[NodeExecution]
+            the graph representing the execution state of the workflow
+        state_index : StateIndex, optional
            the state index of the field to access
 
        Returns
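
To illustrate the unified call convention in isolation, here is a minimal, hypothetical sketch; the Workflow, graph argument and LazyInField classes below are simplified stand-ins rather than the real pydra types, and only the _get_value() signature mirrors the diff above.

# Minimal sketch of the unified _get_value(workflow, graph, state_index) interface.
# All classes here are simplified stand-ins, not the real pydra implementations.
import typing as ty
from dataclasses import dataclass, field


@dataclass
class Workflow:
    inputs: dict[str, ty.Any] = field(default_factory=dict)


@dataclass
class LazyField:
    _field: str

    def _get_value(
        self,
        workflow: Workflow,
        graph: ty.Any,  # DiGraph[NodeExecution] in the real code
        state_index: ty.Any = None,  # StateIndex in the real code
    ) -> ty.Any:
        raise NotImplementedError("LazyField is an abstract class")


@dataclass
class LazyInField(LazyField):
    def _get_value(self, workflow, graph, state_index=None):
        # workflow-input fields resolve directly against the workflow's inputs
        return workflow.inputs[self._field]


wf = Workflow(inputs={"x": 3})
assert LazyInField("x")._get_value(workflow=wf, graph=None) == 3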

pydra/engine/node.py

Lines changed: 23 additions & 19 deletions

@@ -121,14 +121,28 @@ def lzout(self) -> OutputType:
                 type=field.type,
             )
         outputs = self.inputs.Outputs(**lazy_fields)
-        # Flag the output lazy fields as being not typed checked (i.e. assigned to another
-        # node's inputs) yet
+
         outpt: lazy.LazyOutField
         for outpt in attrs_values(outputs).values():
-            outpt._type_checked = False
+            # Assign the current node to the lazy fields so they can access the state
             outpt._node = self
+            # If the node has a non-empty state, wrap the type of the lazy field in
+            # a combination of an optional list and a number of nested StateArrays
+            # types based on the number of states the node is split over and whether
+            # it has a combiner
+            if self._state:
+                type_, _ = TypeParser.strip_splits(outpt._type)
+                if self._state.combiner:
+                    type_ = list[type_]
+                for _ in range(self._state.depth - int(bool(self._state.combiner))):
+                    type_ = StateArray[type_]
+                outpt._type = type_
+            # Flag the output lazy fields as being not typed checked (i.e. assigned to
+            # another node's inputs) yet. This is used to prevent the user from changing
+            # the type of the output after it has been accessed by connecting it to an
+            # output of an upstream node with additional state variables.
+            outpt._type_checked = False
         self._lzout = outputs
-        self._wrap_lzout_types_in_state_arrays()
         return outputs
 
     @property
@@ -217,20 +231,6 @@ def _check_if_outputs_have_been_used(self, msg):
                 + msg
             )
 
-    def _wrap_lzout_types_in_state_arrays(self) -> None:
-        """Wraps a types of the lazy out fields in a number of nested StateArray types
-        based on the number of states the node is split over"""
-        # Unwrap StateArray types from the output types
-        if not self.state:
-            return
-        outpt_lf: lazy.LazyOutField
-        for outpt_lf in attrs_values(self.lzout).values():
-            assert not outpt_lf._type_checked
-            type_, _ = TypeParser.strip_splits(outpt_lf._type)
-            for _ in range(self._state.depth):
-                type_ = StateArray[type_]
-            outpt_lf._type = type_
-
     def _set_state(self) -> None:
         # Add node name to state's splitter, combiner and cont_dim loaded from the def
         splitter = self._definition._splitter
@@ -269,7 +269,11 @@ def _get_upstream_states(self) -> dict[str, tuple["State", list[str]]]:
         """Get the states of the upstream nodes that are connected to this node"""
         upstream_states = {}
         for inpt_name, val in self.input_values:
-            if isinstance(val, lazy.LazyOutField) and val._node.state:
+            if (
+                isinstance(val, lazy.LazyOutField)
+                and val._node.state
+                and val._node.state.depth
+            ):
                 node: Node = val._node
                 # variables that are part of inner splitters should be treated as a containers
                 if node.state and f"{node.name}.{inpt_name}" in node.state.splitter:
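
As a rough worked example of the wrapping rule in the lzout hunk above (a combined split contributes a plain list, and each remaining uncombined split contributes one StateArray layer), here is a self-contained sketch; wrap_output_type and the toy StateArray are illustrative stand-ins, not pydra API.

# Sketch of the lzout type-wrapping logic; StateArray here is a toy stand-in
# for pydra.utils.typing.StateArray, and wrap_output_type is a hypothetical helper.
import typing as ty

T = ty.TypeVar("T")


class StateArray(ty.List[T]):
    """Toy stand-in for pydra.utils.typing.StateArray."""


def wrap_output_type(base, depth: int, has_combiner: bool):
    # a combiner collapses the combined split into a plain list of the base type
    type_ = list[base] if has_combiner else base
    # each remaining (uncombined) split adds one nested StateArray layer,
    # mirroring `state.depth - int(bool(state.combiner))` in the diff
    for _ in range(depth - int(has_combiner)):
        type_ = StateArray[type_]
    return type_


print(wrap_output_type(int, depth=2, has_combiner=True))   # StateArray around list[int]
print(wrap_output_type(int, depth=2, has_combiner=False))  # StateArray nested twice around int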

pydra/engine/specs.py

Lines changed: 1 addition & 38 deletions

@@ -35,14 +35,12 @@
 from pydra.utils.typing import StateArray, MultiInputObj
 from pydra.design.base import Field, Arg, Out, RequirementSet, NO_DEFAULT
 from pydra.design import shell
-from pydra.engine.lazy import LazyInField, LazyOutField
 
 if ty.TYPE_CHECKING:
     from pydra.engine.core import Task
     from pydra.engine.graph import DiGraph
     from pydra.engine.submitter import NodeExecution
     from pydra.engine.core import Workflow
-    from pydra.engine.state import StateIndex
     from pydra.engine.environments import Environment
     from pydra.engine.workers import Worker
 
@@ -476,41 +474,6 @@ def _compute_hashes(self) -> ty.Tuple[bytes, ty.Dict[str, bytes]]:
         }
         return hash_function(sorted(field_hashes.items())), field_hashes
 
-    def _resolve_lazy_inputs(
-        self,
-        workflow_inputs: "WorkflowDef",
-        graph: "DiGraph[NodeExecution]",
-        state_index: "StateIndex | None" = None,
-    ) -> Self:
-        """Resolves lazy fields in the task definition by replacing them with their
-        actual values.
-
-        Parameters
-        ----------
-        workflow : Workflow
-            The workflow the task is part of
-        graph : DiGraph[NodeExecution]
-            The execution graph of the workflow
-        state_index : StateIndex, optional
-            The state index for the workflow, by default None
-
-        Returns
-        -------
-        Self
-            The task definition with all lazy fields resolved
-        """
-        from pydra.engine.state import StateIndex
-
-        if state_index is None:
-            state_index = StateIndex()
-        resolved = {}
-        for name, value in attrs_values(self).items():
-            if isinstance(value, LazyInField):
-                resolved[name] = value._get_value(workflow_inputs)
-            elif isinstance(value, LazyOutField):
-                resolved[name] = value._get_value(graph, state_index)
-        return attrs.evolve(self, **resolved)
-
     def _check_rules(self):
         """Check if all rules are satisfied."""
 
@@ -773,7 +736,7 @@ def _from_task(cls, task: "Task[WorkflowDef]") -> Self:
         nodes_dict = {n.name: n for n in exec_graph.nodes}
         for name, lazy_field in attrs_values(workflow.outputs).items():
             try:
-                val_out = lazy_field._get_value(exec_graph)
+                val_out = lazy_field._get_value(workflow=workflow, graph=exec_graph)
                 output_wf[name] = val_out
             except (ValueError, AttributeError):
                 output_wf[name] = None

pydra/engine/state.py

Lines changed: 51 additions & 7 deletions

@@ -41,23 +41,63 @@ def __init__(self, indices: dict[str, int] | None = None):
         else:
             self.indices = OrderedDict(sorted(indices.items()))
 
-    def __repr__(self):
+    def __len__(self) -> int:
+        return len(self.indices)
+
+    def __iter__(self) -> ty.Generator[str, None, None]:
+        return iter(self.indices)
+
+    def __repr__(self) -> str:
         return (
             "StateIndex(" + ", ".join(f"{n}={v}" for n, v in self.indices.items()) + ")"
         )
 
     def __hash__(self):
         return hash(tuple(self.indices.items()))
 
-    def __eq__(self, other):
+    def __eq__(self, other) -> bool:
         return self.indices == other.indices
 
-    def __str__(self):
+    def __str__(self) -> str:
         return "__".join(f"{n}-{i}" for n, i in self.indices.items())
 
-    def __bool__(self):
+    def __bool__(self) -> bool:
         return bool(self.indices)
 
+    def subset(self, state_names: ty.Iterable[str]) -> ty.Self:
+        """Create a new StateIndex with only the specified fields
+
+        Parameters
+        ----------
+        fields : list[str]
+            the fields to keep in the new StateIndex
+
+        Returns
+        -------
+        StateIndex
+            a new StateIndex with only the specified fields
+        """
+        return type(self)({k: v for k, v in self.indices.items() if k in state_names})
+
+    def matches(self, other: "StateIndex") -> bool:
+        """Check if the indices that are present in the other StateIndex match
+
+        Parameters
+        ----------
+        other : StateIndex
+            the other StateIndex to compare against
+
+        Returns
+        -------
+        bool
+            True if all the indices in the other StateIndex match
+        """
+        if not set(self.indices).issuperset(other.indices):
+            raise ValueError(
+                f"StateIndex {self} does not contain all the indices in {other}"
+            )
+        return all(self.indices[k] == v for k, v in other.indices.items())
+
 
 class State:
     """
@@ -172,6 +212,9 @@ def __str__(self):
     def names(self):
         """Return the names of the states."""
         # analysing states from connected tasks if inner_inputs
+        if not hasattr(self, "keys_final"):
+            self.prepare_states()
+            self.prepare_inputs()
         previous_states_keys = {
             f"_{v.name}": v.keys_final for v in self.inner_inputs.values()
         }
@@ -190,13 +233,13 @@
 
     @property
     def depth(self) -> int:
-        """Return the number of uncombined splits of the state, i.e. the number nested
+        """Return the number of splits of the state, i.e. the number nested
         state arrays to wrap around the type of lazy out fields
 
        Returns
        -------
        int
-            number of uncombined splits
+            number of uncombined independent splits (i.e. linked splits only add 1)
        """
        depth = 0
        stack = []
@@ -210,7 +253,8 @@ def depth(self) -> int:
                 stack = []
             else:
                 stack.append(spl)
-        return depth + len(stack)
+        remaining_stack = [s for s in stack if s not in self.combiner]
+        return depth + len(remaining_stack)
 
     @property
     def splitter(self):
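
The new subset() and matches() helpers can be exercised with a pared-down StateIndex whose method bodies mirror the diff above; everything else about the class is omitted here for brevity.

# Pared-down StateIndex showing the behaviour of the new subset()/matches() helpers.
from collections import OrderedDict


class StateIndex:
    def __init__(self, indices=None):
        self.indices = OrderedDict(sorted((indices or {}).items()))

    def subset(self, state_names):
        # keep only the indices for the named state fields
        return type(self)({k: v for k, v in self.indices.items() if k in state_names})

    def matches(self, other):
        # every index present in `other` must agree with this index
        if not set(self.indices).issuperset(other.indices):
            raise ValueError(
                f"{dict(self.indices)} does not contain all the indices in {dict(other.indices)}"
            )
        return all(self.indices[k] == v for k, v in other.indices.items())


full = StateIndex({"a.x": 1, "b.y": 2})
assert full.subset(["a.x"]).indices == {"a.x": 1}
assert full.matches(StateIndex({"b.y": 2}))
assert not full.matches(StateIndex({"b.y": 3}))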
