@@ -159,3 +159,88 @@ func.func @fold_gather_to_lds_of_cast_dest(%global: memref<128x72xf32, 1>, %lds:
159
159
: f32 , memref <128 x72 xf32 , 1 >, memref <?x?xf32 , 3 >
160
160
func.return
161
161
}
162
+
163
+ // -----
164
+
165
// Each scale is extracted as a scalar from a 16-element (2x1x8x1) vector and
// inserted into lane 0 of a constant vector<4xf8E8M0FNU>. The expected output
// instead takes a 4-element slice of the flattened scale vector and selects the
// scale through the scaled_mfma index operand:
//   %scalesA[0, 0, 3, 0] -> linear index 3  -> offsets = [0],  index 3
//   %scalesB[0, 0, 6, 0] -> linear index 6  -> offsets = [4],  index 2
//   %scalesA[1, 0, 1, 0] -> linear index 9  -> offsets = [8],  index 1
//   %scalesB[1, 0, 4, 0] -> linear index 12 -> offsets = [12], index 0
// The slice sources are matched with wildcards so the test does not depend on
// auto-generated SSA value numbering.
// CHECK-LABEL: func @scaled_mfma
// CHECK: %[[SCALE_1:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0], sizes = [4], strides = [1]} : vector<16xf8E8M0FNU> to vector<4xf8E8M0FNU>
// CHECK: %[[SCALE_2:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [4], sizes = [4], strides = [1]} : vector<16xf8E8M0FNU> to vector<4xf8E8M0FNU>
// CHECK: amdgpu.scaled_mfma(%[[SCALE_1]][3] * %{{.*}}) * (%[[SCALE_2]][2] * %{{.*}}) {{.*}}
// CHECK: %[[SCALE_3:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [8], sizes = [4], strides = [1]} : vector<16xf8E8M0FNU> to vector<4xf8E8M0FNU>
// CHECK: %[[SCALE_4:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [12], sizes = [4], strides = [1]} : vector<16xf8E8M0FNU> to vector<4xf8E8M0FNU>
// CHECK: amdgpu.scaled_mfma(%[[SCALE_3]][1] * %{{.*}}) * (%[[SCALE_4]][0] * %{{.*}}) {{.*}}
func.func @scaled_mfma(%opA: vector<32xf4E2M1FN>, %opB: vector<32xf4E2M1FN>, %scalesA: vector<2x1x8x1xf8E8M0FNU>, %scalesB: vector<2x1x8x1xf8E8M0FNU>) -> (vector<4xf32>, vector<4xf32>) {
  // Zero accumulator shared by both scaled_mfma ops.
  %cst_0 = arith.constant dense<0.000000e+00> : vector<4xf32>
  // Placeholder 4-element scale vector; each scalar scale is inserted at lane 0.
  %cst_1 = arith.constant dense<5.877470e-39> : vector<4xf8E8M0FNU>
  %scaleA = vector.extract %scalesA[0, 0, 3, 0] : f8E8M0FNU from vector<2x1x8x1xf8E8M0FNU>
  %sA = vector.insert %scaleA, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
  %scaleB = vector.extract %scalesB[0, 0, 6, 0] : f8E8M0FNU from vector<2x1x8x1xf8E8M0FNU>
  %sB = vector.insert %scaleB, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
  %res_0 = amdgpu.scaled_mfma(%sA[0] * %opA) * (%sB[0] * %opB) + %cst_0 {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
  %scaleC = vector.extract %scalesA[1, 0, 1, 0] : f8E8M0FNU from vector<2x1x8x1xf8E8M0FNU>
  %sC = vector.insert %scaleC, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
  %scaleD = vector.extract %scalesB[1, 0, 4, 0] : f8E8M0FNU from vector<2x1x8x1xf8E8M0FNU>
  %sD = vector.insert %scaleD, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
  %res_1 = amdgpu.scaled_mfma(%sC[0] * %opA) * (%sD[0] * %opB) + %cst_0 {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
  return %res_0, %res_1 : vector<4xf32>, vector<4xf32>
}
187
+
188
+ // -----
189
+
190
// The scale sources here hold only two elements. The expected output keeps the
// scalar vector.extract / vector.insert pairs and leaves both scaled_mfma
// selectors at [0].
// CHECK-LABEL: func @scaled_mfma_less_than_4
// CHECK: vector.extract {{.*}} : f8E8M0FNU from vector<2xf8E8M0FNU>
// CHECK: vector.insert {{.*}} : f8E8M0FNU into vector<4xf8E8M0FNU>
// CHECK: vector.extract {{.*}} : f8E8M0FNU from vector<2xf8E8M0FNU>
// CHECK: vector.insert {{.*}} : f8E8M0FNU into vector<4xf8E8M0FNU>
// CHECK: amdgpu.scaled_mfma({{.*}}[0] * {{.*}}) * ({{.*}}[0] * {{.*}}
func.func @scaled_mfma_less_than_4(%opA: vector<32xf4E2M1FN>, %opB: vector<32xf4E2M1FN>, %scalesA: vector<2xf8E8M0FNU>, %scalesB: vector<2xf8E8M0FNU>) -> vector<4xf32> {
  // Zero accumulator for the multiply-accumulate.
  %acc = arith.constant dense<0.000000e+00> : vector<4xf32>
  // Placeholder 4-element scale vector that receives each scalar at lane 0.
  %scale_init = arith.constant dense<5.877470e-39> : vector<4xf8E8M0FNU>
  // Operand A uses element 0 of its scale vector.
  %elem_a = vector.extract %scalesA[0] : f8E8M0FNU from vector<2xf8E8M0FNU>
  %packed_a = vector.insert %elem_a, %scale_init [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
  // Operand B uses element 1 of its scale vector.
  %elem_b = vector.extract %scalesB[1] : f8E8M0FNU from vector<2xf8E8M0FNU>
  %packed_b = vector.insert %elem_b, %scale_init [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
  %result = amdgpu.scaled_mfma(%packed_a[0] * %opA) * (%packed_b[0] * %opB) + %acc {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
  return %result : vector<4xf32>
}
206
+
207
+ // -----
208
+
209
// Scale vectors whose element counts (5x5 = 25 and 7x23 = 161) are not
// multiples of 4. Only the selector on each scaled_mfma operand is checked.
// Function arguments are captured rather than hard-coded as %argN, and the
// accumulator is matched with a wildcard, so the test does not depend on
// auto-generated SSA names.
// CHECK-LABEL: func @scaled_mfma_ugly_shapes
// CHECK-SAME: (%[[OPA:[a-zA-Z0-9_]+]]: vector<32xf4E2M1FN>, %[[OPB:[a-zA-Z0-9_]+]]: vector<32xf4E2M1FN>
// CHECK: amdgpu.scaled_mfma(%{{.*}}[0] * %[[OPA]]) * (%{{.*}}[3] * %[[OPB]]) + %{{.*}} {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
// CHECK: amdgpu.scaled_mfma(%{{.*}}[1] * %[[OPA]]) * (%{{.*}}[3] * %[[OPB]]) + %{{.*}} {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
// CHECK: amdgpu.scaled_mfma(%{{.*}}[2] * %[[OPA]]) * (%{{.*}}[2] * %[[OPB]]) + %{{.*}} {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
// CHECK: amdgpu.scaled_mfma(%{{.*}}[3] * %[[OPA]]) * (%{{.*}}[1] * %[[OPB]]) + %{{.*}} {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
func.func @scaled_mfma_ugly_shapes(%opA: vector<32xf4E2M1FN>, %opB: vector<32xf4E2M1FN>, %scalesA: vector<5x5xf8E8M0FNU>, %scalesB: vector<7x23xf8E8M0FNU>) -> (vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>) {
  // Zero accumulator shared by all four scaled_mfma ops.
  %cst_0 = arith.constant dense<0.000000e+00> : vector<4xf32>
  // Placeholder 4-element scale vector; each scalar scale is inserted at lane 0.
  %cst_1 = arith.constant dense<5.877470e-39> : vector<4xf8E8M0FNU>

  // Row 4 of the 5x5 vector: linear idx = 4 * 5 + col = 20, 21, 22, 23
  // => opsel = 0, 1, 2, 3 (idx mod 4).
  %scaleA_0_4 = vector.extract %scalesA[4, 0] : f8E8M0FNU from vector<5x5xf8E8M0FNU>
  %scaleA_0_5 = vector.extract %scalesA[4, 1] : f8E8M0FNU from vector<5x5xf8E8M0FNU>
  %scaleA_0_6 = vector.extract %scalesA[4, 2] : f8E8M0FNU from vector<5x5xf8E8M0FNU>
  %scaleA_0_7 = vector.extract %scalesA[4, 3] : f8E8M0FNU from vector<5x5xf8E8M0FNU>

  // Linear idx = 6 * 23 + 22 = 160, the last of 161 elements => opsel = 3.
  // 160 mod 4 is 0; presumably the final four elements (157..160) form the
  // slice here, which makes idx 160 its last byte.
  %scaleB_6_22 = vector.extract %scalesB[6, 22] : f8E8M0FNU from vector<7x23xf8E8M0FNU>
  // Linear idx = 159 => opsel = 3 (159 mod 4).
  %scaleB_6_21 = vector.extract %scalesB[6, 21] : f8E8M0FNU from vector<7x23xf8E8M0FNU>
  // Linear idx = 158 => opsel = 2 (158 mod 4).
  %scaleB_6_20 = vector.extract %scalesB[6, 20] : f8E8M0FNU from vector<7x23xf8E8M0FNU>
  // Linear idx = 157 => opsel = 1 (157 mod 4).
  %scaleB_6_19 = vector.extract %scalesB[6, 19] : f8E8M0FNU from vector<7x23xf8E8M0FNU>

  %sA_0_4 = vector.insert %scaleA_0_4, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
  %sA_0_5 = vector.insert %scaleA_0_5, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
  %sA_0_6 = vector.insert %scaleA_0_6, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
  %sA_0_7 = vector.insert %scaleA_0_7, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>

  %sB_6_22 = vector.insert %scaleB_6_22, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
  %sB_6_21 = vector.insert %scaleB_6_21, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
  %sB_6_20 = vector.insert %scaleB_6_20, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
  %sB_6_19 = vector.insert %scaleB_6_19, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>

  %res_4 = amdgpu.scaled_mfma(%sA_0_4[0] * %opA) * (%sB_6_22[0] * %opB) + %cst_0 {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
  %res_5 = amdgpu.scaled_mfma(%sA_0_5[0] * %opA) * (%sB_6_21[0] * %opB) + %cst_0 {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
  %res_6 = amdgpu.scaled_mfma(%sA_0_6[0] * %opA) * (%sB_6_20[0] * %opB) + %cst_0 {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
  %res_7 = amdgpu.scaled_mfma(%sA_0_7[0] * %opA) * (%sB_6_19[0] * %opB) + %cst_0 {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
  return %res_4, %res_5, %res_6, %res_7 : vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>
}
0 commit comments