Skip to content

Commit 45bb926

Browse files
authored
[mlir][tosa] Add e2e tests for matmul_t_block_scaled (#166567)
Added tests for dtypes fp6e3m2, fp6e2m3, fp4e2m1, mxint8 Signed-off-by: Udaya Ranga <[email protected]>
1 parent d47fdfe commit 45bb926

File tree

1 file changed

+37
-1
lines changed

1 file changed

+37
-1
lines changed

mlir/test/Dialect/Tosa/ops.mlir

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1276,6 +1276,42 @@ func.func @test_matmul_t_block_scaled_mxint8(%arg0: tensor<4x8x32x!tosa.mxint8>,
12761276
return %0 : tensor<4x8x16xf32>
12771277
}
12781278

1279+
// -----
1280+
// CHECK-LABEL: test_matmul_t_block_scaled_fp6e3m2_e2e
1281+
func.func @test_matmul_t_block_scaled_fp6e3m2_e2e(%arg0: tensor<6x2x32xf32>, %arg1: tensor<6x64x32xf32>) -> tensor<6x2x64xf32> {
1282+
%a, %sa = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<6x2x32xf32>) -> (tensor<6x2x32xf6E3M2FN>, tensor<6x2x1xf8E8M0FNU>)
1283+
%b, %sb = tosa.cast_to_block_scaled %arg1 {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<6x64x32xf32>) -> (tensor<6x64x32xf6E3M2FN>, tensor<6x64x1xf8E8M0FNU>)
1284+
%res = tosa.matmul_t_block_scaled %a, %sa, %b, %sb {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<6x2x32xf6E3M2FN>, tensor<6x2x1xf8E8M0FNU>, tensor<6x64x32xf6E3M2FN>, tensor<6x64x1xf8E8M0FNU>) -> tensor<6x2x64xf32>
1285+
return %res : tensor<6x2x64xf32>
1286+
}
1287+
1288+
// -----
1289+
// CHECK-LABEL: test_matmul_t_block_scaled_fp6e2m3_e2e
1290+
func.func @test_matmul_t_block_scaled_fp6e2m3_e2e(%arg0: tensor<6x2x32xf32>, %arg1: tensor<6x64x32xf32>) -> tensor<6x2x64xf32> {
1291+
%a, %sa = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<6x2x32xf32>) -> (tensor<6x2x32xf6E2M3FN>, tensor<6x2x1xf8E8M0FNU>)
1292+
%b, %sb = tosa.cast_to_block_scaled %arg1 {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<6x64x32xf32>) -> (tensor<6x64x32xf6E2M3FN>, tensor<6x64x1xf8E8M0FNU>)
1293+
%res = tosa.matmul_t_block_scaled %a, %sa, %b, %sb {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<6x2x32xf6E2M3FN>, tensor<6x2x1xf8E8M0FNU>, tensor<6x64x32xf6E2M3FN>, tensor<6x64x1xf8E8M0FNU>) -> tensor<6x2x64xf32>
1294+
return %res : tensor<6x2x64xf32>
1295+
}
1296+
1297+
// -----
1298+
// CHECK-LABEL: test_matmul_t_block_scaled_fp4e2m1_e2e
1299+
func.func @test_matmul_t_block_scaled_fp4e2m1_e2e(%arg0: tensor<6x2x32xf32>, %arg1: tensor<6x64x32xf32>) -> tensor<6x2x64xf32> {
1300+
%a, %sa = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<6x2x32xf32>) -> (tensor<6x2x32xf4E2M1FN>, tensor<6x2x1xf8E8M0FNU>)
1301+
%b, %sb = tosa.cast_to_block_scaled %arg1 {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<6x64x32xf32>) -> (tensor<6x64x32xf4E2M1FN>, tensor<6x64x1xf8E8M0FNU>)
1302+
%res = tosa.matmul_t_block_scaled %a, %sa, %b, %sb {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<6x2x32xf4E2M1FN>, tensor<6x2x1xf8E8M0FNU>, tensor<6x64x32xf4E2M1FN>, tensor<6x64x1xf8E8M0FNU>) -> tensor<6x2x64xf32>
1303+
return %res : tensor<6x2x64xf32>
1304+
}
1305+
1306+
// -----
1307+
// CHECK-LABEL: test_matmul_t_block_scaled_mxint8_e2e
1308+
func.func @test_matmul_t_block_scaled_mxint8_e2e(%arg0: tensor<6x2x32xf32>, %arg1: tensor<6x64x32xf32>) -> tensor<6x2x64xf32> {
1309+
%a, %sa = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<6x2x32xf32>) -> (tensor<6x2x32x!tosa.mxint8>, tensor<6x2x1xf8E8M0FNU>)
1310+
%b, %sb = tosa.cast_to_block_scaled %arg1 {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<6x64x32xf32>) -> (tensor<6x64x32x!tosa.mxint8>, tensor<6x64x1xf8E8M0FNU>)
1311+
%res = tosa.matmul_t_block_scaled %a, %sa, %b, %sb {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<6x2x32x!tosa.mxint8>, tensor<6x2x1xf8E8M0FNU>, tensor<6x64x32x!tosa.mxint8>, tensor<6x64x1xf8E8M0FNU>) -> tensor<6x2x64xf32>
1312+
return %res : tensor<6x2x64xf32>
1313+
}
1314+
12791315
// -----
12801316
// CHECK-LABEL: test_cast_from_block_scaled_static
12811317
func.func @test_cast_from_block_scaled_static(%arg0: tensor<4x32xf4E2M1FN>, %arg1: tensor<4x1xf8E8M0FNU>) -> tensor<4x32xf32> {
@@ -1307,7 +1343,7 @@ func.func @test_cast_to_block_scaled_unranked(%arg0: tensor<*xf32>) -> (tensor<*
13071343
// -----
13081344
// CHECK-LABEL: test_cast_to_block_scaled_mxint8
13091345
func.func @test_cast_to_block_scaled_mxint8(%arg0: tensor<4x32xf32>) -> (tensor<4x32x!tosa.mxint8>, tensor<4x1xf8E8M0FNU>) {
1310-
%0:2 = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32, stochastic_round = false} : (tensor<4x32xf32>) -> (tensor<4x32x!tosa.mxint8>, tensor<4x1xf8E8M0FNU>)
1346+
%0:2 = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<4x32xf32>) -> (tensor<4x32x!tosa.mxint8>, tensor<4x1xf8E8M0FNU>)
13111347
return %0#0, %0#1 : tensor<4x32x!tosa.mxint8>, tensor<4x1xf8E8M0FNU>
13121348
}
13131349

0 commit comments

Comments
 (0)