@@ -204,6 +204,41 @@ func.func @entry(%arg0: memref<8x32x32x32xbf16>, %arg1: memref<2x32x16x32x2xbf16
 
 // -----
 
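+// Checks lowering of a VNNI-packed bf16 vector.contract to AMX with a 3x1
+// register blocking: the 48x16 f32 accumulator is split into three 16x16
+// tiles along M (an AMX tile holds at most 16 rows of 64 bytes) and a single
+// tile along N, so the contraction maps to three amx.tile_mulf ops.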
+func.func @optimal_register_blocking_3x1(%arg0: memref<1x48x16x2xbf16>, %arg1: memref<1x16x16x2xbf16>, %arg2: memref<48x16xf32>) -> memref<48x16xf32> {
+  %0 = ub.poison : f32
+  %1 = ub.poison : bf16
+  %c0 = arith.constant 0 : index
+  %c48 = arith.constant 48 : index
+  %c16 = arith.constant 16 : index
+  %c1 = arith.constant 1 : index
+  scf.for %arg3 = %c0 to %c48 step %c48 {
+    scf.for %arg4 = %c0 to %c16 step %c16 {
+      %subview = memref.subview %arg2[%arg3, %arg4] [48, 16] [1, 1] : memref<48x16xf32> to memref<48x16xf32, strided<[16, 1], offset: ?>>
+      %2 = vector.transfer_read %subview[%c0, %c0], %0 {in_bounds = [true, true]} : memref<48x16xf32, strided<[16, 1], offset: ?>>, vector<48x16xf32>
+      %3 = scf.for %arg5 = %c0 to %c1 step %c1 iter_args(%arg6 = %2) -> (vector<48x16xf32>) {
+        %4 = scf.for %arg7 = %c0 to %c16 step %c16 iter_args(%arg8 = %arg6) -> (vector<48x16xf32>) {
+          %subview_0 = memref.subview %arg0[%arg5, %arg3, %arg7, 0] [1, 48, 16, 2] [1, 1, 1, 1] : memref<1x48x16x2xbf16> to memref<1x48x16x2xbf16, strided<[1536, 32, 2, 1], offset: ?>>
+          %subview_1 = memref.subview %arg1[%arg5, %arg7, %arg4, 0] [1, 16, 16, 2] [1, 1, 1, 1] : memref<1x16x16x2xbf16> to memref<1x16x16x2xbf16, strided<[512, 32, 2, 1], offset: ?>>
+          %5 = vector.transfer_read %subview_0[%c0, %c0, %c0, %c0], %1 {in_bounds = [true, true, true, true]} : memref<1x48x16x2xbf16, strided<[1536, 32, 2, 1], offset: ?>>, vector<1x48x16x2xbf16>
+          %6 = vector.transfer_read %subview_1[%c0, %c0, %c0, %c0], %1 {in_bounds = [true, true, true, true]} : memref<1x16x16x2xbf16, strided<[512, 32, 2, 1], offset: ?>>, vector<1x16x16x2xbf16>
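+          // VNNI-packed bf16 contraction: d2/d3 are the parallel M/N dims;
+          // d0, d4, and the VNNI pair dim d1 are reduced.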
+          %7 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d4, d1)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d4, d3, d1)>, affine_map<(d0, d1, d2, d3, d4) -> (d2, d3)>], iterator_types = ["reduction", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %5, %6, %arg8 : vector<1x48x16x2xbf16>, vector<1x16x16x2xbf16> into vector<48x16xf32>
+          scf.yield %7 : vector<48x16xf32>
+        }
+        scf.yield %4 : vector<48x16xf32>
+      }
+      vector.transfer_write %3, %subview[%c0, %c0] {in_bounds = [true, true]} : vector<48x16xf32>, memref<48x16xf32, strided<[16, 1], offset: ?>>
+    }
+  }
+  return %arg2 : memref<48x16xf32>
+}
+
+// CHECK-LABEL: func.func @optimal_register_blocking_3x1
+// CHECK-COUNT-3: amx.tile_load
+// CHECK-COUNT-3: amx.tile_mulf
+// CHECK-COUNT-3: amx.tile_store
+
+// -----
+
 // This test shows the lowering of a mixed precision vector.contract
 // (i8 x i8 -> i32) to the AMX dialect.
 func.func @entry(%arg0: memref<4x16x64x64xi8>, %arg1: memref<16x16x16x64x4xi8>, %arg2: memref<4x16x64x64xi32>) {