Skip to content

Commit 125a0a5

Browse files
committed
fixed up retrospective setting of node inputs via the inputs attribute of the TaskOutputs object
1 parent 546f2c9 commit 125a0a5

File tree

3 files changed

+93
-23
lines changed

3 files changed

+93
-23
lines changed

pydra/compose/base/task.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -39,11 +39,16 @@ class Outputs:
3939
RESERVED_FIELD_NAMES = ("inputs",)
4040

4141
_cache_dir: Path = attrs.field(default=None, init=False, repr=False)
42+
_node = attrs.field(default=None, init=False, repr=False)
4243

4344
@property
4445
def inputs(self):
4546
"""The inputs object associated with a lazy-outputs object"""
46-
return self._get_node().inputs
47+
if self._node is None:
48+
raise AttributeError(
49+
f"{self} outputs object is not a lazy output of a workflow node"
50+
)
51+
return self._node.inputs
4752

4853
@classmethod
4954
def _from_task(cls, job: "Job[TaskType]") -> Self:
@@ -81,14 +86,6 @@ def _results(self) -> "Result[Self]":
8186
with open(results_path, "rb") as f:
8287
return cp.load(f)
8388

84-
def _get_node(self):
85-
try:
86-
return self._node
87-
except AttributeError:
88-
raise AttributeError(
89-
f"{self} outputs object is not a lazy output of a workflow node"
90-
) from None
91-
9289
def __iter__(self) -> ty.Generator[str, None, None]:
9390
"""The names of the fields in the output object"""
9491
return iter(sorted(f.name for f in attrs_fields(self)))

pydra/compose/tests/test_workflow_fields.py

Lines changed: 82 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from operator import attrgetter
2+
from pathlib import Path
23
from copy import copy
34
from unittest.mock import Mock
45
import pytest
@@ -15,17 +16,17 @@
1516

1617

1718
@python.define
18-
def Add(a, b):
19+
def Add(a: int | float, b: int | float) -> int | float:
1920
return a + b
2021

2122

2223
@python.define
23-
def Mul(a, b):
24+
def Mul(a: int | float, b: int | float) -> int | float:
2425
return a * b
2526

2627

2728
@python.define(outputs=["divided"])
28-
def Divide(x, y):
29+
def Divide(x: int | float, y: int | float) -> float:
2930
return x / y
3031

3132

@@ -68,7 +69,9 @@ def MyTestWorkflow(a, b):
6869
wf = Workflow.construct(workflow_spec)
6970
assert wf.inputs.a == 1
7071
assert wf.inputs.b == 2.0
71-
assert wf.outputs.out == LazyOutField(node=wf["Mul"], field="out", type=ty.Any)
72+
assert wf.outputs.out == LazyOutField(
73+
node=wf["Mul"], field="out", type=int | float, type_checked=True
74+
)
7275

7376
# Nodes are named after the specs by default
7477
assert list(wf.node_names) == ["Add", "Mul"]
@@ -185,7 +188,9 @@ class Outputs(workflow.Outputs):
185188
wf = Workflow.construct(workflow_spec)
186189
assert wf.inputs.a == 1
187190
assert wf.inputs.b == 2.0
188-
assert wf.outputs.out == LazyOutField(node=wf["Mul"], field="out", type=ty.Any)
191+
assert wf.outputs.out == LazyOutField(
192+
node=wf["Mul"], field="out", type=int | float, type_checked=True
193+
)
189194

190195
# Nodes are named after the specs by default
191196
assert list(wf.node_names) == ["Add", "Mul"]
@@ -323,7 +328,7 @@ def MyTestWorkflow(a: int, b: float) -> tuple[float, float]:
323328
node=wf["Mul"], field="out", type=float, type_checked=True
324329
)
325330
assert wf.outputs.out2 == LazyOutField(
326-
node=wf["division"], field="divided", type=ty.Any
331+
node=wf["division"], field="divided", type=float, type_checked=True
327332
)
328333
assert list(wf.node_names) == ["addition", "Mul", "division"]
329334

@@ -362,8 +367,12 @@ def MyTestWorkflow(a: int, b: float):
362367
wf = Workflow.construct(workflow_spec)
363368
assert wf.inputs.a == 1
364369
assert wf.inputs.b == 2.0
365-
assert wf.outputs.out1 == LazyOutField(node=wf["Mul"], field="out", type=ty.Any)
366-
assert wf.outputs.out2 == LazyOutField(node=wf["Add"], field="out", type=ty.Any)
370+
assert wf.outputs.out1 == LazyOutField(
371+
node=wf["Mul"], field="out", type=int | float, type_checked=True
372+
)
373+
assert wf.outputs.out2 == LazyOutField(
374+
node=wf["Add"], field="out", type=int | float, type_checked=True
375+
)
367376
assert list(wf.node_names) == ["Add", "Mul"]
368377

369378

@@ -500,3 +509,68 @@ def RecursiveNestedWorkflow(a: float, depth: int) -> float:
500509
type=float,
501510
type_checked=True,
502511
)
512+
513+
514+
def test_workflow_lzout_inputs1(tmp_path: Path):
515+
516+
@workflow.define
517+
def InputAccessWorkflow(a, b, c):
518+
add = workflow.add(Add(a=a, b=b))
519+
add.inputs.a = c
520+
mul = workflow.add(Mul(a=add.out, b=b))
521+
return mul.out
522+
523+
input_access_workflow = InputAccessWorkflow(a=1, b=2.0, c=3.0)
524+
outputs = input_access_workflow(cache_root=tmp_path)
525+
assert outputs.out == 10.0
526+
527+
528+
def test_workflow_lzout_inputs2(tmp_path: Path):
529+
530+
@workflow.define
531+
def InputAccessWorkflow(a, b, c):
532+
add = workflow.add(Add(a=a, b=b))
533+
add.inputs.a = c
534+
mul = workflow.add(Mul(a=add.out, b=b))
535+
return mul.out
536+
537+
input_access_workflow = InputAccessWorkflow(a=1, b=2.0, c=3.0)
538+
outputs = input_access_workflow(cache_root=tmp_path)
539+
assert outputs.out == 10.0
540+
541+
542+
def test_workflow_lzout_inputs2(tmp_path: Path):
543+
"""Set the inputs of the 'add' node after its outputs have been accessed
544+
but the state has not been altered"""
545+
546+
@workflow.define
547+
def InputAccessWorkflow(a, b, c):
548+
add = workflow.add(Add(a=a, b=b))
549+
mul = workflow.add(Mul(a=add.out, b=b))
550+
add.inputs.a = c
551+
return mul.out
552+
553+
input_access_workflow = InputAccessWorkflow(a=1, b=2.0, c=3.0)
554+
outputs = input_access_workflow(cache_root=tmp_path)
555+
assert outputs.out == 10.0
556+
557+
558+
def test_workflow_lzout_inputs_state_change_fail(tmp_path: Path):
559+
"""Set the inputs of the 'mul' node after its outputs have been accessed
560+
with an upstream lazy field that has a different state than the original.
561+
This changes the type of the input and is therefore not permitted"""
562+
563+
@workflow.define
564+
def InputAccessWorkflow(a, b, c):
565+
add1 = workflow.add(Add(a=a, b=b), name="add1")
566+
add2 = workflow.add(Add(a=a).split(b=c), name="add2")
567+
mul1 = workflow.add(Mul(a=add1.out, b=b), name="mul1")
568+
mul2 = workflow.add(Mul(a=mul1.out, b=b), name="mul2")
569+
mul1.inputs.a = add2.out
570+
return mul2.out
571+
572+
input_access_workflow = InputAccessWorkflow(a=1, b=2.0, c=[3.0, 4.0])
573+
with pytest.raises(
574+
RuntimeError, match="have already been accessed and therefore cannot set"
575+
):
576+
input_access_workflow.construct()

pydra/engine/node.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from enum import Enum
44
import attrs
55
from pydra.engine import lazy
6-
from pydra.utils.general import attrs_values
6+
from pydra.utils.general import attrs_values, task_dict
77
from pydra.utils.typing import is_lazy
88
from pydra.engine.state import State, add_name_splitter, add_name_combiner
99

@@ -79,7 +79,7 @@ def __setattr__(self, name: str, value: ty.Any) -> None:
7979
f"cannot set {name!r} input to {value} because it changes the "
8080
f"state"
8181
)
82-
self._set_state()
82+
self._node._set_state()
8383

8484
@property
8585
def inputs(self) -> Inputs:
@@ -144,6 +144,7 @@ def lzout(self) -> OutputType:
144144
# output of an upstream node with additional state variables.
145145
outpt._type_checked = False
146146
self._lzout = outputs
147+
outputs._node = self
147148
return outputs
148149

149150
@property
@@ -161,10 +162,8 @@ def combiner(self):
161162
def _check_if_outputs_have_been_used(self, msg):
162163
used = []
163164
if self._lzout:
164-
for outpt_name, outpt_val in attrs.asdict(
165-
self._lzout, recurse=False
166-
).items():
167-
if outpt_val.type_checked:
165+
for outpt_name, outpt_val in task_dict(self._lzout).items():
166+
if outpt_val._type_checked:
168167
used.append(outpt_name)
169168
if used:
170169
raise RuntimeError(

0 commit comments

Comments
 (0)