33from copy import deepcopy
44import itertools
55from collections import OrderedDict
6- from operator import itemgetter
76from functools import reduce
87import typing as ty
98from . import helpers_state as hlpst
@@ -41,7 +40,7 @@ def __init__(self, indices: dict[str, int] | None = None):
4140 if indices is None :
4241 self .indices = OrderedDict ()
4342 else :
44- self .indices = OrderedDict (sorted ( indices .items () ))
43+ self .indices = OrderedDict (indices .items ())
4544
4645 def __len__ (self ) -> int :
4746 return len (self .indices )
@@ -53,13 +52,12 @@ def __getitem__(self, key: str) -> int:
5352 return self .indices [key ]
5453
5554 def __lt__ (self , other : "StateIndex" ) -> bool :
56- if set (self .indices ) != set (other .indices ):
55+ if list (self .indices ) != list (other .indices ):
5756 raise ValueError (
58- f"StateIndex { self } does not contain the same indices as { other } "
57+ f"StateIndex { self } does not contain the same indices in the same order "
58+ f"as { other } : { list (self .indices )} != { list (other .indices )} "
5959 )
60- return sorted (self .indices .items (), key = itemgetter (0 )) < sorted (
61- other .indices .items (), key = itemgetter (0 )
62- )
60+ return tuple (self .indices .items ()) < tuple (other .indices .items ())
6361
6462 def __repr__ (self ) -> str :
6563 return (
@@ -273,24 +271,29 @@ def depth(self, after_combine: bool = True) -> int:
273271 int
274272 number of splits in the state (i.e. linked splits only add 1)
275273 """
276- depth = 0
277- stack = []
278274
279- def included (s ):
280- return s not in self .combiner if after_combine else True
275+ # replace field names with 1 or 0 (1 if the field is included in the state)
276+ include_rpn = [
277+ (
278+ s
279+ if s in ["." , "*" ]
280+ else (int (s not in self .combiner ) if after_combine else 1 )
281+ )
282+ for s in self .splitter_rpn
283+ ]
281284
282- for spl in self . splitter_rpn :
283- if spl in [ "." , "*" ] :
284- if spl == "." :
285- depth += int ( all ( included ( s ) for s in stack ))
286- else :
287- assert spl == "*"
288- depth += len ([ s for s in stack if included ( s )])
289- stack = []
285+ stack = []
286+ for opr in include_rpn :
287+ if opr == "." :
288+ assert len ( stack ) >= 2
289+ stack . append ( stack . pop () and stack . pop ())
290+ elif opr == "*" :
291+ assert len (stack ) >= 2
292+ stack . append ( stack . pop () + stack . pop ())
290293 else :
291- stack .append (spl )
292- remaining_stack = [ s for s in stack if included ( s )]
293- return depth + len ( remaining_stack )
294+ stack .append (opr )
295+ assert len ( stack ) == 1
296+ return stack [ 0 ]
294297
295298 def nest_output_type (self , type_ : type ) -> type :
296299 """Nests a type of an output field in a combination of lists and state-arrays
0 commit comments