Skip to content

Commit e37e769

Browse files
committed
adding cond_dim option to splitter
1 parent 81af3b8 commit e37e769

File tree

5 files changed

+217
-71
lines changed

5 files changed

+217
-71
lines changed

pydra/engine/helpers_state.py

Lines changed: 64 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -362,14 +362,15 @@ def iter_splits(iterable, keys):
362362
yield dict(zip(keys, list(flatten(iter, max_depth=1000))))
363363

364364

365-
def input_shape(in1):
365+
def input_shape(inp, cont_dim=1):
366366
"""Get input shape."""
367367
# TODO: have to be changed for inner splitter (sometimes different length)
368-
shape = [len(in1)]
368+
cont_dim -= 1
369+
shape = [len(inp)]
369370
last_shape = None
370-
for value in in1:
371-
if isinstance(value, list):
372-
cur_shape = input_shape(value)
371+
for value in inp:
372+
if isinstance(value, list) and cont_dim > 0:
373+
cur_shape = input_shape(value, cont_dim)
373374
if last_shape is None:
374375
last_shape = cur_shape
375376
elif last_shape != cur_shape:
@@ -383,11 +384,34 @@ def input_shape(in1):
383384
return tuple(shape)
384385

385386

386-
def splits(splitter_rpn, inputs, inner_inputs=None):
387-
"""Split process as specified by an rpn splitter, from left to right."""
387+
def splits(splitter_rpn, inputs, inner_inputs=None, cont_dim=None):
388+
"""
389+
Splits input variable as specified by splitter
390+
391+
Parameters
392+
----------
393+
splitter_rpn : list
394+
splitter in RPN notation
395+
inputs: dict
396+
input variables
397+
inner_inputs: dict, optional
398+
inner input specification
399+
400+
401+
Returns
402+
-------
403+
splitter : list
404+
each element contains indices for inputs
405+
keys: list
406+
names of input variables
407+
408+
"""
409+
388410
stack = []
389411
keys = []
390-
shapes_var = {}
412+
if cont_dim is None:
413+
cont_dim = {}
414+
# analysing states from connected tasks if inner_inputs
391415
if inner_inputs:
392416
previous_states_ind = {
393417
"_{}".format(v.name): (v.ind_l_final, v.keys_final)
@@ -407,9 +431,9 @@ def splits(splitter_rpn, inputs, inner_inputs=None):
407431
op_single,
408432
inputs,
409433
inner_inputs,
410-
shapes_var,
411434
previous_states_ind,
412435
keys_fromLeftSpl,
436+
cont_dim=cont_dim,
413437
)
414438

415439
terms = {}
@@ -418,7 +442,11 @@ def splits(splitter_rpn, inputs, inner_inputs=None):
418442
shape = {}
419443
# iterating splitter_rpn
420444
for token in splitter_rpn:
421-
if token in [".", "*"]:
445+
if token not in [".", "*"]: # token is one of the input var
446+
# adding variable to the stack
447+
stack.append(token)
448+
else:
449+
# removing Right and Left var from the stack
422450
terms["R"] = stack.pop()
423451
terms["L"] = stack.pop()
424452
# checking if terms are strings, shapes, etc.
@@ -429,10 +457,14 @@ def splits(splitter_rpn, inputs, inner_inputs=None):
429457
trm_val[lr] = previous_states_ind[term][0]
430458
shape[lr] = (len(trm_val[lr]),)
431459
else:
432-
shape[lr] = input_shape(inputs[term])
460+
if term in cont_dim:
461+
shape[lr] = input_shape(
462+
inputs[term], cont_dim=cont_dim[term]
463+
)
464+
else:
465+
shape[lr] = input_shape(inputs[term])
433466
trm_val[lr] = range(reduce(lambda x, y: x * y, shape[lr]))
434467
trm_str[lr] = True
435-
shapes_var[term] = shape[lr]
436468
else:
437469
trm_val[lr], shape[lr] = term
438470
trm_str[lr] = False
@@ -447,6 +479,7 @@ def splits(splitter_rpn, inputs, inner_inputs=None):
447479
)
448480
newshape = shape["R"]
449481
if token == "*":
482+
# TODO: pomyslec
450483
newshape = tuple(list(shape["L"]) + list(shape["R"]))
451484

452485
# creating list with keys
@@ -466,7 +499,6 @@ def splits(splitter_rpn, inputs, inner_inputs=None):
466499
elif trm_str["R"]:
467500
keys = keys + new_keys["R"]
468501

469-
#
470502
newtrm_val = {}
471503
for lr in ["R", "L"]:
472504
# TODO: rewrite once I have more tests
@@ -491,13 +523,11 @@ def splits(splitter_rpn, inputs, inner_inputs=None):
491523

492524
pushval = (op[token](newtrm_val["L"], newtrm_val["R"]), newshape)
493525
stack.append(pushval)
494-
else: # name of one of the inputs (token not in [".", "*"])
495-
stack.append(token)
496526

497527
val = stack.pop()
498528
if isinstance(val, tuple):
499529
val = val[0]
500-
return val, keys, shapes_var, keys_fromLeftSpl
530+
return val, keys, keys_fromLeftSpl
501531

502532

503533
# dj: TODO: do I need keys?
@@ -636,17 +666,22 @@ def splits_groups(splitter_rpn, combiner=None, inner_inputs=None):
636666

637667

638668
def _single_op_splits(
639-
op_single, inputs, inner_inputs, shapes_var, previous_states_ind, keys_fromLeftSpl
669+
op_single,
670+
inputs,
671+
inner_inputs,
672+
previous_states_ind,
673+
keys_fromLeftSpl,
674+
cont_dim=None,
640675
):
641676
if op_single.startswith("_"):
642677
return (
643678
previous_states_ind[op_single][0],
644679
previous_states_ind[op_single][1],
645-
None,
646680
keys_fromLeftSpl,
647681
)
648-
shape = input_shape(inputs[op_single])
649-
shapes_var[op_single] = shape
682+
if cont_dim is None:
683+
cont_dim = {}
684+
shape = input_shape(inputs[op_single], cont_dim=cont_dim.get(op_single, 1))
650685
trmval = range(reduce(lambda x, y: x * y, shape))
651686
if op_single in inner_inputs:
652687
# TODO: have to be changed if differ length
@@ -659,11 +694,11 @@ def _single_op_splits(
659694
res = op["."](op_out, trmval)
660695
val = res
661696
keys = inner_inputs[op_single].keys_final + [op_single]
662-
return val, keys, shapes_var, keys_fromLeftSpl
697+
return val, keys, keys_fromLeftSpl
663698
else:
664699
val = op["*"](trmval)
665700
keys = [op_single]
666-
return val, keys, shapes_var, keys_fromLeftSpl
701+
return val, keys, keys_fromLeftSpl
667702

668703

669704
def _single_op_splits_groups(
@@ -727,10 +762,15 @@ def combine_final_groups(combiner, groups, groups_stack, keys):
727762
return keys_final, groups_final, groups_stack_final, combiner_all
728763

729764

730-
def map_splits(split_iter, inputs):
765+
def map_splits(split_iter, inputs, cont_dim=None):
731766
"""Get a dictionary of prescribed splits."""
767+
if cont_dim is None:
768+
cont_dim = {}
732769
for split in split_iter:
733-
yield {k: list(flatten(ensure_list(inputs[k])))[v] for k, v in split.items()}
770+
yield {
771+
k: list(flatten(ensure_list(inputs[k]), max_depth=cont_dim.get(k, None)))[v]
772+
for k, v in split.items()
773+
}
734774

735775

736776
# Functions for merging and completing splitters in states.

pydra/engine/state.py

Lines changed: 24 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -346,7 +346,7 @@ def push_new_states(self):
346346
stack = [gr + nr_gr_f for gr in stack]
347347
self.groups_stack_final.append(stack)
348348

349-
def prepare_states(self, inputs):
349+
def prepare_states(self, inputs, cont_dim=None):
350350
"""
351351
Prepare a full list of state indices and state values.
352352
@@ -357,6 +357,10 @@ def prepare_states(self, inputs):
357357
specific elements from inputs that can be used running interfaces
358358
359359
"""
360+
if cont_dim is None:
361+
self.cont_dim = {}
362+
else:
363+
self.cont_dim = cont_dim
360364
if isinstance(inputs, BaseSpec):
361365
self.inputs = hlpst.inputs_types_to_dict(self.name, inputs)
362366
else:
@@ -366,7 +370,7 @@ def prepare_states(self, inputs):
366370
# I think now this if is never used
367371
if not hasattr(st, "states_ind"):
368372
# dj: should i provide different inputs?
369-
st.prepare_states(self.inputs)
373+
st.prepare_states(self.inputs, cont_dim=cont_dim)
370374
self.inputs.update(st.inputs)
371375
self.prepare_states_ind()
372376
self.prepare_states_val()
@@ -395,8 +399,11 @@ def prepare_states_ind(self):
395399
partial_rpn = hlpst.remove_inp_from_splitter_rpn(
396400
deepcopy(self.splitter_rpn_compact), elements_to_remove
397401
)
398-
values_out_pr, keys_out_pr, _, kL = hlpst.splits(
399-
partial_rpn, self.inputs, inner_inputs=self.inner_inputs
402+
values_out_pr, keys_out_pr, kL = hlpst.splits(
403+
partial_rpn,
404+
self.inputs,
405+
inner_inputs=self.inner_inputs,
406+
cont_dim=self.cont_dim,
400407
)
401408
values_pr = list(values_out_pr)
402409

@@ -429,8 +436,11 @@ def prepare_states_combined_ind(self, elements_to_remove_comb):
429436
)
430437
# TODO: create a function for this!!
431438
if combined_rpn:
432-
val_r, key_r, _, _ = hlpst.splits(
433-
combined_rpn, self.inputs, inner_inputs=self.inner_inputs
439+
val_r, key_r, _ = hlpst.splits(
440+
combined_rpn,
441+
self.inputs,
442+
inner_inputs=self.inner_inputs,
443+
cont_dim=self.cont_dim,
434444
)
435445
values = list(val_r)
436446
else:
@@ -464,7 +474,9 @@ def prepare_states_combined_ind(self, elements_to_remove_comb):
464474

465475
def prepare_states_val(self):
466476
"""Evaluate states values having states indices."""
467-
self.states_val = list(hlpst.map_splits(self.states_ind, self.inputs))
477+
self.states_val = list(
478+
hlpst.map_splits(self.states_ind, self.inputs, cont_dim=self.cont_dim)
479+
)
468480
return self.states_val
469481

470482
def prepare_inputs(self):
@@ -489,8 +501,11 @@ def prepare_inputs(self):
489501
deepcopy(self.splitter_rpn_compact), elements_to_remove
490502
)
491503
if partial_rpn:
492-
values_inp, keys_inp, _, _ = hlpst.splits(
493-
partial_rpn, self.inputs, inner_inputs=self.inner_inputs
504+
values_inp, keys_inp, _ = hlpst.splits(
505+
partial_rpn,
506+
self.inputs,
507+
inner_inputs=self.inner_inputs,
508+
cont_dim=self.cont_dim,
494509
)
495510
inputs_ind = values_inp
496511
else:

0 commit comments

Comments
 (0)