@@ -95,8 +95,6 @@ def __init__(self, name, splitter=None, combiner=None, other_states=None):
95
95
# if other_states, the connections have to be updated
96
96
if self .other_states :
97
97
self .update_connections ()
98
- else :
99
- self .set_input_groups (state_fields = False )
100
98
101
99
def __str__ (self ):
102
100
"""Generate a string representation of the object."""
@@ -232,7 +230,7 @@ def right_combiner_all(self):
232
230
if hasattr (self , "_right_combiner_all" ):
233
231
return self ._right_combiner_all
234
232
else :
235
- return self .combiner
233
+ return self .right_combiner
236
234
237
235
@property
238
236
def left_combiner_all (self ):
@@ -264,6 +262,7 @@ def other_states(self, other_states):
264
262
265
263
@property
266
264
def inner_inputs (self ):
265
+ """input fields from previous nodes"""
267
266
if self .other_states :
268
267
_inner_inputs = {}
269
268
for name , (st , inp ) in self .other_states .items ():
@@ -280,7 +279,6 @@ def update_connections(self, new_other_states=None, new_combiner=None):
280
279
self ._connect_splitters ()
281
280
if new_combiner :
282
281
self .combiner = new_combiner
283
- self .set_input_groups ()
284
282
285
283
def _connect_splitters (self ):
286
284
"""
@@ -371,6 +369,9 @@ def set_input_groups(self, state_fields=True):
371
369
other_states = self .other_states ,
372
370
state_fields = state_fields ,
373
371
)
372
+ # merging groups from previous nodes if any input come from previous the nodes
373
+ if self .inner_inputs :
374
+ self ._merge_previous_groups ()
374
375
keys_f , group_for_inputs_f , groups_stack_f , combiner_all = hlpst .splits_groups (
375
376
right_splitter_rpn ,
376
377
combiner = self .right_combiner ,
@@ -381,19 +382,15 @@ def set_input_groups(self, state_fields=True):
381
382
self ._right_keys_final = keys_f
382
383
self ._right_group_for_inputs_final = group_for_inputs_f
383
384
self ._right_groups_stack_final = groups_stack_f
384
- self .connect_groups ()
385
+ if self .right_splitter : # if Right part, adding groups from current st
386
+ self ._add_current_groups ()
387
+
385
388
else :
386
389
self .group_for_inputs_final = group_for_inputs_f
387
390
self .groups_stack_final = groups_stack_f
388
391
self .keys_final = keys_f
389
392
390
- def connect_groups (self ):
391
- """"Connect previous states and evaluate the final groups."""
392
- self ._merge_previous_states ()
393
- if self .right_splitter : # if Right part, adding groups from current st
394
- self .push_new_states ()
395
-
396
- def _merge_previous_states (self ):
393
+ def _merge_previous_groups (self ):
397
394
"""Merge groups from all previous nodes."""
398
395
last_gr = 0
399
396
self .groups_stack_final = []
@@ -420,6 +417,8 @@ def _merge_previous_states(self):
420
417
st_combiner = [
421
418
comb for comb in self .left_combiner_all if comb in st .splitter_rpn_final
422
419
]
420
+ if not hasattr (st , "keys_final" ):
421
+ st .set_input_groups ()
423
422
if st_combiner :
424
423
# keys and groups from previous states
425
424
# after taking into account combiner from current state
@@ -462,7 +461,7 @@ def _merge_previous_states(self):
462
461
nmb_gr += len (groups )
463
462
last_gr += nmb_gr
464
463
465
- def push_new_states (self ):
464
+ def _add_current_groups (self ):
466
465
"""Add additional groups from the current state."""
467
466
self .keys_final += self ._right_keys_final
468
467
nr_gr_f = max (self .group_for_inputs_final .values ()) + 1
@@ -516,6 +515,7 @@ def prepare_states(self, inputs, cont_dim=None):
516
515
# checking if splitter and combiner have valid forms
517
516
self .splitter_validation ()
518
517
self .combiner_validation ()
518
+ self .set_input_groups ()
519
519
# container dimension for each input, specifies how nested the input is
520
520
if cont_dim is None :
521
521
self .cont_dim = {}
0 commit comments