@@ -33,12 +33,15 @@ def test_n_t_f16_f32_f16():
33
33
match_lines (
34
34
mlir ,
35
35
[
36
+ "#map = affine_map<(d0, d1, d2) -> (d0, d2)>" ,
37
+ "#map1 = affine_map<(d0, d1, d2) -> (d1, d2)>" ,
38
+ "#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>" ,
36
39
"module {" ,
37
40
"func.func @main(%arg0: tensor<512x14336xf16>, %arg1: tensor<4096x14336xf16>) -> tensor<512x4096xf16> {" ,
38
41
"%cst = arith.constant 0.000000e+00 : f32" ,
39
42
"%0 = tensor.empty() : tensor<512x4096xf32>" ,
40
43
"%1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<512x4096xf32>) -> tensor<512x4096xf32>" ,
41
- "%2 = linalg.matmul_transpose_b {cast = #linalg.type_fn<cast_signed>} ins(%arg0, %arg1 : tensor<512x14336xf16>, tensor<4096x14336xf16>) outs(%1 : tensor<512x4096xf32>) -> tensor<512x4096xf32>" ,
44
+ "%2 = linalg.matmul indexing_maps = [#map, #map1, #map2] ins(%arg0, %arg1 : tensor<512x14336xf16>, tensor<4096x14336xf16>) outs(%1 : tensor<512x4096xf32>) -> tensor<512x4096xf32>" ,
42
45
"%3 = arith.truncf %2 : tensor<512x4096xf32> to tensor<512x4096xf16>" ,
43
46
"return %3 : tensor<512x4096xf16>" ,
44
47
],
@@ -61,12 +64,17 @@ def test_n_t_f8_f32_f8():
61
64
match_lines (
62
65
mlir ,
63
66
[
67
+ "#map = affine_map<(d0, d1, d2) -> (d0, d2)>" ,
68
+ "#map1 = affine_map<(d0, d1, d2) -> (d1, d2)>" ,
69
+ "#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>" ,
64
70
"module {" ,
65
71
"func.func @main(%arg0: tensor<512x14336xf8E4M3FNUZ>, %arg1: tensor<4096x14336xf8E4M3FNUZ>) -> tensor<512x4096xf8E4M3FNUZ> {" ,
66
72
"%cst = arith.constant 0.000000e+00 : f32" ,
67
73
"%0 = tensor.empty() : tensor<512x4096xf32>" ,
68
74
"%1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<512x4096xf32>) -> tensor<512x4096xf32>" ,
69
- "%2 = linalg.matmul_transpose_b {cast = #linalg.type_fn<cast_signed>} ins(%arg0, %arg1 : tensor<512x14336xf8E4M3FNUZ>, tensor<4096x14336xf8E4M3FNUZ>) outs(%1 : tensor<512x4096xf32>) -> tensor<512x4096xf32>" ,
75
+ "%2 = linalg.matmul indexing_maps = [#map, #map1, #map2] "
76
+ "ins(%arg0, %arg1 : tensor<512x14336xf8E4M3FNUZ>, tensor<4096x14336xf8E4M3FNUZ>) "
77
+ "outs(%1 : tensor<512x4096xf32>) -> tensor<512x4096xf32>" ,
70
78
"%3 = arith.truncf %2 : tensor<512x4096xf32> to tensor<512x4096xf8E4M3FNUZ>" ,
71
79
"return %3 : tensor<512x4096xf8E4M3FNUZ>" ,
72
80
],
@@ -90,14 +98,19 @@ def test_n_t_f16_f32_f16_dynamic_dim_M():
90
98
match_lines (
91
99
mlir ,
92
100
[
101
+ "#map = affine_map<(d0, d1, d2) -> (d0, d2)>" ,
102
+ "#map1 = affine_map<(d0, d1, d2) -> (d1, d2)>" ,
103
+ "#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>" ,
93
104
"module {" ,
94
105
"func.func @main(%arg0: tensor<?x4096xf16>, %arg1: tensor<14336x4096xf16>) -> tensor<?x14336xf16> {" ,
95
106
"%cst = arith.constant 0.000000e+00 : f32" ,
96
107
"%c0 = arith.constant 0 : index" ,
97
108
"%dim = tensor.dim %arg0, %c0 : tensor<?x4096xf16>" ,
98
109
"%0 = tensor.empty(%dim) : tensor<?x14336xf32>" ,
99
110
"%1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<?x14336xf32>) -> tensor<?x14336xf32>" ,
100
- "%2 = linalg.matmul_transpose_b {cast = #linalg.type_fn<cast_signed>} ins(%arg0, %arg1 : tensor<?x4096xf16>, tensor<14336x4096xf16>) outs(%1 : tensor<?x14336xf32>) -> tensor<?x14336xf32>" ,
111
+ "%2 = linalg.matmul indexing_maps = [#map, #map1, #map2] "
112
+ "ins(%arg0, %arg1 : tensor<?x4096xf16>, tensor<14336x4096xf16>) "
113
+ "outs(%1 : tensor<?x14336xf32>) -> tensor<?x14336xf32>" ,
101
114
"%3 = arith.truncf %2 : tensor<?x14336xf32> to tensor<?x14336xf16>" ,
102
115
"return %3 : tensor<?x14336xf16>" ,
103
116
],
@@ -121,14 +134,19 @@ def test_t_n_f16_f32_f16_dynamic_dim_N():
121
134
match_lines (
122
135
mlir ,
123
136
[
137
+ "#map = affine_map<(d0, d1, d2) -> (d2, d0)>" ,
138
+ "#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>" ,
139
+ "#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>" ,
124
140
"module {" ,
125
141
"func.func @main(%arg0: tensor<4096x512xf16>, %arg1: tensor<4096x?xf16>) -> tensor<512x?xf16> {" ,
126
142
"%cst = arith.constant 0.000000e+00 : f32" ,
127
143
"%c1 = arith.constant 1 : index" ,
128
144
"%dim = tensor.dim %arg1, %c1 : tensor<4096x?xf16>" ,
129
145
"%0 = tensor.empty(%dim) : tensor<512x?xf32>" ,
130
146
"%1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<512x?xf32>) -> tensor<512x?xf32>" ,
131
- "%2 = linalg.matmul_transpose_a {cast = #linalg.type_fn<cast_signed>} ins(%arg0, %arg1 : tensor<4096x512xf16>, tensor<4096x?xf16>) outs(%1 : tensor<512x?xf32>) -> tensor<512x?xf32>" ,
147
+ "%2 = linalg.matmul indexing_maps = [#map, #map1, #map2] "
148
+ "ins(%arg0, %arg1 : tensor<4096x512xf16>, tensor<4096x?xf16>) "
149
+ "outs(%1 : tensor<512x?xf32>) -> tensor<512x?xf32>" ,
132
150
"%3 = arith.truncf %2 : tensor<512x?xf32> to tensor<512x?xf16>" ,
133
151
"return %3 : tensor<512x?xf16>" ,
134
152
],
@@ -213,12 +231,17 @@ def test_n_t_bf16_f32_bf16():
213
231
match_lines (
214
232
mlir ,
215
233
[
234
+ "#map = affine_map<(d0, d1, d2) -> (d0, d2)>" ,
235
+ "#map1 = affine_map<(d0, d1, d2) -> (d1, d2)>" ,
236
+ "#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>" ,
216
237
"module {" ,
217
238
"func.func @main(%arg0: tensor<2x8192xbf16>, %arg1: tensor<1280x8192xbf16>) -> tensor<2x1280xbf16> {" ,
218
239
"%cst = arith.constant 0.000000e+00 : f32" ,
219
240
"%0 = tensor.empty() : tensor<2x1280xf32>" ,
220
241
"%1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<2x1280xf32>) -> tensor<2x1280xf32>" ,
221
- "%2 = linalg.matmul_transpose_b {cast = #linalg.type_fn<cast_signed>} ins(%arg0, %arg1 : tensor<2x8192xbf16>, tensor<1280x8192xbf16>) outs(%1 : tensor<2x1280xf32>) -> tensor<2x1280xf32>" ,
242
+ "%2 = linalg.matmul indexing_maps = [#map, #map1, #map2] "
243
+ "ins(%arg0, %arg1 : tensor<2x8192xbf16>, tensor<1280x8192xbf16>) "
244
+ "outs(%1 : tensor<2x1280xf32>) -> tensor<2x1280xf32>" ,
222
245
"%3 = arith.truncf %2 : tensor<2x1280xf32> to tensor<2x1280xbf16>" ,
223
246
"return %3 : tensor<2x1280xbf16>" ,
224
247
],
@@ -241,12 +264,17 @@ def test_t_n_f16_f32_f16():
241
264
match_lines (
242
265
mlir ,
243
266
[
267
+ "#map = affine_map<(d0, d1, d2) -> (d2, d0)>" ,
268
+ "#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>" ,
269
+ "#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>" ,
244
270
"module {" ,
245
271
"func.func @main(%arg0: tensor<5120x32000xf16>, %arg1: tensor<5120x1xf16>) -> tensor<32000x1xf16> {" ,
246
272
"%cst = arith.constant 0.000000e+00 : f32" ,
247
273
"%0 = tensor.empty() : tensor<32000x1xf32>" ,
248
274
"%1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<32000x1xf32>) -> tensor<32000x1xf32>" ,
249
- "%2 = linalg.matmul_transpose_a {cast = #linalg.type_fn<cast_signed>} ins(%arg0, %arg1 : tensor<5120x32000xf16>, tensor<5120x1xf16>) outs(%1 : tensor<32000x1xf32>) -> tensor<32000x1xf32>" ,
275
+ "%2 = linalg.matmul indexing_maps = [#map, #map1, #map2] "
276
+ "ins(%arg0, %arg1 : tensor<5120x32000xf16>, tensor<5120x1xf16>) "
277
+ "outs(%1 : tensor<32000x1xf32>) -> tensor<32000x1xf32>" ,
250
278
"%3 = arith.truncf %2 : tensor<32000x1xf32> to tensor<32000x1xf16>" ,
251
279
"return %3 : tensor<32000x1xf16>" ,
252
280
],
@@ -269,12 +297,17 @@ def test_t_n_bf16_f32_bf16():
269
297
match_lines (
270
298
mlir ,
271
299
[
300
+ "#map = affine_map<(d0, d1, d2) -> (d2, d0)>" ,
301
+ "#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>" ,
302
+ "#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>" ,
272
303
"module {" ,
273
304
"func.func @main(%arg0: tensor<5120x32000xbf16>, %arg1: tensor<5120x1xbf16>) -> tensor<32000x1xbf16> {" ,
274
305
"%cst = arith.constant 0.000000e+00 : f32" ,
275
306
"%0 = tensor.empty() : tensor<32000x1xf32>" ,
276
307
"%1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<32000x1xf32>) -> tensor<32000x1xf32>" ,
277
- "%2 = linalg.matmul_transpose_a {cast = #linalg.type_fn<cast_signed>} ins(%arg0, %arg1 : tensor<5120x32000xbf16>, tensor<5120x1xbf16>) outs(%1 : tensor<32000x1xf32>) -> tensor<32000x1xf32>" ,
308
+ "%2 = linalg.matmul indexing_maps = [#map, #map1, #map2] "
309
+ "ins(%arg0, %arg1 : tensor<5120x32000xbf16>, tensor<5120x1xbf16>) "
310
+ "outs(%1 : tensor<32000x1xf32>) -> tensor<32000x1xf32>" ,
278
311
"%3 = arith.truncf %2 : tensor<32000x1xf32> to tensor<32000x1xbf16>" ,
279
312
"return %3 : tensor<32000x1xbf16>" ,
280
313
],
@@ -325,12 +358,17 @@ def test_n_t_i8_i32_i8():
325
358
match_lines (
326
359
mlir ,
327
360
[
361
+ "#map = affine_map<(d0, d1, d2) -> (d0, d2)>" ,
362
+ "#map1 = affine_map<(d0, d1, d2) -> (d1, d2)>" ,
363
+ "#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>" ,
328
364
"module {" ,
329
365
"func.func @main(%arg0: tensor<128x128xi8>, %arg1: tensor<128x128xi8>) -> tensor<128x128xi8> {" ,
330
366
"%c0_i32 = arith.constant 0 : i32" ,
331
367
"%0 = tensor.empty() : tensor<128x128xi32>" ,
332
368
"%1 = linalg.fill ins(%c0_i32 : i32) outs(%0 : tensor<128x128xi32>) -> tensor<128x128xi32>" ,
333
- "%2 = linalg.matmul_transpose_b {cast = #linalg.type_fn<cast_signed>} ins(%arg0, %arg1 : tensor<128x128xi8>, tensor<128x128xi8>) outs(%1 : tensor<128x128xi32>) -> tensor<128x128xi32>" ,
369
+ "%2 = linalg.matmul indexing_maps = [#map, #map1, #map2] "
370
+ "ins(%arg0, %arg1 : tensor<128x128xi8>, tensor<128x128xi8>) "
371
+ "outs(%1 : tensor<128x128xi32>) -> tensor<128x128xi32>" ,
334
372
"%3 = arith.trunci %2 : tensor<128x128xi32> to tensor<128x128xi8>" ,
335
373
"return %3 : tensor<128x128xi8>" ,
336
374
],
0 commit comments