Commit 1ca0654

Update model benchmarks to matmul_transpose_b (#75)
In general, we try to use matmul_transpose_b (the tensors are M x K and N x K in memory) for most matrix multiplications in IREE today. However, most of the tests that hardcoded a layout were using matmul_transpose_a instead, leading to insufficient benchmark coverage of the common case.
1 parent c31afba commit 1ca0654
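
To make the layout distinction in the commit message concrete, here is a minimal pure-Python sketch of the two conventions. It is an illustration only, not IREE code: the helper names `matmul_transpose_b` and `matmul_transpose_a` are hypothetical functions over nested lists, and the K x M / K x N layout for the transpose-A case is an assumption based on the usual meaning of that name rather than something this commit states.

```python
def matmul_transpose_b(a, b_t):
    """C = A @ B^T, with A stored as M x K and B stored as N x K."""
    m, k = len(a), len(a[0])
    n = len(b_t)
    # Both operands are read along their rows (the K dimension).
    return [
        [sum(a[i][p] * b_t[j][p] for p in range(k)) for j in range(n)]
        for i in range(m)
    ]


def matmul_transpose_a(a_t, b):
    """C = A^T @ B, with A stored as K x M and B stored as K x N (assumed layout)."""
    k, m = len(a_t), len(a_t[0])
    n = len(b[0])
    # Both operands are read down their columns (the K dimension is outermost).
    return [
        [sum(a_t[p][i] * b[p][j] for p in range(k)) for j in range(n)]
        for i in range(m)
    ]


# The same logical product expressed in both storage layouts.
a = [[1, 2, 3], [4, 5, 6]]        # M x K = 2 x 3
b_t = [[1, 0, 1], [0, 1, 0]]      # N x K = 2 x 3
a_t = [[1, 4], [2, 5], [3, 6]]    # K x M = 3 x 2
b = [[1, 0], [0, 1], [1, 0]]      # K x N = 3 x 2

c_from_tb = matmul_transpose_b(a, b_t)
c_from_ta = matmul_transpose_a(a_t, b)
```

Both calls compute the same 2 x 2 result; only the in-memory layout of the operands differs, which is the property the benchmark configurations are meant to exercise.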

1 file changed: +6 -6 lines changed

iree_kernel_benchmark/gemmbench/problems.py

Lines changed: 6 additions & 6 deletions
@@ -727,8 +727,8 @@ def llama13bmatvec(dtype: str, raw_accumulators: bool) -> list[GemmConfig]:
 m,
 n,
 k,
-"T",
 "N",
+"T",
 dtype,
 get_default_accumulator_element_type(dtype),
 get_default_result_element_type(dtype, raw_accumulators),
@@ -747,8 +747,8 @@ def llama70bmatvec(dtype: str, raw_accumulators: bool) -> list[GemmConfig]:
 m,
 n,
 k,
-"T",
 "N",
+"T",
 dtype,
 get_default_accumulator_element_type(dtype),
 get_default_result_element_type(dtype, raw_accumulators),
@@ -768,8 +768,8 @@ def llama13bskinny(dtype: str, raw_accumulators: bool) -> list[GemmConfig]:
 m,
 batch,
 k,
-"T",
 "N",
+"T",
 dtype,
 get_default_accumulator_element_type(dtype),
 get_default_result_element_type(dtype, raw_accumulators),
@@ -789,8 +789,8 @@ def llama70bskinny(dtype: str, raw_accumulators: bool) -> list[GemmConfig]:
 m,
 batch,
 k,
-"T",
 "N",
+"T",
 dtype,
 get_default_accumulator_element_type(dtype),
 get_default_result_element_type(dtype, raw_accumulators),
@@ -1009,9 +1009,9 @@ def get_matching_configs(
 if not config_re.match(config.get_name()):
     continue
 # TODO(https://github.com/iree-org/iree/issues/20446):
-# Mx1xK transpose-A configurations temporarily skipped because they
+# Mx1xK transpose-A/-B configurations temporarily skipped because they
 # trigger an IREE/MLIR bug causing a compilation failure.
-if config.N == 1 and config.tA == "T":
+if config.N == 1 and (config.tA == "T" or config.tB == "T"):
     continue
 matching_configs.append((tag, config))
