@@ -45,12 +45,12 @@ class Node(ty.Generic[OutputType]):
4545 init = False , default = None , eq = False , hash = False , repr = False
4646 )
4747 _state : State | None = attrs .field (init = False , default = NOT_SET )
48- _cont_dim : dict [str , int ] | None = attrs .field (
49- init = False , default = None
50- ) # QUESTION: should this be included in the state?
51- _inner_cont_dim : dict [str , int ] = attrs .field (
52- init = False , factory = dict
53- ) # QUESTION: should this be included in the state?
48+ # _cont_dim: dict[str, int] | None = attrs.field(
49+ # init=False, default=None
50+ # ) # QUESTION: should this be included in the state?
51+ # _inner_cont_dim: dict[str, int] = attrs.field(
52+ # init=False, factory=dict
53+ # ) # QUESTION: should this be included in the state?
5454
5555 def __attrs_post_init__ (self ):
5656 self ._set_state ()
@@ -179,16 +179,20 @@ def _check_if_outputs_have_been_used(self, msg):
179179
180180 def _set_state (self ) -> None :
181181 # Add node name to state's splitter, combiner and cont_dim loaded from the def
182- splitter = self ._definition ._splitter
183- combiner = self ._definition ._combiner
182+ splitter = deepcopy (
183+ self ._definition ._splitter
184+ ) # these can be modified by the state
185+ combiner = deepcopy (
186+ self ._definition ._combiner
187+ ) # these can be modified by the state
184188 if splitter :
185189 splitter = hlpst .add_name_splitter (splitter , self .name )
186190 if combiner :
187191 combiner = hlpst .add_name_combiner (combiner , self .name )
188192 if self ._definition ._cont_dim :
189- self . _cont_dim = {}
193+ cont_dim = {}
190194 for key , val in self ._definition ._cont_dim .items ():
191- self . _cont_dim [f"{ self .name } .{ key } " ] = val
195+ cont_dim [f"{ self .name } .{ key } " ] = val
192196 other_states = self ._get_upstream_states ()
193197 if splitter or combiner or other_states :
194198 self ._state = State (
@@ -197,6 +201,7 @@ def _set_state(self) -> None:
197201 splitter = splitter ,
198202 other_states = other_states ,
199203 combiner = combiner ,
204+ cont_dim = cont_dim ,
200205 )
201206 if combiner :
202207 if not_split := [
0 commit comments