33from copy import deepcopy
44import itertools
55from collections import OrderedDict
6+ from operator import itemgetter
67from functools import reduce
78import typing as ty
89from . import helpers_state as hlpst
910from .helpers import ensure_list , attrs_values
11+ from pydra .utils .typing import StateArray , TypeParser
1012
1113# from .specs import BaseDef
1214if ty .TYPE_CHECKING :
@@ -47,6 +49,18 @@ def __len__(self) -> int:
4749 def __iter__ (self ) -> ty .Generator [str , None , None ]:
4850 return iter (self .indices )
4951
52+ def __getitem__ (self , key : str ) -> int :
53+ return self .indices [key ]
54+
55+ def __lt__ (self , other : "StateIndex" ) -> bool :
56+ if set (self .indices ) != set (other .indices ):
57+ raise ValueError (
58+ f"StateIndex { self } does not contain the same indices as { other } "
59+ )
60+ return sorted (self .indices .items (), key = itemgetter (0 )) < sorted (
61+ other .indices .items (), key = itemgetter (0 )
62+ )
63+
5064 def __repr__ (self ) -> str :
5165 return (
5266 "StateIndex(" + ", " .join (f"{ n } ={ v } " for n , v in self .indices .items ()) + ")"
@@ -79,6 +93,21 @@ def subset(self, state_names: ty.Iterable[str]) -> ty.Self:
7993 """
8094 return type (self )({k : v for k , v in self .indices .items () if k in state_names })
8195
96+ def missing (self , state_names : ty .Iterable [str ]) -> ty .List [str ]:
97+ """Return the fields that are missing from the StateIndex
98+
99+ Parameters
100+ ----------
101+ fields : list[str]
102+ the fields to check for
103+
104+ Returns
105+ -------
106+ list[str]
107+ the fields that are missing from the StateIndex
108+ """
109+ return [f for f in state_names if f not in self .indices ]
110+
82111 def matches (self , other : "StateIndex" ) -> bool :
83112 """Check if the indices that are present in the other StateIndex match
84113
@@ -92,6 +121,8 @@ def matches(self, other: "StateIndex") -> bool:
92121 bool
93122 True if all the indices in the other StateIndex match
94123 """
124+ if isinstance (other , dict ):
125+ other = StateIndex (other )
95126 if not set (self .indices ).issuperset (other .indices ):
96127 raise ValueError (
97128 f"StateIndex { self } does not contain all the indices in { other } "
@@ -211,10 +242,6 @@ def __str__(self):
211242 @property
212243 def names (self ):
213244 """Return the names of the states."""
214- # analysing states from connected tasks if inner_inputs
215- if not hasattr (self , "keys_final" ):
216- self .prepare_states ()
217- self .prepare_inputs ()
218245 previous_states_keys = {
219246 f"_{ v .name } " : v .keys_final for v in self .inner_inputs .values ()
220247 }
@@ -265,6 +292,41 @@ def included(s):
265292 remaining_stack = [s for s in stack if included (s )]
266293 return depth + len (remaining_stack )
267294
295+ def nest_output_type (self , type_ : type ) -> type :
296+ """Nests a type of an output field in a combination of lists and state-arrays
297+ based on the state's splitter and combiner
298+
299+ Parameters
300+ ----------
301+ type_ : type
302+ the type of the output field
303+
304+ Returns
305+ -------
306+ type
307+ the nested type of the output field
308+ """
309+
310+ state_array_depth = self .depth ()
311+
312+ # If there is a combination, it will get flattened into a single list
313+ if self .depth (after_combine = False ) > state_array_depth :
314+ type_ = list [type_ ]
315+
316+ # Nest the uncombined state arrays around the type
317+ for _ in range (state_array_depth ):
318+ type_ = StateArray [type_ ]
319+ return type_
320+
321+ @classmethod
322+ def combine_state_arrays (cls , type_ : type ) -> type :
323+ """Collapses (potentially nested) state array(s) into a single list"""
324+ if TypeParser .get_origin (type_ ) is StateArray :
325+ # Implicitly combine any remaining uncombined states into a single
326+ # list
327+ type_ = list [TypeParser .strip_splits (type_ )[0 ]]
328+ return type_
329+
268330 @property
269331 def splitter (self ):
270332 """Get the splitter of the state."""
0 commit comments