@@ -51,3 +51,54 @@ func.func @mfma_to_rocdl(%arg0 : vector<8xf16>, %arg1 : vector<16xf32>,
5151
5252 func.return
5353}
54+
55+ // CHECK-LABEL: func @scaled_mfma_to_rocdl(
56+ // CHECK-SAME: %[[ARG0:.*]]: vector<16xf32>, %[[ARG1:.*]]: vector<4xf32>, %[[ARG2:.*]]: vector<32xf8E4M3FN>, %[[ARG3:.*]]: vector<32xf8E5M2>, %[[ARG4:.*]]: vector<32xf6E2M3FN>, %[[ARG5:.*]]: vector<32xf6E3M2FN>, %[[ARG6:.*]]: vector<32xf4E2M1FN>, %[[ARG7:.*]]: vector<4xf8E8M0FNU>, %[[ARG8:.*]]: f8E8M0FNU
57+ func.func @scaled_mfma_to_rocdl (%arg0 : vector <16 xf32 >,
58+ %arg1 : vector <4 xf32 >, %arg2 : vector <32 xf8 E4 M3 FN>,
59+ %arg3 : vector <32 xf8 E5 M2 >, %arg4 : vector <32 xf6 E2 M3 FN>,
60+ %arg5 : vector <32 xf6 E3 M2 FN>, %arg6 : vector <32 xf4 E2 M1 FN>,
61+ %arg7 : vector <4 xf8 E8 M0 FNU>, %arg8 : f8E8M0FNU ) {
62+
63+ // CHECK: %[[c0:.+]] = llvm.mlir.constant(0 : i32) : i32
64+ // CHECK: %[[c1:.+]] = llvm.mlir.constant(1 : i32) : i32
65+ // CHECK: %[[b0:.+]] = llvm.bitcast {{.*}} : vector<4xi8> to i32
66+ // CHECK: %[[z0:.+]] = llvm.zext {{.*}} : i8 to i32
67+
68+ // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c0]], %[[b0]], %[[c1]], %[[z0]] : (vector<8xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
69+ amdgpu.scaled_mfma (%arg7 [0 ] * %arg2 ) * (%arg8 [1 ] * %arg2 ) + %arg0 { k = 64 : i32 , m = 32 : i32 , n = 32 : i32 } : vector <4 xf8 E8 M0 FNU>, vector <32 xf8 E4 M3 FN>, f8E8M0FNU , vector <32 xf8 E4 M3 FN>, vector <16 xf32 >
70+ // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c0]], %[[b0]], %[[c1]], %[[z0]] : (vector<8xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
71+ amdgpu.scaled_mfma (%arg7 [0 ] * %arg2 ) * (%arg8 [1 ] * %arg2 ) + %arg1 { k = 128 : i32 , m = 16 : i32 , n = 16 : i32 } : vector <4 xf8 E8 M0 FNU>, vector <32 xf8 E4 M3 FN>, f8E8M0FNU , vector <32 xf8 E4 M3 FN>, vector <4 xf32 >
72+
73+ // CHECK: llvm.bitcast
74+
75+ // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c0]], %[[b0]], %[[c1]], %[[z0]] : (vector<8xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
76+ amdgpu.scaled_mfma (%arg7 [0 ] * %arg3 ) * (%arg8 [1 ] * %arg3 ) + %arg0 { k = 64 : i32 , m = 32 : i32 , n = 32 : i32 } : vector <4 xf8 E8 M0 FNU>, vector <32 xf8 E5 M2 >, f8E8M0FNU , vector <32 xf8 E5 M2 >, vector <16 xf32 >
77+ // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c0]], %[[b0]], %[[c1]], %[[z0]] : (vector<8xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
78+ amdgpu.scaled_mfma (%arg7 [0 ] * %arg3 ) * (%arg8 [1 ] * %arg3 ) + %arg1 { k = 128 : i32 , m = 16 : i32 , n = 16 : i32 } : vector <4 xf8 E8 M0 FNU>, vector <32 xf8 E5 M2 >, f8E8M0FNU , vector <32 xf8 E5 M2 >, vector <4 xf32 >
79+
80+ // CHECK: llvm.bitcast
81+
82+ // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c0]], %[[b0]], %[[c1]], %[[z0]] : (vector<6xi32>, vector<6xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
83+ amdgpu.scaled_mfma (%arg7 [0 ] * %arg4 ) * (%arg8 [1 ] * %arg4 ) + %arg0 { k = 64 : i32 , m = 32 : i32 , n = 32 : i32 } : vector <4 xf8 E8 M0 FNU>, vector <32 xf6 E2 M3 FN>, f8E8M0FNU , vector <32 xf6 E2 M3 FN>, vector <16 xf32 >
84+ // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c0]], %[[b0]], %[[c1]], %[[z0]] : (vector<6xi32>, vector<6xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
85+ amdgpu.scaled_mfma (%arg7 [0 ] * %arg4 ) * (%arg8 [1 ] * %arg4 ) + %arg1 { k = 128 : i32 , m = 16 : i32 , n = 16 : i32 } : vector <4 xf8 E8 M0 FNU>, vector <32 xf6 E2 M3 FN>, f8E8M0FNU , vector <32 xf6 E2 M3 FN>, vector <4 xf32 >
86+
87+ // CHECK: llvm.bitcast
88+ // CHECK: llvm.mlir.constant(3 : i32) : i32
89+
90+ // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c0]], %[[b0]], %[[c1]], %[[z0]] : (vector<6xi32>, vector<6xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
91+ amdgpu.scaled_mfma (%arg7 [0 ] * %arg5 ) * (%arg8 [1 ] * %arg5 ) + %arg0 { k = 64 : i32 , m = 32 : i32 , n = 32 : i32 } : vector <4 xf8 E8 M0 FNU>, vector <32 xf6 E3 M2 FN>, f8E8M0FNU , vector <32 xf6 E3 M2 FN>, vector <16 xf32 >
92+ // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c0]], %[[b0]], %[[c1]], %[[z0]] : (vector<6xi32>, vector<6xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
93+ amdgpu.scaled_mfma (%arg7 [0 ] * %arg5 ) * (%arg8 [1 ] * %arg5 ) + %arg1 { k = 128 : i32 , m = 16 : i32 , n = 16 : i32 } : vector <4 xf8 E8 M0 FNU>, vector <32 xf6 E3 M2 FN>, f8E8M0FNU , vector <32 xf6 E3 M2 FN>, vector <4 xf32 >
94+
95+ // CHECK: llvm.bitcast
96+ // CHECK: llvm.mlir.constant(4 : i32) : i32
97+
98+ // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c0]], %[[b0]], %[[c1]], %[[z0]] : (vector<4xi32>, vector<4xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
99+ amdgpu.scaled_mfma (%arg7 [0 ] * %arg6 ) * (%arg8 [1 ] * %arg6 ) + %arg0 { k = 64 : i32 , m = 32 : i32 , n = 32 : i32 } : vector <4 xf8 E8 M0 FNU>, vector <32 xf4 E2 M1 FN>, f8E8M0FNU , vector <32 xf4 E2 M1 FN>, vector <16 xf32 >
100+ // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c0]], %[[b0]], %[[c1]], %[[z0]] : (vector<4xi32>, vector<4xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
101+ amdgpu.scaled_mfma (%arg7 [0 ] * %arg6 ) * (%arg8 [1 ] * %arg6 ) + %arg1 { k = 128 : i32 , m = 16 : i32 , n = 16 : i32 } : vector <4 xf8 E8 M0 FNU>, vector <32 xf4 E2 M1 FN>, f8E8M0FNU , vector <32 xf4 E2 M1 FN>, vector <4 xf32 >
102+
103+ func.return
104+ }
0 commit comments