diff --git a/backends/vulkan/test/op_tests/cases.py b/backends/vulkan/test/op_tests/cases.py index d32fa715734..9cec4891c10 100644 --- a/backends/vulkan/test/op_tests/cases.py +++ b/backends/vulkan/test/op_tests/cases.py @@ -169,7 +169,13 @@ def get_linear_test_suites(): @register_test_suite("aten._weight_int8pack_mm.default") def get_weight_int8pack_mm_inputs(): - MKN_list = common_MKN_list + MKN_list = [ + [6, 480, 256], + [6, 256, 1024], + [6, 1024, 256], + [6, 256, 256], + [6, 256, 512], + ] inputs_list = [((M, K), (N, K), (N)) for M, K, N in MKN_list]