Skip to content

Commit 287eec9

Browse files
committed
some small edits to the state class, adding comments
1 parent bce5ca3 commit 287eec9

File tree

3 files changed

+79
-95
lines changed

3 files changed

+79
-95
lines changed

pydra/engine/helpers_state.py

Lines changed: 47 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,11 @@ def splitter2rpn(splitter, other_states=None, state_fields=True):
1919
Parameters
2020
----------
2121
splitter :
22-
TODO
22+
splitter (standard form)
2323
other_states :
24-
TODO
24+
other states that are connected to the state
2525
state_fields : :obj:`bool`
26-
TODO
26+
if False the splitter from the previous states are unwrapped
2727
2828
"""
2929
if not splitter:
@@ -174,7 +174,6 @@ def _iterate_list(element, sign, other_states, output_splitter, state_fields=Tru
174174
)
175175

176176

177-
# functions used in State to know which element should be used for a specific axis
178177
def converter_groups_to_input(group_for_inputs):
179178
"""
180179
Return fields for each axis and number of all groups.
@@ -184,7 +183,7 @@ def converter_groups_to_input(group_for_inputs):
184183
Parameters
185184
----------
186185
group_for_inputs :
187-
TODO
186+
specified axes (groups) for each input
188187
189188
"""
190189
input_for_axis = {}
@@ -199,7 +198,6 @@ def converter_groups_to_input(group_for_inputs):
199198
return input_for_axis, ngr
200199

201200

202-
# function used in State if combiner
203201
def remove_inp_from_splitter_rpn(splitter_rpn, inputs_to_remove):
204202
"""
205203
Remove inputs due to combining.
@@ -211,7 +209,7 @@ def remove_inp_from_splitter_rpn(splitter_rpn, inputs_to_remove):
211209
splitter_rpn :
212210
The splitter in reverse polish notation
213211
inputs_to_remove :
214-
TODO
212+
input names that should be removed from the splitter
215213
216214
"""
217215
splitter_rpn_copy = splitter_rpn.copy()
@@ -258,12 +256,12 @@ def rpn2splitter(splitter_rpn):
258256
Parameters
259257
----------
260258
splitter_rpn :
261-
TODO
259+
splitter in reverse polish notation
262260
263261
Returns
264262
-------
265263
splitter :
266-
TODO
264+
splitter in the standard/original form
267265
268266
"""
269267
if splitter_rpn == []:
@@ -298,9 +296,8 @@ def rpn2splitter(splitter_rpn):
298296
return rpn2splitter(splitter_modified)
299297

300298

301-
# used in the Node to change names in a splitter and combiner
302299
def add_name_combiner(combiner, name):
303-
"""Add a combiner."""
300+
""" adding a node's name to each field from the combiner"""
304301
combiner_changed = []
305302
for comb in combiner:
306303
if "." not in comb:
@@ -311,7 +308,7 @@ def add_name_combiner(combiner, name):
311308

312309

313310
def add_name_splitter(splitter, name):
314-
"""Change names of splitter: adding names of the node."""
311+
""" adding a node's name to each field from the splitter"""
315312
if isinstance(splitter, str):
316313
return _add_name([splitter], name)[0]
317314
elif isinstance(splitter, list):
@@ -322,6 +319,7 @@ def add_name_splitter(splitter, name):
322319

323320

324321
def _add_name(mlist, name):
322+
""" adding anem to each element from the list"""
325323
for i, elem in enumerate(mlist):
326324
if isinstance(elem, str):
327325
if "." in elem or elem.startswith("_"):
@@ -421,22 +419,15 @@ def splits(splitter_rpn, inputs, inner_inputs=None, cont_dim=None):
421419
for _, v in inner_inputs.items()
422420
}
423421
inner_inputs = {k: v for k, v in inner_inputs.items() if k in splitter_rpn}
424-
keys_fromLeftSpl = ["_{}".format(st.name) for _, st in inner_inputs.items()]
425422
else:
426423
previous_states_ind = {}
427424
inner_inputs = {}
428-
keys_fromLeftSpl = []
429425

430426
# when splitter is a single element (no operators)
431427
if len(splitter_rpn) == 1:
432428
op_single = splitter_rpn[0]
433429
return _single_op_splits(
434-
op_single,
435-
inputs,
436-
inner_inputs,
437-
previous_states_ind,
438-
keys_fromLeftSpl,
439-
cont_dim=cont_dim,
430+
op_single, inputs, inner_inputs, previous_states_ind, cont_dim=cont_dim
440431
)
441432

442433
terms = {}
@@ -530,12 +521,41 @@ def splits(splitter_rpn, inputs, inner_inputs=None, cont_dim=None):
530521
val = stack.pop()
531522
if isinstance(val, tuple):
532523
val = val[0]
533-
return val, keys, keys_fromLeftSpl
524+
return val, keys
525+
526+
527+
def _single_op_splits(
528+
op_single, inputs, inner_inputs, previous_states_ind, cont_dim=None
529+
):
530+
""" splits function if splitter is a singleton"""
531+
if op_single.startswith("_"):
532+
return (previous_states_ind[op_single][0], previous_states_ind[op_single][1])
533+
if cont_dim is None:
534+
cont_dim = {}
535+
shape = input_shape(inputs[op_single], cont_dim=cont_dim.get(op_single, 1))
536+
trmval = range(reduce(lambda x, y: x * y, shape))
537+
if op_single in inner_inputs:
538+
# TODO: have to be changed if differ length
539+
inner_len = [shape[-1]] * reduce(lambda x, y: x * y, shape[:-1])
540+
# this come from the previous node
541+
outer_ind = inner_inputs[op_single].ind_l
542+
op_out = itertools.chain.from_iterable(
543+
itertools.repeat(x, n) for x, n in zip(outer_ind, inner_len)
544+
)
545+
res = op["."](op_out, trmval)
546+
val = res
547+
keys = inner_inputs[op_single].keys_final + [op_single]
548+
return val, keys
549+
else:
550+
val = op["*"](trmval)
551+
keys = [op_single]
552+
return val, keys
534553

535554

536-
# dj: TODO: do I need keys?
537555
def splits_groups(splitter_rpn, combiner=None, inner_inputs=None):
538-
"""Process splitter rpn from left to right."""
556+
""" splits inputs to groups (axes) and creates stacks for these groups
557+
This is used to specify which input can be combined.
558+
"""
539559
if not splitter_rpn:
540560
return [], {}, [], []
541561
stack = []
@@ -556,9 +576,7 @@ def splits_groups(splitter_rpn, combiner=None, inner_inputs=None):
556576
# when splitter is a single element (no operators)
557577
if len(splitter_rpn) == 1:
558578
op_single = splitter_rpn[0]
559-
return _single_op_splits_groups(
560-
op_single, combiner, inner_inputs, previous_states_ind, groups
561-
)
579+
return _single_op_splits_groups(op_single, combiner, inner_inputs, groups)
562580

563581
# len(splitter_rpn) > 1
564582
# iterating splitter_rpn
@@ -668,45 +686,8 @@ def splits_groups(splitter_rpn, combiner=None, inner_inputs=None):
668686
return keys, groups, groups_stack, []
669687

670688

671-
def _single_op_splits(
672-
op_single,
673-
inputs,
674-
inner_inputs,
675-
previous_states_ind,
676-
keys_fromLeftSpl, # TODO NOW do I need it?
677-
cont_dim=None,
678-
):
679-
if op_single.startswith("_"):
680-
return (
681-
previous_states_ind[op_single][0],
682-
previous_states_ind[op_single][1],
683-
keys_fromLeftSpl,
684-
)
685-
if cont_dim is None:
686-
cont_dim = {}
687-
shape = input_shape(inputs[op_single], cont_dim=cont_dim.get(op_single, 1))
688-
trmval = range(reduce(lambda x, y: x * y, shape))
689-
if op_single in inner_inputs:
690-
# TODO: have to be changed if differ length
691-
inner_len = [shape[-1]] * reduce(lambda x, y: x * y, shape[:-1])
692-
# this come from the previous node
693-
outer_ind = inner_inputs[op_single].ind_l
694-
op_out = itertools.chain.from_iterable(
695-
itertools.repeat(x, n) for x, n in zip(outer_ind, inner_len)
696-
)
697-
res = op["."](op_out, trmval)
698-
val = res
699-
keys = inner_inputs[op_single].keys_final + [op_single]
700-
return val, keys, keys_fromLeftSpl
701-
else:
702-
val = op["*"](trmval)
703-
keys = [op_single]
704-
return val, keys, keys_fromLeftSpl
705-
706-
707-
def _single_op_splits_groups(
708-
op_single, combiner, inner_inputs, previous_states_ind, groups
709-
):
689+
def _single_op_splits_groups(op_single, combiner, inner_inputs, groups):
690+
""" splits_groups function if splitter is a singleton"""
710691
if op_single in inner_inputs:
711692
# TODO: have to be changed if differ length
712693
# TODO: i think I don't want to add here from left part
@@ -768,7 +749,7 @@ def combine_final_groups(combiner, groups, groups_stack, keys):
768749

769750

770751
def map_splits(split_iter, inputs, cont_dim=None):
771-
"""Get a dictionary of prescribed splits."""
752+
"""generate a dictionary of inputs prescribed by the splitter."""
772753
if cont_dim is None:
773754
cont_dim = {}
774755
for split in split_iter:

pydra/engine/state.py

Lines changed: 28 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,7 @@ def splitter_rpn_final(self):
132132

133133
@property
134134
def splitter_final(self):
135+
""" final splitter, after removing the combined fields"""
135136
return hlpst.rpn2splitter(self.splitter_rpn_final)
136137

137138
@property
@@ -153,6 +154,7 @@ def splitter_rpn_compact(self):
153154

154155
@property
155156
def right_splitter(self):
157+
""" current state splitter (i.e. the Right part)"""
156158
lr_flag = self._left_right_check(self.splitter)
157159
if lr_flag == "Left":
158160
return None
@@ -161,8 +163,19 @@ def right_splitter(self):
161163
elif lr_flag == "[Left, Right]":
162164
return self.splitter[1]
163165

166+
@property
167+
def right_splitter_rpn(self):
168+
if self.right_splitter:
169+
right_splitter_rpn = hlpst.splitter2rpn(
170+
self.right_splitter, other_states=self.other_states
171+
)
172+
return right_splitter_rpn
173+
else:
174+
return []
175+
164176
@property
165177
def left_splitter(self):
178+
""" splitters from the previous stated (i.e. the Light part)"""
166179
if hasattr(self, "_left_splitter"):
167180
return self._left_splitter
168181
else:
@@ -261,15 +274,13 @@ def inner_inputs(self):
261274
return {}
262275

263276
def update_connections(self, new_other_states=None, new_combiner=None):
277+
""" updating states connections and input groups"""
264278
if new_other_states:
265279
self.other_states = new_other_states
266280
self._connect_splitters()
267281
if new_combiner:
268282
self.combiner = new_combiner
269283
self.set_input_groups()
270-
self.states_val = []
271-
self.inputs_ind = []
272-
self.final_combined_ind_mapping = {}
273284

274285
def _connect_splitters(self):
275286
"""
@@ -308,7 +319,7 @@ def _connect_splitters(self):
308319
self.splitter = deepcopy(self._left_splitter)
309320

310321
def _complete_left(self, left=None):
311-
"""Add all splitters from previous nodes (completing left part)."""
322+
"""Add all splitters from previous nodes (completing the Left part)."""
312323
if left:
313324
rpn_left = hlpst.splitter2rpn(
314325
left, other_states=self.other_states, state_fields=False
@@ -470,6 +481,7 @@ def push_new_states(self):
470481
self.groups_stack_final.append(stack)
471482

472483
def splitter_validation(self):
484+
""" validating if the splitter is correct (after all states are connected)"""
473485
for spl in self.splitter_rpn_compact:
474486
if not (
475487
spl in [".", "*"]
@@ -483,6 +495,7 @@ def splitter_validation(self):
483495
)
484496

485497
def combiner_validation(self):
498+
""" validating if the combiner is correct (after all states are connected)"""
486499
if self.combiner:
487500
if not self.splitter:
488501
raise Exception("splitter has to be set before setting combiner")
@@ -516,7 +529,6 @@ def prepare_states(self, inputs, cont_dim=None):
516529
for nm, (st, _) in self.other_states.items():
517530
# I think now this if is never used
518531
if not hasattr(st, "states_ind"):
519-
# dj: should i provide different inputs?
520532
st.prepare_states(self.inputs, cont_dim=cont_dim)
521533
self.inputs.update(st.inputs)
522534
self.prepare_states_ind()
@@ -532,6 +544,7 @@ def prepare_states_ind(self):
532544
# removing elements that are connected to inner splitter
533545
# (they will be taken into account in hlpst.splits anyway)
534546
# _comb part will be used in prepare_states_combined_ind
547+
# TODO: need tests in test_Workflow.py
535548
elements_to_remove = []
536549
elements_to_remove_comb = []
537550
for name, (st, inp) in self.other_states.items():
@@ -546,7 +559,7 @@ def prepare_states_ind(self):
546559
partial_rpn = hlpst.remove_inp_from_splitter_rpn(
547560
deepcopy(self.splitter_rpn_compact), elements_to_remove
548561
)
549-
values_out_pr, keys_out_pr, kL = hlpst.splits(
562+
values_out_pr, keys_out_pr = hlpst.splits(
550563
partial_rpn,
551564
self.inputs,
552565
inner_inputs=self.inner_inputs,
@@ -581,9 +594,8 @@ def prepare_states_combined_ind(self, elements_to_remove_comb):
581594
combined_rpn = hlpst.remove_inp_from_splitter_rpn(
582595
deepcopy(partial_rpn), self.right_combiner_all + self.left_combiner_all
583596
)
584-
# TODO: create a function for this!!
585597
if combined_rpn:
586-
val_r, key_r, _ = hlpst.splits(
598+
val_r, key_r = hlpst.splits(
587599
combined_rpn,
588600
self.inputs,
589601
inner_inputs=self.inner_inputs,
@@ -596,7 +608,6 @@ def prepare_states_combined_ind(self, elements_to_remove_comb):
596608

597609
keys_out = key_r
598610
if values:
599-
# NOW TODO: move to init?
600611
self.ind_l_final = values
601612
self.keys_final = keys_out
602613
# groups after combiner
@@ -628,28 +639,19 @@ def prepare_states_val(self):
628639

629640
def prepare_inputs(self):
630641
"""
631-
Get inputs ready.
642+
Preparing inputs indices, merges input from previous states.
632643
633-
1. Remove elements that come from connected states.
634-
2. Merge elements that come from outputs of previous nodes.
635-
3. Remove elements connected to the inner splitter.
644+
Includes indices for fields from inner splitters
645+
(removes elements connected to the inner splitters fields).
636646
637647
"""
638648
if not self.other_states:
639649
self.inputs_ind = self.states_ind
640650
else:
641-
# removing elements that come from connected states
642-
elements_to_remove = [
643-
spl
644-
for spl in self.splitter_rpn_compact
645-
if spl[1:] in self.other_states.keys()
646-
]
647-
partial_rpn = hlpst.remove_inp_from_splitter_rpn(
648-
deepcopy(self.splitter_rpn_compact), elements_to_remove
649-
)
650-
if partial_rpn:
651-
values_inp, keys_inp, _ = hlpst.splits(
652-
partial_rpn,
651+
# elements from the current node (the Right part)
652+
if self.right_splitter_rpn:
653+
values_inp, keys_inp = hlpst.splits(
654+
self.right_splitter_rpn,
653655
self.inputs,
654656
inner_inputs=self.inner_inputs,
655657
cont_dim=self.cont_dim,
@@ -700,5 +702,6 @@ def prepare_inputs(self):
700702
# iter_splits using inputs from current state/node
701703
self.inputs_ind = list(hlpst.iter_splits(inputs_ind, keys_inp))
702704
# removing elements that are connected to inner splitter
705+
# TODO - add tests to test_workflow.py (not sure if we want to remove it)
703706
for el in connected_to_inner:
704707
[dict.pop(el) for dict in self.inputs_ind]

0 commit comments

Comments
 (0)