@@ -41,23 +41,63 @@ def __init__(self, indices: dict[str, int] | None = None):
4141 else :
4242 self .indices = OrderedDict (sorted (indices .items ()))
4343
44- def __repr__ (self ):
44+ def __len__ (self ) -> int :
45+ return len (self .indices )
46+
47+ def __iter__ (self ) -> ty .Generator [str , None , None ]:
48+ return iter (self .indices )
49+
50+ def __repr__ (self ) -> str :
4551 return (
4652 "StateIndex(" + ", " .join (f"{ n } ={ v } " for n , v in self .indices .items ()) + ")"
4753 )
4854
4955 def __hash__ (self ):
5056 return hash (tuple (self .indices .items ()))
5157
52- def __eq__ (self , other ):
58+ def __eq__ (self , other ) -> bool :
5359 return self .indices == other .indices
5460
55- def __str__ (self ):
61+ def __str__ (self ) -> str :
5662 return "__" .join (f"{ n } -{ i } " for n , i in self .indices .items ())
5763
58- def __bool__ (self ):
64+ def __bool__ (self ) -> bool :
5965 return bool (self .indices )
6066
67+ def subset (self , state_names : ty .Iterable [str ]) -> ty .Self :
68+ """Create a new StateIndex with only the specified fields
69+
70+ Parameters
71+ ----------
72+ fields : list[str]
73+ the fields to keep in the new StateIndex
74+
75+ Returns
76+ -------
77+ StateIndex
78+ a new StateIndex with only the specified fields
79+ """
80+ return type (self )({k : v for k , v in self .indices .items () if k in state_names })
81+
82+ def matches (self , other : "StateIndex" ) -> bool :
83+ """Check if the indices that are present in the other StateIndex match
84+
85+ Parameters
86+ ----------
87+ other : StateIndex
88+ the other StateIndex to compare against
89+
90+ Returns
91+ -------
92+ bool
93+ True if all the indices in the other StateIndex match
94+ """
95+ if not set (self .indices ).issuperset (other .indices ):
96+ raise ValueError (
97+ f"StateIndex { self } does not contain all the indices in { other } "
98+ )
99+ return all (self .indices [k ] == v for k , v in other .indices .items ())
100+
61101
62102class State :
63103 """
@@ -172,6 +212,9 @@ def __str__(self):
172212 def names (self ):
173213 """Return the names of the states."""
174214 # analysing states from connected tasks if inner_inputs
215+ if not hasattr (self , "keys_final" ):
216+ self .prepare_states ()
217+ self .prepare_inputs ()
175218 previous_states_keys = {
176219 f"_{ v .name } " : v .keys_final for v in self .inner_inputs .values ()
177220 }
@@ -190,13 +233,13 @@ def names(self):
190233
191234 @property
192235 def depth (self ) -> int :
193- """Return the number of uncombined splits of the state, i.e. the number nested
236+ """Return the number of splits of the state, i.e. the number nested
194237 state arrays to wrap around the type of lazy out fields
195238
196239 Returns
197240 -------
198241 int
199- number of uncombined splits
242+ number of uncombined independent splits (i.e. linked splits only add 1)
200243 """
201244 depth = 0
202245 stack = []
@@ -210,7 +253,8 @@ def depth(self) -> int:
210253 stack = []
211254 else :
212255 stack .append (spl )
213- return depth + len (stack )
256+ remaining_stack = [s for s in stack if s not in self .combiner ]
257+ return depth + len (remaining_stack )
214258
215259 @property
216260 def splitter (self ):
0 commit comments