Skip to content

Commit baddea3

Browse files
committed
Use IREE's MLIR builder Python bindings for gemmbench IR generation
Resolves #52. The README is changed to suggest using the pre-release IREE builds (the CI already used them), because the linalg.matmul constructor is broken in the latest stable IREE release but works in the latest pre-release builds.
1 parent 6b4299b commit baddea3

File tree

4 files changed

+49
-82
lines changed

4 files changed

+49
-82
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
If you are not using a local iree build, install the iree pip packages:
66
```
7-
pip install --find-links https://iree.dev/pip-release-links.html iree-base-compiler iree-base-runtime --upgrade
7+
pip install --pre --find-links https://iree.dev/pip-release-links.html iree-base-compiler iree-base-runtime --upgrade
88
```
99

1010
Create a python environment and install the requirements for the project:

iree_kernel_benchmark/gemmbench/gemm_utils.py

Lines changed: 39 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import os
2020
import traceback
2121
from iree.compiler import ir
22+
from iree.compiler.dialects import arith, func, linalg, tensor
2223

2324

2425
@dataclass
@@ -89,75 +90,52 @@ def generate_mlir(config: GemmConfig):
8990
K = config.K
9091
M = config.M
9192
N = config.N
93+
tA = config.tA
94+
tB = config.tB
9295

9396
with ir.Location.name(config.get_name()):
9497
operand_element_type = _convert_dtype_to_mlir(config.operand_element_type)
9598
acc_element_type = _convert_dtype_to_mlir(config.accumulator_element_type)
9699
result_element_type = _convert_dtype_to_mlir(config.result_element_type)
97100
is_integer = isinstance(operand_element_type, ir.IntegerType)
98101
literal_zero = ir.IntegerAttr.get(acc_element_type, 0) if is_integer else ir.FloatAttr.get(acc_element_type, 0.0)
99-
K_M_operand_tensor_type = ir.RankedTensorType.get([K, M], operand_element_type)
100-
M_K_operand_tensor_type = ir.RankedTensorType.get([M, K], operand_element_type)
101-
K_N_operand_tensor_type = ir.RankedTensorType.get([K, N], operand_element_type)
102-
N_K_operand_tensor_type = ir.RankedTensorType.get([N, K], operand_element_type)
103-
M_N_acc_tensor_type = ir.RankedTensorType.get([M, N], acc_element_type)
104-
M_N_result_tensor_type = ir.RankedTensorType.get([M, N], result_element_type)
105-
106-
trunc_op = "arith.trunci" if is_integer else "arith.truncf"
107-
108-
tA = config.tA
109-
tB = config.tB
110-
mlir_template_matmul_transpose_a = f"""
111-
module {{
112-
func.func @main(%arg0: {K_M_operand_tensor_type}, %arg1: {K_N_operand_tensor_type}) -> {M_N_result_tensor_type} {{
113-
%cst = arith.constant {literal_zero}
114-
%0 = tensor.empty() : {M_N_acc_tensor_type}
115-
%1 = linalg.fill ins(%cst : {acc_element_type}) outs(%0 : {M_N_acc_tensor_type}) -> {M_N_acc_tensor_type}
116-
%2 = linalg.matmul_transpose_a ins(%arg0, %arg1 : {K_M_operand_tensor_type}, {K_N_operand_tensor_type})
117-
outs(%1 : {M_N_acc_tensor_type})
118-
-> {M_N_acc_tensor_type}
119-
"""
120-
121-
mlir_template_matmul_transpose_b = f"""
122-
module {{
123-
func.func @main(%arg0: {M_K_operand_tensor_type}, %arg1: {N_K_operand_tensor_type}) -> {M_N_result_tensor_type} {{
124-
%cst = arith.constant {literal_zero}
125-
%0 = tensor.empty() : {M_N_acc_tensor_type}
126-
%1 = linalg.fill ins(%cst : {acc_element_type}) outs(%0 : {M_N_acc_tensor_type}) -> {M_N_acc_tensor_type}
127-
%2 = linalg.matmul_transpose_b ins(%arg0, %arg1 : {M_K_operand_tensor_type}, {N_K_operand_tensor_type})
128-
outs(%1 : {M_N_acc_tensor_type})
129-
-> {M_N_acc_tensor_type}
130-
"""
131-
132-
mlir_template_matmul_normal = f"""
133-
module {{
134-
func.func @main(%arg0: {M_K_operand_tensor_type}, %arg1: {K_N_operand_tensor_type}) -> {M_N_result_tensor_type} {{
135-
%cst = arith.constant {literal_zero}
136-
%0 = tensor.empty() : {M_N_acc_tensor_type}
137-
%1 = linalg.fill ins(%cst : {acc_element_type}) outs(%0 : {M_N_acc_tensor_type}) -> {M_N_acc_tensor_type}
138-
%2 = linalg.matmul ins(%arg0, %arg1 : {M_K_operand_tensor_type}, {K_N_operand_tensor_type})
139-
outs(%1 : {M_N_acc_tensor_type})
140-
-> {M_N_acc_tensor_type}
141-
"""
142-
mlir_template_matmul = mlir_template_matmul_transpose_a if tA == "T" else mlir_template_matmul_transpose_b if tB == "T" else mlir_template_matmul_normal
143-
144-
mlir_template_return_truncated = f"""
145-
%3 = {trunc_op} %2 : {M_N_acc_tensor_type} to {M_N_result_tensor_type}
146-
return %3 : {M_N_result_tensor_type}
147-
}}
148-
}}
149-
"""
150-
151-
mlir_template_return_untruncated = f"""
152-
return %2 : {M_N_result_tensor_type}
153-
}}
154-
}}
155-
"""
156-
157-
mlir_template_return = mlir_template_return_untruncated if (acc_element_type == result_element_type) else mlir_template_return_truncated
158-
159-
return mlir_template_matmul + mlir_template_return
160102

103+
# Transpose A
104+
if tA == "T":
105+
arg0_type = ir.RankedTensorType.get([K, M], operand_element_type)
106+
arg1_type = ir.RankedTensorType.get([K, N], operand_element_type)
107+
# Transpose B
108+
elif tB == "T":
109+
arg0_type = ir.RankedTensorType.get([M, K], operand_element_type)
110+
arg1_type = ir.RankedTensorType.get([N, K], operand_element_type)
111+
# "Normal" path (can't transpose both)
112+
else:
113+
assert tA == "N" and tB == "N"
114+
arg0_type = ir.RankedTensorType.get([M, K], operand_element_type)
115+
arg1_type = ir.RankedTensorType.get([K, N], operand_element_type)
116+
result_type = ir.RankedTensorType.get([M, N], result_element_type)
117+
118+
module = ir.Module.create()
119+
with ir.InsertionPoint(module.body):
120+
@func.FuncOp.from_py_func(arg0_type, arg1_type)
121+
def main(arg0, arg1):
122+
zero_element = arith.constant(value = literal_zero, result = acc_element_type)
123+
empty_tensor = tensor.empty(element_type = acc_element_type, sizes = [M, N])
124+
filled_tensor = linalg.fill(zero_element, outs = [empty_tensor])
125+
126+
if tA == "T":
127+
acc = linalg.matmul_transpose_a(arg0, arg1, outs = [filled_tensor])
128+
elif tB == "T":
129+
acc = linalg.matmul_transpose_b(arg0, arg1, outs = [filled_tensor])
130+
else:
131+
acc = linalg.matmul(arg0, arg1, outs = [filled_tensor])
132+
133+
if acc_element_type == result_element_type:
134+
return acc
135+
if is_integer:
136+
return arith.trunci(result_type, acc)
137+
return arith.truncf(result_type, acc)
138+
return f"{module}"
161139

162140
@dataclass
163141
class TkTunedConfig:

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,4 @@ tqdm
33
matplotlib
44
torch>=2.3.0
55
pytest>=8.3.5
6+
PyYAML>=6.0.2 # Required by iree.compiler.dialects.linalg

tests/test_gemmbench_mlir_gen.py

Lines changed: 8 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,7 @@ def test_n_t_f16_f32_f16():
3434
"%cst = arith.constant 0.000000e+00 : f32",
3535
"%0 = tensor.empty() : tensor<512x4096xf32>",
3636
"%1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<512x4096xf32>) -> tensor<512x4096xf32>",
37-
"%2 = linalg.matmul_transpose_b ins(%arg0, %arg1 : tensor<512x14336xf16>, tensor<4096x14336xf16>)",
38-
"outs(%1 : tensor<512x4096xf32>)",
39-
"-> tensor<512x4096xf32>",
37+
"%2 = linalg.matmul_transpose_b {cast = #linalg.type_fn<cast_signed>} ins(%arg0, %arg1 : tensor<512x14336xf16>, tensor<4096x14336xf16>) outs(%1 : tensor<512x4096xf32>) -> tensor<512x4096xf32>",
4038
"%3 = arith.truncf %2 : tensor<512x4096xf32> to tensor<512x4096xf16>",
4139
"return %3 : tensor<512x4096xf16>",
4240
],
@@ -64,9 +62,7 @@ def test_n_t_bf16_f32_bf16():
6462
"%cst = arith.constant 0.000000e+00 : f32",
6563
"%0 = tensor.empty() : tensor<2x1280xf32>",
6664
"%1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<2x1280xf32>) -> tensor<2x1280xf32>",
67-
"%2 = linalg.matmul_transpose_b ins(%arg0, %arg1 : tensor<2x8192xbf16>, tensor<1280x8192xbf16>)",
68-
"outs(%1 : tensor<2x1280xf32>)",
69-
"-> tensor<2x1280xf32>",
65+
"%2 = linalg.matmul_transpose_b {cast = #linalg.type_fn<cast_signed>} ins(%arg0, %arg1 : tensor<2x8192xbf16>, tensor<1280x8192xbf16>) outs(%1 : tensor<2x1280xf32>) -> tensor<2x1280xf32>",
7066
"%3 = arith.truncf %2 : tensor<2x1280xf32> to tensor<2x1280xbf16>",
7167
"return %3 : tensor<2x1280xbf16>",
7268
],
@@ -94,9 +90,7 @@ def test_t_n_f16_f32_f16():
9490
"%cst = arith.constant 0.000000e+00 : f32",
9591
"%0 = tensor.empty() : tensor<32000x1xf32>",
9692
"%1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<32000x1xf32>) -> tensor<32000x1xf32>",
97-
"%2 = linalg.matmul_transpose_a ins(%arg0, %arg1 : tensor<5120x32000xf16>, tensor<5120x1xf16>)",
98-
"outs(%1 : tensor<32000x1xf32>)",
99-
"-> tensor<32000x1xf32>",
93+
"%2 = linalg.matmul_transpose_a {cast = #linalg.type_fn<cast_signed>} ins(%arg0, %arg1 : tensor<5120x32000xf16>, tensor<5120x1xf16>) outs(%1 : tensor<32000x1xf32>) -> tensor<32000x1xf32>",
10094
"%3 = arith.truncf %2 : tensor<32000x1xf32> to tensor<32000x1xf16>",
10195
"return %3 : tensor<32000x1xf16>",
10296
],
@@ -124,9 +118,7 @@ def test_t_n_bf16_f32_bf16():
124118
"%cst = arith.constant 0.000000e+00 : f32",
125119
"%0 = tensor.empty() : tensor<32000x1xf32>",
126120
"%1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<32000x1xf32>) -> tensor<32000x1xf32>",
127-
"%2 = linalg.matmul_transpose_a ins(%arg0, %arg1 : tensor<5120x32000xbf16>, tensor<5120x1xbf16>)",
128-
"outs(%1 : tensor<32000x1xf32>)",
129-
"-> tensor<32000x1xf32>",
121+
"%2 = linalg.matmul_transpose_a {cast = #linalg.type_fn<cast_signed>} ins(%arg0, %arg1 : tensor<5120x32000xbf16>, tensor<5120x1xbf16>) outs(%1 : tensor<32000x1xf32>) -> tensor<32000x1xf32>",
130122
"%3 = arith.truncf %2 : tensor<32000x1xf32> to tensor<32000x1xbf16>",
131123
"return %3 : tensor<32000x1xbf16>",
132124
],
@@ -154,9 +146,7 @@ def test_n_n_f16_f32_f16():
154146
"%cst = arith.constant 0.000000e+00 : f32",
155147
"%0 = tensor.empty() : tensor<2048x2048xf32>",
156148
"%1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<2048x2048xf32>) -> tensor<2048x2048xf32>",
157-
"%2 = linalg.matmul ins(%arg0, %arg1 : tensor<2048x1024xf16>, tensor<1024x2048xf16>)",
158-
"outs(%1 : tensor<2048x2048xf32>)",
159-
"-> tensor<2048x2048xf32>",
149+
"%2 = linalg.matmul ins(%arg0, %arg1 : tensor<2048x1024xf16>, tensor<1024x2048xf16>) outs(%1 : tensor<2048x2048xf32>) -> tensor<2048x2048xf32>",
160150
"%3 = arith.truncf %2 : tensor<2048x2048xf32> to tensor<2048x2048xf16>",
161151
"return %3 : tensor<2048x2048xf16>",
162152
],
@@ -181,12 +171,10 @@ def test_n_t_i8_i32_i8():
181171
[
182172
"module {",
183173
"func.func @main(%arg0: tensor<128x128xi8>, %arg1: tensor<128x128xi8>) -> tensor<128x128xi8> {",
184-
"%cst = arith.constant 0 : i32",
174+
"%c0_i32 = arith.constant 0 : i32",
185175
"%0 = tensor.empty() : tensor<128x128xi32>",
186-
"%1 = linalg.fill ins(%cst : i32) outs(%0 : tensor<128x128xi32>) -> tensor<128x128xi32>",
187-
"%2 = linalg.matmul_transpose_b ins(%arg0, %arg1 : tensor<128x128xi8>, tensor<128x128xi8>)",
188-
"outs(%1 : tensor<128x128xi32>)",
189-
"-> tensor<128x128xi32>",
176+
"%1 = linalg.fill ins(%c0_i32 : i32) outs(%0 : tensor<128x128xi32>) -> tensor<128x128xi32>",
177+
"%2 = linalg.matmul_transpose_b {cast = #linalg.type_fn<cast_signed>} ins(%arg0, %arg1 : tensor<128x128xi8>, tensor<128x128xi8>) outs(%1 : tensor<128x128xi32>) -> tensor<128x128xi32>",
190178
"%3 = arith.trunci %2 : tensor<128x128xi32> to tensor<128x128xi8>",
191179
"return %3 : tensor<128x128xi8>",
192180
],

0 commit comments

Comments (0)