@@ -204,38 +204,21 @@ func.func @scaled_mfma_less_than_4(%opA: vector<32xf4E2M1FN>, %opB: vector<32xf4
204204  return  %res_0  : vector <4 xf32 >
205205}
206206
207- 
208207// ----- 
209208
210209// CHECK-LABEL: func @scaled_mfma_ugly_shapes 
211- // CHECK: amdgpu.scaled_mfma(%{{.*}}[0] * %{{.*}}) * (%{{.*}}[2] * %arg1) + %cst {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32> 
212- // CHECK: amdgpu.scaled_mfma(%{{.*}}[1] * %{{.*}}) * (%{{.*}}[3] * %arg1) + %cst {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32> 
213- // CHECK: amdgpu.scaled_mfma(%{{.*}}[2] * %{{.*}}) * (%{{.*}}[0] * %arg1) + %cst {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32> 
214- // CHECK: amdgpu.scaled_mfma(%{{.*}}[3] * %{{.*}}) * (%{{.*}}[1] * %arg1) + %cst {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32> 
215210// CHECK: amdgpu.scaled_mfma(%{{.*}}[0] * %{{.*}}) * (%{{.*}}[3] * %arg1) + %cst {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32> 
216211// CHECK: amdgpu.scaled_mfma(%{{.*}}[1] * %{{.*}}) * (%{{.*}}[3] * %arg1) + %cst {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32> 
217212// CHECK: amdgpu.scaled_mfma(%{{.*}}[2] * %{{.*}}) * (%{{.*}}[2] * %arg1) + %cst {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32> 
218213// CHECK: amdgpu.scaled_mfma(%{{.*}}[3] * %{{.*}}) * (%{{.*}}[1] * %arg1) + %cst {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32> 
219- func.func  @scaled_mfma_ugly_shapes (%opA:  vector <32 xf4 E2 M1 FN>, %opB:  vector <32 xf4 E2 M1 FN>, %scalesA:  vector <5 x5 xf8 E8 M0 FNU>, %scalesB:  vector <7 x23 xf8 E8 M0 FNU>) -> (vector <4 xf32 >, vector <4 xf32 >, vector <4 xf32 >, vector <4 xf32 >,  vector < 4 x f32 >,  vector < 4 x f32 >,  vector < 4 x f32 >,  vector < 4 x f32 > ) {
214+ func.func  @scaled_mfma_ugly_shapes (%opA:  vector <32 xf4 E2 M1 FN>, %opB:  vector <32 xf4 E2 M1 FN>, %scalesA:  vector <5 x5 xf8 E8 M0 FNU>, %scalesB:  vector <7 x23 xf8 E8 M0 FNU>) -> (vector <4 xf32 >, vector <4 xf32 >, vector <4 xf32 >, vector <4 xf32 >) {
220215  %cst_0  = arith.constant  dense <0.000000e+00 > : vector <4 xf32 >
221216  %cst_1  = arith.constant  dense <5.877470e-39 > : vector <4 xf8 E8 M0 FNU>
222-   %scaleA_0_0  = vector.extract  %scalesA [0 , 0 ] : f8E8M0FNU  from  vector <5 x5 xf8 E8 M0 FNU>
223-   %scaleA_0_1  = vector.extract  %scalesA [1 , 0 ] : f8E8M0FNU  from  vector <5 x5 xf8 E8 M0 FNU>
224-   %scaleA_0_2  = vector.extract  %scalesA [2 , 0 ] : f8E8M0FNU  from  vector <5 x5 xf8 E8 M0 FNU>
225-   %scaleA_0_3  = vector.extract  %scalesA [3 , 0 ] : f8E8M0FNU  from  vector <5 x5 xf8 E8 M0 FNU>
226217  %scaleA_0_4  = vector.extract  %scalesA [4 , 0 ] : f8E8M0FNU  from  vector <5 x5 xf8 E8 M0 FNU>
227218  %scaleA_0_5  = vector.extract  %scalesA [4 , 1 ] : f8E8M0FNU  from  vector <5 x5 xf8 E8 M0 FNU>
228219  %scaleA_0_6  = vector.extract  %scalesA [4 , 2 ] : f8E8M0FNU  from  vector <5 x5 xf8 E8 M0 FNU>
229220  %scaleA_0_7  = vector.extract  %scalesA [4 , 3 ] : f8E8M0FNU  from  vector <5 x5 xf8 E8 M0 FNU>
230221
231-   // idx = 138 + 8 = 146 => opsel = 2 
232-   %scaleB_6_8  = vector.extract  %scalesB [6 , 8 ] : f8E8M0FNU  from  vector <7 x23 xf8 E8 M0 FNU>
233-   // idx = 147 => opsel = 3 
234-   %scaleB_6_9  = vector.extract  %scalesB [6 , 9 ] : f8E8M0FNU  from  vector <7 x23 xf8 E8 M0 FNU>
235-   // idx = 148 => opsel = 0 
236-   %scaleB_6_10  = vector.extract  %scalesB [6 , 10 ] : f8E8M0FNU  from  vector <7 x23 xf8 E8 M0 FNU>
237-   // idx = 149 => opsel = 1 
238-   %scaleB_6_11  = vector.extract  %scalesB [6 , 11 ] : f8E8M0FNU  from  vector <7 x23 xf8 E8 M0 FNU>
239222  // idx = 160 => opsel = 3 (last idx of last 4 bytes) 
240223  %scaleB_6_22  = vector.extract  %scalesB [6 , 22 ] : f8E8M0FNU  from  vector <7 x23 xf8 E8 M0 FNU>
241224  // idx = 159 => opsel = 3 
@@ -245,31 +228,19 @@ func.func @scaled_mfma_ugly_shapes(%opA: vector<32xf4E2M1FN>, %opB: vector<32xf4
245228  // idx = 157 => opsel = 1 
246229  %scaleB_6_19  = vector.extract  %scalesB [6 , 19 ] : f8E8M0FNU  from  vector <7 x23 xf8 E8 M0 FNU>
247230
248-   %sA_0_0  = vector.insert  %scaleA_0_0 , %cst_1  [0 ] : f8E8M0FNU  into  vector <4 xf8 E8 M0 FNU>
249-   %sA_0_1  = vector.insert  %scaleA_0_1 , %cst_1  [0 ] : f8E8M0FNU  into  vector <4 xf8 E8 M0 FNU>
250-   %sA_0_2  = vector.insert  %scaleA_0_2 , %cst_1  [0 ] : f8E8M0FNU  into  vector <4 xf8 E8 M0 FNU>
251-   %sA_0_3  = vector.insert  %scaleA_0_3 , %cst_1  [0 ] : f8E8M0FNU  into  vector <4 xf8 E8 M0 FNU>
252231  %sA_0_4  = vector.insert  %scaleA_0_4 , %cst_1  [0 ] : f8E8M0FNU  into  vector <4 xf8 E8 M0 FNU>
253232  %sA_0_5  = vector.insert  %scaleA_0_5 , %cst_1  [0 ] : f8E8M0FNU  into  vector <4 xf8 E8 M0 FNU>
254233  %sA_0_6  = vector.insert  %scaleA_0_6 , %cst_1  [0 ] : f8E8M0FNU  into  vector <4 xf8 E8 M0 FNU>
255234  %sA_0_7  = vector.insert  %scaleA_0_7 , %cst_1  [0 ] : f8E8M0FNU  into  vector <4 xf8 E8 M0 FNU>
256235
257-   %sB_6_8  = vector.insert  %scaleB_6_8 , %cst_1  [0 ] : f8E8M0FNU  into  vector <4 xf8 E8 M0 FNU>
258-   %sB_6_9  = vector.insert  %scaleB_6_9 , %cst_1  [0 ] : f8E8M0FNU  into  vector <4 xf8 E8 M0 FNU>
259-   %sB_6_10  = vector.insert  %scaleB_6_10 , %cst_1  [0 ] : f8E8M0FNU  into  vector <4 xf8 E8 M0 FNU>
260-   %sB_6_11  = vector.insert  %scaleB_6_11 , %cst_1  [0 ] : f8E8M0FNU  into  vector <4 xf8 E8 M0 FNU>
261236  %sB_6_22  = vector.insert  %scaleB_6_22 , %cst_1  [0 ] : f8E8M0FNU  into  vector <4 xf8 E8 M0 FNU>
262237  %sB_6_21  = vector.insert  %scaleB_6_21 , %cst_1  [0 ] : f8E8M0FNU  into  vector <4 xf8 E8 M0 FNU>
263238  %sB_6_20  = vector.insert  %scaleB_6_20 , %cst_1  [0 ] : f8E8M0FNU  into  vector <4 xf8 E8 M0 FNU>
264239  %sB_6_19  = vector.insert  %scaleB_6_19 , %cst_1  [0 ] : f8E8M0FNU  into  vector <4 xf8 E8 M0 FNU>
265240
266-   %res_0  = amdgpu.scaled_mfma (%sA_0_0 [0 ] * %opA ) * (%sB_6_8 [0 ] * %opB ) + %cst_0  {k  = 128  : i32 , m  = 16  : i32 , n  = 16  : i32 } : vector <4 xf8 E8 M0 FNU>, vector <32 xf4 E2 M1 FN>, vector <4 xf8 E8 M0 FNU>, vector <32 xf4 E2 M1 FN>, vector <4 xf32 >
267-   %res_1  = amdgpu.scaled_mfma (%sA_0_1 [0 ] * %opA ) * (%sB_6_9 [0 ] * %opB ) + %cst_0  {k  = 128  : i32 , m  = 16  : i32 , n  = 16  : i32 } : vector <4 xf8 E8 M0 FNU>, vector <32 xf4 E2 M1 FN>, vector <4 xf8 E8 M0 FNU>, vector <32 xf4 E2 M1 FN>, vector <4 xf32 >
268-   %res_2  = amdgpu.scaled_mfma (%sA_0_2 [0 ] * %opA ) * (%sB_6_10 [0 ] * %opB ) + %cst_0  {k  = 128  : i32 , m  = 16  : i32 , n  = 16  : i32 } : vector <4 xf8 E8 M0 FNU>, vector <32 xf4 E2 M1 FN>, vector <4 xf8 E8 M0 FNU>, vector <32 xf4 E2 M1 FN>, vector <4 xf32 >
269-   %res_3  = amdgpu.scaled_mfma (%sA_0_3 [0 ] * %opA ) * (%sB_6_11 [0 ] * %opB ) + %cst_0  {k  = 128  : i32 , m  = 16  : i32 , n  = 16  : i32 } : vector <4 xf8 E8 M0 FNU>, vector <32 xf4 E2 M1 FN>, vector <4 xf8 E8 M0 FNU>, vector <32 xf4 E2 M1 FN>, vector <4 xf32 >
270241  %res_4  = amdgpu.scaled_mfma (%sA_0_4 [0 ] * %opA ) * (%sB_6_22 [0 ] * %opB ) + %cst_0  {k  = 128  : i32 , m  = 16  : i32 , n  = 16  : i32 } : vector <4 xf8 E8 M0 FNU>, vector <32 xf4 E2 M1 FN>, vector <4 xf8 E8 M0 FNU>, vector <32 xf4 E2 M1 FN>, vector <4 xf32 >
271242  %res_5  = amdgpu.scaled_mfma (%sA_0_5 [0 ] * %opA ) * (%sB_6_21 [0 ] * %opB ) + %cst_0  {k  = 128  : i32 , m  = 16  : i32 , n  = 16  : i32 } : vector <4 xf8 E8 M0 FNU>, vector <32 xf4 E2 M1 FN>, vector <4 xf8 E8 M0 FNU>, vector <32 xf4 E2 M1 FN>, vector <4 xf32 >
272243  %res_6  = amdgpu.scaled_mfma (%sA_0_6 [0 ] * %opA ) * (%sB_6_20 [0 ] * %opB ) + %cst_0  {k  = 128  : i32 , m  = 16  : i32 , n  = 16  : i32 } : vector <4 xf8 E8 M0 FNU>, vector <32 xf4 E2 M1 FN>, vector <4 xf8 E8 M0 FNU>, vector <32 xf4 E2 M1 FN>, vector <4 xf32 >
273244  %res_7  = amdgpu.scaled_mfma (%sA_0_7 [0 ] * %opA ) * (%sB_6_19 [0 ] * %opB ) + %cst_0  {k  = 128  : i32 , m  = 16  : i32 , n  = 16  : i32 } : vector <4 xf8 E8 M0 FNU>, vector <32 xf4 E2 M1 FN>, vector <4 xf8 E8 M0 FNU>, vector <32 xf4 E2 M1 FN>, vector <4 xf32 >
274-   return  %res_0  ,  %res_1 ,  %res_2 ,  %res_3 ,  % res_4%res_5 , %res_6 , %res_7  :  vector < 4 x f32 >,  vector < 4 x f32 >,  vector < 4 x f32 >,  vector < 4 x f32 >,  vector <4 xf32 >, vector <4 xf32 >, vector <4 xf32 >, vector <4 xf32 >
245+   return  %res_4 , %res_5 , %res_6 , %res_7  : vector <4 xf32 >, vector <4 xf32 >, vector <4 xf32 >, vector <4 xf32 >
275246}
0 commit comments