Skip to content

Commit ce860d2

Browse files
committed
reimplementing lazy-field split
1 parent 4e331e5 commit ce860d2

File tree

1 file changed

+53
-0
lines changed

1 file changed

+53
-0
lines changed

pydra/engine/lazy.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,12 @@ def _get_value(
6868
"""
6969
raise NotImplementedError("LazyField is an abstract class")
7070

71+
def split(self) -> "LazyField":
72+
""" "Splits" the lazy field over an array of nodes by replacing the sequence type
73+
of the lazy field with StateArray to signify that it will be "split" across
74+
"""
75+
raise NotImplementedError("LazyField is an abstract class")
76+
7177

7278
@attrs.define(kw_only=True)
7379
class LazyInField(LazyField[T]):
@@ -116,6 +122,25 @@ def _get_value(
116122
value = self._apply_cast(value)
117123
return value
118124

125+
def split(self) -> "LazyField":
126+
""" "Splits" the lazy field over an array of nodes by replacing the sequence type
127+
of the lazy field with StateArray to signify that it will be "split" across
128+
"""
129+
from ..utils.typing import TypeParser # pylint: disable=import-outside-toplevel
130+
131+
assert not isinstance(self, LazyInField)
132+
133+
if not TypeParser.matches_type(self.type, list):
134+
raise TypeError(
135+
f"Cannot split non-sequence field {self} of type {self.type}"
136+
)
137+
138+
return type(self)(
139+
name=self.name,
140+
field=self.field,
141+
type=StateArray[TypeParser.get_item_type(self.type)],
142+
)
143+
119144

120145
@attrs.define(kw_only=True)
121146
class LazyOutField(LazyField[T]):
@@ -206,6 +231,34 @@ def retrieve_from_job(job: "Task[DefType]") -> ty.Any:
206231
# else:
207232
return [retrieve_from_job(j) for j in jobs]
208233

234+
def split(self) -> "LazyField":
235+
""" "Splits" the lazy field over an array of nodes by replacing the sequence type
236+
of the lazy field with StateArray to signify that it will be "split" across
237+
"""
238+
from ..utils.typing import TypeParser # pylint: disable=import-outside-toplevel
239+
240+
# Modify the type of the lazy field to include the split across a state-array
241+
inner_type, prev_split_depth = TypeParser.strip_splits(self.type)
242+
assert prev_split_depth <= 1
243+
if inner_type is ty.Any:
244+
type_ = StateArray[ty.Any]
245+
elif TypeParser.matches_type(inner_type, list):
246+
item_type = TypeParser.get_item_type(inner_type)
247+
type_ = StateArray[item_type]
248+
else:
249+
raise TypeError(
250+
f"Cannot split non-sequence field {self} of type {inner_type}"
251+
)
252+
if prev_split_depth:
253+
type_ = StateArray[
254+
type_
255+
] # FIXME: This nesting of StateArray is probably unnecessary
256+
return type(self)[type_](
257+
name=self.name,
258+
field=self.field,
259+
type=type_,
260+
)
261+
209262
@property
210263
def _source(self):
211264
return self._node

0 commit comments

Comments
 (0)