Commit 384e57d

fixed up lazy out splitting

1 parent 03e6951

File tree: 2 files changed (+26, -27 lines)

pydra/engine/workflow/lazy.py (0 additions, 24 deletions)

@@ -11,7 +11,6 @@
 T = ty.TypeVar("T")
 
 TypeOrAny = ty.Union[type, ty.Any]
-Splitter = ty.Union[str, ty.Tuple[str, ...]]
 
 
 @attrs.define(auto_attribs=True, kw_only=True)
@@ -100,29 +99,6 @@ def cast(self, new_type: TypeOrAny) -> Self:
 
     # # def combine(self, combiner: str | list[str]) -> Self:
 
-    # @classmethod
-    # def normalize_splitter(
-    #     cls, splitter: Splitter, strip_previous: bool = True
-    # ) -> ty.Tuple[ty.Tuple[str, ...], ...]:
-    #     """Converts the splitter spec into a consistent tuple[tuple[str, ...], ...] form
-    #     used in LazyFields"""
-    #     if isinstance(splitter, str):
-    #         splitter = (splitter,)
-    #     if isinstance(splitter, tuple):
-    #         splitter = (splitter,)  # type: ignore
-    #     else:
-    #         assert isinstance(splitter, list)
-    #         # convert to frozenset to differentiate from tuple, yet still be hashable
-    #         # (NB: order of fields in list splitters aren't relevant)
-    #         splitter = tuple((s,) if isinstance(s, str) else s for s in splitter)
-    #     # Strip out fields starting with "_" designating splits in upstream nodes
-    #     if strip_previous:
-    #         stripped = tuple(
-    #             tuple(f for f in i if not f.startswith("_")) for i in splitter
-    #         )
-    #         splitter = tuple(s for s in stripped if s)  # type: ignore
-    #     return splitter  # type: ignore
-
     def _apply_cast(self, value):
         """\"Casts\" the value from the retrieved type if a cast has been applied to
         the lazy-field"""
pydra/engine/workflow/node.py (26 additions, 3 deletions)

@@ -13,6 +13,7 @@
 
 
 OutputType = ty.TypeVar("OutputType", bound=OutputsSpec)
+Splitter = ty.Union[str, ty.Tuple[str, ...]]
 
 
 @attrs.define
@@ -186,15 +187,14 @@ def split(
         if not self._state or splitter != self._state.splitter:
             self._set_state(splitter)
         # Wrap types of lazy outputs in StateArray types
-        split_depth = len(lazy.LazyField.normalize_splitter(splitter))
+        split_depth = len(self._normalize_splitter(splitter))
         outpt_lf: lazy.LazyOutField
         for outpt_lf in attrs.asdict(self.lzout, recurse=False).values():
             assert not outpt_lf.type_checked
             outpt_type = outpt_lf.type
             for d in range(split_depth):
                 outpt_type = StateArray[outpt_type]
             outpt_lf.type = outpt_type
-            outpt_lf.splits = frozenset(iter(self._state.splitter))
         return self
 
     def combine(
@@ -250,7 +250,7 @@ def combine(
         else:  # self.state and not self.state.combiner
             self._set_state(splitter=self._state.splitter, combiner=combiner)
         # Wrap types of lazy outputs in StateArray types
-        norm_splitter = lazy.LazyField.normalize_splitter(self._state.splitter)
+        norm_splitter = self._normalize_splitter(self._state.splitter)
         remaining_splits = [
             s for s in norm_splitter if not any(c in s for c in combiner)
         ]
@@ -332,6 +332,29 @@ def _check_if_outputs_have_been_used(self, msg):
             + msg
         )
 
+    @classmethod
+    def _normalize_splitter(
+        cls, splitter: Splitter, strip_previous: bool = True
+    ) -> ty.Tuple[ty.Tuple[str, ...], ...]:
+        """Converts the splitter spec into a consistent tuple[tuple[str, ...], ...] form
+        used in LazyFields"""
+        if isinstance(splitter, str):
+            splitter = (splitter,)
+        if isinstance(splitter, tuple):
+            splitter = (splitter,)  # type: ignore
+        else:
+            assert isinstance(splitter, list)
+            # convert to frozenset to differentiate from tuple, yet still be hashable
+            # (NB: order of fields in list splitters aren't relevant)
+            splitter = tuple((s,) if isinstance(s, str) else s for s in splitter)
+        # Strip out fields starting with "_" designating splits in upstream nodes
+        if strip_previous:
+            stripped = tuple(
+                tuple(f for f in i if not f.startswith("_")) for i in splitter
+            )
+            splitter = tuple(s for s in stripped if s)  # type: ignore
+        return splitter  # type: ignore
+
 
 @attrs.define(auto_attribs=False)
 class Workflow(ty.Generic[OutputType]):
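
For reference, the normalization helper moved onto Node above can be exercised in isolation. The following is a minimal standalone sketch that restates the logic of _normalize_splitter from the diff and checks a few illustrative inputs; the field names "x", "y", "z", and "_prev" are made up for the demo and do not come from the commit.

import typing as ty

Splitter = ty.Union[str, ty.Tuple[str, ...]]


def normalize_splitter(
    splitter: Splitter, strip_previous: bool = True
) -> ty.Tuple[ty.Tuple[str, ...], ...]:
    """Standalone restatement of the Node._normalize_splitter logic shown above."""
    if isinstance(splitter, str):
        splitter = (splitter,)
    if isinstance(splitter, tuple):
        # a tuple splitter groups its fields into a single split level
        splitter = (splitter,)
    else:
        assert isinstance(splitter, list)
        # each element of a list splitter becomes its own split level
        splitter = tuple((s,) if isinstance(s, str) else s for s in splitter)
    if strip_previous:
        # drop "_"-prefixed fields that refer to splits defined in upstream nodes
        stripped = tuple(
            tuple(f for f in level if not f.startswith("_")) for level in splitter
        )
        splitter = tuple(s for s in stripped if s)
    return splitter


assert normalize_splitter("x") == (("x",),)                   # depth 1
assert normalize_splitter(("x", "y")) == (("x", "y"),)        # depth 1
assert normalize_splitter(["x", "y"]) == (("x",), ("y",))     # depth 2
assert normalize_splitter(["x", ("y", "z")]) == (("x",), ("y", "z"))
assert normalize_splitter(["_prev", "x"]) == (("x",),)        # upstream split stripped

The length of the returned tuple is what the split() hunk uses as split_depth.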

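The split() hunk wraps each lazy output's type once per split level, so the nesting depth of the declared output type matches the splitter depth. Below is a rough sketch of that wrapping; StateArray here is a stand-in class for illustration only, not pydra's actual implementation (which is not part of this diff), and wrap_split_type is a hypothetical helper name.

import typing as ty

T = ty.TypeVar("T")


class StateArray(ty.List[T]):
    """Stand-in for pydra's StateArray container that marks split outputs."""


def wrap_split_type(outpt_type, split_depth: int):
    # mirrors the loop in Node.split(): one StateArray layer per split level
    for _ in range(split_depth):
        outpt_type = StateArray[outpt_type]
    return outpt_type


# a splitter like ["x", "y"] normalizes to depth 2, so an int output is
# declared as StateArray[StateArray[int]]
print(wrap_split_type(int, 2))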