Commit 553bb2f

fixing up state maintenance in workflow nodes

1 parent 384e57d commit 553bb2f

File tree

5 files changed: +208 -360 lines changed


pydra/design/tests/test_workflow.py

Lines changed: 34 additions & 43 deletions

@@ -9,18 +9,36 @@
 from pydra.engine.specs import TaskSpec
 from fileformats import video, image
 
+# NB: We use PascalCase for interfaces and workflow functions as it is translated into a class
 
-def test_workflow():
 
-    # NB: We use PascalCase (i.e. class names) as it is translated into a class
+@python.define
+def Add(a, b):
+    return a + b
 
-    @python.define
-    def Add(a, b):
-        return a + b
 
-    @python.define
-    def Mul(a, b):
-        return a * b
+@python.define
+def Mul(a, b):
+    return a * b
+
+
+@python.define(outputs=["divided"])
+def Divide(x, y):
+    return x / y
+
+
+@python.define
+def Sum(x: list[float]) -> float:
+    return sum(x)
+
+
+def a_converter(value):
+    if value is attrs.NOTHING:
+        return value
+    return float(value)
+
+
+def test_workflow():
 
     @workflow.define
     def MyTestWorkflow(a, b):

@@ -109,7 +127,7 @@ def MyTestShellWorkflow(
     assert wf.inputs.input_video == input_video
     assert wf.inputs.watermark == watermark
     assert wf.outputs.output_video == LazyOutField(
-        node=wf["resize"], field="out_video", type=video.Mp4
+        node=wf["resize"], field="out_video", type=video.Mp4, type_checked=True
     )
     assert list(wf.node_names) == ["add_watermark", "resize"]
 

@@ -119,19 +137,6 @@ def test_workflow_canonical():
 
     # NB: We use PascalCase (i.e. class names) as it is translated into a class
 
-    @python.define
-    def Add(a, b):
-        return a + b
-
-    @python.define
-    def Mul(a, b):
-        return a * b
-
-    def a_converter(value):
-        if value is attrs.NOTHING:
-            return value
-        return float(value)
-
     @workflow.define
     class MyTestWorkflow(TaskSpec["MyTestWorkflow.Outputs"]):
 

@@ -220,10 +225,10 @@ def MyTestShellWorkflow(
     )
     wf = Workflow.construct(workflow_spec)
     assert wf["add_watermark"].inputs.in_video == LazyInField(
-        node=wf, field="input_video", type=video.Mp4
+        workflow=wf, field="input_video", type=video.Mp4, type_checked=True
    )
     assert wf["add_watermark"].inputs.watermark == LazyInField(
-        node=wf, field="watermark", type=image.Png
+        workflow=wf, field="watermark", type=image.Png, type_checked=True
     )
 
 

@@ -236,10 +241,6 @@ def Add(x, y):
     def Mul(x, y):
         return x * y
 
-    @python.define(outputs=["divided"])
-    def Divide(x, y):
-        return x / y
-
     @workflow.define(outputs=["out1", "out2"])
     def MyTestWorkflow(a: int, b: float) -> tuple[float, float]:
         """A test workflow demonstration a few alternative ways to set and connect nodes

@@ -279,7 +280,9 @@ def MyTestWorkflow(a: int, b: float) -> tuple[float, float]:
     wf = Workflow.construct(workflow_spec)
     assert wf.inputs.a == 1
     assert wf.inputs.b == 2.0
-    assert wf.outputs.out1 == LazyOutField(node=wf["Mul"], field="out", type=float)
+    assert wf.outputs.out1 == LazyOutField(
+        node=wf["Mul"], field="out", type=float, type_checked=True
+    )
     assert wf.outputs.out2 == LazyOutField(
         node=wf["division"], field="divided", type=ty.Any
     )

@@ -288,14 +291,6 @@ def MyTestWorkflow(a: int, b: float) -> tuple[float, float]:
 
 
 def test_workflow_set_outputs_directly():
-    @python.define
-    def Add(a, b):
-        return a + b
-
-    @python.define
-    def Mul(a, b):
-        return a * b
-
     @workflow.define(outputs={"out1": float, "out2": float})
     def MyTestWorkflow(a: int, b: float):
 

@@ -362,10 +357,6 @@ def Mul(x: float, y: float) -> float:
     def Add(x: float, y: float) -> float:
         return x + y
 
-    @python.define
-    def Sum(x: list[float]) -> float:
-        return sum(x)
-
     @workflow.define
     def MyTestWorkflow(a: list[int], b: list[float], c: float) -> list[float]:
         mul = workflow.add(Mul()).split(x=a, y=b)

@@ -387,11 +378,11 @@ def test_workflow_split_after_access_fail():
     """
 
     @python.define
-    def Add(x, y):
+    def Add(x: float, y: float) -> float:
         return x + y
 
     @python.define
-    def Mul(x, y):
+    def Mul(x: float, y: float) -> float:
         return x * y
 
     @workflow.define
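
The main change to the tests is structural: the helper tasks (Add, Mul, Divide, Sum) and a_converter are now defined once at module level instead of being redefined inside each test, and the expected lazy fields carry type_checked=True after construction. Below is a minimal sketch of that pattern, not an exact excerpt from the repository: the pydra.design import path, the keyword-style wiring of node inputs, and the assumption that an unannotated workflow output is exposed as a single ty.Any field named "out" are all inferred from the assertions above rather than stated in this diff.

import typing as ty

from pydra.design import python, workflow  # import path assumed from the test module's location
from pydra.engine.workflow.base import Workflow
from pydra.engine.workflow.lazy import LazyOutField


@python.define
def Add(a, b):
    return a + b


@python.define
def Mul(a, b):
    return a * b


@workflow.define
def MyTestWorkflow(a, b):
    # each workflow.add() call creates a node (named after the task class by default);
    # lazy outputs of upstream nodes are wired straight into downstream inputs
    add = workflow.add(Add(a=a, b=b))
    mul = workflow.add(Mul(a=add.out, b=b))
    return mul.out


wf = Workflow.construct(MyTestWorkflow(a=1, b=2))
# Lazy outputs resolved during construction now compare equal to a LazyOutField
# keyed by the upstream node and flagged as already type-checked (untyped outputs
# are reported as ty.Any, as in the assertions above).
assert wf.outputs.out == LazyOutField(
    node=wf["Mul"], field="out", type=ty.Any, type_checked=True
)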

pydra/engine/workflow/base.py

Lines changed: 1 addition & 1 deletion

@@ -77,7 +77,7 @@ def construct(
                 lazy_spec,
                 lzy_inpt.name,
                 LazyInField(
-                    node=wf,
+                    workflow=wf,
                     field=lzy_inpt.name,
                     type=lzy_inpt.type,
                 ),
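
This one-line change is the caller side of the lazy.py refactor below: a workflow's own inputs are now referenced through the workflow object (LazyInField(workflow=...)), while node outputs keep referencing their upstream node (LazyOutField(node=...)), and both expose that owner through the new source property. A rough sketch of the distinction, reusing wf and the fileformats types from the tests above; the values are illustrative only.

# `wf` is a constructed Workflow, as in the tests above.
resize_node = wf["resize"]

inp = LazyInField(workflow=wf, field="input_video", type=video.Mp4)  # was node=wf before this commit
out = LazyOutField(node=resize_node, field="out_video", type=video.Mp4)

assert inp.source is wf           # LazyInField.source returns the owning workflow
assert out.source is resize_node  # LazyOutField.source returns the upstream node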

pydra/engine/workflow/lazy.py

Lines changed: 71 additions & 87 deletions

@@ -1,7 +1,9 @@
 import typing as ty
+import abc
 from typing_extensions import Self
 import attrs
 from pydra.utils.typing import StateArray
+from pydra.utils.hash import hash_single
 from . import node
 
 if ty.TYPE_CHECKING:

@@ -13,91 +15,22 @@
 TypeOrAny = ty.Union[type, ty.Any]
 
 
-@attrs.define(auto_attribs=True, kw_only=True)
-class LazyField(ty.Generic[T]):
+@attrs.define(kw_only=True)
+class LazyField(ty.Generic[T], metaclass=abc.ABCMeta):
     """Lazy fields implement promises."""
 
-    node: node.Node
     field: str
     type: TypeOrAny
-    # Set of splitters that have been applied to the lazy field. Note that the splitter
-    # specifications are transformed to a tuple[tuple[str, ...], ...] form where the
-    # outer tuple is the outer product, the inner tuple are inner products (where either
-    # product can be of length==1)
-    splits: ty.FrozenSet[ty.Tuple[ty.Tuple[str, ...], ...]] = attrs.field(
-        factory=frozenset, converter=frozenset
-    )
     cast_from: ty.Optional[ty.Type[ty.Any]] = None
-    # type_checked will be set to False after it is created but defaults to True here for
-    # ease of testing
-    type_checked: bool = True
+    type_checked: bool = False
 
     def __bytes_repr__(self, cache):
-        yield type(self).__name__.encode()
-        yield self.name.encode()
-        yield self.field.encode()
-
-    def cast(self, new_type: TypeOrAny) -> Self:
-        """ "casts" the lazy field to a new type
-
-        Parameters
-        ----------
-        new_type : type
-            the type to cast the lazy-field to
-
-        Returns
-        -------
-        cast_field : LazyField
-            a copy of the lazy field with the new type
-        """
-        return type(self)[new_type](
-            name=self.name,
-            field=self.field,
-            type=new_type,
-            splits=self.splits,
-            cast_from=self.cast_from if self.cast_from else self.type,
-        )
-
-    # def split(self, splitter: Splitter) -> Self:
-    #     """ "Splits" the lazy field over an array of nodes by replacing the sequence type
-    #     of the lazy field with StateArray to signify that it will be "split" across
-
-    #     Parameters
-    #     ----------
-    #     splitter : str or ty.Tuple[str, ...] or ty.List[str]
-    #         the splitter to append to the list of splitters
-    #     """
-    #     from pydra.utils.typing import (
-    #         TypeParser,
-    #     )  # pylint: disable=import-outside-toplevel
-
-    #     splits = self.splits | set([LazyField.normalize_splitter(splitter)])
-    #     # Check to see whether the field has already been split over the given splitter
-    #     if splits == self.splits:
-    #         return self
-
-    #     # Modify the type of the lazy field to include the split across a state-array
-    #     inner_type, prev_split_depth = TypeParser.strip_splits(self.type)
-    #     assert prev_split_depth <= 1
-    #     if inner_type is ty.Any:
-    #         type_ = StateArray[ty.Any]
-    #     elif TypeParser.matches_type(inner_type, list):
-    #         item_type = TypeParser.get_item_type(inner_type)
-    #         type_ = StateArray[item_type]
-    #     else:
-    #         raise TypeError(
-    #             f"Cannot split non-sequence field {self} of type {inner_type}"
-    #         )
-    #     if prev_split_depth:
-    #         type_ = StateArray[type_]
-    #     return type(self)[type_](
-    #         name=self.name,
-    #         field=self.field,
-    #         type=type_,
-    #         splits=splits,
-    #     )
-
-    # # def combine(self, combiner: str | list[str]) -> Self:
+        yield type(self).__name__.encode() + b"("
+        yield from bytes(hash_single(self.source, cache))
+        yield b"field=" + self.field.encode()
+        yield b"type=" + bytes(hash_single(self.type, cache))
+        yield b"cast_from=" + bytes(hash_single(self.cast_from, cache))
+        yield b")"
 
     def _apply_cast(self, value):
         """\"Casts\" the value from the retrieved type if a cast has been applied to

@@ -110,19 +43,24 @@ def _apply_cast(self, value):
         return value
 
 
-@attrs.define(auto_attribs=True, kw_only=True)
+@attrs.define(kw_only=True)
 class LazyInField(LazyField[T]):
 
+    workflow: "Workflow" = attrs.field()
+
     attr_type = "input"
 
     def __eq__(self, other):
         return (
             isinstance(other, LazyInField)
             and self.field == other.field
             and self.type == other.type
-            and self.splits == other.splits
         )
 
+    @property
+    def source(self):
+        return self.workflow
+
     def get_value(self, wf: "Workflow", state_index: ty.Optional[int] = None) -> ty.Any:
         """Return the value of a lazy field.
 

@@ -155,8 +93,31 @@ def apply_splits(obj, depth):
         value = self._apply_cast(value)
         return value
 
+    def cast(self, new_type: TypeOrAny) -> Self:
+        """ "casts" the lazy field to a new type
 
+        Parameters
+        ----------
+        new_type : type
+            the type to cast the lazy-field to
+
+        Returns
+        -------
+        cast_field : LazyInField
+            a copy of the lazy field with the new type
+        """
+        return type(self)[new_type](
+            workflow=self.workflow,
+            field=self.field,
+            type=new_type,
+            cast_from=self.cast_from if self.cast_from else self.type,
+        )
+
+
+@attrs.define(kw_only=True)
 class LazyOutField(LazyField[T]):
+
+    node: node.Node
     attr_type = "output"
 
     def get_value(self, wf: "Workflow", state_index: ty.Optional[int] = None) -> ty.Any:

@@ -178,16 +139,15 @@ def get_value(self, wf: "Workflow", state_index: ty.Optional[int] = None) -> ty.Any:
             TypeParser,
         )  # pylint: disable=import-outside-toplevel
 
-        node = getattr(wf, self.name)
-        result = node.result(state_index=state_index)
+        result = self.node.result(state_index=state_index)
         if result is None:
             raise RuntimeError(
-                f"Could not find results of '{node.name}' node in a sub-directory "
-                f"named '{node.checksum}' in any of the cache locations.\n"
-                + "\n".join(str(p) for p in set(node.cache_locations))
+                f"Could not find results of '{self.node.name}' node in a sub-directory "
+                f"named '{self.node.checksum}' in any of the cache locations.\n"
+                + "\n".join(str(p) for p in set(self.node.cache_locations))
                 + f"\n\nThis is likely due to hash changes in '{self.name}' node inputs. "
-                f"Current values and hashes: {node.inputs}, "
-                f"{node.inputs._hashes}\n\n"
+                f"Current values and hashes: {self.node.inputs}, "
+                f"{self.node.inputs._hashes}\n\n"
                 "Set loglevel to 'debug' in order to track hash changes "
                 "throughout the execution of the workflow.\n\n "
                 "These issues may have been caused by `bytes_repr()` methods "

@@ -224,3 +184,27 @@ def get_nested_results(res, depth: int):
         value = get_nested_results(result, depth=split_depth)
         value = self._apply_cast(value)
         return value
+
+    @property
+    def source(self):
+        return self.node
+
+    def cast(self, new_type: TypeOrAny) -> Self:
+        """ "casts" the lazy field to a new type
+
+        Parameters
+        ----------
+        new_type : type
+            the type to cast the lazy-field to
+
+        Returns
+        -------
+        cast_field : LazyOutField
+            a copy of the lazy field with the new type
+        """
+        return type(self)[new_type](
+            node=self.node,
+            field=self.field,
+            type=new_type,
+            cast_from=self.cast_from if self.cast_from else self.type,
+        )
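
Because cast() is now implemented per subclass, re-typing a lazy field keeps its owner (the workflow for LazyInField, the node for LazyOutField) and records the original type in cast_from, while __bytes_repr__ hashes that owner via hash_single instead of a bare name, so the field's hash tracks the state of the node or workflow it points at. A rough sketch of the cast behaviour, reusing wf and the fileformats types from the tests above; the ty.Any-to-video.Mp4 cast is an arbitrary illustration, not an excerpt.

import typing as ty

# `wf` is a constructed Workflow as in the tests; values are illustrative.
out = LazyOutField(node=wf["resize"], field="out_video", type=ty.Any)

mp4 = out.cast(video.Mp4)
assert mp4.node is out.node     # the owning node is carried over
assert mp4.type is video.Mp4    # the new type replaces the old one ...
assert mp4.cast_from is ty.Any  # ... and the original type is recorded
assert mp4.field == "out_video"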
