@@ -34,9 +34,7 @@ def test_n_t_f16_f32_f16():
3434 "%cst = arith.constant 0.000000e+00 : f32" ,
3535 "%0 = tensor.empty() : tensor<512x4096xf32>" ,
3636 "%1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<512x4096xf32>) -> tensor<512x4096xf32>" ,
37- "%2 = linalg.matmul_transpose_b ins(%arg0, %arg1 : tensor<512x14336xf16>, tensor<4096x14336xf16>)" ,
38- "outs(%1 : tensor<512x4096xf32>)" ,
39- "-> tensor<512x4096xf32>" ,
37+ "%2 = linalg.matmul_transpose_b {cast = #linalg.type_fn<cast_signed>} ins(%arg0, %arg1 : tensor<512x14336xf16>, tensor<4096x14336xf16>) outs(%1 : tensor<512x4096xf32>) -> tensor<512x4096xf32>" ,
4038 "%3 = arith.truncf %2 : tensor<512x4096xf32> to tensor<512x4096xf16>" ,
4139 "return %3 : tensor<512x4096xf16>" ,
4240 ],
@@ -64,9 +62,7 @@ def test_n_t_bf16_f32_bf16():
6462 "%cst = arith.constant 0.000000e+00 : f32" ,
6563 "%0 = tensor.empty() : tensor<2x1280xf32>" ,
6664 "%1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<2x1280xf32>) -> tensor<2x1280xf32>" ,
67- "%2 = linalg.matmul_transpose_b ins(%arg0, %arg1 : tensor<2x8192xbf16>, tensor<1280x8192xbf16>)" ,
68- "outs(%1 : tensor<2x1280xf32>)" ,
69- "-> tensor<2x1280xf32>" ,
65+ "%2 = linalg.matmul_transpose_b {cast = #linalg.type_fn<cast_signed>} ins(%arg0, %arg1 : tensor<2x8192xbf16>, tensor<1280x8192xbf16>) outs(%1 : tensor<2x1280xf32>) -> tensor<2x1280xf32>" ,
7066 "%3 = arith.truncf %2 : tensor<2x1280xf32> to tensor<2x1280xbf16>" ,
7167 "return %3 : tensor<2x1280xbf16>" ,
7268 ],
@@ -94,9 +90,7 @@ def test_t_n_f16_f32_f16():
9490 "%cst = arith.constant 0.000000e+00 : f32" ,
9591 "%0 = tensor.empty() : tensor<32000x1xf32>" ,
9692 "%1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<32000x1xf32>) -> tensor<32000x1xf32>" ,
97- "%2 = linalg.matmul_transpose_a ins(%arg0, %arg1 : tensor<5120x32000xf16>, tensor<5120x1xf16>)" ,
98- "outs(%1 : tensor<32000x1xf32>)" ,
99- "-> tensor<32000x1xf32>" ,
93+ "%2 = linalg.matmul_transpose_a {cast = #linalg.type_fn<cast_signed>} ins(%arg0, %arg1 : tensor<5120x32000xf16>, tensor<5120x1xf16>) outs(%1 : tensor<32000x1xf32>) -> tensor<32000x1xf32>" ,
10094 "%3 = arith.truncf %2 : tensor<32000x1xf32> to tensor<32000x1xf16>" ,
10195 "return %3 : tensor<32000x1xf16>" ,
10296 ],
@@ -124,9 +118,7 @@ def test_t_n_bf16_f32_bf16():
124118 "%cst = arith.constant 0.000000e+00 : f32" ,
125119 "%0 = tensor.empty() : tensor<32000x1xf32>" ,
126120 "%1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<32000x1xf32>) -> tensor<32000x1xf32>" ,
127- "%2 = linalg.matmul_transpose_a ins(%arg0, %arg1 : tensor<5120x32000xbf16>, tensor<5120x1xbf16>)" ,
128- "outs(%1 : tensor<32000x1xf32>)" ,
129- "-> tensor<32000x1xf32>" ,
121+ "%2 = linalg.matmul_transpose_a {cast = #linalg.type_fn<cast_signed>} ins(%arg0, %arg1 : tensor<5120x32000xbf16>, tensor<5120x1xbf16>) outs(%1 : tensor<32000x1xf32>) -> tensor<32000x1xf32>" ,
130122 "%3 = arith.truncf %2 : tensor<32000x1xf32> to tensor<32000x1xbf16>" ,
131123 "return %3 : tensor<32000x1xbf16>" ,
132124 ],
@@ -154,9 +146,7 @@ def test_n_n_f16_f32_f16():
154146 "%cst = arith.constant 0.000000e+00 : f32" ,
155147 "%0 = tensor.empty() : tensor<2048x2048xf32>" ,
156148 "%1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<2048x2048xf32>) -> tensor<2048x2048xf32>" ,
157- "%2 = linalg.matmul ins(%arg0, %arg1 : tensor<2048x1024xf16>, tensor<1024x2048xf16>)" ,
158- "outs(%1 : tensor<2048x2048xf32>)" ,
159- "-> tensor<2048x2048xf32>" ,
149+ "%2 = linalg.matmul ins(%arg0, %arg1 : tensor<2048x1024xf16>, tensor<1024x2048xf16>) outs(%1 : tensor<2048x2048xf32>) -> tensor<2048x2048xf32>" ,
160150 "%3 = arith.truncf %2 : tensor<2048x2048xf32> to tensor<2048x2048xf16>" ,
161151 "return %3 : tensor<2048x2048xf16>" ,
162152 ],
@@ -181,12 +171,10 @@ def test_n_t_i8_i32_i8():
181171 [
182172 "module {" ,
183173 "func.func @main(%arg0: tensor<128x128xi8>, %arg1: tensor<128x128xi8>) -> tensor<128x128xi8> {" ,
184- "%cst = arith.constant 0 : i32" ,
174+ "%c0_i32 = arith.constant 0 : i32" ,
185175 "%0 = tensor.empty() : tensor<128x128xi32>" ,
186- "%1 = linalg.fill ins(%cst : i32) outs(%0 : tensor<128x128xi32>) -> tensor<128x128xi32>" ,
187- "%2 = linalg.matmul_transpose_b ins(%arg0, %arg1 : tensor<128x128xi8>, tensor<128x128xi8>)" ,
188- "outs(%1 : tensor<128x128xi32>)" ,
189- "-> tensor<128x128xi32>" ,
176+ "%1 = linalg.fill ins(%c0_i32 : i32) outs(%0 : tensor<128x128xi32>) -> tensor<128x128xi32>" ,
177+ "%2 = linalg.matmul_transpose_b {cast = #linalg.type_fn<cast_signed>} ins(%arg0, %arg1 : tensor<128x128xi8>, tensor<128x128xi8>) outs(%1 : tensor<128x128xi32>) -> tensor<128x128xi32>" ,
190178 "%3 = arith.trunci %2 : tensor<128x128xi32> to tensor<128x128xi8>" ,
191179 "return %3 : tensor<128x128xi8>" ,
192180 ],
0 commit comments