
Commit 8ce8bed

Authored by Benoit Jacob
Simplifications in e2e matmul tests (#18889)
Two commits:

1. Stop inferring `acc_type`; require specifying it explicitly. Only a few tests were relying on the inference.
2. Stop special-casing narrow float types (previously, only f32 was used as the ABI type, with `arith.truncf` generated internally). This was only needed when these narrow float types were not supported in the rest of IREE.

Signed-off-by: Benoit Jacob <[email protected]>
1 parent: 225baf2
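Concretely, the generator's CLI contract changes as follows; a minimal sketch using flag values taken from the diffs below (any output-path flags the script also takes are omitted here):

# Old: the accumulator type was inferred from the input type.
old_args = ["--lhs_rhs_type=f16", "--shapes=small"]
# New: --acc_type must be spelled out; here f16 inputs accumulate in f32.
new_args = ["--lhs_rhs_type=f16", "--acc_type=f32", "--shapes=small"]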

File tree

3 files changed: +49 −67 lines changed

tests/e2e/matmul/BUILD.bazel

Lines changed: 25 additions & 13 deletions
@@ -360,16 +360,17 @@ X86_64_AVX512_BF16 = X86_64_AVX512 + [
     generator = ":generate_e2e_matmul_tests",
     generator_args = [
         "--lhs_rhs_type=%s" % lhs_rhs_type,
+        "--acc_type=%s" % acc_type,
         "--shapes=small",
     ],
     target_backends_and_drivers = [
         ("vmvx", "local-task"),
     ],
     test_runner = "//tools/testing/e2e:iree-e2e-matmul-test",
     test_type = "matmul",
-) for lhs_rhs_type in [
-    "i8",
-    "f32",
+) for (lhs_rhs_type, acc_type) in [
+    ("i8", "i32"),
+    ("f32", "f32"),
 ]]

 ###########################################################################
@@ -383,6 +384,7 @@ iree_generated_e2e_runner_test(
     generator = ":generate_e2e_matmul_tests",
     generator_args = [
         "--lhs_rhs_type=f32",
+        "--acc_type=f32",
         "--shapes=easy_large_static",
         "--compilation_info=LLVMGPUMatmulSimt",
     ],
@@ -411,6 +413,7 @@ iree_generated_e2e_runner_test(
     generator = ":generate_e2e_matmul_tests",
     generator_args = [
         "--lhs_rhs_type=f32",
+        "--acc_type=f32",
         "--shapes=easy_large_static",
         "--compilation_info=LLVMGPUMatmulTensorCore",
     ],
@@ -437,6 +440,7 @@ iree_generated_e2e_runner_test(
     generator = ":generate_e2e_matmul_tests",
     generator_args = [
         "--lhs_rhs_type=f32",
+        "--acc_type=f32",
     ],
     tags = [
         # CUDA cuInit fails with sanitizer on.
@@ -461,6 +465,7 @@ iree_generated_e2e_runner_test(
     generator = ":generate_e2e_matmul_tests",
     generator_args = [
         "--lhs_rhs_type=f16",
+        "--acc_type=f32",
     ],
     tags = [
         # CUDA cuInit fails with sanitizer on.
@@ -486,6 +491,7 @@ iree_generated_e2e_runner_test(
     generator = ":generate_e2e_matmul_tests",
     generator_args = [
         "--lhs_rhs_type=f32",
+        "--acc_type=f32",
         "--shapes=easy_large_static",
         "--compilation_info=LLVMGPUMatmulTensorCoreMmaSync",
     ],
@@ -513,6 +519,7 @@ iree_generated_e2e_runner_test(
     generator = ":generate_e2e_matmul_tests",
     generator_args = [
         "--lhs_rhs_type=f16",
+        "--acc_type=f32",
         "--shapes=easy_large_static",
         "--compilation_info=LLVMGPUMatmulTensorCore",
     ],
@@ -540,6 +547,7 @@ iree_generated_e2e_runner_test(
     generator = ":generate_e2e_matmul_tests",
     generator_args = [
         "--lhs_rhs_type=f16",
+        "--acc_type=f32",
         "--shapes=easy_large_static",
         "--compilation_info=LLVMGPUMatmulTensorCoreMmaSync",
     ],
@@ -566,6 +574,7 @@ iree_generated_e2e_runner_test(
     generator = ":generate_e2e_matmul_tests",
     generator_args = [
         "--lhs_rhs_type=%s" % lhs_rhs_type,
+        "--acc_type=%s" % acc_type,
     ],
     tags = [
         # CUDA cuInit fails with sanitizer on.
@@ -580,8 +589,8 @@ iree_generated_e2e_runner_test(
     ],
     test_runner = "//tools/testing/e2e:iree-e2e-matmul-test",
     test_type = "matmul",
-) for lhs_rhs_type in [
-    "f32",
+) for (lhs_rhs_type, acc_type) in [
+    ("f32", "f32"),
 ]]

 ###########################################################################
@@ -598,6 +607,7 @@ iree_generated_e2e_runner_test(
     generator = ":generate_e2e_matmul_tests",
     generator_args = [
         "--lhs_rhs_type=%s" % lhs_rhs_type,
+        "--acc_type=%s" % acc_type,
         "--shapes=easy_large_static",
         "--compilation_info=SPIRVVectorizeMali",
     ],
@@ -611,10 +621,10 @@ iree_generated_e2e_runner_test(
     ],
     test_runner = "//tools/testing/e2e:iree-e2e-matmul-test",
     test_type = "matmul",
-) for lhs_rhs_type in [
-    "i8",
-    "f16",
-    "f32",
+) for (lhs_rhs_type, acc_type) in [
+    ("i8", "i32"),
+    ("f16", "f32"),
+    ("f32", "f32"),
 ]]

 [iree_generated_e2e_runner_test(
@@ -625,6 +635,7 @@ iree_generated_e2e_runner_test(
     generator = ":generate_e2e_matmul_tests",
     generator_args = [
         "--lhs_rhs_type=%s" % lhs_rhs_type,
+        "--acc_type=%s" % acc_type,
         "--shapes=easy_large_static",
         "--compilation_info=SPIRVVectorizeNVIDIA",
     ],
@@ -637,10 +648,10 @@ iree_generated_e2e_runner_test(
     ],
     test_runner = "//tools/testing/e2e:iree-e2e-matmul-test",
     test_type = "matmul",
-) for lhs_rhs_type in [
-    "i8",
-    "f16",
-    "f32",
+) for (lhs_rhs_type, acc_type) in [
+    ("i8", "i32"),
+    ("f16", "f32"),
+    ("f32", "f32"),
 ]]

 iree_generated_e2e_runner_test(
@@ -651,6 +662,7 @@ iree_generated_e2e_runner_test(
     generator = ":generate_e2e_matmul_tests",
     generator_args = [
         "--lhs_rhs_type=f16",
+        "--acc_type=f32",
         "--shapes=easy_large_static",
         "--compilation_info=SPIRVCooperativeMatrixVectorize",
     ],
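Every Bazel comprehension above moves from iterating bare input types to iterating explicit (input, accumulator) pairs. A plain-Python sketch of the pattern, as the Starlark behaves the same way (names here are illustrative):

TYPE_PAIRS = [
    ("i8", "i32"),
    ("f16", "f32"),
    ("f32", "f32"),
]

for lhs_rhs_type, acc_type in TYPE_PAIRS:
    # Each test suite now states its accumulator type explicitly,
    # instead of relying on the generator to infer it.
    generator_args = [
        "--lhs_rhs_type=%s" % lhs_rhs_type,
        "--acc_type=%s" % acc_type,
    ]
    print(generator_args)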

tests/e2e/matmul/CMakeLists.txt

Lines changed: 17 additions & 0 deletions
@@ -927,6 +927,7 @@ iree_generated_e2e_runner_test(
     "generate_e2e_matmul_tests.py"
   GENERATOR_ARGS
     "--lhs_rhs_type=i8"
+    "--acc_type=i32"
     "--shapes=small"
   TEST_RUNNER
     iree_tools_testing_e2e_iree-e2e-matmul-test
@@ -948,6 +949,7 @@ iree_generated_e2e_runner_test(
     "generate_e2e_matmul_tests.py"
   GENERATOR_ARGS
     "--lhs_rhs_type=f32"
+    "--acc_type=f32"
     "--shapes=small"
   TEST_RUNNER
     iree_tools_testing_e2e_iree-e2e-matmul-test
@@ -969,6 +971,7 @@ iree_generated_e2e_runner_test(
     "generate_e2e_matmul_tests.py"
   GENERATOR_ARGS
     "--lhs_rhs_type=f32"
+    "--acc_type=f32"
     "--shapes=easy_large_static"
     "--compilation_info=LLVMGPUMatmulSimt"
   TEST_RUNNER
@@ -994,6 +997,7 @@ iree_generated_e2e_runner_test(
     "generate_e2e_matmul_tests.py"
   GENERATOR_ARGS
     "--lhs_rhs_type=f32"
+    "--acc_type=f32"
     "--shapes=easy_large_static"
     "--compilation_info=LLVMGPUMatmulTensorCore"
   TEST_RUNNER
@@ -1021,6 +1025,7 @@ iree_generated_e2e_runner_test(
     "generate_e2e_matmul_tests.py"
   GENERATOR_ARGS
     "--lhs_rhs_type=f32"
+    "--acc_type=f32"
   TEST_RUNNER
     iree_tools_testing_e2e_iree-e2e-matmul-test
   TARGET_BACKENDS
@@ -1046,6 +1051,7 @@ iree_generated_e2e_runner_test(
     "generate_e2e_matmul_tests.py"
   GENERATOR_ARGS
     "--lhs_rhs_type=f16"
+    "--acc_type=f32"
   TEST_RUNNER
     iree_tools_testing_e2e_iree-e2e-matmul-test
   TARGET_BACKENDS
@@ -1071,6 +1077,7 @@ iree_generated_e2e_runner_test(
     "generate_e2e_matmul_tests.py"
  GENERATOR_ARGS
     "--lhs_rhs_type=f32"
+    "--acc_type=f32"
     "--shapes=easy_large_static"
     "--compilation_info=LLVMGPUMatmulTensorCoreMmaSync"
   TEST_RUNNER
@@ -1098,6 +1105,7 @@ iree_generated_e2e_runner_test(
     "generate_e2e_matmul_tests.py"
   GENERATOR_ARGS
     "--lhs_rhs_type=f16"
+    "--acc_type=f32"
     "--shapes=easy_large_static"
     "--compilation_info=LLVMGPUMatmulTensorCore"
   TEST_RUNNER
@@ -1125,6 +1133,7 @@ iree_generated_e2e_runner_test(
     "generate_e2e_matmul_tests.py"
   GENERATOR_ARGS
     "--lhs_rhs_type=f16"
+    "--acc_type=f32"
     "--shapes=easy_large_static"
     "--compilation_info=LLVMGPUMatmulTensorCoreMmaSync"
   TEST_RUNNER
@@ -1152,6 +1161,7 @@ iree_generated_e2e_runner_test(
     "generate_e2e_matmul_tests.py"
   GENERATOR_ARGS
     "--lhs_rhs_type=f32"
+    "--acc_type=f32"
   TEST_RUNNER
     iree_tools_testing_e2e_iree-e2e-matmul-test
   TARGET_BACKENDS
@@ -1177,6 +1187,7 @@ iree_generated_e2e_runner_test(
     "generate_e2e_matmul_tests.py"
   GENERATOR_ARGS
     "--lhs_rhs_type=i8"
+    "--acc_type=i32"
     "--shapes=easy_large_static"
     "--compilation_info=SPIRVVectorizeMali"
   TEST_RUNNER
@@ -1201,6 +1212,7 @@ iree_generated_e2e_runner_test(
     "generate_e2e_matmul_tests.py"
   GENERATOR_ARGS
     "--lhs_rhs_type=f16"
+    "--acc_type=f32"
     "--shapes=easy_large_static"
     "--compilation_info=SPIRVVectorizeMali"
   TEST_RUNNER
@@ -1225,6 +1237,7 @@ iree_generated_e2e_runner_test(
     "generate_e2e_matmul_tests.py"
   GENERATOR_ARGS
     "--lhs_rhs_type=f32"
+    "--acc_type=f32"
     "--shapes=easy_large_static"
     "--compilation_info=SPIRVVectorizeMali"
   TEST_RUNNER
@@ -1249,6 +1262,7 @@ iree_generated_e2e_runner_test(
     "generate_e2e_matmul_tests.py"
   GENERATOR_ARGS
     "--lhs_rhs_type=i8"
+    "--acc_type=i32"
     "--shapes=easy_large_static"
     "--compilation_info=SPIRVVectorizeNVIDIA"
   TEST_RUNNER
@@ -1273,6 +1287,7 @@ iree_generated_e2e_runner_test(
     "generate_e2e_matmul_tests.py"
   GENERATOR_ARGS
     "--lhs_rhs_type=f16"
+    "--acc_type=f32"
     "--shapes=easy_large_static"
     "--compilation_info=SPIRVVectorizeNVIDIA"
   TEST_RUNNER
@@ -1297,6 +1312,7 @@ iree_generated_e2e_runner_test(
     "generate_e2e_matmul_tests.py"
   GENERATOR_ARGS
     "--lhs_rhs_type=f32"
+    "--acc_type=f32"
     "--shapes=easy_large_static"
     "--compilation_info=SPIRVVectorizeNVIDIA"
   TEST_RUNNER
@@ -1321,6 +1337,7 @@ iree_generated_e2e_runner_test(
     "generate_e2e_matmul_tests.py"
   GENERATOR_ARGS
     "--lhs_rhs_type=f16"
+    "--acc_type=f32"
     "--shapes=easy_large_static"
     "--compilation_info=SPIRVCooperativeMatrixVectorize"
   TEST_RUNNER
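The CMake side is the same mechanical change: every GENERATOR_ARGS list gains the explicit accumulator flag. Across all suites touched by this commit the pairings reduce to a small table; a Python sketch for reference (derived from the diff itself, not a rule encoded anywhere in the generator):

ACC_TYPE_FOR = {
    "i8": "i32",   # integer inputs accumulate in i32
    "f16": "f32",  # half inputs accumulate in f32 in these suites
    "f32": "f32",  # float inputs accumulate at the same width
}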

tests/e2e/matmul/generate_e2e_matmul_tests.py

Lines changed: 7 additions & 54 deletions
@@ -545,20 +545,6 @@ def int_or_DYN(s: DimSize):
     return s.value or "DYN"


-# Gets friendlier form/type that we can use as arg types which we can cast into the target_type.
-def cast_argtype_if_required(target_type: MatrixElemTypeId):
-    if target_type == MatrixElemTypeId.F8E4M3FNUZ:
-        return MatrixElemTypeId.F32
-    return target_type
-
-
-# Gets the op needed to cast/convert from the friendly form/type into the target_type.
-def get_castback_from_arg_op(target_type: MatrixElemTypeId):
-    if target_type == MatrixElemTypeId.F8E4M3FNUZ:
-        return "arith.truncf"
-    return ValueError(f"Unhandled castback type of {target_type}")
-
-
 # Describes the fully resolved shape dimensions of all 3 input matrices,
 # LHS, RHS, and Accumulator, in a testcase.
 # Each value is a string, which may either represent a positive integer such as "123",
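The two deleted helpers implemented the retired ABI workaround: f8E4M3FNUZ arguments crossed the function boundary as f32 and were truncated back inside the generated function. A hypothetical before/after sketch of the argument typing (function names here are illustrative, not from the source):

def old_arg_type(t: str) -> str:
    # Mirrors the removed cast_argtype_if_required: widen at the ABI.
    return "f32" if t == "f8E4M3FNUZ" else t

def new_arg_type(t: str) -> str:
    # After this commit, the declared type is used directly end to end.
    return t

assert old_arg_type("f8E4M3FNUZ") == "f32"
assert new_arg_type("f8E4M3FNUZ") == "f8E4M3FNUZ"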
@@ -659,9 +645,8 @@ def generate_function(
     acc_r = int_or_question_mark(shapes.acc_rows)
     acc_c = int_or_question_mark(shapes.acc_cols)

-    casted_lhs_rhs_type = cast_argtype_if_required(lhs_rhs_type)
-    lhs_tensor_type = f"tensor<{lhs_r}x{lhs_c}x{casted_lhs_rhs_type.value}>"
-    rhs_tensor_type = f"tensor<{rhs_r}x{rhs_c}x{casted_lhs_rhs_type.value}>"
+    lhs_tensor_type = f"tensor<{lhs_r}x{lhs_c}x{lhs_rhs_type.value}>"
+    rhs_tensor_type = f"tensor<{rhs_r}x{rhs_c}x{lhs_rhs_type.value}>"
     acc_tensor_type = f"tensor<{acc_r}x{acc_c}x{acc_type.value}>"

     if transpose_rhs:
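With the casting helper gone, the tensor types are built straight from the requested element types. Evaluating the f-strings above for an illustrative dynamic-shape case (where int_or_question_mark yields "?"):

lhs_r = lhs_c = acc_r = acc_c = "?"        # dynamic dimensions
lhs_rhs_value, acc_value = "f16", "f32"    # illustrative element types

lhs_tensor_type = f"tensor<{lhs_r}x{lhs_c}x{lhs_rhs_value}>"
acc_tensor_type = f"tensor<{acc_r}x{acc_c}x{acc_value}>"

assert lhs_tensor_type == "tensor<?x?xf16>"
assert acc_tensor_type == "tensor<?x?xf32>"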
@@ -680,15 +665,6 @@
         func_definition = func_definition + compilation_info_string
         generate_function.compilation_index += 1
     compute = f"  %result = {op_name} {compilation_info_attr}ins(%lhs, %rhs: {lhs_tensor_type}, {rhs_tensor_type}) outs(%acc: {acc_tensor_type}) -> {acc_tensor_type}\n"
-    if casted_lhs_rhs_type != lhs_rhs_type:
-        castback_op = get_castback_from_arg_op(lhs_rhs_type)
-        compute_lhs_tensor_type = f"tensor<{lhs_r}x{lhs_c}x{lhs_rhs_type.value}>"
-        compute_rhs_tensor_type = f"tensor<{rhs_r}x{rhs_c}x{lhs_rhs_type.value}>"
-        compute = (
-            f"  %lhs_casted = {castback_op} %lhs: {lhs_tensor_type} to {compute_lhs_tensor_type}\n"
-            f"  %rhs_casted = {castback_op} %rhs: {rhs_tensor_type} to {compute_rhs_tensor_type}\n"
-            f"  %result = {op_name} {compilation_info_attr}ins(%lhs_casted, %rhs_casted: {compute_lhs_tensor_type}, {compute_rhs_tensor_type}) outs(%acc: {acc_tensor_type}) -> {acc_tensor_type}"
-        )
     if shape.accumulate:
         signature = f"({lhs_tensor_type}, {rhs_tensor_type}, {acc_tensor_type}) -> {acc_tensor_type}"
         import_declaration = f"func.func private @module.{func_name}(%lhs: !hal.buffer_view, %rhs: !hal.buffer_view, %acc: !hal.buffer_view) -> !hal.buffer_view"
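Only the single direct code path remains; the branch that emitted %lhs_casted/%rhs_casted via arith.truncf is gone. For a concrete case, the surviving f-string expands to one op line (op_name and the shapes below are assumptions for illustration; the generator also handles transposed variants):

op_name, compilation_info_attr = "linalg.matmul", ""   # assumed values
lhs_tensor_type = "tensor<123x456xf16>"
rhs_tensor_type = "tensor<456x789xf16>"
acc_tensor_type = "tensor<123x789xf32>"

compute = f"  %result = {op_name} {compilation_info_attr}ins(%lhs, %rhs: {lhs_tensor_type}, {rhs_tensor_type}) outs(%acc: {acc_tensor_type}) -> {acc_tensor_type}\n"
print(compute)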
@@ -818,9 +794,8 @@ def generate_call(
         rhs_shape = [shape.k, shape.n]
         transpose_rhs = 0

-    casted_lhs_rhs_type = cast_argtype_if_required(lhs_rhs_type)
-    op = op + generate_random_matrix("lhs", lhs_shape, casted_lhs_rhs_type)
-    op = op + generate_random_matrix("rhs", rhs_shape, casted_lhs_rhs_type)
+    op = op + generate_random_matrix("lhs", lhs_shape, lhs_rhs_type)
+    op = op + generate_random_matrix("rhs", rhs_shape, lhs_rhs_type)
     if shape.accumulate:
         op = op + generate_random_matrix("acc", [shape.m, shape.n], acc_type)
         # TODO(#16168): there's a bug with in-place input->output aliasing and
@@ -919,16 +894,15 @@ def parse_arguments():
            "f8E5M2FNUZ",
            "f8E4M3FNUZ",
        ],
-        help="Numeric type of input matrices",
+        help="Numeric type of input LHS and RHS matrices",
        required=True,
    )
    parser.add_argument(
        "--acc_type",
        type=str,
        choices=["i32", "f32", "f16", "bf16"],
-        help="Numeric type of input matrices",
-        default="",
-        required=False,
+        help="Numeric type of the accumulator and result matrices",
+        required=True,
    )
    parser.add_argument(
        "--shapes",
@@ -1005,30 +979,9 @@ def write_calls_file(functions, calls, filename, requirements):
        file.write(module_definition)


-# For now, the accumulator type can always be inferred from the input LHS/RHS
-# type, so we do that. That is temporary: eventually there will be cases
-# where the same input types are used with different accumulator types, e.g.
-# f16 inputs with both f16 and f32 accumulator.
-def infer_acc_type(lhs_rhs_type: MatrixElemTypeId, acc_type: MatrixElemTypeId):
-    if acc_type != MatrixElemTypeId.NONE:
-        return acc_type
-    if lhs_rhs_type == MatrixElemTypeId.F8E5M2:
-        return MatrixElemTypeId.F32
-    if lhs_rhs_type == MatrixElemTypeId.F8E4M3:
-        return MatrixElemTypeId.F32
-    if lhs_rhs_type == MatrixElemTypeId.F8E5M2FNUZ:
-        return MatrixElemTypeId.F32
-    if lhs_rhs_type == MatrixElemTypeId.F8E4M3FNUZ:
-        return MatrixElemTypeId.F32
-    if lhs_rhs_type == MatrixElemTypeId.I8:
-        return MatrixElemTypeId.I32
-    return lhs_rhs_type
-
-
 def main(args):
    lhs_rhs_type = MatrixElemTypeId(args.lhs_rhs_type)
    acc_type = MatrixElemTypeId(args.acc_type)
-    acc_type = infer_acc_type(lhs_rhs_type, acc_type)
    shapes_id = ShapesId(args.shapes)
    compilation_info_id = CompilationInfoId(args.compilation_info)

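The deleted comment already anticipated why inference had to go: one input type can legitimately pair with several accumulator types. Explicit flags make that expressible; a sketch of two such suites (pairs taken from the removed comment's own f16 example, hypothetical as test configurations):

CASES = [
    ("f16", "f16"),  # half-precision accumulation
    ("f16", "f32"),  # widened accumulation
]
for lhs_rhs_type, acc_type in CASES:
    args = [f"--lhs_rhs_type={lhs_rhs_type}", f"--acc_type={acc_type}"]
    print(args)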