@@ -163,11 +163,11 @@ func.func @fold_gather_to_lds_of_cast_dest(%global: memref<128x72xf32, 1>, %lds:
// -----

// CHECK-LABEL: func @scaled_mfma
-// CHECK: %[[SCALE_1:.*]] = vector.extract %{{.*}}[0] : vector<4xf8E8M0FNU> from vector<4x4xf8E8M0FNU>
-// CHECK: %[[SCALE_2:.*]] = vector.extract %{{.*}}[1] : vector<4xf8E8M0FNU> from vector<4x4xf8E8M0FNU>
+// CHECK: %[[SCALE_1:.*]] = vector.extract_strided_slice %0 {offsets = [0], sizes = [4], strides = [1]} : vector<16xf8E8M0FNU> to vector<4xf8E8M0FNU>
+// CHECK: %[[SCALE_2:.*]] = vector.extract_strided_slice %2 {offsets = [4], sizes = [4], strides = [1]} : vector<16xf8E8M0FNU> to vector<4xf8E8M0FNU>
// CHECK: amdgpu.scaled_mfma(%[[SCALE_1]][3] * %{{.*}}) * (%[[SCALE_2]][2] * %{{.*}}) {{.*}}
-// CHECK: %[[SCALE_3:.*]] = vector.extract %{{.*}}[2] : vector<4xf8E8M0FNU> from vector<4x4xf8E8M0FNU>
-// CHECK: %[[SCALE_4:.*]] = vector.extract %{{.*}}[3] : vector<4xf8E8M0FNU> from vector<4x4xf8E8M0FNU>
+// CHECK: %[[SCALE_3:.*]] = vector.extract_strided_slice %5 {offsets = [8], sizes = [4], strides = [1]} : vector<16xf8E8M0FNU> to vector<4xf8E8M0FNU>
+// CHECK: %[[SCALE_4:.*]] = vector.extract_strided_slice %7 {offsets = [12], sizes = [4], strides = [1]} : vector<16xf8E8M0FNU> to vector<4xf8E8M0FNU>
// CHECK: amdgpu.scaled_mfma(%[[SCALE_3]][1] * %{{.*}}) * (%[[SCALE_4]][0] * %{{.*}}) {{.*}}
func.func @scaled_mfma(%opA: vector<32xf4E2M1FN>, %opB: vector<32xf4E2M1FN>, %scalesA: vector<2x1x8x1xf8E8M0FNU>, %scalesB: vector<2x1x8x1xf8E8M0FNU>) -> (vector<4xf32>, vector<4xf32>) {
  %cst_0 = arith.constant dense<0.000000e+00> : vector<4xf32>
@@ -184,3 +184,92 @@ func.func @scaled_mfma(%opA: vector<32xf4E2M1FN>, %opB: vector<32xf4E2M1FN>, %sc
  %res_1 = amdgpu.scaled_mfma(%sC[0] * %opA) * (%sD[0] * %opB) + %cst_0 {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
  return %res_0, %res_1 : vector<4xf32>, vector<4xf32>
}
+
+// -----
+
+// CHECK-LABEL: func @scaled_mfma_less_than_4
+// CHECK: vector.extract {{.*}} : f8E8M0FNU from vector<2xf8E8M0FNU>
+// CHECK: vector.insert {{.*}} : f8E8M0FNU into vector<4xf8E8M0FNU>
+// CHECK: vector.extract {{.*}} : f8E8M0FNU from vector<2xf8E8M0FNU>
+// CHECK: vector.insert {{.*}} : f8E8M0FNU into vector<4xf8E8M0FNU>
+// CHECK: amdgpu.scaled_mfma({{.*}}[0] * {{.*}}) * ({{.*}}[0] * {{.*}}
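+// Note: a 2-element scale vector is too small to form a full 4-byte packed
+// slice, so (as the CHECK lines above expect) the extract/insert sequence is
+// left intact and the opsel index stays 0.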
+func.func @scaled_mfma_less_than_4(%opA: vector<32xf4E2M1FN>, %opB: vector<32xf4E2M1FN>, %scalesA: vector<2xf8E8M0FNU>, %scalesB: vector<2xf8E8M0FNU>) -> vector<4xf32> {
+  %cst_0 = arith.constant dense<0.000000e+00> : vector<4xf32>
+  %cst_1 = arith.constant dense<5.877470e-39> : vector<4xf8E8M0FNU>
+  %scaleA = vector.extract %scalesA[0] : f8E8M0FNU from vector<2xf8E8M0FNU>
+  %sA = vector.insert %scaleA, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
+  %scaleB = vector.extract %scalesB[1] : f8E8M0FNU from vector<2xf8E8M0FNU>
+  %sB = vector.insert %scaleB, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
+  %res_0 = amdgpu.scaled_mfma(%sA[0] * %opA) * (%sB[0] * %opB) + %cst_0 {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
+  return %res_0 : vector<4xf32>
+}
+
+
+// -----
+
+// CHECK-LABEL: func @scaled_mfma_ugly_shapes
+// CHECK: amdgpu.scaled_mfma(%{{.*}}[0] * %{{.*}}) * (%{{.*}}[2] * %arg1) + %cst {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
+// CHECK: amdgpu.scaled_mfma(%{{.*}}[1] * %{{.*}}) * (%{{.*}}[3] * %arg1) + %cst {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
+// CHECK: amdgpu.scaled_mfma(%{{.*}}[2] * %{{.*}}) * (%{{.*}}[0] * %arg1) + %cst {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
+// CHECK: amdgpu.scaled_mfma(%{{.*}}[3] * %{{.*}}) * (%{{.*}}[1] * %arg1) + %cst {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
+// CHECK: amdgpu.scaled_mfma(%{{.*}}[0] * %{{.*}}) * (%{{.*}}[3] * %arg1) + %cst {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
+// CHECK: amdgpu.scaled_mfma(%{{.*}}[1] * %{{.*}}) * (%{{.*}}[3] * %arg1) + %cst {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
+// CHECK: amdgpu.scaled_mfma(%{{.*}}[2] * %{{.*}}) * (%{{.*}}[2] * %arg1) + %cst {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
+// CHECK: amdgpu.scaled_mfma(%{{.*}}[3] * %{{.*}}) * (%{{.*}}[1] * %arg1) + %cst {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
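+// Exercises scale vectors whose element counts (5x5 = 25 and 7x23 = 161) are
+// not multiples of 4, so the final 4-element slice cannot start at
+// idx - idx % 4 and has to end at the vector boundary (see idx 160 below).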
+func.func @scaled_mfma_ugly_shapes(%opA: vector<32xf4E2M1FN>, %opB: vector<32xf4E2M1FN>, %scalesA: vector<5x5xf8E8M0FNU>, %scalesB: vector<7x23xf8E8M0FNU>) -> (vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>) {
+  %cst_0 = arith.constant dense<0.000000e+00> : vector<4xf32>
+  %cst_1 = arith.constant dense<5.877470e-39> : vector<4xf8E8M0FNU>
+  %scaleA_0_0 = vector.extract %scalesA[0, 0] : f8E8M0FNU from vector<5x5xf8E8M0FNU>
+  %scaleA_0_1 = vector.extract %scalesA[1, 0] : f8E8M0FNU from vector<5x5xf8E8M0FNU>
+  %scaleA_0_2 = vector.extract %scalesA[2, 0] : f8E8M0FNU from vector<5x5xf8E8M0FNU>
+  %scaleA_0_3 = vector.extract %scalesA[3, 0] : f8E8M0FNU from vector<5x5xf8E8M0FNU>
+  %scaleA_0_4 = vector.extract %scalesA[4, 0] : f8E8M0FNU from vector<5x5xf8E8M0FNU>
+  %scaleA_0_5 = vector.extract %scalesA[4, 1] : f8E8M0FNU from vector<5x5xf8E8M0FNU>
+  %scaleA_0_6 = vector.extract %scalesA[4, 2] : f8E8M0FNU from vector<5x5xf8E8M0FNU>
+  %scaleA_0_7 = vector.extract %scalesA[4, 3] : f8E8M0FNU from vector<5x5xf8E8M0FNU>
+
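+  // For %scalesB the linearized index is idx = row * 23 + col. Each packed
+  // scale operand covers 4 consecutive bytes, so opsel = idx % 4, except in
+  // the tail where the slice is pinned to the last 4 in-bounds elements.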
+  // idx = 6 * 23 + 8 = 146 => opsel = 2
+  %scaleB_6_8 = vector.extract %scalesB[6, 8] : f8E8M0FNU from vector<7x23xf8E8M0FNU>
+  // idx = 147 => opsel = 3
+  %scaleB_6_9 = vector.extract %scalesB[6, 9] : f8E8M0FNU from vector<7x23xf8E8M0FNU>
+  // idx = 148 => opsel = 0
+  %scaleB_6_10 = vector.extract %scalesB[6, 10] : f8E8M0FNU from vector<7x23xf8E8M0FNU>
+  // idx = 149 => opsel = 1
+  %scaleB_6_11 = vector.extract %scalesB[6, 11] : f8E8M0FNU from vector<7x23xf8E8M0FNU>
+  // idx = 160 => opsel = 3 (last index of the final 4-byte slice)
+  %scaleB_6_22 = vector.extract %scalesB[6, 22] : f8E8M0FNU from vector<7x23xf8E8M0FNU>
+  // idx = 159 => opsel = 3
+  %scaleB_6_21 = vector.extract %scalesB[6, 21] : f8E8M0FNU from vector<7x23xf8E8M0FNU>
+  // idx = 158 => opsel = 2
+  %scaleB_6_20 = vector.extract %scalesB[6, 20] : f8E8M0FNU from vector<7x23xf8E8M0FNU>
+  // idx = 157 => opsel = 1
+  %scaleB_6_19 = vector.extract %scalesB[6, 19] : f8E8M0FNU from vector<7x23xf8E8M0FNU>
+
+  %sA_0_0 = vector.insert %scaleA_0_0, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
+  %sA_0_1 = vector.insert %scaleA_0_1, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
+  %sA_0_2 = vector.insert %scaleA_0_2, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
+  %sA_0_3 = vector.insert %scaleA_0_3, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
+  %sA_0_4 = vector.insert %scaleA_0_4, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
+  %sA_0_5 = vector.insert %scaleA_0_5, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
+  %sA_0_6 = vector.insert %scaleA_0_6, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
+  %sA_0_7 = vector.insert %scaleA_0_7, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
+
+  %sB_6_8 = vector.insert %scaleB_6_8, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
+  %sB_6_9 = vector.insert %scaleB_6_9, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
+  %sB_6_10 = vector.insert %scaleB_6_10, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
+  %sB_6_11 = vector.insert %scaleB_6_11, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
+  %sB_6_22 = vector.insert %scaleB_6_22, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
+  %sB_6_21 = vector.insert %scaleB_6_21, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
+  %sB_6_20 = vector.insert %scaleB_6_20, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
+  %sB_6_19 = vector.insert %scaleB_6_19, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
+
+  %res_0 = amdgpu.scaled_mfma(%sA_0_0[0] * %opA) * (%sB_6_8[0] * %opB) + %cst_0 {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
+  %res_1 = amdgpu.scaled_mfma(%sA_0_1[0] * %opA) * (%sB_6_9[0] * %opB) + %cst_0 {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
+  %res_2 = amdgpu.scaled_mfma(%sA_0_2[0] * %opA) * (%sB_6_10[0] * %opB) + %cst_0 {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
+  %res_3 = amdgpu.scaled_mfma(%sA_0_3[0] * %opA) * (%sB_6_11[0] * %opB) + %cst_0 {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
+  %res_4 = amdgpu.scaled_mfma(%sA_0_4[0] * %opA) * (%sB_6_22[0] * %opB) + %cst_0 {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
+  %res_5 = amdgpu.scaled_mfma(%sA_0_5[0] * %opA) * (%sB_6_21[0] * %opB) + %cst_0 {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
+  %res_6 = amdgpu.scaled_mfma(%sA_0_6[0] * %opA) * (%sB_6_20[0] * %opB) + %cst_0 {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
+  %res_7 = amdgpu.scaled_mfma(%sA_0_7[0] * %opA) * (%sB_6_19[0] * %opB) + %cst_0 {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
+  return %res_0, %res_1, %res_2, %res_3, %res_4, %res_5, %res_6, %res_7 : vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>
+}