@@ -1276,6 +1276,42 @@ func.func @test_matmul_t_block_scaled_mxint8(%arg0: tensor<4x8x32x!tosa.mxint8>,
12761276 return %0 : tensor <4 x8 x16 xf32 >
12771277}
12781278
1279+ // -----
1280+ // CHECK-LABEL: test_matmul_t_block_scaled_fp6e3m2_e2e
1281+ func.func @test_matmul_t_block_scaled_fp6e3m2_e2e (%arg0: tensor <6 x2 x32 xf32 >, %arg1: tensor <6 x64 x32 xf32 >) -> tensor <6 x2 x64 xf32 > {
1282+ %a , %sa = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size <BLOCK_SIZE_32 > : i32 } : (tensor <6 x2 x32 xf32 >) -> (tensor <6 x2 x32 xf6 E3 M2 FN>, tensor <6 x2 x1 xf8 E8 M0 FNU>)
1283+ %b , %sb = tosa.cast_to_block_scaled %arg1 {block_size = #tosa.block_size <BLOCK_SIZE_32 > : i32 } : (tensor <6 x64 x32 xf32 >) -> (tensor <6 x64 x32 xf6 E3 M2 FN>, tensor <6 x64 x1 xf8 E8 M0 FNU>)
1284+ %res = tosa.matmul_t_block_scaled %a , %sa , %b , %sb {block_size = #tosa.block_size <BLOCK_SIZE_32 > : i32 } : (tensor <6 x2 x32 xf6 E3 M2 FN>, tensor <6 x2 x1 xf8 E8 M0 FNU>, tensor <6 x64 x32 xf6 E3 M2 FN>, tensor <6 x64 x1 xf8 E8 M0 FNU>) -> tensor <6 x2 x64 xf32 >
1285+ return %res : tensor <6 x2 x64 xf32 >
1286+ }
1287+
1288+ // -----
1289+ // CHECK-LABEL: test_matmul_t_block_scaled_fp6e2m3_e2e
1290+ func.func @test_matmul_t_block_scaled_fp6e2m3_e2e (%arg0: tensor <6 x2 x32 xf32 >, %arg1: tensor <6 x64 x32 xf32 >) -> tensor <6 x2 x64 xf32 > {
1291+ %a , %sa = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size <BLOCK_SIZE_32 > : i32 } : (tensor <6 x2 x32 xf32 >) -> (tensor <6 x2 x32 xf6 E2 M3 FN>, tensor <6 x2 x1 xf8 E8 M0 FNU>)
1292+ %b , %sb = tosa.cast_to_block_scaled %arg1 {block_size = #tosa.block_size <BLOCK_SIZE_32 > : i32 } : (tensor <6 x64 x32 xf32 >) -> (tensor <6 x64 x32 xf6 E2 M3 FN>, tensor <6 x64 x1 xf8 E8 M0 FNU>)
1293+ %res = tosa.matmul_t_block_scaled %a , %sa , %b , %sb {block_size = #tosa.block_size <BLOCK_SIZE_32 > : i32 } : (tensor <6 x2 x32 xf6 E2 M3 FN>, tensor <6 x2 x1 xf8 E8 M0 FNU>, tensor <6 x64 x32 xf6 E2 M3 FN>, tensor <6 x64 x1 xf8 E8 M0 FNU>) -> tensor <6 x2 x64 xf32 >
1294+ return %res : tensor <6 x2 x64 xf32 >
1295+ }
1296+
1297+ // -----
1298+ // CHECK-LABEL: test_matmul_t_block_scaled_fp4e2m1_e2e
1299+ func.func @test_matmul_t_block_scaled_fp4e2m1_e2e (%arg0: tensor <6 x2 x32 xf32 >, %arg1: tensor <6 x64 x32 xf32 >) -> tensor <6 x2 x64 xf32 > {
1300+ %a , %sa = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size <BLOCK_SIZE_32 > : i32 } : (tensor <6 x2 x32 xf32 >) -> (tensor <6 x2 x32 xf4 E2 M1 FN>, tensor <6 x2 x1 xf8 E8 M0 FNU>)
1301+ %b , %sb = tosa.cast_to_block_scaled %arg1 {block_size = #tosa.block_size <BLOCK_SIZE_32 > : i32 } : (tensor <6 x64 x32 xf32 >) -> (tensor <6 x64 x32 xf4 E2 M1 FN>, tensor <6 x64 x1 xf8 E8 M0 FNU>)
1302+ %res = tosa.matmul_t_block_scaled %a , %sa , %b , %sb {block_size = #tosa.block_size <BLOCK_SIZE_32 > : i32 } : (tensor <6 x2 x32 xf4 E2 M1 FN>, tensor <6 x2 x1 xf8 E8 M0 FNU>, tensor <6 x64 x32 xf4 E2 M1 FN>, tensor <6 x64 x1 xf8 E8 M0 FNU>) -> tensor <6 x2 x64 xf32 >
1303+ return %res : tensor <6 x2 x64 xf32 >
1304+ }
1305+
1306+ // -----
1307+ // CHECK-LABEL: test_matmul_t_block_scaled_mxint8_e2e
1308+ func.func @test_matmul_t_block_scaled_mxint8_e2e (%arg0: tensor <6 x2 x32 xf32 >, %arg1: tensor <6 x64 x32 xf32 >) -> tensor <6 x2 x64 xf32 > {
1309+ %a , %sa = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size <BLOCK_SIZE_32 > : i32 } : (tensor <6 x2 x32 xf32 >) -> (tensor <6 x2 x32 x!tosa.mxint8 >, tensor <6 x2 x1 xf8 E8 M0 FNU>)
1310+ %b , %sb = tosa.cast_to_block_scaled %arg1 {block_size = #tosa.block_size <BLOCK_SIZE_32 > : i32 } : (tensor <6 x64 x32 xf32 >) -> (tensor <6 x64 x32 x!tosa.mxint8 >, tensor <6 x64 x1 xf8 E8 M0 FNU>)
1311+ %res = tosa.matmul_t_block_scaled %a , %sa , %b , %sb {block_size = #tosa.block_size <BLOCK_SIZE_32 > : i32 } : (tensor <6 x2 x32 x!tosa.mxint8 >, tensor <6 x2 x1 xf8 E8 M0 FNU>, tensor <6 x64 x32 x!tosa.mxint8 >, tensor <6 x64 x1 xf8 E8 M0 FNU>) -> tensor <6 x2 x64 xf32 >
1312+ return %res : tensor <6 x2 x64 xf32 >
1313+ }
1314+
12791315// -----
12801316// CHECK-LABEL: test_cast_from_block_scaled_static
12811317func.func @test_cast_from_block_scaled_static (%arg0: tensor <4 x32 xf4 E2 M1 FN>, %arg1: tensor <4 x1 xf8 E8 M0 FNU>) -> tensor <4 x32 xf32 > {
@@ -1307,7 +1343,7 @@ func.func @test_cast_to_block_scaled_unranked(%arg0: tensor<*xf32>) -> (tensor<*
13071343// -----
13081344// CHECK-LABEL: test_cast_to_block_scaled_mxint8
13091345func.func @test_cast_to_block_scaled_mxint8 (%arg0: tensor <4 x32 xf32 >) -> (tensor <4 x32 x!tosa.mxint8 >, tensor <4 x1 xf8 E8 M0 FNU>) {
1310- %0:2 = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size <BLOCK_SIZE_32 > : i32 , stochastic_round = false } : (tensor <4 x32 xf32 >) -> (tensor <4 x32 x!tosa.mxint8 >, tensor <4 x1 xf8 E8 M0 FNU>)
1346+ %0:2 = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size <BLOCK_SIZE_32 > : i32 } : (tensor <4 x32 xf32 >) -> (tensor <4 x32 x!tosa.mxint8 >, tensor <4 x1 xf8 E8 M0 FNU>)
13111347 return %0#0 , %0#1 : tensor <4 x32 x!tosa.mxint8 >, tensor <4 x1 xf8 E8 M0 FNU>
13121348}
13131349
0 commit comments