Skip to content

Commit e4fb921

Browse files
authored
gemmbenchmark: drop matmul_transpose variants (#84)
This PR drops matmul_transpose variants since transposed linalg named op variants will be dropped in llvm/llvm-project#147961. Issue: iree-org/iree#21349 Signed-off-by: Bangtian Liu <[email protected]>
1 parent f695efc commit e4fb921

File tree

2 files changed

+71
-12
lines changed

2 files changed

+71
-12
lines changed

iree_kernel_benchmark/gemmbench/gemm_utils.py

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -197,12 +197,33 @@ def main(arg0, arg1):
197197
)
198198
filled_tensor = linalg.fill(zero_element, outs=[empty_tensor])
199199

200+
# Define dimension expressions.
201+
d0 = ir.AffineDimExpr.get(0) # M
202+
d1 = ir.AffineDimExpr.get(1) # N
203+
d2 = ir.AffineDimExpr.get(2) # K
204+
# Default maps.
205+
map_A = ir.AffineMap.get(3, 0, [d0, d2])
206+
map_B = ir.AffineMap.get(3, 0, [d2, d1])
207+
map_C = ir.AffineMap.get(3, 0, [d0, d1])
200208
if tA == "T":
201-
acc = linalg.matmul_transpose_a(arg0, arg1, outs=[filled_tensor])
209+
map_A = ir.AffineMap.get(3, 0, [d2, d0])
202210
elif tB == "T":
203-
acc = linalg.matmul_transpose_b(arg0, arg1, outs=[filled_tensor])
204-
else:
205-
acc = linalg.matmul(arg0, arg1, outs=[filled_tensor])
211+
map_B = ir.AffineMap.get(3, 0, [d1, d2])
212+
213+
indexing_maps = ir.ArrayAttr.get(
214+
[
215+
ir.AffineMapAttr.get(map_A),
216+
ir.AffineMapAttr.get(map_B),
217+
ir.AffineMapAttr.get(map_C),
218+
]
219+
)
220+
221+
acc = linalg.matmul(
222+
arg0,
223+
arg1,
224+
outs=[filled_tensor],
225+
indexing_maps=indexing_maps,
226+
)
206227

207228
if acc_element_type == result_element_type:
208229
return acc

tests/test_gemmbench_mlir_gen.py

Lines changed: 46 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -33,12 +33,15 @@ def test_n_t_f16_f32_f16():
3333
match_lines(
3434
mlir,
3535
[
36+
"#map = affine_map<(d0, d1, d2) -> (d0, d2)>",
37+
"#map1 = affine_map<(d0, d1, d2) -> (d1, d2)>",
38+
"#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>",
3639
"module {",
3740
"func.func @main(%arg0: tensor<512x14336xf16>, %arg1: tensor<4096x14336xf16>) -> tensor<512x4096xf16> {",
3841
"%cst = arith.constant 0.000000e+00 : f32",
3942
"%0 = tensor.empty() : tensor<512x4096xf32>",
4043
"%1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<512x4096xf32>) -> tensor<512x4096xf32>",
41-
"%2 = linalg.matmul_transpose_b {cast = #linalg.type_fn<cast_signed>} ins(%arg0, %arg1 : tensor<512x14336xf16>, tensor<4096x14336xf16>) outs(%1 : tensor<512x4096xf32>) -> tensor<512x4096xf32>",
44+
"%2 = linalg.matmul indexing_maps = [#map, #map1, #map2] ins(%arg0, %arg1 : tensor<512x14336xf16>, tensor<4096x14336xf16>) outs(%1 : tensor<512x4096xf32>) -> tensor<512x4096xf32>",
4245
"%3 = arith.truncf %2 : tensor<512x4096xf32> to tensor<512x4096xf16>",
4346
"return %3 : tensor<512x4096xf16>",
4447
],
@@ -61,12 +64,17 @@ def test_n_t_f8_f32_f8():
6164
match_lines(
6265
mlir,
6366
[
67+
"#map = affine_map<(d0, d1, d2) -> (d0, d2)>",
68+
"#map1 = affine_map<(d0, d1, d2) -> (d1, d2)>",
69+
"#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>",
6470
"module {",
6571
"func.func @main(%arg0: tensor<512x14336xf8E4M3FNUZ>, %arg1: tensor<4096x14336xf8E4M3FNUZ>) -> tensor<512x4096xf8E4M3FNUZ> {",
6672
"%cst = arith.constant 0.000000e+00 : f32",
6773
"%0 = tensor.empty() : tensor<512x4096xf32>",
6874
"%1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<512x4096xf32>) -> tensor<512x4096xf32>",
69-
"%2 = linalg.matmul_transpose_b {cast = #linalg.type_fn<cast_signed>} ins(%arg0, %arg1 : tensor<512x14336xf8E4M3FNUZ>, tensor<4096x14336xf8E4M3FNUZ>) outs(%1 : tensor<512x4096xf32>) -> tensor<512x4096xf32>",
75+
"%2 = linalg.matmul indexing_maps = [#map, #map1, #map2] "
76+
"ins(%arg0, %arg1 : tensor<512x14336xf8E4M3FNUZ>, tensor<4096x14336xf8E4M3FNUZ>) "
77+
"outs(%1 : tensor<512x4096xf32>) -> tensor<512x4096xf32>",
7078
"%3 = arith.truncf %2 : tensor<512x4096xf32> to tensor<512x4096xf8E4M3FNUZ>",
7179
"return %3 : tensor<512x4096xf8E4M3FNUZ>",
7280
],
@@ -90,14 +98,19 @@ def test_n_t_f16_f32_f16_dynamic_dim_M():
9098
match_lines(
9199
mlir,
92100
[
101+
"#map = affine_map<(d0, d1, d2) -> (d0, d2)>",
102+
"#map1 = affine_map<(d0, d1, d2) -> (d1, d2)>",
103+
"#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>",
93104
"module {",
94105
"func.func @main(%arg0: tensor<?x4096xf16>, %arg1: tensor<14336x4096xf16>) -> tensor<?x14336xf16> {",
95106
"%cst = arith.constant 0.000000e+00 : f32",
96107
"%c0 = arith.constant 0 : index",
97108
"%dim = tensor.dim %arg0, %c0 : tensor<?x4096xf16>",
98109
"%0 = tensor.empty(%dim) : tensor<?x14336xf32>",
99110
"%1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<?x14336xf32>) -> tensor<?x14336xf32>",
100-
"%2 = linalg.matmul_transpose_b {cast = #linalg.type_fn<cast_signed>} ins(%arg0, %arg1 : tensor<?x4096xf16>, tensor<14336x4096xf16>) outs(%1 : tensor<?x14336xf32>) -> tensor<?x14336xf32>",
111+
"%2 = linalg.matmul indexing_maps = [#map, #map1, #map2] "
112+
"ins(%arg0, %arg1 : tensor<?x4096xf16>, tensor<14336x4096xf16>) "
113+
"outs(%1 : tensor<?x14336xf32>) -> tensor<?x14336xf32>",
101114
"%3 = arith.truncf %2 : tensor<?x14336xf32> to tensor<?x14336xf16>",
102115
"return %3 : tensor<?x14336xf16>",
103116
],
@@ -121,14 +134,19 @@ def test_t_n_f16_f32_f16_dynamic_dim_N():
121134
match_lines(
122135
mlir,
123136
[
137+
"#map = affine_map<(d0, d1, d2) -> (d2, d0)>",
138+
"#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>",
139+
"#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>",
124140
"module {",
125141
"func.func @main(%arg0: tensor<4096x512xf16>, %arg1: tensor<4096x?xf16>) -> tensor<512x?xf16> {",
126142
"%cst = arith.constant 0.000000e+00 : f32",
127143
"%c1 = arith.constant 1 : index",
128144
"%dim = tensor.dim %arg1, %c1 : tensor<4096x?xf16>",
129145
"%0 = tensor.empty(%dim) : tensor<512x?xf32>",
130146
"%1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<512x?xf32>) -> tensor<512x?xf32>",
131-
"%2 = linalg.matmul_transpose_a {cast = #linalg.type_fn<cast_signed>} ins(%arg0, %arg1 : tensor<4096x512xf16>, tensor<4096x?xf16>) outs(%1 : tensor<512x?xf32>) -> tensor<512x?xf32>",
147+
"%2 = linalg.matmul indexing_maps = [#map, #map1, #map2] "
148+
"ins(%arg0, %arg1 : tensor<4096x512xf16>, tensor<4096x?xf16>) "
149+
"outs(%1 : tensor<512x?xf32>) -> tensor<512x?xf32>",
132150
"%3 = arith.truncf %2 : tensor<512x?xf32> to tensor<512x?xf16>",
133151
"return %3 : tensor<512x?xf16>",
134152
],
@@ -213,12 +231,17 @@ def test_n_t_bf16_f32_bf16():
213231
match_lines(
214232
mlir,
215233
[
234+
"#map = affine_map<(d0, d1, d2) -> (d0, d2)>",
235+
"#map1 = affine_map<(d0, d1, d2) -> (d1, d2)>",
236+
"#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>",
216237
"module {",
217238
"func.func @main(%arg0: tensor<2x8192xbf16>, %arg1: tensor<1280x8192xbf16>) -> tensor<2x1280xbf16> {",
218239
"%cst = arith.constant 0.000000e+00 : f32",
219240
"%0 = tensor.empty() : tensor<2x1280xf32>",
220241
"%1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<2x1280xf32>) -> tensor<2x1280xf32>",
221-
"%2 = linalg.matmul_transpose_b {cast = #linalg.type_fn<cast_signed>} ins(%arg0, %arg1 : tensor<2x8192xbf16>, tensor<1280x8192xbf16>) outs(%1 : tensor<2x1280xf32>) -> tensor<2x1280xf32>",
242+
"%2 = linalg.matmul indexing_maps = [#map, #map1, #map2] "
243+
"ins(%arg0, %arg1 : tensor<2x8192xbf16>, tensor<1280x8192xbf16>) "
244+
"outs(%1 : tensor<2x1280xf32>) -> tensor<2x1280xf32>",
222245
"%3 = arith.truncf %2 : tensor<2x1280xf32> to tensor<2x1280xbf16>",
223246
"return %3 : tensor<2x1280xbf16>",
224247
],
@@ -241,12 +264,17 @@ def test_t_n_f16_f32_f16():
241264
match_lines(
242265
mlir,
243266
[
267+
"#map = affine_map<(d0, d1, d2) -> (d2, d0)>",
268+
"#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>",
269+
"#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>",
244270
"module {",
245271
"func.func @main(%arg0: tensor<5120x32000xf16>, %arg1: tensor<5120x1xf16>) -> tensor<32000x1xf16> {",
246272
"%cst = arith.constant 0.000000e+00 : f32",
247273
"%0 = tensor.empty() : tensor<32000x1xf32>",
248274
"%1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<32000x1xf32>) -> tensor<32000x1xf32>",
249-
"%2 = linalg.matmul_transpose_a {cast = #linalg.type_fn<cast_signed>} ins(%arg0, %arg1 : tensor<5120x32000xf16>, tensor<5120x1xf16>) outs(%1 : tensor<32000x1xf32>) -> tensor<32000x1xf32>",
275+
"%2 = linalg.matmul indexing_maps = [#map, #map1, #map2] "
276+
"ins(%arg0, %arg1 : tensor<5120x32000xf16>, tensor<5120x1xf16>) "
277+
"outs(%1 : tensor<32000x1xf32>) -> tensor<32000x1xf32>",
250278
"%3 = arith.truncf %2 : tensor<32000x1xf32> to tensor<32000x1xf16>",
251279
"return %3 : tensor<32000x1xf16>",
252280
],
@@ -269,12 +297,17 @@ def test_t_n_bf16_f32_bf16():
269297
match_lines(
270298
mlir,
271299
[
300+
"#map = affine_map<(d0, d1, d2) -> (d2, d0)>",
301+
"#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>",
302+
"#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>",
272303
"module {",
273304
"func.func @main(%arg0: tensor<5120x32000xbf16>, %arg1: tensor<5120x1xbf16>) -> tensor<32000x1xbf16> {",
274305
"%cst = arith.constant 0.000000e+00 : f32",
275306
"%0 = tensor.empty() : tensor<32000x1xf32>",
276307
"%1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<32000x1xf32>) -> tensor<32000x1xf32>",
277-
"%2 = linalg.matmul_transpose_a {cast = #linalg.type_fn<cast_signed>} ins(%arg0, %arg1 : tensor<5120x32000xbf16>, tensor<5120x1xbf16>) outs(%1 : tensor<32000x1xf32>) -> tensor<32000x1xf32>",
308+
"%2 = linalg.matmul indexing_maps = [#map, #map1, #map2] "
309+
"ins(%arg0, %arg1 : tensor<5120x32000xbf16>, tensor<5120x1xbf16>) "
310+
"outs(%1 : tensor<32000x1xf32>) -> tensor<32000x1xf32>",
278311
"%3 = arith.truncf %2 : tensor<32000x1xf32> to tensor<32000x1xbf16>",
279312
"return %3 : tensor<32000x1xbf16>",
280313
],
@@ -325,12 +358,17 @@ def test_n_t_i8_i32_i8():
325358
match_lines(
326359
mlir,
327360
[
361+
"#map = affine_map<(d0, d1, d2) -> (d0, d2)>",
362+
"#map1 = affine_map<(d0, d1, d2) -> (d1, d2)>",
363+
"#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>",
328364
"module {",
329365
"func.func @main(%arg0: tensor<128x128xi8>, %arg1: tensor<128x128xi8>) -> tensor<128x128xi8> {",
330366
"%c0_i32 = arith.constant 0 : i32",
331367
"%0 = tensor.empty() : tensor<128x128xi32>",
332368
"%1 = linalg.fill ins(%c0_i32 : i32) outs(%0 : tensor<128x128xi32>) -> tensor<128x128xi32>",
333-
"%2 = linalg.matmul_transpose_b {cast = #linalg.type_fn<cast_signed>} ins(%arg0, %arg1 : tensor<128x128xi8>, tensor<128x128xi8>) outs(%1 : tensor<128x128xi32>) -> tensor<128x128xi32>",
369+
"%2 = linalg.matmul indexing_maps = [#map, #map1, #map2] "
370+
"ins(%arg0, %arg1 : tensor<128x128xi8>, tensor<128x128xi8>) "
371+
"outs(%1 : tensor<128x128xi32>) -> tensor<128x128xi32>",
334372
"%3 = arith.trunci %2 : tensor<128x128xi32> to tensor<128x128xi8>",
335373
"return %3 : tensor<128x128xi8>",
336374
],

0 commit comments

Comments
 (0)