@@ -85,7 +85,7 @@ def __init__(
85
85
name ,
86
86
splitter = None ,
87
87
combiner = None ,
88
- cont_dim = None ,
88
+ container_ndim = None ,
89
89
other_states = None ,
90
90
):
91
91
"""
@@ -109,8 +109,8 @@ def __init__(
109
109
self .splitter = splitter
110
110
# temporary combiner
111
111
self .combiner = combiner
112
- self .cont_dim = cont_dim or {}
113
- self ._inner_cont_dim = {}
112
+ self .container_ndim = container_ndim or {}
113
+ self ._inner_container_ndim = {}
114
114
self ._inputs_ind = None
115
115
# if other_states, the connections have to be updated
116
116
if self .other_states :
@@ -377,12 +377,12 @@ def prev_state_splitter_rpn_compact(self):
377
377
return self ._prev_state_splitter_rpn_compact
378
378
379
379
@property
380
- def cont_dim_all (self ):
381
- # adding inner_cont_dim to the general container_dimension provided by the users
382
- cont_dim_all = deepcopy (self .cont_dim )
383
- for k , v in self ._inner_cont_dim .items ():
384
- cont_dim_all [k ] = cont_dim_all .get (k , 1 ) + v
385
- return cont_dim_all
380
+ def container_ndim_all (self ):
381
+ # adding inner_container_ndim to the general container_dimension provided by the users
382
+ container_ndim_all = deepcopy (self .container_ndim )
383
+ for k , v in self ._inner_container_ndim .items ():
384
+ container_ndim_all [k ] = container_ndim_all .get (k , 1 ) + v
385
+ return container_ndim_all
386
386
387
387
@property
388
388
def combiner (self ):
@@ -869,7 +869,7 @@ def combiner_validation(self):
869
869
def prepare_states (
870
870
self ,
871
871
inputs : dict [str , ty .Any ],
872
- cont_dim : dict [str , int ] | None = None ,
872
+ container_ndim : dict [str , int ] | None = None ,
873
873
):
874
874
"""
875
875
Prepare a full list of state indices and state values.
@@ -885,13 +885,13 @@ def prepare_states(
885
885
self .combiner_validation ()
886
886
self .set_input_groups ()
887
887
self .inputs = inputs
888
- if cont_dim is not None :
889
- self .cont_dim = cont_dim
888
+ if container_ndim is not None :
889
+ self .container_ndim = container_ndim
890
890
if self .other_states :
891
891
st : State
892
892
for nm , (st , _ ) in self .other_states .items ():
893
893
self .inputs .update (st .inputs )
894
- self .cont_dim .update (st .cont_dim_all )
894
+ self .container_ndim .update (st .container_ndim_all )
895
895
896
896
self .prepare_states_ind ()
897
897
self .prepare_states_val ()
@@ -995,7 +995,9 @@ def prepare_states_combined_ind(self, elements_to_remove_comb):
995
995
def prepare_states_val (self ):
996
996
"""Evaluate states values having states indices."""
997
997
self .states_val = list (
998
- map_splits (self .states_ind , self .inputs , cont_dim = self .cont_dim_all )
998
+ map_splits (
999
+ self .states_ind , self .inputs , container_ndim = self .container_ndim_all
1000
+ )
999
1001
)
1000
1002
return self .states_val
1001
1003
@@ -1165,8 +1167,8 @@ def _processing_terms(self, term, previous_states_ind):
1165
1167
var_ind , new_keys = previous_states_ind [term ]
1166
1168
shape = (len (var_ind ),)
1167
1169
else :
1168
- cont_dim = self .cont_dim_all .get (term , 1 )
1169
- shape = input_shape (self .inputs [term ], cont_dim = cont_dim )
1170
+ container_ndim = self .container_ndim_all .get (term , 1 )
1171
+ shape = input_shape (self .inputs [term ], container_ndim = container_ndim )
1170
1172
var_ind = range (reduce (lambda x , y : x * y , shape ))
1171
1173
new_keys = [term ]
1172
1174
# checking if the term is in inner_inputs
@@ -1186,7 +1188,8 @@ def _processing_terms(self, term, previous_states_ind):
1186
1188
def _single_op_splits (self , op_single ):
1187
1189
"""splits function if splitter is a singleton"""
1188
1190
shape = input_shape (
1189
- self .inputs [op_single ], cont_dim = self .cont_dim_all .get (op_single , 1 )
1191
+ self .inputs [op_single ],
1192
+ container_ndim = self .container_ndim_all .get (op_single , 1 ),
1190
1193
)
1191
1194
val_ind = range (reduce (lambda x , y : x * y , shape ))
1192
1195
if op_single in self .inner_inputs :
@@ -1211,8 +1214,8 @@ def _single_op_splits(self, op_single):
1211
1214
def _get_element (self , value : ty .Any , field_name : str , ind : int ) -> ty .Any :
1212
1215
"""
1213
1216
Extracting element of the inputs taking into account
1214
- container dimension of the specific element that can be set in self.state.cont_dim .
1215
- If input name is not in cont_dim , it is assumed that the input values has
1217
+ container dimension of the specific element that can be set in self.state.container_ndim .
1218
+ If input name is not in container_ndim , it is assumed that the input values has
1216
1219
a container dimension of 1, so only the most outer dim will be used for splitting.
1217
1220
1218
1221
Parameters
@@ -1229,11 +1232,11 @@ def _get_element(self, value: ty.Any, field_name: str, ind: int) -> ty.Any:
1229
1232
Any
1230
1233
specific element of the input field
1231
1234
"""
1232
- if f"{ self .name } .{ field_name } " in self .cont_dim_all :
1235
+ if f"{ self .name } .{ field_name } " in self .container_ndim_all :
1233
1236
return list (
1234
1237
flatten (
1235
1238
ensure_list (value ),
1236
- max_depth = self .cont_dim_all [f"{ self .name } .{ field_name } " ],
1239
+ max_depth = self .container_ndim_all [f"{ self .name } .{ field_name } " ],
1237
1240
)
1238
1241
)[ind ]
1239
1242
else :
@@ -1600,15 +1603,15 @@ def iter_splits(iterable, keys):
1600
1603
yield dict (zip (keys , list (flatten (iter , max_depth = 1000 ))))
1601
1604
1602
1605
1603
- def input_shape (inp , cont_dim = 1 ):
1606
+ def input_shape (inp , container_ndim = 1 ):
1604
1607
"""Get input shape, depends on the container dimension, if not specify it is assumed to be 1"""
1605
1608
# TODO: have to be changed for inner splitter (sometimes different length)
1606
- cont_dim -= 1
1609
+ container_ndim -= 1
1607
1610
shape = [len (inp )]
1608
1611
last_shape = None
1609
1612
for value in inp :
1610
- if isinstance (value , list ) and cont_dim > 0 :
1611
- cur_shape = input_shape (value , cont_dim )
1613
+ if isinstance (value , list ) and container_ndim > 0 :
1614
+ cur_shape = input_shape (value , container_ndim )
1612
1615
if last_shape is None :
1613
1616
last_shape = cur_shape
1614
1617
elif last_shape != cur_shape :
@@ -1828,13 +1831,15 @@ def combine_final_groups(combiner, groups, groups_stack, keys):
1828
1831
return keys_final , groups_final_map , groups_stack_final , combiner_all
1829
1832
1830
1833
1831
- def map_splits (split_iter , inputs , cont_dim = None ):
1834
+ def map_splits (split_iter , inputs , container_ndim = None ):
1832
1835
"""generate a dictionary of inputs prescribed by the splitter."""
1833
- if cont_dim is None :
1834
- cont_dim = {}
1836
+ if container_ndim is None :
1837
+ container_ndim = {}
1835
1838
for split in split_iter :
1836
1839
yield {
1837
- k : list (flatten (ensure_list (inputs [k ]), max_depth = cont_dim .get (k , None )))[v ]
1840
+ k : list (
1841
+ flatten (ensure_list (inputs [k ]), max_depth = container_ndim .get (k , None ))
1842
+ )[v ]
1838
1843
for k , v in split .items ()
1839
1844
}
1840
1845
0 commit comments