Skip to content

Commit 457c423

Browse files
committed
initial effort to separate the connections within state from the init etc.
1 parent 603854b commit 457c423

File tree

2 files changed

+63
-44
lines changed

2 files changed

+63
-44
lines changed

pydra/engine/core.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -776,20 +776,21 @@ def create_connections(self, task):
776776
)
777777
# if task has connections state has to be recalculated
778778
if other_states:
779+
if hasattr(task, "fut_combiner"):
780+
combiner = task.fut_combiner
781+
else:
782+
combiner = None
783+
779784
if task.state:
780-
old_splitter = task.state.splitter
785+
task.state.update_connections(
786+
new_other_states=other_states, new_combiner=combiner
787+
)
781788
else:
782-
old_splitter = None
783-
if hasattr(task, "fut_combiner"):
784789
task.state = state.State(
785790
task.name,
786-
splitter=old_splitter,
791+
splitter=None,
787792
other_states=other_states,
788-
combiner=task.fut_combiner,
789-
)
790-
else:
791-
task.state = state.State(
792-
task.name, splitter=old_splitter, other_states=other_states
793+
combiner=combiner,
793794
)
794795

795796
async def _run(self, submitter=None, rerun=False, **kwargs):

pydra/engine/state.py

Lines changed: 53 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -96,20 +96,31 @@ def __init__(self, name, splitter=None, combiner=None, other_states=None):
9696
self.missing_connections = False
9797
self.other_states = other_states
9898
self.splitter = splitter
99+
# temporary combiner
100+
self._combiner = combiner
99101
# if missing_connections, we can't continue, should wait for updates
100-
# TODO: should find a better way, so it's not in the init, but combiner complicates
101102
if not self.missing_connections:
102-
self._connect_splitters()
103-
self.combiner = combiner
104-
self.inner_inputs = {}
105-
for name, (st, inp) in self.other_states.items():
106-
if f"_{st.name}" in self.splitter_rpn_compact:
107-
self.inner_inputs[f"{self.name}.{inp}"] = st
108-
self.set_input_groups()
109-
self.set_splitter_final()
110-
self.states_val = []
111-
self.inputs_ind = []
112-
self.final_combined_ind_mapping = {}
103+
self.update_connections()
104+
105+
def update_connections(self, new_other_states=None, new_combiner=None):
106+
if new_other_states:
107+
self.missing_connections = False
108+
self.other_states = new_other_states
109+
self.splitter = self._splitter
110+
self._connect_splitters()
111+
if new_combiner:
112+
self.combiner = new_combiner
113+
else:
114+
self.combiner = self._combiner
115+
self.inner_inputs = {}
116+
for name, (st, inp) in self.other_states.items():
117+
if f"_{st.name}" in self.splitter_rpn_compact:
118+
self.inner_inputs[f"{self.name}.{inp}"] = st
119+
self.set_input_groups()
120+
self.set_splitter_final()
121+
self.states_val = []
122+
self.inputs_ind = []
123+
self.final_combined_ind_mapping = {}
113124

114125
def __str__(self):
115126
"""Generate a string representation of the object."""
@@ -175,24 +186,31 @@ def combiner(self):
175186

176187
@combiner.setter
177188
def combiner(self, combiner):
178-
if combiner:
179-
if not self.splitter:
180-
raise Exception("splitter has to be set before setting combiner")
181-
if not isinstance(combiner, (str, list)):
182-
raise Exception("combiner has to be a string or a list")
183-
self._combiner = hlpst.add_name_combiner(ensure_list(combiner), self.name)
184-
if set(self._combiner) - set(self.splitter_rpn):
185-
raise Exception("all combiners have to be in the splitter")
186-
# combiners from the current fields: i.e. {self.name}.input
187-
self._right_combiner = [
188-
comb for comb in self._combiner if self.name in comb
189-
]
190-
# combiners from the previous states
191-
self._left_combiner = list(set(self._combiner) - set(self._right_combiner))
189+
if self.missing_connections or not self.splitter:
190+
self._combiner = combiner
192191
else:
193-
self._combiner = []
194-
self._left_combiner = []
195-
self._right_combiner = []
192+
if combiner:
193+
if not self.splitter:
194+
raise Exception("splitter has to be set before setting combiner")
195+
if not isinstance(combiner, (str, list)):
196+
raise Exception("combiner has to be a string or a list")
197+
self._combiner = hlpst.add_name_combiner(
198+
ensure_list(combiner), self.name
199+
)
200+
if set(self._combiner) - set(self.splitter_rpn):
201+
raise Exception("all combiners have to be in the splitter")
202+
# combiners from the current fields: i.e. {self.name}.input
203+
self._right_combiner = [
204+
comb for comb in self._combiner if self.name in comb
205+
]
206+
# combiners from the previous states
207+
self._left_combiner = list(
208+
set(self._combiner) - set(self._right_combiner)
209+
)
210+
else:
211+
self._combiner = []
212+
self._left_combiner = []
213+
self._right_combiner = []
196214

197215
def _connect_splitters(self):
198216
"""
@@ -276,7 +294,7 @@ def _complete_left(self, left=None):
276294
left = left[0]
277295
return left
278296

279-
def _left_right_check(self, splitter_part, rec_lev=0):
297+
def _left_right_check(self, splitter_part, check_nested=True):
280298
"""
281299
Check if splitter_part is purely Left, Right
282300
or [Left, Right] if the splitter_part is a list (outer splitter)
@@ -299,9 +317,9 @@ def _left_right_check(self, splitter_part, rec_lev=0):
299317
return "Right"
300318
elif (
301319
isinstance(self.splitter, list)
302-
and rec_lev == 0
303-
and self._left_right_check(self.splitter[0], rec_lev=1) == "Left"
304-
and self._left_right_check(self.splitter[1], rec_lev=1) == "Right"
320+
and check_nested
321+
and self._left_right_check(self.splitter[0], check_nested=False) == "Left"
322+
and self._left_right_check(self.splitter[1], check_nested=False) == "Right"
305323
):
306324
return "[Left, Right]" # Left and Right parts separated in outer scalar
307325
else:
@@ -339,11 +357,11 @@ def set_input_groups(self):
339357

340358
def connect_groups(self):
341359
""""Connect previous states and evaluate the final groups."""
342-
self.merge_previous_states()
360+
self._merge_previous_states()
343361
if self._right_splitter: # if Right part, adding groups from current st
344362
self.push_new_states()
345363

346-
def merge_previous_states(self):
364+
def _merge_previous_states(self):
347365
"""Merge groups from all previous nodes."""
348366
last_gr = 0
349367
self.groups_stack_final = []

0 commit comments

Comments
 (0)