@@ -52,6 +52,8 @@ func.func @mfma_to_rocdl(%arg0 : vector<8xf16>, %arg1 : vector<16xf32>,
5252 func.return
5353}
5454
55+ // CHECK-LABEL: func @scaled_mfma_to_rocdl(
56+ // CHECK-SAME: %[[ARG0:.*]]: vector<16xf32>, %[[ARG1:.*]]: vector<4xf32>, %[[ARG2:.*]]: vector<32xf8E4M3FN>, %[[ARG3:.*]]: vector<32xf8E5M2>, %[[ARG4:.*]]: vector<32xf6E2M3FN>, %[[ARG5:.*]]: vector<32xf6E3M2FN>, %[[ARG6:.*]]: vector<32xf4E2M1FN>, %[[ARG7:.*]]: vector<4xi8>, %[[ARG8:.*]]: i8
5557func.func @scaled_mfma_to_rocdl (%arg0 : vector <16 xf32 >,
5658 %arg1 : vector <4 xf32 >, %arg2 : vector <32 xf8 E4 M3 FN>,
5759 %arg3 : vector <32 xf8 E5 M2 >, %arg4 : vector <32 xf6 E2 M3 FN>,
@@ -60,42 +62,42 @@ func.func @scaled_mfma_to_rocdl(%arg0 : vector<16xf32>,
6062
6163 // CHECK: %[[c0:.+]] = llvm.mlir.constant(0 : i32) : i32
6264 // CHECK: %[[c1:.+]] = llvm.mlir.constant(1 : i32) : i32
63- // CHECK: %[[c2:.+]] = llvm.bitcast{{.*}} : vector<4xi8> to i32
64- // CHECK: %[[c3:.+]] = llvm.zext{{.*}} : i8 to i32
65+ // CHECK: %[[b0:.+]] = llvm.bitcast %[[ARG7]] : vector<4xi8> to i32
66+ // CHECK: %[[z0:.+]] = llvm.zext %[[ARG8]] : i8 to i32
6567
66- // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c0]], %[[c2]], %[[c1]], %[[c3]] : (vector<8xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
68+ // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c0]], %[[b0]], %[[c1]], %[[z0]] : (vector<8xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
6769 amdgpu.scaled_mfma (%arg7 [0 ] * %arg2 ) * (%arg8 [1 ] * %arg2 ) + %arg0 { k = 64 : i32 , m = 32 : i32 , n = 32 : i32 } : vector <4 xi8 >, vector <32 xf8 E4 M3 FN>, i8 , vector <32 xf8 E4 M3 FN>, vector <16 xf32 >
68- // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c0]], %[[c2]], %[[c1]], %[[c3]] : (vector<8xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
70+ // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c0]], %[[b0]], %[[c1]], %[[z0]] : (vector<8xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
6971 amdgpu.scaled_mfma (%arg7 [0 ] * %arg2 ) * (%arg8 [1 ] * %arg2 ) + %arg1 { k = 128 : i32 , m = 16 : i32 , n = 16 : i32 } : vector <4 xi8 >, vector <32 xf8 E4 M3 FN>, i8 , vector <32 xf8 E4 M3 FN>, vector <4 xf32 >
7072
7173 // CHECK: llvm.bitcast
7274
73- // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c0]], %[[c2]], %[[c1]], %[[c3]] : (vector<8xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
75+ // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c0]], %[[b0]], %[[c1]], %[[z0]] : (vector<8xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
7476 amdgpu.scaled_mfma (%arg7 [0 ] * %arg3 ) * (%arg8 [1 ] * %arg3 ) + %arg0 { k = 64 : i32 , m = 32 : i32 , n = 32 : i32 } : vector <4 xi8 >, vector <32 xf8 E5 M2 >, i8 , vector <32 xf8 E5 M2 >, vector <16 xf32 >
75- // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c0]], %[[c2]], %[[c1]], %[[c3]] : (vector<8xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
77+ // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c0]], %[[b0]], %[[c1]], %[[z0]] : (vector<8xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
7678 amdgpu.scaled_mfma (%arg7 [0 ] * %arg3 ) * (%arg8 [1 ] * %arg3 ) + %arg1 { k = 128 : i32 , m = 16 : i32 , n = 16 : i32 } : vector <4 xi8 >, vector <32 xf8 E5 M2 >, i8 , vector <32 xf8 E5 M2 >, vector <4 xf32 >
7779
7880 // CHECK: llvm.bitcast
7981
80- // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c0]], %[[c2]], %[[c1]], %[[c3]] : (vector<6xi32>, vector<6xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
82+ // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c0]], %[[b0]], %[[c1]], %[[z0]] : (vector<6xi32>, vector<6xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
8183 amdgpu.scaled_mfma (%arg7 [0 ] * %arg4 ) * (%arg8 [1 ] * %arg4 ) + %arg0 { k = 64 : i32 , m = 32 : i32 , n = 32 : i32 } : vector <4 xi8 >, vector <32 xf6 E2 M3 FN>, i8 , vector <32 xf6 E2 M3 FN>, vector <16 xf32 >
82- // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c0]], %[[c2]], %[[c1]], %[[c3]] : (vector<6xi32>, vector<6xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
84+ // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c0]], %[[b0]], %[[c1]], %[[z0]] : (vector<6xi32>, vector<6xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
8385 amdgpu.scaled_mfma (%arg7 [0 ] * %arg4 ) * (%arg8 [1 ] * %arg4 ) + %arg1 { k = 128 : i32 , m = 16 : i32 , n = 16 : i32 } : vector <4 xi8 >, vector <32 xf6 E2 M3 FN>, i8 , vector <32 xf6 E2 M3 FN>, vector <4 xf32 >
8486
8587 // CHECK: llvm.bitcast
8688 // CHECK: llvm.mlir.constant(3 : i32) : i32
8789
88- // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c0]], %[[c2]], %[[c1]], %[[c3]] : (vector<6xi32>, vector<6xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
90+ // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c0]], %[[b0]], %[[c1]], %[[z0]] : (vector<6xi32>, vector<6xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
8991 amdgpu.scaled_mfma (%arg7 [0 ] * %arg5 ) * (%arg8 [1 ] * %arg5 ) + %arg0 { k = 64 : i32 , m = 32 : i32 , n = 32 : i32 } : vector <4 xi8 >, vector <32 xf6 E3 M2 FN>, i8 , vector <32 xf6 E3 M2 FN>, vector <16 xf32 >
90- // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c0]], %[[c2]], %[[c1]], %[[c3]] : (vector<6xi32>, vector<6xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
92+ // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c0]], %[[b0]], %[[c1]], %[[z0]] : (vector<6xi32>, vector<6xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
9193 amdgpu.scaled_mfma (%arg7 [0 ] * %arg5 ) * (%arg8 [1 ] * %arg5 ) + %arg1 { k = 128 : i32 , m = 16 : i32 , n = 16 : i32 } : vector <4 xi8 >, vector <32 xf6 E3 M2 FN>, i8 , vector <32 xf6 E3 M2 FN>, vector <4 xf32 >
9294
9395 // CHECK: llvm.bitcast
9496 // CHECK: llvm.mlir.constant(4 : i32) : i32
9597
96- // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c0]], %[[c2]], %[[c1]], %[[c3]] : (vector<4xi32>, vector<4xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
98+ // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c0]], %[[b0]], %[[c1]], %[[z0]] : (vector<4xi32>, vector<4xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
9799 amdgpu.scaled_mfma (%arg7 [0 ] * %arg6 ) * (%arg8 [1 ] * %arg6 ) + %arg0 { k = 64 : i32 , m = 32 : i32 , n = 32 : i32 } : vector <4 xi8 >, vector <32 xf4 E2 M1 FN>, i8 , vector <32 xf4 E2 M1 FN>, vector <16 xf32 >
98- // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c0]], %[[c2]], %[[c1]], %[[c3]] : (vector<4xi32>, vector<4xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
100+ // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c0]], %[[b0]], %[[c1]], %[[z0]] : (vector<4xi32>, vector<4xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
99101 amdgpu.scaled_mfma (%arg7 [0 ] * %arg6 ) * (%arg8 [1 ] * %arg6 ) + %arg1 { k = 128 : i32 , m = 16 : i32 , n = 16 : i32 } : vector <4 xi8 >, vector <32 xf4 E2 M1 FN>, i8 , vector <32 xf4 E2 M1 FN>, vector <4 xf32 >
100102
101103 func.return
0 commit comments