@@ -159,3 +159,88 @@ func.func @fold_gather_to_lds_of_cast_dest(%global: memref<128x72xf32, 1>, %lds:
159159 : f32 , memref <128 x72 xf32 , 1 >, memref <?x?xf32 , 3 >
160160 func.return
161161}
162+
163+ // -----
164+
// Scales pulled out of a 2x1x8x1 vector (16 elements when flattened) are
// expected to be rewritten as a 4-wide strided slice containing the scale,
// with the scaled_mfma selector pointing at the scale's position in that slice.
// The slice sources are matched with wildcards rather than auto-numbered SSA
// names so unrelated changes in value numbering do not break the test.
// CHECK-LABEL: func @scaled_mfma
// CHECK: %[[SCALE_1:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0], sizes = [4], strides = [1]} : vector<16xf8E8M0FNU> to vector<4xf8E8M0FNU>
// CHECK: %[[SCALE_2:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [4], sizes = [4], strides = [1]} : vector<16xf8E8M0FNU> to vector<4xf8E8M0FNU>
// CHECK: amdgpu.scaled_mfma(%[[SCALE_1]][3] * %{{.*}}) * (%[[SCALE_2]][2] * %{{.*}}) {{.*}}
// CHECK: %[[SCALE_3:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [8], sizes = [4], strides = [1]} : vector<16xf8E8M0FNU> to vector<4xf8E8M0FNU>
// CHECK: %[[SCALE_4:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [12], sizes = [4], strides = [1]} : vector<16xf8E8M0FNU> to vector<4xf8E8M0FNU>
// CHECK: amdgpu.scaled_mfma(%[[SCALE_3]][1] * %{{.*}}) * (%[[SCALE_4]][0] * %{{.*}}) {{.*}}
func.func @scaled_mfma(%opA: vector<32xf4E2M1FN>, %opB: vector<32xf4E2M1FN>, %scalesA: vector<2x1x8x1xf8E8M0FNU>, %scalesB: vector<2x1x8x1xf8E8M0FNU>) -> (vector<4xf32>, vector<4xf32>) {
  // Zero accumulator and the filler vector each scalar scale is inserted into.
  %cst_0 = arith.constant dense<0.000000e+00> : vector<4xf32>
  %cst_1 = arith.constant dense<5.877470e-39> : vector<4xf8E8M0FNU>
  // [0, 0, 3, 0] -> linear idx 3 => slice offset 0, opsel 3.
  %scaleA = vector.extract %scalesA[0, 0, 3, 0] : f8E8M0FNU from vector<2x1x8x1xf8E8M0FNU>
  %sA = vector.insert %scaleA, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
  // [0, 0, 6, 0] -> linear idx 6 => slice offset 4, opsel 2.
  %scaleB = vector.extract %scalesB[0, 0, 6, 0] : f8E8M0FNU from vector<2x1x8x1xf8E8M0FNU>
  %sB = vector.insert %scaleB, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
  %res_0 = amdgpu.scaled_mfma(%sA[0] * %opA) * (%sB[0] * %opB) + %cst_0 {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
  // [1, 0, 1, 0] -> linear idx 9 => slice offset 8, opsel 1.
  %scaleC = vector.extract %scalesA[1, 0, 1, 0] : f8E8M0FNU from vector<2x1x8x1xf8E8M0FNU>
  %sC = vector.insert %scaleC, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
  // [1, 0, 4, 0] -> linear idx 12 => slice offset 12, opsel 0.
  %scaleD = vector.extract %scalesB[1, 0, 4, 0] : f8E8M0FNU from vector<2x1x8x1xf8E8M0FNU>
  %sD = vector.insert %scaleD, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
  %res_1 = amdgpu.scaled_mfma(%sC[0] * %opA) * (%sD[0] * %opB) + %cst_0 {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
  return %res_0, %res_1 : vector<4xf32>, vector<4xf32>
}
187+
188+ // -----
189+
// The scale sources here hold only two elements. The directives below expect
// the scalar extract/insert pairs to remain in the output and both
// scaled_mfma selectors to stay at [0].
// NOTE(review): presumably this is because a 4-wide slice cannot be taken
// from a 2-element vector; confirm against the canonicalization pattern.
// CHECK-LABEL: func @scaled_mfma_less_than_4
// CHECK: vector.extract {{.*}} : f8E8M0FNU from vector<2xf8E8M0FNU>
// CHECK: vector.insert {{.*}} : f8E8M0FNU into vector<4xf8E8M0FNU>
// CHECK: vector.extract {{.*}} : f8E8M0FNU from vector<2xf8E8M0FNU>
// CHECK: vector.insert {{.*}} : f8E8M0FNU into vector<4xf8E8M0FNU>
// CHECK: amdgpu.scaled_mfma({{.*}}[0] * {{.*}}) * ({{.*}}[0] * {{.*}}
func.func @scaled_mfma_less_than_4(%opA: vector<32xf4E2M1FN>, %opB: vector<32xf4E2M1FN>, %scalesA: vector<2xf8E8M0FNU>, %scalesB: vector<2xf8E8M0FNU>) -> vector<4xf32> {
  // Accumulator input: all zeros.
  %acc = arith.constant dense<0.000000e+00> : vector<4xf32>
  // Filler vector that each scalar scale is inserted into at lane 0.
  %fill = arith.constant dense<5.877470e-39> : vector<4xf8E8M0FNU>
  // Scale for operand A: element 0 of %scalesA.
  %a_scalar = vector.extract %scalesA[0] : f8E8M0FNU from vector<2xf8E8M0FNU>
  %a_packed = vector.insert %a_scalar, %fill [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
  // Scale for operand B: element 1 of %scalesB.
  %b_scalar = vector.extract %scalesB[1] : f8E8M0FNU from vector<2xf8E8M0FNU>
  %b_packed = vector.insert %b_scalar, %fill [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
  %out = amdgpu.scaled_mfma(%a_packed[0] * %opA) * (%b_packed[0] * %opB) + %acc {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
  return %out : vector<4xf32>
}
206+
207+ // -----
208+
// Scale sources whose element counts (5x5 = 25, 7x23 = 161) are not multiples
// of four. Only the selector indices are pinned; operand and accumulator
// names are matched with wildcards so the test does not depend on
// printer-generated SSA names.
// CHECK-LABEL: func @scaled_mfma_ugly_shapes
// CHECK: amdgpu.scaled_mfma(%{{.*}}[0] * %{{.*}}) * (%{{.*}}[3] * %{{.*}}) + %{{.*}} {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
// CHECK: amdgpu.scaled_mfma(%{{.*}}[1] * %{{.*}}) * (%{{.*}}[3] * %{{.*}}) + %{{.*}} {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
// CHECK: amdgpu.scaled_mfma(%{{.*}}[2] * %{{.*}}) * (%{{.*}}[2] * %{{.*}}) + %{{.*}} {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
// CHECK: amdgpu.scaled_mfma(%{{.*}}[3] * %{{.*}}) * (%{{.*}}[1] * %{{.*}}) + %{{.*}} {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
func.func @scaled_mfma_ugly_shapes(%opA: vector<32xf4E2M1FN>, %opB: vector<32xf4E2M1FN>, %scalesA: vector<5x5xf8E8M0FNU>, %scalesB: vector<7x23xf8E8M0FNU>) -> (vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>) {
  // Zero accumulator and the filler vector each scalar scale is inserted into.
  %cst_0 = arith.constant dense<0.000000e+00> : vector<4xf32>
  %cst_1 = arith.constant dense<5.877470e-39> : vector<4xf8E8M0FNU>
  // idx = 4 * 5 + 0 = 20 => opsel = 0
  %scaleA_4_0 = vector.extract %scalesA[4, 0] : f8E8M0FNU from vector<5x5xf8E8M0FNU>
  // idx = 21 => opsel = 1
  %scaleA_4_1 = vector.extract %scalesA[4, 1] : f8E8M0FNU from vector<5x5xf8E8M0FNU>
  // idx = 22 => opsel = 2
  %scaleA_4_2 = vector.extract %scalesA[4, 2] : f8E8M0FNU from vector<5x5xf8E8M0FNU>
  // idx = 23 => opsel = 3
  %scaleA_4_3 = vector.extract %scalesA[4, 3] : f8E8M0FNU from vector<5x5xf8E8M0FNU>

  // idx = 160 => opsel = 3 (last idx of last 4 bytes)
  %scaleB_6_22 = vector.extract %scalesB[6, 22] : f8E8M0FNU from vector<7x23xf8E8M0FNU>
  // idx = 159 => opsel = 3
  %scaleB_6_21 = vector.extract %scalesB[6, 21] : f8E8M0FNU from vector<7x23xf8E8M0FNU>
  // idx = 158 => opsel = 2
  %scaleB_6_20 = vector.extract %scalesB[6, 20] : f8E8M0FNU from vector<7x23xf8E8M0FNU>
  // idx = 157 => opsel = 1
  %scaleB_6_19 = vector.extract %scalesB[6, 19] : f8E8M0FNU from vector<7x23xf8E8M0FNU>

  %sA_4_0 = vector.insert %scaleA_4_0, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
  %sA_4_1 = vector.insert %scaleA_4_1, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
  %sA_4_2 = vector.insert %scaleA_4_2, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
  %sA_4_3 = vector.insert %scaleA_4_3, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>

  %sB_6_22 = vector.insert %scaleB_6_22, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
  %sB_6_21 = vector.insert %scaleB_6_21, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
  %sB_6_20 = vector.insert %scaleB_6_20, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
  %sB_6_19 = vector.insert %scaleB_6_19, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>

  %res_4 = amdgpu.scaled_mfma(%sA_4_0[0] * %opA) * (%sB_6_22[0] * %opB) + %cst_0 {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
  %res_5 = amdgpu.scaled_mfma(%sA_4_1[0] * %opA) * (%sB_6_21[0] * %opB) + %cst_0 {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
  %res_6 = amdgpu.scaled_mfma(%sA_4_2[0] * %opA) * (%sB_6_20[0] * %opB) + %cst_0 {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
  %res_7 = amdgpu.scaled_mfma(%sA_4_3[0] * %opA) * (%sB_6_19[0] * %opB) + %cst_0 {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
  return %res_4, %res_5, %res_6, %res_7 : vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>
}
0 commit comments