@@ -802,3 +802,54 @@ module attributes {transform.with_named_sequence} {
802802 transform.yield
803803 }
804804}
805+
806+ // -----
807+
808+ // Test hoisting of vector.transfer_read/transfer_write pairs with same location
809+ // and this location is marked with assume_align.
810+
811+ // CHECK-LABEL: func.func @hoist_vector_transfer_read_write() {
812+ // CHECK: %c0 = arith.constant 0 : index
813+ // CHECK-NEXT: %c256 = arith.constant 256 : index
814+ // CHECK-NEXT: %c4096 = arith.constant 4096 : index
815+ // CHECK-NEXT: %cst = arith.constant 0.000000e+00 : f16
816+ // CHECK-NEXT: %alloc = memref.alloc() : memref<4096x4096xf16>
817+ // CHECK-NEXT: %alloc_0 = memref.alloc() : memref<4096x4096xf16>
818+ // CHECK-NEXT: %assume_align = memref.assume_alignment %alloc, 64 : memref<4096x4096xf16>
819+ // CHECK-NEXT: scf.for %arg0 = %c256 to %c4096 step %c256 {
820+ // CHECK-NEXT: %0 = vector.transfer_read %assume_align[%c0, %c0], %cst {in_bounds = [true, true]} : memref<4096x4096xf16>, vector<16x16xf16>
821+ // CHECK-NEXT: %1 = vector.transfer_read %alloc_0[%arg0, %arg0], %cst {in_bounds = [true, true]} : memref<4096x4096xf16>, vector<16x16xf16>
822+ // CHECK-NEXT: %2 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %1, %1, %0 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
823+ // CHECK-NEXT: vector.transfer_write %2, %assume_align[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xf16>, memref<4096x4096xf16>
824+ // CHECK-NEXT: }
825+ // CHECK-NEXT: return
826+ // CHECK-NEXT: }
827+
828+ func.func @hoist_vector_transfer_read_write () {
829+ %c0 = arith.constant 0 : index
830+ %c64 = arith.constant 64 : index
831+ %c256 = arith.constant 256 : index
832+ %c4096 = arith.constant 4096 : index
833+ %cst_0 = arith.constant 0.000000e+00 : f16
834+ %m0 = memref.alloc () : memref <4096 x4096 xf16 >
835+ %m1 = memref.alloc () : memref <4096 x4096 xf16 >
836+ %assume_align_0 = memref.assume_alignment %m0 , 64 : memref <4096 x4096 xf16 >
837+ %assume_align_1 = memref.assume_alignment %m1 , 64 : memref <4096 x4096 xf16 >
838+ scf.for %arg0 = %c256 to %c4096 step %c256 {
839+ %1 = vector.transfer_read %assume_align_0 [%c0 , %c0 ], %cst_0 {in_bounds = [true , true ]} : memref <4096 x4096 xf16 >, vector <16 x16 xf16 >
840+ %2 = vector.transfer_read %m1 [%arg0 , %arg0 ], %cst_0 {in_bounds = [true , true ]} : memref <4096 x4096 xf16 >, vector <16 x16 xf16 >
841+ %3 = vector.contract {index ing_maps = [affine_map <(d0 , d1 , d2 ) -> (d0 , d2 )>, affine_map <(d0 , d1 , d2 ) -> (d2 , d1 )>, affine_map <(d0 , d1 , d2 ) -> (d0 , d1 )>], iterator_types = [" parallel" , " parallel" , " reduction" ], kind = #vector.kind <add >} %2 , %2 , %1 : vector <16 x16 xf16 >, vector <16 x16 xf16 > into vector <16 x16 xf16 >
842+ vector.transfer_write %3 , %assume_align_0 [%c0 , %c0 ] {in_bounds = [true , true ]} : vector <16 x16 xf16 >, memref <4096 x4096 xf16 >
843+ }
844+ return
845+ }
846+
847+ module attributes {transform.with_named_sequence } {
848+ transform.named_sequence @__transform_main (%arg1: !transform.any_op {transform.readonly }) {
849+ %0 = transform.structured.match ops {[" func.func" ]} in %arg1
850+ : (!transform.any_op ) -> !transform.any_op
851+ transform.structured.hoist_redundant_vector_transfers %0
852+ : (!transform.any_op ) -> !transform.any_op
853+ transform.yield
854+ }
855+ }
0 commit comments