Skip to content

Commit 603854b

Browse files
committed
moving some functions from helper_state to the State methods; some mnt. edits
1 parent feafec8 commit 603854b

File tree

4 files changed

+185
-205
lines changed

4 files changed

+185
-205
lines changed

pydra/engine/helpers_state.py

Lines changed: 0 additions & 104 deletions
Original file line numberDiff line numberDiff line change
@@ -776,110 +776,6 @@ def map_splits(split_iter, inputs, cont_dim=None):
776776
}
777777

778778

779-
# Functions for merging and completing splitters in states.
780-
def connect_splitters(splitter, other_states, state_fields=False):
781-
"""Link splitters."""
782-
if splitter:
783-
# if splitter is string, have to check if this is Left or Right part (Left is required)
784-
if isinstance(splitter, str):
785-
# so this is the Left part
786-
if splitter.startswith("_"):
787-
left_part = _complete_left(
788-
left=splitter, other_states=other_states, state_fields=state_fields
789-
)
790-
right_part = None
791-
else: # this is Right part
792-
left_part = _complete_left(
793-
other_states=other_states, state_fields=state_fields
794-
)
795-
right_part = splitter
796-
# if splitter is tuple, it has to be either Left or Right part
797-
elif isinstance(splitter, tuple):
798-
lr_flag = _left_right_check(splitter, other_states=other_states)
799-
if lr_flag == "Left":
800-
left_part = _complete_left(left=splitter, other_states=other_states)
801-
right_part = None
802-
elif lr_flag == "Right":
803-
left_part = _complete_left(other_states=other_states)
804-
right_part = splitter
805-
else:
806-
raise Exception("splitter mix Left and Right parts in scalar splitter")
807-
elif isinstance(splitter, list):
808-
lr_flag = _left_right_check(splitter, other_states=other_states)
809-
if lr_flag == "Left":
810-
left_part = _complete_left(
811-
left=splitter, other_states=other_states, state_fields=state_fields
812-
)
813-
right_part = None
814-
elif lr_flag == "Right":
815-
left_part = _complete_left(
816-
other_states=other_states, state_fields=state_fields
817-
)
818-
right_part = splitter
819-
elif (
820-
_left_right_check(splitter[0], other_states=other_states) == "Left"
821-
and _left_right_check(splitter[1], other_states=other_states) == "Right"
822-
):
823-
left_part = _complete_left(
824-
left=splitter[0],
825-
other_states=other_states,
826-
state_fields=state_fields,
827-
)
828-
right_part = splitter[1]
829-
else:
830-
raise Exception("splitter doesn't have separated Left and Right parts")
831-
else:
832-
raise Exception(
833-
"splitter has to be str, tuple or list, "
834-
"{} was provided".format(type(splitter))
835-
)
836-
else:
837-
# if there is no splitter, I create the Left part
838-
left_part = _complete_left(other_states=other_states, state_fields=state_fields)
839-
right_part = None
840-
if right_part:
841-
splitter = [deepcopy(left_part), deepcopy(right_part)]
842-
else:
843-
splitter = deepcopy(left_part)
844-
return splitter, left_part, right_part
845-
846-
847-
def _complete_left(other_states, left=None, state_fields=False):
848-
"""Add all splitters from previous nodes (completing left part)."""
849-
if left:
850-
rpn_left = splitter2rpn(
851-
left, other_states=other_states, state_fields=state_fields
852-
)
853-
for name, (st, inp) in list(other_states.items())[::-1]:
854-
if "_{}".format(name) not in rpn_left and st.splitter_final:
855-
left = ["_{}".format(name), left]
856-
else:
857-
left = ["_{}".format(name) for name in other_states]
858-
if len(left) == 1:
859-
left = left[0]
860-
return left
861-
862-
863-
def _left_right_check(splitter_part, other_states):
864-
"""
865-
Check if splitter_part is purely Left or Right.
866-
867-
String is returned. If the splitter_part is mixed None is returned.
868-
869-
"""
870-
rpn_part = splitter2rpn(
871-
splitter_part, other_states=other_states, state_fields=False
872-
)
873-
inputs_in_splitter = [i for i in rpn_part if i not in ["*", "."]]
874-
others_in_splitter = [
875-
True if el.startswith("_") else False for el in inputs_in_splitter
876-
]
877-
if all(others_in_splitter):
878-
return "Left"
879-
elif (not all(others_in_splitter)) and (not any(others_in_splitter)):
880-
return "Right"
881-
882-
883779
def inputs_types_to_dict(name, inputs):
884780
"""Convert type.Inputs to dictionary."""
885781
# dj: any better option?

pydra/engine/state.py

Lines changed: 94 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -97,8 +97,9 @@ def __init__(self, name, splitter=None, combiner=None, other_states=None):
9797
self.other_states = other_states
9898
self.splitter = splitter
9999
# if missing_connections, we can't continue, should wait for updates
100+
# TODO: should find a better way, so it's not in the init, but combiner complicates
100101
if not self.missing_connections:
101-
self.connect_splitters()
102+
self._connect_splitters()
102103
self.combiner = combiner
103104
self.inner_inputs = {}
104105
for name, (st, inp) in self.other_states.items():
@@ -193,21 +194,15 @@ def combiner(self, combiner):
193194
self._left_combiner = []
194195
self._right_combiner = []
195196

196-
def connect_splitters(self):
197+
def _connect_splitters(self):
197198
"""
198199
Connect splitters from previous nodes.
199200
200201
Evaluates Left (the part from previous states) and Right (current state) parts.
201202
202203
"""
203204
if self.other_states:
204-
(
205-
self.splitter,
206-
self._left_splitter,
207-
self._right_splitter,
208-
) = hlpst.connect_splitters(
209-
splitter=self.splitter, other_states=self.other_states
210-
)
205+
self._merge_splitters()
211206
# left rpn part, but keeping the names of the nodes, e.g. [_NA, _NB, *]
212207
self._left_splitter_rpn_compact = hlpst.splitter2rpn(
213208
deepcopy(self._left_splitter),
@@ -219,13 +214,99 @@ def connect_splitters(self):
219214
)
220215
else: # if other_states is empty there is only Right part
221216
self._left_splitter = None
222-
self._left_splitter_rpn_compact = []
223-
self._left_splitter_rpn = []
217+
self._left_splitter_rpn_compact, self._left_splitter_rpn = [], []
224218
self._right_splitter = self.splitter
225219
self._right_splitter_rpn = hlpst.splitter2rpn(
226220
deepcopy(self._right_splitter), other_states=self.other_states
227221
)
228222

223+
def _merge_splitters(self):
224+
"""
225+
Merging current splitter with the ones from other states.
226+
227+
If left splitter is not provided the splitter has to be completed.
228+
229+
"""
230+
if self.splitter:
231+
# if splitter is string, have to check if this is Left or Right part (Left is required)
232+
if isinstance(self.splitter, str):
233+
# so this is the Left part
234+
if self.splitter.startswith("_"):
235+
self._left_splitter = self._complete_left(left=self.splitter)
236+
self._right_splitter = None
237+
else: # this is Right part
238+
self._left_splitter = self._complete_left()
239+
self._right_splitter = self.splitter
240+
elif isinstance(self.splitter, (tuple, list)):
241+
lr_flag = self._left_right_check(self.splitter)
242+
if lr_flag == "Left":
243+
self._left_splitter = self._complete_left(left=self.splitter)
244+
self._right_splitter = None
245+
elif lr_flag == "Right":
246+
self._left_splitter = self._complete_left()
247+
self._right_splitter = self.splitter
248+
elif lr_flag == "[Left, Right]":
249+
self._left_splitter = self._complete_left(left=self.splitter[0])
250+
self._right_splitter = self.splitter[1]
251+
else:
252+
# if there is no splitter, I create the Left part
253+
self._left_splitter = self._complete_left()
254+
self._right_splitter = None
255+
256+
if self._right_splitter:
257+
self.splitter = [
258+
deepcopy(self._left_splitter),
259+
deepcopy(self._right_splitter),
260+
]
261+
else:
262+
self.splitter = deepcopy(self._left_splitter)
263+
264+
def _complete_left(self, left=None):
265+
"""Add all splitters from previous nodes (completing left part)."""
266+
if left:
267+
rpn_left = hlpst.splitter2rpn(
268+
left, other_states=self.other_states, state_fields=False
269+
)
270+
for name, (st, inp) in list(self.other_states.items())[::-1]:
271+
if "_{}".format(name) not in rpn_left and st.splitter_final:
272+
left = ["_{}".format(name), left]
273+
else:
274+
left = ["_{}".format(name) for name in self.other_states]
275+
if len(left) == 1:
276+
left = left[0]
277+
return left
278+
279+
def _left_right_check(self, splitter_part, rec_lev=0):
280+
"""
281+
Check if splitter_part is purely Left, Right
282+
or [Left, Right] if the splitter_part is a list (outer splitter)
283+
284+
String is returned.
285+
286+
If the splitter_part is mixed exception is raised.
287+
288+
"""
289+
rpn_part = hlpst.splitter2rpn(
290+
splitter_part, other_states=self.other_states, state_fields=False
291+
)
292+
inputs_in_splitter = [i for i in rpn_part if i not in ["*", "."]]
293+
others_in_splitter = [
294+
True if el.startswith("_") else False for el in inputs_in_splitter
295+
]
296+
if all(others_in_splitter):
297+
return "Left"
298+
elif (not all(others_in_splitter)) and (not any(others_in_splitter)):
299+
return "Right"
300+
elif (
301+
isinstance(self.splitter, list)
302+
and rec_lev == 0
303+
and self._left_right_check(self.splitter[0], rec_lev=1) == "Left"
304+
and self._left_right_check(self.splitter[1], rec_lev=1) == "Right"
305+
):
306+
return "[Left, Right]" # Left and Right parts separated in outer scalar
307+
else:
308+
raise Exception("Left and Right splitters are mixed - splitter invalid")
309+
229310
def set_splitter_final(self):
230311
"""Evaluate a final splitter after combining."""
231312
_splitter_rpn_final = hlpst.remove_inp_from_splitter_rpn(
@@ -309,7 +390,8 @@ def merge_previous_states(self):
309390
else:
310391
# if no element from st.splitter is in the current combiner,
311392
# using st attributes without changes
312-
self.keys_final += st.keys_final
393+
if st.keys_final:
394+
self.keys_final += st.keys_final
313395
group_for_inputs = st.group_for_inputs_final
314396
groups_stack = st.groups_stack_final
315397

pydra/engine/tests/test_helpers_state.py

Lines changed: 0 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -606,90 +606,3 @@ def test_groups_to_input(group_for_inputs, input_for_groups, ndim):
606606
res = hlpst.converter_groups_to_input(group_for_inputs)
607607
assert res[0] == input_for_groups
608608
assert res[1] == ndim
609-
610-
611-
@pytest.mark.parametrize(
612-
"splitter, other_states, expected_splitter, expected_left, expected_right",
613-
[
614-
(
615-
None,
616-
{"NA": (other_states_to_tests(splitter="NA.a"), "b")},
617-
"_NA",
618-
"_NA",
619-
None,
620-
),
621-
(
622-
"b",
623-
{"NA": (other_states_to_tests(splitter="NA.a"), "b")},
624-
["_NA", "b"],
625-
"_NA",
626-
"b",
627-
),
628-
(
629-
("b", "c"),
630-
{"NA": (other_states_to_tests(splitter="NA.a"), "b")},
631-
["_NA", ("b", "c")],
632-
"_NA",
633-
("b", "c"),
634-
),
635-
(
636-
None,
637-
{
638-
"NA": (other_states_to_tests(splitter="NA.a"), "a"),
639-
"NB": (other_states_to_tests(splitter="NB.a"), "b"),
640-
},
641-
["_NA", "_NB"],
642-
["_NA", "_NB"],
643-
None,
644-
),
645-
(
646-
"b",
647-
{
648-
"NA": (other_states_to_tests(splitter="NA.a"), "a"),
649-
"NB": (other_states_to_tests(splitter="NB.a"), "b"),
650-
},
651-
[["_NA", "_NB"], "b"],
652-
["_NA", "_NB"],
653-
"b",
654-
),
655-
(
656-
["_NA", "b"],
657-
{
658-
"NA": (other_states_to_tests(splitter="NA.a"), "a"),
659-
"NB": (other_states_to_tests(splitter="NB.a"), "b"),
660-
},
661-
[["_NB", "_NA"], "b"],
662-
["_NB", "_NA"],
663-
"b",
664-
),
665-
],
666-
)
667-
def test_connect_splitters(
668-
splitter, other_states, expected_splitter, expected_left, expected_right
669-
):
670-
updated_splitter, left_splitter, right_splitter = hlpst.connect_splitters(
671-
splitter, other_states
672-
)
673-
assert updated_splitter == expected_splitter
674-
assert left_splitter == expected_left
675-
assert right_splitter == expected_right
676-
677-
678-
@pytest.mark.parametrize(
679-
"splitter, other_states",
680-
[
681-
("_NB", {"NA": (other_states_to_tests(splitter="NA.a"), "b")}),
682-
(("_NA", "b"), {"NA": (other_states_to_tests(splitter="NA.a"), "b")}),
683-
(["b", "_NA"], {"NA": (other_states_to_tests(splitter="NA.a"), "b")}),
684-
(
685-
["_NB", ["_NA", "b"]],
686-
{
687-
"NA": (other_states_to_tests(splitter="NA.a"), "a"),
688-
"NB": (other_states_to_tests(splitter="NB.a"), "b"),
689-
},
690-
),
691-
],
692-
)
693-
def test_connect_splitters_exception(splitter, other_states):
694-
with pytest.raises(Exception):
695-
hlpst.connect_splitters(splitter, other_states, state_fields=True)

0 commit comments

Comments
 (0)