@@ -377,6 +377,14 @@ def prev_state_splitter_rpn_compact(self):
377377 """
378378 return self ._prev_state_splitter_rpn_compact
379379
380+ @property
381+ def cont_dim_all (self ):
382+ # adding inner_cont_dim to the general container_dimension provided by the users
383+ cont_dim_all = deepcopy (self .cont_dim )
384+ for k , v in self ._inner_cont_dim .items ():
385+ cont_dim_all [k ] = cont_dim_all .get (k , 1 ) + v
386+ return cont_dim_all
387+
380388 @property
381389 def combiner (self ):
382390 """the combiner associated to the state."""
@@ -858,7 +866,6 @@ def combiner_validation(self):
858866 def prepare_states (
859867 self ,
860868 inputs : dict [str , ty .Any ],
861- cont_dim : dict [str , int ] | None = None ,
862869 ):
863870 """
864871 Prepare a full list of state indices and state values.
@@ -874,8 +881,6 @@ def prepare_states(
874881 self .combiner_validation ()
875882 self .set_input_groups ()
876883 self .inputs = inputs
877- if not self .cont_dim :
878- self .cont_dim = cont_dim or {}
879884 if self .other_states :
880885 st : State
881886 for nm , (st , _ ) in self .other_states .items ():
@@ -986,7 +991,7 @@ def prepare_states_combined_ind(self, elements_to_remove_comb):
986991 def prepare_states_val (self ):
987992 """Evaluate states values having states indices."""
988993 self .states_val = list (
989- hlpst .map_splits (self .states_ind , self .inputs , cont_dim = self .cont_dim )
994+ hlpst .map_splits (self .states_ind , self .inputs , cont_dim = self .cont_dim_all )
990995 )
991996 return self .states_val
992997
@@ -1156,7 +1161,7 @@ def _processing_terms(self, term, previous_states_ind):
11561161 var_ind , new_keys = previous_states_ind [term ]
11571162 shape = (len (var_ind ),)
11581163 else :
1159- cont_dim = self .cont_dim .get (term , 1 )
1164+ cont_dim = self .cont_dim_all .get (term , 1 )
11601165 shape = hlpst .input_shape (self .inputs [term ], cont_dim = cont_dim )
11611166 var_ind = range (reduce (lambda x , y : x * y , shape ))
11621167 new_keys = [term ]
@@ -1177,7 +1182,7 @@ def _processing_terms(self, term, previous_states_ind):
11771182 def _single_op_splits (self , op_single ):
11781183 """splits function if splitter is a singleton"""
11791184 shape = hlpst .input_shape (
1180- self .inputs [op_single ], cont_dim = self .cont_dim .get (op_single , 1 )
1185+ self .inputs [op_single ], cont_dim = self .cont_dim_all .get (op_single , 1 )
11811186 )
11821187 val_ind = range (reduce (lambda x , y : x * y , shape ))
11831188 if op_single in self .inner_inputs :
0 commit comments