@@ -31,8 +31,12 @@ def __init__(
31
31
"splitter, keys_exp, groups_exp, grstack_exp" ,
32
32
[
33
33
("a" , ["a" ], {"a" : 0 }, [[0 ]]),
34
+ (["a" ], ["a" ], {"a" : 0 }, [[0 ]]),
35
+ (("a" ,), ["a" ], {"a" : 0 }, [[0 ]]),
34
36
(("a" , "b" ), ["a" , "b" ], {"a" : 0 , "b" : 0 }, [[0 ]]),
35
37
(["a" , "b" ], ["a" , "b" ], {"a" : 0 , "b" : 1 }, [[0 , 1 ]]),
38
+ ([["a" , "b" ]], ["a" , "b" ], {"a" : 0 , "b" : 1 }, [[0 , 1 ]]),
39
+ ((["a" , "b" ],), ["a" , "b" ], {"a" : 0 , "b" : 1 }, [[0 , 1 ]]),
36
40
((["a" , "b" ], "c" ), ["a" , "b" , "c" ], {"a" : 0 , "b" : 1 , "c" : [0 , 1 ]}, [[0 , 1 ]]),
37
41
([("a" , "b" ), "c" ], ["a" , "b" , "c" ], {"a" : 0 , "b" : 0 , "c" : 1 }, [[0 , 1 ]]),
38
42
([["a" , "b" ], "c" ], ["a" , "b" , "c" ], {"a" : 0 , "b" : 1 , "c" : 2 }, [[0 , 1 , 2 ]]),
@@ -58,6 +62,8 @@ def test_splits_groups(splitter, keys_exp, groups_exp, grstack_exp):
58
62
"keys_final_exp, groups_final_exp, grstack_final_exp" ,
59
63
[
60
64
("a" , ["a" ], ["a" ], [], {}, []),
65
+ (["a" ], ["a" ], ["a" ], [], {}, []),
66
+ (("a" ,), ["a" ], ["a" ], [], {}, []),
61
67
(("a" , "b" ), ["a" ], ["a" , "b" ], [], {}, [[]]),
62
68
(("a" , "b" ), ["b" ], ["a" , "b" ], [], {}, [[]]),
63
69
(["a" , "b" ], ["b" ], ["b" ], ["a" ], {"a" : 0 }, [[0 ]]),
@@ -69,6 +75,8 @@ def test_splits_groups(splitter, keys_exp, groups_exp, grstack_exp):
69
75
([("a" , "b" ), "c" ], ["a" ], ["a" , "b" ], ["c" ], {"c" : 0 }, [[0 ]]),
70
76
([("a" , "b" ), "c" ], ["b" ], ["a" , "b" ], ["c" ], {"c" : 0 }, [[0 ]]),
71
77
([("a" , "b" ), "c" ], ["c" ], ["c" ], ["a" , "b" ], {"a" : 0 , "b" : 0 }, [[0 ]]),
78
+ ([[("a" , "b" ), "c" ]], ["c" ], ["c" ], ["a" , "b" ], {"a" : 0 , "b" : 0 }, [[0 ]]),
79
+ (([("a" , "b" ), "c" ],), ["c" ], ["c" ], ["a" , "b" ], {"a" : 0 , "b" : 0 }, [[0 ]]),
72
80
],
73
81
)
74
82
def test_splits_groups_comb (
@@ -94,6 +102,8 @@ def test_splits_groups_comb(
94
102
"splitter, cont_dim, values, keys, splits" ,
95
103
[
96
104
("a" , None , [(0 ,), (1 ,)], ["a" ], [{"a" : 1 }, {"a" : 2 }]),
105
+ (["a" ], None , [(0 ,), (1 ,)], ["a" ], [{"a" : 1 }, {"a" : 2 }]),
106
+ (("a" ,), None , [(0 ,), (1 ,)], ["a" ], [{"a" : 1 }, {"a" : 2 }]),
97
107
(
98
108
("a" , "v" ),
99
109
None ,
@@ -468,6 +478,8 @@ def test_splits_2(splitter_rpn, inner_inputs, values, keys, splits):
468
478
(["a" , ("b" , ["c" , "d" ])], ["a" , "b" , "c" , "d" , "*" , "." , "*" ]),
469
479
((["a" , "b" ], "c" ), ["a" , "b" , "*" , "c" , "." ]),
470
480
((["a" , "b" ], ["c" , "d" ]), ["a" , "b" , "*" , "c" , "d" , "*" , "." ]),
481
+ (([["a" , "b" ]], ["c" , "d" ]), ["a" , "b" , "*" , "c" , "d" , "*" , "." ]),
482
+ (((["a" , "b" ],), ["c" , "d" ]), ["a" , "b" , "*" , "c" , "d" , "*" , "." ]),
471
483
([("a" , "b" ), ("c" , "d" )], ["a" , "b" , "." , "c" , "d" , "." , "*" ]),
472
484
],
473
485
)
0 commit comments