Skip to content

Commit 6b4299b

Browse files
authored
Use IREE's MLIR builder Python bindings to construct gemmbench types (#56)
For now, these get implicitly converted to strings when used by the IR template strings, but this still prepares for generating the entire IR with the bindings (the goal of #52).
1 parent 8e344a8 commit 6b4299b

File tree

2 files changed

+68
-37
lines changed

2 files changed

+68
-37
lines changed

iree_kernel_benchmark/gemmbench/gemm_utils.py

Lines changed: 55 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from ..utils import *
1919
import os
2020
import traceback
21+
from iree.compiler import ir
2122

2223

2324
@dataclass
@@ -71,63 +72,84 @@ def get_flops(self) -> int:
7172
flops = 2 * self.M * self.N * self.K
7273
return flops
7374

75+
def _convert_dtype_to_mlir(dtype: str) -> ir.Type:
    """Map a gemmbench dtype string (e.g. "f16", "i32") to the matching
    MLIR type built through IREE's Python bindings.

    Must run under an active ``ir.Context`` (the ``ir.*Type.get`` builders
    require one). An unrecognized dtype string raises ``KeyError``.
    """
    # Signless integers, keyed by bit width.
    integer_widths = {"i8": 8, "i16": 16, "i32": 32, "i64": 64}
    if dtype in integer_widths:
        return ir.IntegerType.get_signless(integer_widths[dtype])
    # Floating-point types; indexing an unknown key raises KeyError,
    # matching the original lookup behavior.
    float_types = {
        "f16": ir.F16Type,
        "f32": ir.F32Type,
        "f64": ir.F64Type,
        "bf16": ir.BF16Type,
    }
    return float_types[dtype].get()
7487

7588
def generate_mlir(config: GemmConfig):
7689
K = config.K
7790
M = config.M
7891
N = config.N
79-
operand_element_type = config.operand_element_type
80-
acc_element_type = config.accumulator_element_type
81-
result_element_type = config.result_element_type
82-
is_integer = operand_element_type.startswith('i')
83-
literal_zero = "0" if is_integer else "0.0"
92+
93+
with ir.Location.name(config.get_name()):
94+
operand_element_type = _convert_dtype_to_mlir(config.operand_element_type)
95+
acc_element_type = _convert_dtype_to_mlir(config.accumulator_element_type)
96+
result_element_type = _convert_dtype_to_mlir(config.result_element_type)
97+
is_integer = isinstance(operand_element_type, ir.IntegerType)
98+
literal_zero = ir.IntegerAttr.get(acc_element_type, 0) if is_integer else ir.FloatAttr.get(acc_element_type, 0.0)
99+
K_M_operand_tensor_type = ir.RankedTensorType.get([K, M], operand_element_type)
100+
M_K_operand_tensor_type = ir.RankedTensorType.get([M, K], operand_element_type)
101+
K_N_operand_tensor_type = ir.RankedTensorType.get([K, N], operand_element_type)
102+
N_K_operand_tensor_type = ir.RankedTensorType.get([N, K], operand_element_type)
103+
M_N_acc_tensor_type = ir.RankedTensorType.get([M, N], acc_element_type)
104+
M_N_result_tensor_type = ir.RankedTensorType.get([M, N], result_element_type)
105+
84106
trunc_op = "arith.trunci" if is_integer else "arith.truncf"
85107

86108
tA = config.tA
87109
tB = config.tB
88110
mlir_template_matmul_transpose_a = f"""
89111
module {{
90-
func.func @main(%arg0: tensor<{K}x{M}x{operand_element_type}>, %arg1: tensor<{K}x{N}x{operand_element_type}>) -> tensor<{M}x{N}x{result_element_type}> {{
91-
%cst = arith.constant {literal_zero} : {acc_element_type}
92-
%0 = tensor.empty() : tensor<{M}x{N}x{acc_element_type}>
93-
%1 = linalg.fill ins(%cst : {acc_element_type}) outs(%0 : tensor<{M}x{N}x{acc_element_type}>) -> tensor<{M}x{N}x{acc_element_type}>
94-
%2 = linalg.matmul_transpose_a ins(%arg0, %arg1 : tensor<{K}x{M}x{operand_element_type}>, tensor<{K}x{N}x{operand_element_type}>)
95-
outs(%1 : tensor<{M}x{N}x{acc_element_type}>)
96-
-> tensor<{M}x{N}x{acc_element_type}>
112+
func.func @main(%arg0: {K_M_operand_tensor_type}, %arg1: {K_N_operand_tensor_type}) -> {M_N_result_tensor_type} {{
113+
%cst = arith.constant {literal_zero}
114+
%0 = tensor.empty() : {M_N_acc_tensor_type}
115+
%1 = linalg.fill ins(%cst : {acc_element_type}) outs(%0 : {M_N_acc_tensor_type}) -> {M_N_acc_tensor_type}
116+
%2 = linalg.matmul_transpose_a ins(%arg0, %arg1 : {K_M_operand_tensor_type}, {K_N_operand_tensor_type})
117+
outs(%1 : {M_N_acc_tensor_type})
118+
-> {M_N_acc_tensor_type}
97119
"""
98120

99121
mlir_template_matmul_transpose_b = f"""
100122
module {{
101-
func.func @main(%arg0: tensor<{M}x{K}x{operand_element_type}>, %arg1: tensor<{N}x{K}x{operand_element_type}>) -> tensor<{M}x{N}x{result_element_type}> {{
102-
%cst = arith.constant {literal_zero} : {acc_element_type}
103-
%0 = tensor.empty() : tensor<{M}x{N}x{acc_element_type}>
104-
%1 = linalg.fill ins(%cst : {acc_element_type}) outs(%0 : tensor<{M}x{N}x{acc_element_type}>) -> tensor<{M}x{N}x{acc_element_type}>
105-
%2 = linalg.matmul_transpose_b ins(%arg0, %arg1 : tensor<{M}x{K}x{operand_element_type}>, tensor<{N}x{K}x{operand_element_type}>)
106-
outs(%1 : tensor<{M}x{N}x{acc_element_type}>)
107-
-> tensor<{M}x{N}x{acc_element_type}>
123+
func.func @main(%arg0: {M_K_operand_tensor_type}, %arg1: {N_K_operand_tensor_type}) -> {M_N_result_tensor_type} {{
124+
%cst = arith.constant {literal_zero}
125+
%0 = tensor.empty() : {M_N_acc_tensor_type}
126+
%1 = linalg.fill ins(%cst : {acc_element_type}) outs(%0 : {M_N_acc_tensor_type}) -> {M_N_acc_tensor_type}
127+
%2 = linalg.matmul_transpose_b ins(%arg0, %arg1 : {M_K_operand_tensor_type}, {N_K_operand_tensor_type})
128+
outs(%1 : {M_N_acc_tensor_type})
129+
-> {M_N_acc_tensor_type}
108130
"""
109131

110132
mlir_template_matmul_normal = f"""
111133
module {{
112-
func.func @main(%arg0: tensor<{M}x{K}x{operand_element_type}>, %arg1: tensor<{K}x{N}x{operand_element_type}>) -> tensor<{M}x{N}x{result_element_type}> {{
113-
%cst = arith.constant {literal_zero} : {acc_element_type}
114-
%0 = tensor.empty() : tensor<{M}x{N}x{acc_element_type}>
115-
%1 = linalg.fill ins(%cst : {acc_element_type}) outs(%0 : tensor<{M}x{N}x{acc_element_type}>) -> tensor<{M}x{N}x{acc_element_type}>
116-
%2 = linalg.matmul ins(%arg0, %arg1 : tensor<{M}x{K}x{operand_element_type}>, tensor<{K}x{N}x{operand_element_type}>)
117-
outs(%1 : tensor<{M}x{N}x{acc_element_type}>)
118-
-> tensor<{M}x{N}x{acc_element_type}>
134+
func.func @main(%arg0: {M_K_operand_tensor_type}, %arg1: {K_N_operand_tensor_type}) -> {M_N_result_tensor_type} {{
135+
%cst = arith.constant {literal_zero}
136+
%0 = tensor.empty() : {M_N_acc_tensor_type}
137+
%1 = linalg.fill ins(%cst : {acc_element_type}) outs(%0 : {M_N_acc_tensor_type}) -> {M_N_acc_tensor_type}
138+
%2 = linalg.matmul ins(%arg0, %arg1 : {M_K_operand_tensor_type}, {K_N_operand_tensor_type})
139+
outs(%1 : {M_N_acc_tensor_type})
140+
-> {M_N_acc_tensor_type}
119141
"""
120142
mlir_template_matmul = mlir_template_matmul_transpose_a if tA == "T" else mlir_template_matmul_transpose_b if tB == "T" else mlir_template_matmul_normal
121143

122144
mlir_template_return_truncated = f"""
123-
%3 = {trunc_op} %2 : tensor<{M}x{N}x{acc_element_type}> to tensor<{M}x{N}x{result_element_type}>
124-
return %3 : tensor<{M}x{N}x{result_element_type}>
145+
%3 = {trunc_op} %2 : {M_N_acc_tensor_type} to {M_N_result_tensor_type}
146+
return %3 : {M_N_result_tensor_type}
125147
}}
126148
}}
127149
"""
128150

129151
mlir_template_return_untruncated = f"""
130-
return %2 : tensor<{M}x{N}x{result_element_type}>
152+
return %2 : {M_N_result_tensor_type}
131153
}}
132154
}}
133155
"""
@@ -167,7 +189,7 @@ def get_tk_tuned_config(config: GemmConfig) -> TkTunedConfig:
167189
# Default config
168190
return TkTunedConfig(64, 64, 32, 2, 2, 1, 2, 2, 2, 1, 1, 2)
169191

170-
def _convert_dtype(dtype: str):
192+
def _convert_dtype_to_tk(dtype: str):
171193
dtypes = {
172194
"i8": tkl.i8,
173195
"i16": tkl.i16,
@@ -189,7 +211,7 @@ def generate_tk_mlir(config: GemmConfig, vmfb_file: Path):
189211
assert config.operand_element_type == 'f16', "Unsupported problem"
190212
assert config.accumulator_element_type == 'f32', "Unsupported problem"
191213

192-
res_dtype = _convert_dtype(config.result_element_type)
214+
res_dtype = _convert_dtype_to_tk(config.result_element_type)
193215
# Input sizes
194216
M = tkl.sym.M
195217
N = tkl.sym.N
@@ -306,7 +328,8 @@ def compile_gemm_config(
306328
f.write(traceback.format_exc())
307329
return mlir_file, None
308330
else:
309-
mlir_content = generate_mlir(config)
331+
with ir.Context():
332+
mlir_content = generate_mlir(config)
310333

311334
# Write MLIR content to file
312335
with open(mlir_file, "w") as f:

tests/test_gemmbench_mlir_gen.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,18 @@
11
from iree_kernel_benchmark.gemmbench.gemm_utils import GemmConfig, generate_mlir
22
from .utils import match_lines
3+
from iree.compiler import ir
4+
import pytest
35

46
# These tests should contain a small sampling of the actual problem set, enough
57
# to exercise most of the code paths in the MLIR generation.
68

79

10+
@pytest.fixture(autouse=True)
def run_with_mlir_ctx():
    """Autouse fixture: run every test inside a fresh MLIR context.

    ``generate_mlir`` builds ``ir.*`` types via IREE's bindings, which
    require an active ``ir.Context`` for the duration of the test.
    """
    context = ir.Context()
    with context:
        yield
14+
15+
816
def test_n_t_f16_f32_f16():
917
# From 'llama8b_prefill'
1018
cfg = GemmConfig(
@@ -23,7 +31,7 @@ def test_n_t_f16_f32_f16():
2331
[
2432
"module {",
2533
"func.func @main(%arg0: tensor<512x14336xf16>, %arg1: tensor<4096x14336xf16>) -> tensor<512x4096xf16> {",
26-
"%cst = arith.constant 0.0 : f32",
34+
"%cst = arith.constant 0.000000e+00 : f32",
2735
"%0 = tensor.empty() : tensor<512x4096xf32>",
2836
"%1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<512x4096xf32>) -> tensor<512x4096xf32>",
2937
"%2 = linalg.matmul_transpose_b ins(%arg0, %arg1 : tensor<512x14336xf16>, tensor<4096x14336xf16>)",
@@ -53,7 +61,7 @@ def test_n_t_bf16_f32_bf16():
5361
[
5462
"module {",
5563
"func.func @main(%arg0: tensor<2x8192xbf16>, %arg1: tensor<1280x8192xbf16>) -> tensor<2x1280xbf16> {",
56-
"%cst = arith.constant 0.0 : f32",
64+
"%cst = arith.constant 0.000000e+00 : f32",
5765
"%0 = tensor.empty() : tensor<2x1280xf32>",
5866
"%1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<2x1280xf32>) -> tensor<2x1280xf32>",
5967
"%2 = linalg.matmul_transpose_b ins(%arg0, %arg1 : tensor<2x8192xbf16>, tensor<1280x8192xbf16>)",
@@ -83,7 +91,7 @@ def test_t_n_f16_f32_f16():
8391
[
8492
"module {",
8593
"func.func @main(%arg0: tensor<5120x32000xf16>, %arg1: tensor<5120x1xf16>) -> tensor<32000x1xf16> {",
86-
"%cst = arith.constant 0.0 : f32",
94+
"%cst = arith.constant 0.000000e+00 : f32",
8795
"%0 = tensor.empty() : tensor<32000x1xf32>",
8896
"%1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<32000x1xf32>) -> tensor<32000x1xf32>",
8997
"%2 = linalg.matmul_transpose_a ins(%arg0, %arg1 : tensor<5120x32000xf16>, tensor<5120x1xf16>)",
@@ -113,7 +121,7 @@ def test_t_n_bf16_f32_bf16():
113121
[
114122
"module {",
115123
"func.func @main(%arg0: tensor<5120x32000xbf16>, %arg1: tensor<5120x1xbf16>) -> tensor<32000x1xbf16> {",
116-
"%cst = arith.constant 0.0 : f32",
124+
"%cst = arith.constant 0.000000e+00 : f32",
117125
"%0 = tensor.empty() : tensor<32000x1xf32>",
118126
"%1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<32000x1xf32>) -> tensor<32000x1xf32>",
119127
"%2 = linalg.matmul_transpose_a ins(%arg0, %arg1 : tensor<5120x32000xbf16>, tensor<5120x1xbf16>)",
@@ -143,7 +151,7 @@ def test_n_n_f16_f32_f16():
143151
[
144152
"module {",
145153
"func.func @main(%arg0: tensor<2048x1024xf16>, %arg1: tensor<1024x2048xf16>) -> tensor<2048x2048xf16> {",
146-
"%cst = arith.constant 0.0 : f32",
154+
"%cst = arith.constant 0.000000e+00 : f32",
147155
"%0 = tensor.empty() : tensor<2048x2048xf32>",
148156
"%1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<2048x2048xf32>) -> tensor<2048x2048xf32>",
149157
"%2 = linalg.matmul ins(%arg0, %arg1 : tensor<2048x1024xf16>, tensor<1024x2048xf16>)",

0 commit comments

Comments
 (0)