@@ -586,6 +586,98 @@ def test_state_connect_6a():
586
586
]
587
587
588
588
589
+ def test_state_connect_7 ():
590
+ """two 'connected' states with multiple fields that are connected
591
+ no explicit splitter for the second state
592
+ """
593
+ st1 = State (name = "NA" , splitter = "a" )
594
+ st2 = State (name = "NB" , other_states = {"NA" : (st1 , ["x" , "y" ])})
595
+ # should take into account that x, y come from the same task
596
+ assert st2 .splitter == "_NA"
597
+ assert st2 .splitter_rpn == ["NA.a" ]
598
+ assert st2 .prev_state_splitter == st2 .splitter
599
+ assert st2 .prev_state_splitter_rpn == st2 .splitter_rpn
600
+ assert st2 .current_splitter is None
601
+ assert st2 .current_splitter_rpn == []
602
+
603
+ st2 .prepare_states (inputs = {"NA.a" : [3 , 5 ]})
604
+ assert st2 .group_for_inputs_final == {"NA.a" : 0 }
605
+ assert st2 .groups_stack_final == [[0 ]]
606
+ assert st2 .states_ind == [{"NA.a" : 0 }, {"NA.a" : 1 }]
607
+ assert st2 .states_val == [{"NA.a" : 3 }, {"NA.a" : 5 }]
608
+
609
+ st2 .prepare_inputs ()
610
+ # since x,y come from the same state, they should have the same index
611
+ assert st2 .inputs_ind == [{"NB.x" : 0 , "NB.y" : 0 }, {"NB.x" : 1 , "NB.y" : 1 }]
612
+
613
+
614
+ def test_state_connect_8 ():
615
+ """three 'connected' states: NA -> NB -> NC; NA -> NC (only NA has its own splitter)
616
+ pydra should recognize, that there is only one splitter - NA
617
+ and it should give the same as the previous test
618
+ """
619
+ st1 = State (name = "NA" , splitter = "a" )
620
+ st2 = State (name = "NB" , other_states = {"NA" : (st1 , "b" )})
621
+ st3 = State (name = "NC" , other_states = {"NA" : (st1 , "x" ), "NB" : (st2 , "y" )})
622
+ # x comes from NA and y comes from NB, but NB has only NA's splitter,
623
+ # so it should be treated as both inputs are from NA state
624
+ assert st3 .splitter == "_NA"
625
+ assert st3 .splitter_rpn == ["NA.a" ]
626
+ assert st3 .prev_state_splitter == st3 .splitter
627
+ assert st3 .prev_state_splitter_rpn == st3 .splitter_rpn
628
+ assert st3 .current_splitter is None
629
+ assert st3 .current_splitter_rpn == []
630
+
631
+ st3 .prepare_states (inputs = {"NA.a" : [3 , 5 ]})
632
+ assert st3 .group_for_inputs_final == {"NA.a" : 0 }
633
+ assert st3 .groups_stack_final == [[0 ]]
634
+ assert st3 .states_ind == [{"NA.a" : 0 }, {"NA.a" : 1 }]
635
+ assert st3 .states_val == [{"NA.a" : 3 }, {"NA.a" : 5 }]
636
+
637
+ st3 .prepare_inputs ()
638
+ # since x,y come from the same state (although y indirectly), they should have the same index
639
+ assert st3 .inputs_ind == [{"NC.x" : 0 , "NC.y" : 0 }, {"NC.x" : 1 , "NC.y" : 1 }]
640
+
641
+
642
+ @pytest .mark .xfail (
643
+ reason = "doesn't recognize that NC.y has 4 elements (not independend on NC.x)"
644
+ )
645
+ def test_state_connect_9 ():
646
+ """four 'connected' states: NA1 -> NB; NA2 -> NB, NA1 -> NC; NB -> NC
647
+ pydra should recognize, that there is only one splitter - NA_1 and NA_2
648
+
649
+ """
650
+ st1 = State (name = "NA_1" , splitter = "a" )
651
+ st1a = State (name = "NA_2" , splitter = "a" )
652
+ st2 = State (name = "NB" , other_states = {"NA_1" : (st1 , "b" ), "NA_2" : (st1a , "c" )})
653
+ st3 = State (name = "NC" , other_states = {"NA_1" : (st1 , "x" ), "NB" : (st2 , "y" )})
654
+ # x comes from NA_1 and y comes from NB, but NB has only NA_1/2's splitters,
655
+ assert st3 .splitter == ["_NA_1" , "_NA_2" ]
656
+ assert st3 .splitter_rpn == ["NA_1.a" , "NA_2.a" , "*" ]
657
+ assert st3 .prev_state_splitter == st3 .splitter
658
+ assert st3 .prev_state_splitter_rpn == st3 .splitter_rpn
659
+ assert st3 .current_splitter is None
660
+ assert st3 .current_splitter_rpn == []
661
+
662
+ st3 .prepare_states (inputs = {"NA_1.a" : [3 , 5 ], "NA_2.a" : [11 , 12 ]})
663
+ assert st3 .group_for_inputs_final == {"NA_1.a" : 0 , "NA_2.a" : 1 }
664
+ assert st3 .groups_stack_final == [[0 , 1 ]]
665
+ assert st3 .states_ind == [
666
+ {"NA_1.a" : 0 , "NA_2.a" : 0 },
667
+ {"NA_1.a" : 0 , "NA_2.a" : 1 },
668
+ {"NA_1.a" : 1 , "NA_2.a" : 0 },
669
+ {"NA_1.a" : 1 , "NA_2.a" : 1 },
670
+ ]
671
+
672
+ st3 .prepare_inputs ()
673
+ assert st3 .inputs_ind == [
674
+ {"NC.x" : 0 , "NC.y" : 0 },
675
+ {"NC.x" : 0 , "NC.y" : 1 },
676
+ {"NC.x" : 1 , "NC.y" : 2 },
677
+ {"NC.x" : 1 , "NC.y" : 3 },
678
+ ]
679
+
680
+
589
681
def test_state_connect_innerspl_1 ():
590
682
"""two 'connected' states: testing groups, prepare_states and prepare_inputs,
591
683
the second state has an inner splitter, full splitter provided
@@ -605,7 +697,7 @@ def test_state_connect_innerspl_1():
605
697
inputs = {"NA.a" : [3 , 5 ], "NB.b" : [[1 , 10 , 100 ], [2 , 20 , 200 ]]},
606
698
cont_dim = {"NB.b" : 2 }, # will be treated as 2d container
607
699
)
608
- assert st2 .other_states ["NA" ][1 ] == "b"
700
+ assert st2 .other_states ["NA" ][1 ] == [ "b" ]
609
701
assert st2 .group_for_inputs_final == {"NA.a" : 0 , "NB.b" : 1 }
610
702
assert st2 .groups_stack_final == [[0 ], [1 ]]
611
703
@@ -653,7 +745,7 @@ def test_state_connect_innerspl_1a():
653
745
assert st2 .current_splitter == "NB.b"
654
746
assert st2 .current_splitter_rpn == ["NB.b" ]
655
747
656
- assert st2 .other_states ["NA" ][1 ] == "b"
748
+ assert st2 .other_states ["NA" ][1 ] == [ "b" ]
657
749
658
750
st2 .prepare_states (
659
751
inputs = {"NA.a" : [3 , 5 ], "NB.b" : [[1 , 10 , 100 ], [2 , 20 , 200 ]]},
@@ -717,7 +809,7 @@ def test_state_connect_innerspl_2():
717
809
inputs = {"NA.a" : [3 , 5 ], "NB.b" : [[1 , 10 , 100 ], [2 , 20 , 200 ]], "NB.c" : [13 , 17 ]},
718
810
cont_dim = {"NB.b" : 2 }, # will be treated as 2d container
719
811
)
720
- assert st2 .other_states ["NA" ][1 ] == "b"
812
+ assert st2 .other_states ["NA" ][1 ] == [ "b" ]
721
813
assert st2 .group_for_inputs_final == {"NA.a" : 0 , "NB.c" : 1 , "NB.b" : 2 }
722
814
assert st2 .groups_stack_final == [[0 ], [1 , 2 ]]
723
815
@@ -778,7 +870,7 @@ def test_state_connect_innerspl_2a():
778
870
779
871
assert st2 .splitter == ["_NA" , ["NB.b" , "NB.c" ]]
780
872
assert st2 .splitter_rpn == ["NA.a" , "NB.b" , "NB.c" , "*" , "*" ]
781
- assert st2 .other_states ["NA" ][1 ] == "b"
873
+ assert st2 .other_states ["NA" ][1 ] == [ "b" ]
782
874
783
875
st2 .prepare_states (
784
876
inputs = {"NA.a" : [3 , 5 ], "NB.b" : [[1 , 10 , 100 ], [2 , 20 , 200 ]], "NB.c" : [13 , 17 ]},
@@ -839,6 +931,7 @@ def test_state_connect_innerspl_3():
839
931
the second state has one inner splitter and one 'normal' splitter
840
932
the prev-state parts of the splitter have to be added
841
933
"""
934
+
842
935
st1 = State (name = "NA" , splitter = "a" )
843
936
st2 = State (name = "NB" , splitter = ["c" , "b" ], other_states = {"NA" : (st1 , "b" )})
844
937
st3 = State (name = "NC" , splitter = "d" , other_states = {"NB" : (st2 , "a" )})
@@ -986,8 +1079,8 @@ def test_state_connect_innerspl_4():
986
1079
987
1080
assert st3 .splitter == [["_NA" , "_NB" ], "NC.d" ]
988
1081
assert st3 .splitter_rpn == ["NA.a" , "NB.b" , "NB.c" , "*" , "*" , "NC.d" , "*" ]
989
- assert st3 .other_states ["NA" ][1 ] == "e"
990
- assert st3 .other_states ["NB" ][1 ] == "f"
1082
+ assert st3 .other_states ["NA" ][1 ] == [ "e" ]
1083
+ assert st3 .other_states ["NB" ][1 ] == [ "f" ]
991
1084
992
1085
st3 .prepare_states (
993
1086
inputs = {
@@ -1736,12 +1829,12 @@ def test_connect_splitters_exception_1(splitter, other_states):
1736
1829
1737
1830
1738
1831
def test_connect_splitters_exception_2 ():
1739
- st = State (
1740
- name = "CN" ,
1741
- splitter = "_NB" ,
1742
- other_states = {"NA" : (State (name = "NA" , splitter = "a" ), "b" )},
1743
- )
1744
1832
with pytest .raises (PydraStateError ) as excinfo :
1833
+ st = State (
1834
+ name = "CN" ,
1835
+ splitter = "_NB" ,
1836
+ other_states = {"NA" : (State (name = "NA" , splitter = "a" ), "b" )},
1837
+ )
1745
1838
st .set_input_groups ()
1746
1839
assert "can't ask for splitter from NB" in str (excinfo .value )
1747
1840
0 commit comments