@@ -922,86 +922,103 @@ def get_split_with_sizes_inputs():
922922 Test = namedtuple ("VkSliceTest" , ["self" , "sizes" , "dim" ])
923923 test_cases = [
924924 # Split on Width
925+ Test (self = (S1 , 7 , 10 , 11 ), sizes = [1 , 3 , 3 , 5 ], dim = 3 ),
925926 Test (self = (S1 , 7 , 10 , 10 ), sizes = [1 , 2 , 3 , 4 ], dim = 3 ),
927+ Test (self = (7 , 10 , 11 ), sizes = [1 , 3 , 3 , 5 ], dim = 2 ),
926928 Test (self = (7 , 10 , 10 ), sizes = [1 , 2 , 3 , 4 ], dim = 2 ),
929+ Test (self = (7 , 10 , 11 ), sizes = [3 , 8 ], dim = 2 ),
927930 Test (self = (7 , 10 , 10 ), sizes = [1 , 9 ], dim = 2 ),
928931 Test (self = (10 , 10 ), sizes = [1 , 9 ], dim = 1 ),
929932 Test (self = (10 ,), sizes = [1 , 9 ], dim = 0 ),
930933 # Split on Height
934+ Test (self = (S1 , 7 , 11 , 10 ), sizes = [1 , 3 , 3 , 5 ], dim = 2 ),
931935 Test (self = (S1 , 7 , 10 , 10 ), sizes = [1 , 2 , 3 , 4 ], dim = 2 ),
936+ Test (self = (7 , 11 , 10 ), sizes = [1 , 3 , 3 , 5 ], dim = 1 ),
932937 Test (self = (7 , 10 , 10 ), sizes = [1 , 2 , 3 , 4 ], dim = 1 ),
938+ Test (self = (7 , 11 , 11 ), sizes = [3 , 8 ], dim = 1 ),
933939 Test (self = (7 , 10 , 10 ), sizes = [10 ], dim = 1 ),
934940 Test (self = (7 , 6 , 10 ), sizes = [1 , 1 , 1 , 1 , 1 , 1 ], dim = 1 ),
935941 Test (self = (10 , 10 ), sizes = [1 , 2 , 3 , 4 ], dim = 0 ),
936942 # Split on Batch
937943 Test (self = (10 , 7 , 10 , 10 ), sizes = [3 , 6 , 1 ], dim = 0 ),
938944 Test (self = (10 , 7 , 10 , 10 ), sizes = [10 ], dim = 0 ),
939945 # Split on Channel
946+ Test (self = (7 , 13 , 4 , 8 ), sizes = [3 , 5 , 2 , 3 ], dim = 1 ),
940947 Test (self = (7 , 13 , 4 , 8 ), sizes = [3 , 6 , 1 , 3 ], dim = 1 ),
948+ Test (self = (7 , 13 , 4 , 8 ), sizes = [3 , 3 , 2 , 5 , 1 ], dim = 1 ),
941949 Test (self = (7 , 13 , 4 , 8 ), sizes = [3 , 3 , 3 , 3 , 1 ], dim = 1 ),
950+ Test (self = (13 , 4 , 8 ), sizes = [3 , 5 , 2 , 1 , 2 ], dim = 0 ),
942951 Test (self = (13 , 4 , 8 ), sizes = [3 , 3 , 3 , 3 , 1 ], dim = 0 ),
943952 Test (self = (13 , 4 , 8 ), sizes = [2 , 9 , 2 ], dim = 0 ),
944953 Test (self = (13 , 4 , 8 ), sizes = [13 ], dim = 0 ),
945954 ]
946955 test_suite = VkTestSuite ([tuple (tc ) for tc in test_cases ])
947956
948957 test_suite .layouts = [
958+ "utils::kWidthPacked" ,
959+ "utils::kHeightPacked" ,
949960 "utils::kChannelsPacked" ,
950961 ]
951962 test_suite .data_gen = "make_seq_tensor"
952963 test_suite .dtypes = ["at::kFloat" ]
953964 return test_suite
954965
955966
956- @register_test_suite ("aten.split.Tensor" )
957- def get_split_tensor_inputs ():
958- test_suite = VkTestSuite (
959- [
960- # Split on Width
961- ((S1 , 7 , 10 , 12 ), 12 , 3 ),
962- ((S1 , 7 , 10 , 12 ), 3 , 3 ),
963- ((S1 , 7 , 10 , 12 ), 1 , 3 ),
964- ((7 , 10 , 12 ), 12 , 2 ),
965- ((7 , 10 , 12 ), 3 , 2 ),
966- ((7 , 10 , 12 ), 1 , 2 ),
967- ((10 , 12 ), 12 , 1 ),
968- ((10 , 12 ), 3 , 1 ),
969- ((10 , 12 ), 1 , 1 ),
970- ((12 ,), 12 , 0 ),
971- ((12 ,), 3 , 0 ),
972- ((12 ,), 1 , 0 ),
973- # Split on Height
974- ((S1 , 7 , 12 , 8 ), 12 , 2 ),
975- ((S1 , 7 , 12 , 8 ), 3 , 2 ),
976- ((S1 , 7 , 12 , 8 ), 1 , 2 ),
977- ((7 , 12 , 8 ), 12 , 1 ),
978- ((7 , 12 , 8 ), 3 , 1 ),
979- ((7 , 12 , 8 ), 1 , 1 ),
980- ((12 , 8 ), 12 , 0 ),
981- ((12 , 8 ), 3 , 0 ),
982- ((12 , 8 ), 1 , 0 ),
983- # Split on Batch
984- ((12 , 7 , 10 , 10 ), 12 , 0 ),
985- ((12 , 7 , 10 , 10 ), 3 , 0 ),
986- ((12 , 7 , 10 , 10 ), 1 , 0 ),
987- # Split on Channel
988- ((7 , 15 , 10 , 10 ), 15 , 1 ),
989- ((7 , 15 , 10 , 10 ), 5 , 1 ),
990- ((7 , 15 , 10 , 10 ), 3 , 1 ),
991- ((7 , 15 , 10 , 10 ), 1 , 1 ),
992- ((15 , 10 , 10 ), 15 , 0 ),
993- ((15 , 10 , 10 ), 5 , 0 ),
994- ((15 , 10 , 10 ), 3 , 0 ),
995- ((15 , 10 , 10 ), 1 , 0 ),
996- ]
997- )
998-
999- test_suite .layouts = [
1000- "utils::kChannelsPacked" ,
1001- ]
1002- test_suite .data_gen = "make_seq_tensor"
1003- test_suite .dtypes = ["at::kFloat" ]
1004- return test_suite
967+ # @register_test_suite("aten.split.Tensor")
968+ # def get_split_tensor_inputs():
969+ # test_suite = VkTestSuite(
970+ # [
971+ # # Split on Width
972+ # ((M1, 7, 10, 12), 12, 3),
973+ # ((S1, 7, 10, 12), 12, 3),
974+ # ((M1, 7, 10, 12), 3, 3),
975+ # ((S1, 7, 10, 12), 3, 3),
976+ # ((M1, 7, 10, 12), 1, 3),
977+ # ((S1, 7, 10, 12), 1, 3),
978+ # ((7, 10, 12), 12, 2),
979+ # ((7, 10, 12), 3, 2),
980+ # ((7, 10, 12), 1, 2),
981+ # ((2, 3, 4), 1, 2),
982+ # ((10, 12), 12, 1),
983+ # ((10, 12), 3, 1),
984+ # ((10, 12), 1, 1),
985+ # ((12,), 12, 0),
986+ # ((12,), 3, 0),
987+ # ((12,), 1, 0),
988+ # # Split on Height
989+ # ((S1, 7, 12, 8), 12, 2),
990+ # ((S1, 7, 12, 8), 3, 2),
991+ # ((S1, 7, 12, 8), 1, 2),
992+ # ((7, 12, 8), 12, 1),
993+ # ((7, 12, 8), 3, 1),
994+ # ((7, 12, 8), 1, 1),
995+ # ((12, 8), 12, 0),
996+ # ((12, 8), 3, 0),
997+ # ((12, 8), 1, 0),
998+ # # Split on Batch
999+ # ((12, 7, 10, 10), 12, 0),
1000+ # ((12, 7, 10, 10), 3, 0),
1001+ # ((12, 7, 10, 10), 1, 0),
1002+ # # Split on Channel
1003+ # ((7, 15, 10, 10), 15, 1),
1004+ # ((7, 15, 10, 10), 5, 1),
1005+ # ((7, 15, 10, 10), 3, 1),
1006+ # ((7, 15, 10, 10), 1, 1),
1007+ # ((15, 10, 10), 15, 0),
1008+ # ((15, 10, 10), 5, 0),
1009+ # ((15, 10, 10), 3, 0),
1010+ # ((15, 10, 10), 1, 0),
1011+ # ]
1012+ # )
1013+
1014+ # test_suite.layouts = [
1015+ # "utils::kWidthPacked",
1016+ # "utils::kHeightPacked",
1017+ # "utils::kChannelsPacked",
1018+ # ]
1019+ # test_suite.data_gen = "make_seq_tensor"
1020+ # test_suite.dtypes = ["at::kFloat"]
1021+ # return test_suite
10051022
10061023
10071024def get_reduce_inputs (is_softmax : bool = False ):
0 commit comments