| 
13 | 13 | 
 
  | 
14 | 14 | 
 
  | 
15 | 15 | OutputType = ty.TypeVar("OutputType", bound=OutputsSpec)  | 
 | 16 | +Splitter = ty.Union[str, ty.Tuple[str, ...]]  | 
16 | 17 | 
 
  | 
17 | 18 | 
 
  | 
18 | 19 | @attrs.define  | 
@@ -186,15 +187,14 @@ def split(  | 
186 | 187 |         if not self._state or splitter != self._state.splitter:  | 
187 | 188 |             self._set_state(splitter)  | 
188 | 189 |         # Wrap types of lazy outputs in StateArray types  | 
189 |  | -        split_depth = len(lazy.LazyField.normalize_splitter(splitter))  | 
 | 190 | +        split_depth = len(self._normalize_splitter(splitter))  | 
190 | 191 |         outpt_lf: lazy.LazyOutField  | 
191 | 192 |         for outpt_lf in attrs.asdict(self.lzout, recurse=False).values():  | 
192 | 193 |             assert not outpt_lf.type_checked  | 
193 | 194 |             outpt_type = outpt_lf.type  | 
194 | 195 |             for d in range(split_depth):  | 
195 | 196 |                 outpt_type = StateArray[outpt_type]  | 
196 | 197 |             outpt_lf.type = outpt_type  | 
197 |  | -            outpt_lf.splits = frozenset(iter(self._state.splitter))  | 
198 | 198 |         return self  | 
199 | 199 | 
 
  | 
200 | 200 |     def combine(  | 
@@ -250,7 +250,7 @@ def combine(  | 
250 | 250 |         else:  # self.state and not self.state.combiner  | 
251 | 251 |             self._set_state(splitter=self._state.splitter, combiner=combiner)  | 
252 | 252 |         # Wrap types of lazy outputs in StateArray types  | 
253 |  | -        norm_splitter = lazy.LazyField.normalize_splitter(self._state.splitter)  | 
 | 253 | +        norm_splitter = self._normalize_splitter(self._state.splitter)  | 
254 | 254 |         remaining_splits = [  | 
255 | 255 |             s for s in norm_splitter if not any(c in s for c in combiner)  | 
256 | 256 |         ]  | 
@@ -332,6 +332,29 @@ def _check_if_outputs_have_been_used(self, msg):  | 
332 | 332 |                 + msg  | 
333 | 333 |             )  | 
334 | 334 | 
 
  | 
 | 335 | +    @classmethod  | 
 | 336 | +    def _normalize_splitter(  | 
 | 337 | +        cls, splitter: Splitter, strip_previous: bool = True  | 
 | 338 | +    ) -> ty.Tuple[ty.Tuple[str, ...], ...]:  | 
 | 339 | +        """Converts the splitter spec into a consistent tuple[tuple[str, ...], ...] form  | 
 | 340 | +        used in LazyFields"""  | 
 | 341 | +        if isinstance(splitter, str):  | 
 | 342 | +            splitter = (splitter,)  | 
 | 343 | +        if isinstance(splitter, tuple):  | 
 | 344 | +            splitter = (splitter,)  # type: ignore  | 
 | 345 | +        else:  | 
 | 346 | +            assert isinstance(splitter, list)  | 
 | 347 | +            # convert to frozenset to differentiate from tuple, yet still be hashable  | 
 | 348 | +            # (NB: order of fields in list splitters aren't relevant)  | 
 | 349 | +            splitter = tuple((s,) if isinstance(s, str) else s for s in splitter)  | 
 | 350 | +        # Strip out fields starting with "_" designating splits in upstream nodes  | 
 | 351 | +        if strip_previous:  | 
 | 352 | +            stripped = tuple(  | 
 | 353 | +                tuple(f for f in i if not f.startswith("_")) for i in splitter  | 
 | 354 | +            )  | 
 | 355 | +            splitter = tuple(s for s in stripped if s)  # type: ignore  | 
 | 356 | +        return splitter  # type: ignore  | 
 | 357 | + | 
335 | 358 | 
 
  | 
336 | 359 | @attrs.define(auto_attribs=False)  | 
337 | 360 | class Workflow(ty.Generic[OutputType]):  | 
 | 
0 commit comments