Skip to content

Commit bc71076

Browse files
authored
Merge pull request #465 from djarecka/fix/states_connection
[wip] fixing connections for multiple inputs
2 parents a9224dd + c2d63d7 commit bc71076

File tree

6 files changed

+279
-32
lines changed

6 files changed

+279
-32
lines changed

pydra/engine/core.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -952,16 +952,21 @@ def create_connections(self, task, detailed=False):
952952
(task.name, field.name, val.name, val.field)
953953
)
954954
logger.debug("Connecting %s to %s", val.name, task.name)
955-
955+
# adding a state from the previous task to other_states
956956
if (
957957
getattr(self, val.name).state
958958
and getattr(self, val.name).state.splitter_rpn_final
959959
):
960-
# adding a state from the previous task to other_states
961-
other_states[val.name] = (
962-
getattr(self, val.name).state,
963-
field.name,
964-
)
960+
# adding task_name: (task.state, [a field from the connection]
961+
if val.name not in other_states:
962+
other_states[val.name] = (
963+
getattr(self, val.name).state,
964+
[field.name],
965+
)
966+
else:
967+
# if the task already exist in other_state,
968+
# additional field name should be added to the list of fields
969+
other_states[val.name][1].append(field.name)
965970
else: # LazyField with the wf input
966971
# connections with wf input should be added to the detailed graph description
967972
if detailed:

pydra/engine/state.py

Lines changed: 107 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -304,9 +304,10 @@ def inner_inputs(self):
304304
"""
305305
if self.other_states:
306306
_inner_inputs = {}
307-
for name, (st, inp) in self.other_states.items():
307+
for name, (st, inp_l) in self.other_states.items():
308308
if f"_{st.name}" in self.splitter_rpn_compact:
309-
_inner_inputs[f"{self.name}.{inp}"] = st
309+
for inp in inp_l:
310+
_inner_inputs[f"{self.name}.{inp}"] = st
310311
return _inner_inputs
311312
else:
312313
return {}
@@ -323,6 +324,10 @@ def update_connections(self, new_other_states=None, new_combiner=None):
323324
"""
324325
if new_other_states:
325326
self.other_states = new_other_states
327+
# ensuring that the connected fields are set as a list
328+
self.other_states = {
329+
nm: (st, ensure_list(flds)) for nm, (st, flds) in self.other_states.items()
330+
}
326331
self._connect_splitters()
327332
if new_combiner:
328333
self.combiner = new_combiner
@@ -388,8 +393,93 @@ def _complete_prev_state(self, prev_state=None):
388393
prev_state = [f"_{name}" for name in self.other_states]
389394
if len(prev_state) == 1:
390395
prev_state = prev_state[0]
396+
397+
if isinstance(prev_state, list):
398+
prev_state = self._removed_repeated(prev_state)
391399
return prev_state
392400

401+
def _removed_repeated(self, previous_splitters):
402+
"""removing states from previous tasks that are repeated either directly or indirectly"""
403+
for el in previous_splitters:
404+
if el[1:] not in self.other_states:
405+
raise hlpst.PydraStateError(
406+
f"can't ask for splitter from {el[1:]}, other nodes that are connected: {self.other_states}"
407+
)
408+
409+
repeated = set(
410+
[
411+
(el, previous_splitters.count(el))
412+
for el in previous_splitters
413+
if previous_splitters.count(el) > 1
414+
]
415+
)
416+
if repeated:
417+
# assuming that I want to remove fro right
418+
previous_splitters.reverse()
419+
for el, cnt in repeated:
420+
for ii in range(cnt):
421+
previous_splitters.remove(el)
422+
previous_splitters.reverse()
423+
424+
el_state = []
425+
el_connect = []
426+
el_state_connect = []
427+
for el in previous_splitters:
428+
nm = el[1:]
429+
st = self.other_states[nm][0]
430+
if not st.other_states:
431+
# states that has no other connections
432+
el_state.append(el)
433+
else: # element has previous_connection
434+
if st.current_splitter: # final?
435+
# states that has previous connections and it's own splitter
436+
el_state_connect.append((el, st.prev_state_splitter))
437+
else:
438+
# states with previous connections but no additional splitter
439+
el_connect.append((el, st.prev_state_splitter))
440+
441+
for el in el_connect:
442+
nm = el[0][1:]
443+
repeated_prev = set(ensure_list(el[1])).intersection(el_state)
444+
if repeated_prev:
445+
for r_el in repeated_prev:
446+
r_nm = r_el[1:]
447+
self.other_states[r_nm] = (
448+
self.other_states[r_nm][0],
449+
self.other_states[r_nm][1] + self.other_states[nm][1],
450+
)
451+
new_st = set(ensure_list(el[1])) - set(el_state)
452+
if not new_st:
453+
previous_splitters.remove(el[0])
454+
else:
455+
for n_el in new_st:
456+
n_nm = n_el[1:]
457+
self.other_states[n_nm] = (
458+
self.other_states[nm][0].other_states[n_nm][0],
459+
self.other_states[nm][1],
460+
)
461+
# removing el of the splitter and adding new_st instead
462+
ind = previous_splitters.index(el[0])
463+
if ind == len(previous_splitters) - 1:
464+
previous_splitters = previous_splitters[:-1] + list(new_st)
465+
else:
466+
previous_splitters = (
467+
previous_splitters[:ind]
468+
+ list(new_st)
469+
+ previous_splitters[ind + 1 :]
470+
)
471+
# TODO: this part is not tested, needs more work
472+
for el in el_state_connect:
473+
repeated_prev = set(ensure_list(el[1])).intersection(el_state)
474+
if repeated_prev:
475+
for r_el in repeated_prev:
476+
previous_splitters.remove(r_el)
477+
478+
if len(previous_splitters) == 1:
479+
return previous_splitters[0]
480+
else:
481+
return previous_splitters
482+
393483
def _prevst_current_check(self, splitter_part, check_nested=True):
394484
"""
395485
Check if splitter_part is purely prev-state part, the current part,
@@ -640,14 +730,15 @@ def prepare_states_ind(self):
640730
# TODO: need tests in test_Workflow.py
641731
elements_to_remove = []
642732
elements_to_remove_comb = []
643-
for name, (st, inp) in self.other_states.items():
644-
if (
645-
f"{self.name}.{inp}" in self.splitter_rpn
646-
and f"_{name}" in self.splitter_rpn_compact
647-
):
648-
elements_to_remove.append(f"_{name}")
649-
if f"{self.name}.{inp}" not in self.combiner:
650-
elements_to_remove_comb.append(f"_{name}")
733+
for name, (st, inp_l) in self.other_states.items():
734+
for inp in inp_l:
735+
if (
736+
f"{self.name}.{inp}" in self.splitter_rpn
737+
and f"_{name}" in self.splitter_rpn_compact
738+
):
739+
elements_to_remove.append(f"_{name}")
740+
if f"{self.name}.{inp}" not in self.combiner:
741+
elements_to_remove_comb.append(f"_{name}")
651742

652743
partial_rpn = hlpst.remove_inp_from_splitter_rpn(
653744
deepcopy(self.splitter_rpn_compact), elements_to_remove
@@ -770,8 +861,9 @@ def prepare_inputs(self):
770861
for ii, el in enumerate(self.prev_state_splitter_rpn_compact):
771862
if el in ["*", "."]:
772863
continue
773-
st, inp = self.other_states[el[1:]]
774-
if f"{self.name}.{inp}" in self.splitter_rpn: # inner splitter
864+
st, inp_l = self.other_states[el[1:]]
865+
inp_l = [f"{self.name}.{inp}" for inp in inp_l]
866+
if set(inp_l).intersection(self.splitter_rpn): # inner splitter
775867
connected_to_inner += [
776868
el for el in st.splitter_rpn_final if el not in [".", "*"]
777869
]
@@ -784,8 +876,9 @@ def prepare_inputs(self):
784876
else:
785877
inputs_ind_prev = hlpst.op["*"](inputs_ind_prev, st_ind)
786878
else:
787-
inputs_ind_prev = hlpst.op["*"](st_ind)
788-
keys_inp_prev += [f"{self.name}.{inp}"]
879+
# TODO: more tests needed
880+
inputs_ind_prev = hlpst.op["."](*[st_ind] * len(inp_l))
881+
keys_inp_prev += inp_l
789882
keys_inp = keys_inp_prev + keys_inp
790883

791884
if inputs_ind and inputs_ind_prev:

pydra/engine/tests/test_state.py

Lines changed: 104 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -586,6 +586,98 @@ def test_state_connect_6a():
586586
]
587587

588588

589+
def test_state_connect_7():
590+
"""two 'connected' states with multiple fields that are connected
591+
no explicit splitter for the second state
592+
"""
593+
st1 = State(name="NA", splitter="a")
594+
st2 = State(name="NB", other_states={"NA": (st1, ["x", "y"])})
595+
# should take into account that x, y come from the same task
596+
assert st2.splitter == "_NA"
597+
assert st2.splitter_rpn == ["NA.a"]
598+
assert st2.prev_state_splitter == st2.splitter
599+
assert st2.prev_state_splitter_rpn == st2.splitter_rpn
600+
assert st2.current_splitter is None
601+
assert st2.current_splitter_rpn == []
602+
603+
st2.prepare_states(inputs={"NA.a": [3, 5]})
604+
assert st2.group_for_inputs_final == {"NA.a": 0}
605+
assert st2.groups_stack_final == [[0]]
606+
assert st2.states_ind == [{"NA.a": 0}, {"NA.a": 1}]
607+
assert st2.states_val == [{"NA.a": 3}, {"NA.a": 5}]
608+
609+
st2.prepare_inputs()
610+
# since x,y come from the same state, they should have the same index
611+
assert st2.inputs_ind == [{"NB.x": 0, "NB.y": 0}, {"NB.x": 1, "NB.y": 1}]
612+
613+
614+
def test_state_connect_8():
615+
"""three 'connected' states: NA -> NB -> NC; NA -> NC (only NA has its own splitter)
616+
pydra should recognize, that there is only one splitter - NA
617+
and it should give the same as the previous test
618+
"""
619+
st1 = State(name="NA", splitter="a")
620+
st2 = State(name="NB", other_states={"NA": (st1, "b")})
621+
st3 = State(name="NC", other_states={"NA": (st1, "x"), "NB": (st2, "y")})
622+
# x comes from NA and y comes from NB, but NB has only NA's splitter,
623+
# so it should be treated as both inputs are from NA state
624+
assert st3.splitter == "_NA"
625+
assert st3.splitter_rpn == ["NA.a"]
626+
assert st3.prev_state_splitter == st3.splitter
627+
assert st3.prev_state_splitter_rpn == st3.splitter_rpn
628+
assert st3.current_splitter is None
629+
assert st3.current_splitter_rpn == []
630+
631+
st3.prepare_states(inputs={"NA.a": [3, 5]})
632+
assert st3.group_for_inputs_final == {"NA.a": 0}
633+
assert st3.groups_stack_final == [[0]]
634+
assert st3.states_ind == [{"NA.a": 0}, {"NA.a": 1}]
635+
assert st3.states_val == [{"NA.a": 3}, {"NA.a": 5}]
636+
637+
st3.prepare_inputs()
638+
# since x,y come from the same state (although y indirectly), they should have the same index
639+
assert st3.inputs_ind == [{"NC.x": 0, "NC.y": 0}, {"NC.x": 1, "NC.y": 1}]
640+
641+
642+
@pytest.mark.xfail(
643+
reason="doesn't recognize that NC.y has 4 elements (not independend on NC.x)"
644+
)
645+
def test_state_connect_9():
646+
"""four 'connected' states: NA1 -> NB; NA2 -> NB, NA1 -> NC; NB -> NC
647+
pydra should recognize, that there is only one splitter - NA_1 and NA_2
648+
649+
"""
650+
st1 = State(name="NA_1", splitter="a")
651+
st1a = State(name="NA_2", splitter="a")
652+
st2 = State(name="NB", other_states={"NA_1": (st1, "b"), "NA_2": (st1a, "c")})
653+
st3 = State(name="NC", other_states={"NA_1": (st1, "x"), "NB": (st2, "y")})
654+
# x comes from NA_1 and y comes from NB, but NB has only NA_1/2's splitters,
655+
assert st3.splitter == ["_NA_1", "_NA_2"]
656+
assert st3.splitter_rpn == ["NA_1.a", "NA_2.a", "*"]
657+
assert st3.prev_state_splitter == st3.splitter
658+
assert st3.prev_state_splitter_rpn == st3.splitter_rpn
659+
assert st3.current_splitter is None
660+
assert st3.current_splitter_rpn == []
661+
662+
st3.prepare_states(inputs={"NA_1.a": [3, 5], "NA_2.a": [11, 12]})
663+
assert st3.group_for_inputs_final == {"NA_1.a": 0, "NA_2.a": 1}
664+
assert st3.groups_stack_final == [[0, 1]]
665+
assert st3.states_ind == [
666+
{"NA_1.a": 0, "NA_2.a": 0},
667+
{"NA_1.a": 0, "NA_2.a": 1},
668+
{"NA_1.a": 1, "NA_2.a": 0},
669+
{"NA_1.a": 1, "NA_2.a": 1},
670+
]
671+
672+
st3.prepare_inputs()
673+
assert st3.inputs_ind == [
674+
{"NC.x": 0, "NC.y": 0},
675+
{"NC.x": 0, "NC.y": 1},
676+
{"NC.x": 1, "NC.y": 2},
677+
{"NC.x": 1, "NC.y": 3},
678+
]
679+
680+
589681
def test_state_connect_innerspl_1():
590682
"""two 'connected' states: testing groups, prepare_states and prepare_inputs,
591683
the second state has an inner splitter, full splitter provided
@@ -605,7 +697,7 @@ def test_state_connect_innerspl_1():
605697
inputs={"NA.a": [3, 5], "NB.b": [[1, 10, 100], [2, 20, 200]]},
606698
cont_dim={"NB.b": 2}, # will be treated as 2d container
607699
)
608-
assert st2.other_states["NA"][1] == "b"
700+
assert st2.other_states["NA"][1] == ["b"]
609701
assert st2.group_for_inputs_final == {"NA.a": 0, "NB.b": 1}
610702
assert st2.groups_stack_final == [[0], [1]]
611703

@@ -653,7 +745,7 @@ def test_state_connect_innerspl_1a():
653745
assert st2.current_splitter == "NB.b"
654746
assert st2.current_splitter_rpn == ["NB.b"]
655747

656-
assert st2.other_states["NA"][1] == "b"
748+
assert st2.other_states["NA"][1] == ["b"]
657749

658750
st2.prepare_states(
659751
inputs={"NA.a": [3, 5], "NB.b": [[1, 10, 100], [2, 20, 200]]},
@@ -717,7 +809,7 @@ def test_state_connect_innerspl_2():
717809
inputs={"NA.a": [3, 5], "NB.b": [[1, 10, 100], [2, 20, 200]], "NB.c": [13, 17]},
718810
cont_dim={"NB.b": 2}, # will be treated as 2d container
719811
)
720-
assert st2.other_states["NA"][1] == "b"
812+
assert st2.other_states["NA"][1] == ["b"]
721813
assert st2.group_for_inputs_final == {"NA.a": 0, "NB.c": 1, "NB.b": 2}
722814
assert st2.groups_stack_final == [[0], [1, 2]]
723815

@@ -778,7 +870,7 @@ def test_state_connect_innerspl_2a():
778870

779871
assert st2.splitter == ["_NA", ["NB.b", "NB.c"]]
780872
assert st2.splitter_rpn == ["NA.a", "NB.b", "NB.c", "*", "*"]
781-
assert st2.other_states["NA"][1] == "b"
873+
assert st2.other_states["NA"][1] == ["b"]
782874

783875
st2.prepare_states(
784876
inputs={"NA.a": [3, 5], "NB.b": [[1, 10, 100], [2, 20, 200]], "NB.c": [13, 17]},
@@ -839,6 +931,7 @@ def test_state_connect_innerspl_3():
839931
the second state has one inner splitter and one 'normal' splitter
840932
the prev-state parts of the splitter have to be added
841933
"""
934+
842935
st1 = State(name="NA", splitter="a")
843936
st2 = State(name="NB", splitter=["c", "b"], other_states={"NA": (st1, "b")})
844937
st3 = State(name="NC", splitter="d", other_states={"NB": (st2, "a")})
@@ -986,8 +1079,8 @@ def test_state_connect_innerspl_4():
9861079

9871080
assert st3.splitter == [["_NA", "_NB"], "NC.d"]
9881081
assert st3.splitter_rpn == ["NA.a", "NB.b", "NB.c", "*", "*", "NC.d", "*"]
989-
assert st3.other_states["NA"][1] == "e"
990-
assert st3.other_states["NB"][1] == "f"
1082+
assert st3.other_states["NA"][1] == ["e"]
1083+
assert st3.other_states["NB"][1] == ["f"]
9911084

9921085
st3.prepare_states(
9931086
inputs={
@@ -1736,12 +1829,12 @@ def test_connect_splitters_exception_1(splitter, other_states):
17361829

17371830

17381831
def test_connect_splitters_exception_2():
1739-
st = State(
1740-
name="CN",
1741-
splitter="_NB",
1742-
other_states={"NA": (State(name="NA", splitter="a"), "b")},
1743-
)
17441832
with pytest.raises(PydraStateError) as excinfo:
1833+
st = State(
1834+
name="CN",
1835+
splitter="_NB",
1836+
other_states={"NA": (State(name="NA", splitter="a"), "b")},
1837+
)
17451838
st.set_input_groups()
17461839
assert "can't ask for splitter from NB" in str(excinfo.value)
17471840

0 commit comments

Comments
 (0)