@@ -19,11 +19,11 @@ def splitter2rpn(splitter, other_states=None, state_fields=True):
19
19
Parameters
20
20
----------
21
21
splitter :
22
- TODO
22
+ splitter (standard form)
23
23
other_states :
24
- TODO
24
+ other states that are connected to the state
25
25
state_fields : :obj:`bool`
26
- TODO
26
+ if False the splitter from the previous states are unwrapped
27
27
28
28
"""
29
29
if not splitter :
@@ -174,7 +174,6 @@ def _iterate_list(element, sign, other_states, output_splitter, state_fields=Tru
174
174
)
175
175
176
176
177
- # functions used in State to know which element should be used for a specific axis
178
177
def converter_groups_to_input (group_for_inputs ):
179
178
"""
180
179
Return fields for each axis and number of all groups.
@@ -184,7 +183,7 @@ def converter_groups_to_input(group_for_inputs):
184
183
Parameters
185
184
----------
186
185
group_for_inputs :
187
- TODO
186
+ specified axes (groups) for each input
188
187
189
188
"""
190
189
input_for_axis = {}
@@ -199,7 +198,6 @@ def converter_groups_to_input(group_for_inputs):
199
198
return input_for_axis , ngr
200
199
201
200
202
- # function used in State if combiner
203
201
def remove_inp_from_splitter_rpn (splitter_rpn , inputs_to_remove ):
204
202
"""
205
203
Remove inputs due to combining.
@@ -211,7 +209,7 @@ def remove_inp_from_splitter_rpn(splitter_rpn, inputs_to_remove):
211
209
splitter_rpn :
212
210
The splitter in reverse polish notation
213
211
inputs_to_remove :
214
- TODO
212
+ input names that should be removed from the splitter
215
213
216
214
"""
217
215
splitter_rpn_copy = splitter_rpn .copy ()
@@ -258,12 +256,12 @@ def rpn2splitter(splitter_rpn):
258
256
Parameters
259
257
----------
260
258
splitter_rpn :
261
- TODO
259
+ splitter in reverse polish notation
262
260
263
261
Returns
264
262
-------
265
263
splitter :
266
- TODO
264
+ splitter in the standard/original form
267
265
268
266
"""
269
267
if splitter_rpn == []:
@@ -298,9 +296,8 @@ def rpn2splitter(splitter_rpn):
298
296
return rpn2splitter (splitter_modified )
299
297
300
298
301
- # used in the Node to change names in a splitter and combiner
302
299
def add_name_combiner (combiner , name ):
303
- """Add a combiner. """
300
+ """ adding a node's name to each field from the combiner"""
304
301
combiner_changed = []
305
302
for comb in combiner :
306
303
if "." not in comb :
@@ -311,7 +308,7 @@ def add_name_combiner(combiner, name):
311
308
312
309
313
310
def add_name_splitter (splitter , name ):
314
- """Change names of splitter: adding names of the node. """
311
+ """ adding a node's name to each field from the splitter """
315
312
if isinstance (splitter , str ):
316
313
return _add_name ([splitter ], name )[0 ]
317
314
elif isinstance (splitter , list ):
@@ -322,6 +319,7 @@ def add_name_splitter(splitter, name):
322
319
323
320
324
321
def _add_name (mlist , name ):
322
+ """ adding anem to each element from the list"""
325
323
for i , elem in enumerate (mlist ):
326
324
if isinstance (elem , str ):
327
325
if "." in elem or elem .startswith ("_" ):
@@ -421,22 +419,15 @@ def splits(splitter_rpn, inputs, inner_inputs=None, cont_dim=None):
421
419
for _ , v in inner_inputs .items ()
422
420
}
423
421
inner_inputs = {k : v for k , v in inner_inputs .items () if k in splitter_rpn }
424
- keys_fromLeftSpl = ["_{}" .format (st .name ) for _ , st in inner_inputs .items ()]
425
422
else :
426
423
previous_states_ind = {}
427
424
inner_inputs = {}
428
- keys_fromLeftSpl = []
429
425
430
426
# when splitter is a single element (no operators)
431
427
if len (splitter_rpn ) == 1 :
432
428
op_single = splitter_rpn [0 ]
433
429
return _single_op_splits (
434
- op_single ,
435
- inputs ,
436
- inner_inputs ,
437
- previous_states_ind ,
438
- keys_fromLeftSpl ,
439
- cont_dim = cont_dim ,
430
+ op_single , inputs , inner_inputs , previous_states_ind , cont_dim = cont_dim
440
431
)
441
432
442
433
terms = {}
@@ -530,12 +521,41 @@ def splits(splitter_rpn, inputs, inner_inputs=None, cont_dim=None):
530
521
val = stack .pop ()
531
522
if isinstance (val , tuple ):
532
523
val = val [0 ]
533
- return val , keys , keys_fromLeftSpl
524
+ return val , keys
525
+
526
+
527
+ def _single_op_splits (
528
+ op_single , inputs , inner_inputs , previous_states_ind , cont_dim = None
529
+ ):
530
+ """ splits function if splitter is a singleton"""
531
+ if op_single .startswith ("_" ):
532
+ return (previous_states_ind [op_single ][0 ], previous_states_ind [op_single ][1 ])
533
+ if cont_dim is None :
534
+ cont_dim = {}
535
+ shape = input_shape (inputs [op_single ], cont_dim = cont_dim .get (op_single , 1 ))
536
+ trmval = range (reduce (lambda x , y : x * y , shape ))
537
+ if op_single in inner_inputs :
538
+ # TODO: have to be changed if differ length
539
+ inner_len = [shape [- 1 ]] * reduce (lambda x , y : x * y , shape [:- 1 ])
540
+ # this come from the previous node
541
+ outer_ind = inner_inputs [op_single ].ind_l
542
+ op_out = itertools .chain .from_iterable (
543
+ itertools .repeat (x , n ) for x , n in zip (outer_ind , inner_len )
544
+ )
545
+ res = op ["." ](op_out , trmval )
546
+ val = res
547
+ keys = inner_inputs [op_single ].keys_final + [op_single ]
548
+ return val , keys
549
+ else :
550
+ val = op ["*" ](trmval )
551
+ keys = [op_single ]
552
+ return val , keys
534
553
535
554
536
- # dj: TODO: do I need keys?
537
555
def splits_groups (splitter_rpn , combiner = None , inner_inputs = None ):
538
- """Process splitter rpn from left to right."""
556
+ """ splits inputs to groups (axes) and creates stacks for these groups
557
+ This is used to specify which input can be combined.
558
+ """
539
559
if not splitter_rpn :
540
560
return [], {}, [], []
541
561
stack = []
@@ -556,9 +576,7 @@ def splits_groups(splitter_rpn, combiner=None, inner_inputs=None):
556
576
# when splitter is a single element (no operators)
557
577
if len (splitter_rpn ) == 1 :
558
578
op_single = splitter_rpn [0 ]
559
- return _single_op_splits_groups (
560
- op_single , combiner , inner_inputs , previous_states_ind , groups
561
- )
579
+ return _single_op_splits_groups (op_single , combiner , inner_inputs , groups )
562
580
563
581
# len(splitter_rpn) > 1
564
582
# iterating splitter_rpn
@@ -668,45 +686,8 @@ def splits_groups(splitter_rpn, combiner=None, inner_inputs=None):
668
686
return keys , groups , groups_stack , []
669
687
670
688
671
- def _single_op_splits (
672
- op_single ,
673
- inputs ,
674
- inner_inputs ,
675
- previous_states_ind ,
676
- keys_fromLeftSpl , # TODO NOW do I need it?
677
- cont_dim = None ,
678
- ):
679
- if op_single .startswith ("_" ):
680
- return (
681
- previous_states_ind [op_single ][0 ],
682
- previous_states_ind [op_single ][1 ],
683
- keys_fromLeftSpl ,
684
- )
685
- if cont_dim is None :
686
- cont_dim = {}
687
- shape = input_shape (inputs [op_single ], cont_dim = cont_dim .get (op_single , 1 ))
688
- trmval = range (reduce (lambda x , y : x * y , shape ))
689
- if op_single in inner_inputs :
690
- # TODO: have to be changed if differ length
691
- inner_len = [shape [- 1 ]] * reduce (lambda x , y : x * y , shape [:- 1 ])
692
- # this come from the previous node
693
- outer_ind = inner_inputs [op_single ].ind_l
694
- op_out = itertools .chain .from_iterable (
695
- itertools .repeat (x , n ) for x , n in zip (outer_ind , inner_len )
696
- )
697
- res = op ["." ](op_out , trmval )
698
- val = res
699
- keys = inner_inputs [op_single ].keys_final + [op_single ]
700
- return val , keys , keys_fromLeftSpl
701
- else :
702
- val = op ["*" ](trmval )
703
- keys = [op_single ]
704
- return val , keys , keys_fromLeftSpl
705
-
706
-
707
- def _single_op_splits_groups (
708
- op_single , combiner , inner_inputs , previous_states_ind , groups
709
- ):
689
+ def _single_op_splits_groups (op_single , combiner , inner_inputs , groups ):
690
+ """ splits_groups function if splitter is a singleton"""
710
691
if op_single in inner_inputs :
711
692
# TODO: have to be changed if differ length
712
693
# TODO: i think I don't want to add here from left part
@@ -768,7 +749,7 @@ def combine_final_groups(combiner, groups, groups_stack, keys):
768
749
769
750
770
751
def map_splits (split_iter , inputs , cont_dim = None ):
771
- """Get a dictionary of prescribed splits ."""
752
+ """generate a dictionary of inputs prescribed by the splitter ."""
772
753
if cont_dim is None :
773
754
cont_dim = {}
774
755
for split in split_iter :
0 commit comments