@@ -85,7 +85,7 @@ def __init__(
8585 name ,
8686 splitter = None ,
8787 combiner = None ,
88- cont_dim = None ,
88+ container_ndim = None ,
8989 other_states = None ,
9090 ):
9191 """
@@ -109,8 +109,8 @@ def __init__(
109109 self .splitter = splitter
110110 # temporary combiner
111111 self .combiner = combiner
112- self .cont_dim = cont_dim or {}
113- self ._inner_cont_dim = {}
112+ self .container_ndim = container_ndim or {}
113+ self ._inner_container_ndim = {}
114114 self ._inputs_ind = None
115115 # if other_states, the connections have to be updated
116116 if self .other_states :
@@ -377,12 +377,12 @@ def prev_state_splitter_rpn_compact(self):
377377 return self ._prev_state_splitter_rpn_compact
378378
379379 @property
380- def cont_dim_all (self ):
381- # adding inner_cont_dim to the general container_dimension provided by the users
382- cont_dim_all = deepcopy (self .cont_dim )
383- for k , v in self ._inner_cont_dim .items ():
384- cont_dim_all [k ] = cont_dim_all .get (k , 1 ) + v
385- return cont_dim_all
380+ def container_ndim_all (self ):
381+ # adding inner_container_ndim to the general container_dimension provided by the users
382+ container_ndim_all = deepcopy (self .container_ndim )
383+ for k , v in self ._inner_container_ndim .items ():
384+ container_ndim_all [k ] = container_ndim_all .get (k , 1 ) + v
385+ return container_ndim_all
386386
387387 @property
388388 def combiner (self ):
@@ -869,7 +869,7 @@ def combiner_validation(self):
869869 def prepare_states (
870870 self ,
871871 inputs : dict [str , ty .Any ],
872- cont_dim : dict [str , int ] | None = None ,
872+ container_ndim : dict [str , int ] | None = None ,
873873 ):
874874 """
875875 Prepare a full list of state indices and state values.
@@ -885,13 +885,13 @@ def prepare_states(
885885 self .combiner_validation ()
886886 self .set_input_groups ()
887887 self .inputs = inputs
888- if cont_dim is not None :
889- self .cont_dim = cont_dim
888+ if container_ndim is not None :
889+ self .container_ndim = container_ndim
890890 if self .other_states :
891891 st : State
892892 for nm , (st , _ ) in self .other_states .items ():
893893 self .inputs .update (st .inputs )
894- self .cont_dim .update (st .cont_dim_all )
894+ self .container_ndim .update (st .container_ndim_all )
895895
896896 self .prepare_states_ind ()
897897 self .prepare_states_val ()
@@ -995,7 +995,9 @@ def prepare_states_combined_ind(self, elements_to_remove_comb):
995995 def prepare_states_val (self ):
996996 """Evaluate states values having states indices."""
997997 self .states_val = list (
998- map_splits (self .states_ind , self .inputs , cont_dim = self .cont_dim_all )
998+ map_splits (
999+ self .states_ind , self .inputs , container_ndim = self .container_ndim_all
1000+ )
9991001 )
10001002 return self .states_val
10011003
@@ -1165,8 +1167,8 @@ def _processing_terms(self, term, previous_states_ind):
11651167 var_ind , new_keys = previous_states_ind [term ]
11661168 shape = (len (var_ind ),)
11671169 else :
1168- cont_dim = self .cont_dim_all .get (term , 1 )
1169- shape = input_shape (self .inputs [term ], cont_dim = cont_dim )
1170+ container_ndim = self .container_ndim_all .get (term , 1 )
1171+ shape = input_shape (self .inputs [term ], container_ndim = container_ndim )
11701172 var_ind = range (reduce (lambda x , y : x * y , shape ))
11711173 new_keys = [term ]
11721174 # checking if the term is in inner_inputs
@@ -1186,7 +1188,8 @@ def _processing_terms(self, term, previous_states_ind):
11861188 def _single_op_splits (self , op_single ):
11871189 """splits function if splitter is a singleton"""
11881190 shape = input_shape (
1189- self .inputs [op_single ], cont_dim = self .cont_dim_all .get (op_single , 1 )
1191+ self .inputs [op_single ],
1192+ container_ndim = self .container_ndim_all .get (op_single , 1 ),
11901193 )
11911194 val_ind = range (reduce (lambda x , y : x * y , shape ))
11921195 if op_single in self .inner_inputs :
@@ -1211,8 +1214,8 @@ def _single_op_splits(self, op_single):
12111214 def _get_element (self , value : ty .Any , field_name : str , ind : int ) -> ty .Any :
12121215 """
12131216 Extracting element of the inputs taking into account
1214- container dimension of the specific element that can be set in self.state.cont_dim .
1215- If input name is not in cont_dim , it is assumed that the input values has
1217+ container dimension of the specific element that can be set in self.state.container_ndim .
1218+ If input name is not in container_ndim , it is assumed that the input values has
12161219 a container dimension of 1, so only the most outer dim will be used for splitting.
12171220
12181221 Parameters
@@ -1229,11 +1232,11 @@ def _get_element(self, value: ty.Any, field_name: str, ind: int) -> ty.Any:
12291232 Any
12301233 specific element of the input field
12311234 """
1232- if f"{ self .name } .{ field_name } " in self .cont_dim_all :
1235+ if f"{ self .name } .{ field_name } " in self .container_ndim_all :
12331236 return list (
12341237 flatten (
12351238 ensure_list (value ),
1236- max_depth = self .cont_dim_all [f"{ self .name } .{ field_name } " ],
1239+ max_depth = self .container_ndim_all [f"{ self .name } .{ field_name } " ],
12371240 )
12381241 )[ind ]
12391242 else :
@@ -1600,15 +1603,15 @@ def iter_splits(iterable, keys):
16001603 yield dict (zip (keys , list (flatten (iter , max_depth = 1000 ))))
16011604
16021605
1603- def input_shape (inp , cont_dim = 1 ):
1606+ def input_shape (inp , container_ndim = 1 ):
16041607 """Get input shape, depends on the container dimension, if not specify it is assumed to be 1"""
16051608 # TODO: have to be changed for inner splitter (sometimes different length)
1606- cont_dim -= 1
1609+ container_ndim -= 1
16071610 shape = [len (inp )]
16081611 last_shape = None
16091612 for value in inp :
1610- if isinstance (value , list ) and cont_dim > 0 :
1611- cur_shape = input_shape (value , cont_dim )
1613+ if isinstance (value , list ) and container_ndim > 0 :
1614+ cur_shape = input_shape (value , container_ndim )
16121615 if last_shape is None :
16131616 last_shape = cur_shape
16141617 elif last_shape != cur_shape :
@@ -1828,13 +1831,15 @@ def combine_final_groups(combiner, groups, groups_stack, keys):
18281831 return keys_final , groups_final_map , groups_stack_final , combiner_all
18291832
18301833
1831- def map_splits (split_iter , inputs , cont_dim = None ):
1834+ def map_splits (split_iter , inputs , container_ndim = None ):
18321835 """generate a dictionary of inputs prescribed by the splitter."""
1833- if cont_dim is None :
1834- cont_dim = {}
1836+ if container_ndim is None :
1837+ container_ndim = {}
18351838 for split in split_iter :
18361839 yield {
1837- k : list (flatten (ensure_list (inputs [k ]), max_depth = cont_dim .get (k , None )))[v ]
1840+ k : list (
1841+ flatten (ensure_list (inputs [k ]), max_depth = container_ndim .get (k , None ))
1842+ )[v ]
18381843 for k , v in split .items ()
18391844 }
18401845
0 commit comments