Skip to content

Commit d8a0320

Browse files
committed
important changes in grouping fields in states (removing it from init) and adding to prepare_states, changes in set_input_groups); adding tests to teh state class
1 parent 287eec9 commit d8a0320

File tree

3 files changed

+160
-89
lines changed

3 files changed

+160
-89
lines changed

pydra/engine/helpers_state.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -701,6 +701,7 @@ def _single_op_splits_groups(op_single, combiner, inner_inputs, groups):
701701
if combiner == [op_single]:
702702
return [], {}, [], combiner
703703
else:
704+
# TODO: probably not needed, should be already check by st.combiner_validation
704705
raise Exception(
705706
"all fields from the combiner have to be in splitter_rpn: {}, but combiner: {} is set".format(
706707
[op_single], combiner

pydra/engine/state.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -95,8 +95,6 @@ def __init__(self, name, splitter=None, combiner=None, other_states=None):
9595
# if other_states, the connections have to be updated
9696
if self.other_states:
9797
self.update_connections()
98-
else:
99-
self.set_input_groups(state_fields=False)
10098

10199
def __str__(self):
102100
"""Generate a string representation of the object."""
@@ -232,7 +230,7 @@ def right_combiner_all(self):
232230
if hasattr(self, "_right_combiner_all"):
233231
return self._right_combiner_all
234232
else:
235-
return self.combiner
233+
return self.right_combiner
236234

237235
@property
238236
def left_combiner_all(self):
@@ -264,6 +262,7 @@ def other_states(self, other_states):
264262

265263
@property
266264
def inner_inputs(self):
265+
"""input fields from previous nodes"""
267266
if self.other_states:
268267
_inner_inputs = {}
269268
for name, (st, inp) in self.other_states.items():
@@ -280,7 +279,6 @@ def update_connections(self, new_other_states=None, new_combiner=None):
280279
self._connect_splitters()
281280
if new_combiner:
282281
self.combiner = new_combiner
283-
self.set_input_groups()
284282

285283
def _connect_splitters(self):
286284
"""
@@ -371,6 +369,9 @@ def set_input_groups(self, state_fields=True):
371369
other_states=self.other_states,
372370
state_fields=state_fields,
373371
)
372+
# merging groups from previous nodes if any input come from previous the nodes
373+
if self.inner_inputs:
374+
self._merge_previous_groups()
374375
keys_f, group_for_inputs_f, groups_stack_f, combiner_all = hlpst.splits_groups(
375376
right_splitter_rpn,
376377
combiner=self.right_combiner,
@@ -381,19 +382,15 @@ def set_input_groups(self, state_fields=True):
381382
self._right_keys_final = keys_f
382383
self._right_group_for_inputs_final = group_for_inputs_f
383384
self._right_groups_stack_final = groups_stack_f
384-
self.connect_groups()
385+
if self.right_splitter: # if Right part, adding groups from current st
386+
self._add_current_groups()
387+
385388
else:
386389
self.group_for_inputs_final = group_for_inputs_f
387390
self.groups_stack_final = groups_stack_f
388391
self.keys_final = keys_f
389392

390-
def connect_groups(self):
391-
""""Connect previous states and evaluate the final groups."""
392-
self._merge_previous_states()
393-
if self.right_splitter: # if Right part, adding groups from current st
394-
self.push_new_states()
395-
396-
def _merge_previous_states(self):
393+
def _merge_previous_groups(self):
397394
"""Merge groups from all previous nodes."""
398395
last_gr = 0
399396
self.groups_stack_final = []
@@ -420,6 +417,8 @@ def _merge_previous_states(self):
420417
st_combiner = [
421418
comb for comb in self.left_combiner_all if comb in st.splitter_rpn_final
422419
]
420+
if not hasattr(st, "keys_final"):
421+
st.set_input_groups()
423422
if st_combiner:
424423
# keys and groups from previous states
425424
# after taking into account combiner from current state
@@ -462,7 +461,7 @@ def _merge_previous_states(self):
462461
nmb_gr += len(groups)
463462
last_gr += nmb_gr
464463

465-
def push_new_states(self):
464+
def _add_current_groups(self):
466465
"""Add additional groups from the current state."""
467466
self.keys_final += self._right_keys_final
468467
nr_gr_f = max(self.group_for_inputs_final.values()) + 1
@@ -516,6 +515,7 @@ def prepare_states(self, inputs, cont_dim=None):
516515
# checking if splitter and combiner have valid forms
517516
self.splitter_validation()
518517
self.combiner_validation()
518+
self.set_input_groups()
519519
# container dimension for each input, specifies how nested the input is
520520
if cont_dim is None:
521521
self.cont_dim = {}

0 commit comments

Comments
 (0)