@@ -362,14 +362,15 @@ def iter_splits(iterable, keys):
362
362
yield dict (zip (keys , list (flatten (iter , max_depth = 1000 ))))
363
363
364
364
365
- def input_shape (in1 ):
365
+ def input_shape (inp , cont_dim = 1 ):
366
366
"""Get input shape."""
367
367
# TODO: have to be changed for inner splitter (sometimes different length)
368
- shape = [len (in1 )]
368
+ cont_dim -= 1
369
+ shape = [len (inp )]
369
370
last_shape = None
370
- for value in in1 :
371
- if isinstance (value , list ):
372
- cur_shape = input_shape (value )
371
+ for value in inp :
372
+ if isinstance (value , list ) and cont_dim > 0 :
373
+ cur_shape = input_shape (value , cont_dim )
373
374
if last_shape is None :
374
375
last_shape = cur_shape
375
376
elif last_shape != cur_shape :
@@ -383,11 +384,34 @@ def input_shape(in1):
383
384
return tuple (shape )
384
385
385
386
386
- def splits (splitter_rpn , inputs , inner_inputs = None ):
387
- """Split process as specified by an rpn splitter, from left to right."""
387
+ def splits (splitter_rpn , inputs , inner_inputs = None , cont_dim = None ):
388
+ """
389
+ Splits input variable as specified by splitter
390
+
391
+ Parameters
392
+ ----------
393
+ splitter_rpn : list
394
+ splitter in RPN notation
395
+ inputs: dict
396
+ input variables
397
+ inner_inputs: dict, optional
398
+ inner input specification
399
+
400
+
401
+ Returns
402
+ -------
403
+ splitter : list
404
+ each element contains indices for inputs
405
+ keys: list
406
+ names of input variables
407
+
408
+ """
409
+
388
410
stack = []
389
411
keys = []
390
- shapes_var = {}
412
+ if cont_dim is None :
413
+ cont_dim = {}
414
+ # analysing states from connected tasks if inner_inputs
391
415
if inner_inputs :
392
416
previous_states_ind = {
393
417
"_{}" .format (v .name ): (v .ind_l_final , v .keys_final )
@@ -407,9 +431,9 @@ def splits(splitter_rpn, inputs, inner_inputs=None):
407
431
op_single ,
408
432
inputs ,
409
433
inner_inputs ,
410
- shapes_var ,
411
434
previous_states_ind ,
412
435
keys_fromLeftSpl ,
436
+ cont_dim = cont_dim ,
413
437
)
414
438
415
439
terms = {}
@@ -418,7 +442,11 @@ def splits(splitter_rpn, inputs, inner_inputs=None):
418
442
shape = {}
419
443
# iterating splitter_rpn
420
444
for token in splitter_rpn :
421
- if token in ["." , "*" ]:
445
+ if token not in ["." , "*" ]: # token is one of the input var
446
+ # adding variable to the stack
447
+ stack .append (token )
448
+ else :
449
+ # removing Right and Left var from the stack
422
450
terms ["R" ] = stack .pop ()
423
451
terms ["L" ] = stack .pop ()
424
452
# checking if terms are strings, shapes, etc.
@@ -429,10 +457,14 @@ def splits(splitter_rpn, inputs, inner_inputs=None):
429
457
trm_val [lr ] = previous_states_ind [term ][0 ]
430
458
shape [lr ] = (len (trm_val [lr ]),)
431
459
else :
432
- shape [lr ] = input_shape (inputs [term ])
460
+ if term in cont_dim :
461
+ shape [lr ] = input_shape (
462
+ inputs [term ], cont_dim = cont_dim [term ]
463
+ )
464
+ else :
465
+ shape [lr ] = input_shape (inputs [term ])
433
466
trm_val [lr ] = range (reduce (lambda x , y : x * y , shape [lr ]))
434
467
trm_str [lr ] = True
435
- shapes_var [term ] = shape [lr ]
436
468
else :
437
469
trm_val [lr ], shape [lr ] = term
438
470
trm_str [lr ] = False
@@ -447,6 +479,7 @@ def splits(splitter_rpn, inputs, inner_inputs=None):
447
479
)
448
480
newshape = shape ["R" ]
449
481
if token == "*" :
482
+ # TODO: pomyslec
450
483
newshape = tuple (list (shape ["L" ]) + list (shape ["R" ]))
451
484
452
485
# creating list with keys
@@ -466,7 +499,6 @@ def splits(splitter_rpn, inputs, inner_inputs=None):
466
499
elif trm_str ["R" ]:
467
500
keys = keys + new_keys ["R" ]
468
501
469
- #
470
502
newtrm_val = {}
471
503
for lr in ["R" , "L" ]:
472
504
# TODO: rewrite once I have more tests
@@ -491,13 +523,11 @@ def splits(splitter_rpn, inputs, inner_inputs=None):
491
523
492
524
pushval = (op [token ](newtrm_val ["L" ], newtrm_val ["R" ]), newshape )
493
525
stack .append (pushval )
494
- else : # name of one of the inputs (token not in [".", "*"])
495
- stack .append (token )
496
526
497
527
val = stack .pop ()
498
528
if isinstance (val , tuple ):
499
529
val = val [0 ]
500
- return val , keys , shapes_var , keys_fromLeftSpl
530
+ return val , keys , keys_fromLeftSpl
501
531
502
532
503
533
# dj: TODO: do I need keys?
@@ -636,17 +666,22 @@ def splits_groups(splitter_rpn, combiner=None, inner_inputs=None):
636
666
637
667
638
668
def _single_op_splits (
639
- op_single , inputs , inner_inputs , shapes_var , previous_states_ind , keys_fromLeftSpl
669
+ op_single ,
670
+ inputs ,
671
+ inner_inputs ,
672
+ previous_states_ind ,
673
+ keys_fromLeftSpl ,
674
+ cont_dim = None ,
640
675
):
641
676
if op_single .startswith ("_" ):
642
677
return (
643
678
previous_states_ind [op_single ][0 ],
644
679
previous_states_ind [op_single ][1 ],
645
- None ,
646
680
keys_fromLeftSpl ,
647
681
)
648
- shape = input_shape (inputs [op_single ])
649
- shapes_var [op_single ] = shape
682
+ if cont_dim is None :
683
+ cont_dim = {}
684
+ shape = input_shape (inputs [op_single ], cont_dim = cont_dim .get (op_single , 1 ))
650
685
trmval = range (reduce (lambda x , y : x * y , shape ))
651
686
if op_single in inner_inputs :
652
687
# TODO: have to be changed if differ length
@@ -659,11 +694,11 @@ def _single_op_splits(
659
694
res = op ["." ](op_out , trmval )
660
695
val = res
661
696
keys = inner_inputs [op_single ].keys_final + [op_single ]
662
- return val , keys , shapes_var , keys_fromLeftSpl
697
+ return val , keys , keys_fromLeftSpl
663
698
else :
664
699
val = op ["*" ](trmval )
665
700
keys = [op_single ]
666
- return val , keys , shapes_var , keys_fromLeftSpl
701
+ return val , keys , keys_fromLeftSpl
667
702
668
703
669
704
def _single_op_splits_groups (
@@ -727,10 +762,15 @@ def combine_final_groups(combiner, groups, groups_stack, keys):
727
762
return keys_final , groups_final , groups_stack_final , combiner_all
728
763
729
764
730
- def map_splits (split_iter , inputs ):
765
+ def map_splits (split_iter , inputs , cont_dim = None ):
731
766
"""Get a dictionary of prescribed splits."""
767
+ if cont_dim is None :
768
+ cont_dim = {}
732
769
for split in split_iter :
733
- yield {k : list (flatten (ensure_list (inputs [k ])))[v ] for k , v in split .items ()}
770
+ yield {
771
+ k : list (flatten (ensure_list (inputs [k ]), max_depth = cont_dim .get (k , None )))[v ]
772
+ for k , v in split .items ()
773
+ }
734
774
735
775
736
776
# Functions for merging and completing splitters in states.
0 commit comments