Skip to content

Commit 3b387e7

Browse files
committed
[mlir][NFC] Pre-commit test for linalg hoisting
1 parent cb355de commit 3b387e7

File tree

1 file changed

+51
-0
lines changed

1 file changed

+51
-0
lines changed

mlir/test/Dialect/Linalg/hoisting.mlir

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -802,3 +802,54 @@ module attributes {transform.with_named_sequence} {
802802
transform.yield
803803
}
804804
}
805+
806+
// -----
807+
808+
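// Test hoisting of a vector.transfer_read/transfer_write pair that accesses the same
809+
// location, where that location is on a memref marked with memref.assume_alignment.
// This is a pre-commit test: the expectations below keep the pair inside the scf.for loop.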
810+
811+
// CHECK-LABEL: func.func @hoist_vector_transfer_read_write() {
812+
// CHECK: %c0 = arith.constant 0 : index
813+
// CHECK-NEXT: %c256 = arith.constant 256 : index
814+
// CHECK-NEXT: %c4096 = arith.constant 4096 : index
815+
// CHECK-NEXT: %cst = arith.constant 0.000000e+00 : f16
816+
// CHECK-NEXT: %alloc = memref.alloc() : memref<4096x4096xf16>
817+
// CHECK-NEXT: %alloc_0 = memref.alloc() : memref<4096x4096xf16>
818+
// CHECK-NEXT: %assume_align = memref.assume_alignment %alloc, 64 : memref<4096x4096xf16>
819+
// CHECK-NEXT: scf.for %arg0 = %c256 to %c4096 step %c256 {
820+
// CHECK-NEXT: %0 = vector.transfer_read %assume_align[%c0, %c0], %cst {in_bounds = [true, true]} : memref<4096x4096xf16>, vector<16x16xf16>
821+
// CHECK-NEXT: %1 = vector.transfer_read %alloc_0[%arg0, %arg0], %cst {in_bounds = [true, true]} : memref<4096x4096xf16>, vector<16x16xf16>
822+
// CHECK-NEXT: %2 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %1, %1, %0 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
823+
// CHECK-NEXT: vector.transfer_write %2, %assume_align[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xf16>, memref<4096x4096xf16>
824+
// CHECK-NEXT: }
825+
// CHECK-NEXT: return
826+
// CHECK-NEXT: }
827+
828+
// Input IR for the test. On every iteration of the scf.for loop a 16x16 f16
// vector is read from %assume_align_0 at [%c0, %c0], passed to a
// vector.contract, and the result is written back to %assume_align_0 at the
// same [%c0, %c0] indices. Those indices do not depend on the loop variable.
func.func @hoist_vector_transfer_read_write() {
829+
%c0 = arith.constant 0 : index
830+
// NOTE(review): %c64 is not referenced by any operation in this function.
%c64 = arith.constant 64 : index
831+
%c256 = arith.constant 256 : index
832+
%c4096 = arith.constant 4096 : index
833+
// Padding value supplied to both vector.transfer_read ops below.
%cst_0 = arith.constant 0.000000e+00 : f16
834+
%m0 = memref.alloc() : memref<4096x4096xf16>
835+
%m1 = memref.alloc() : memref<4096x4096xf16>
836+
// The read/write pair in the loop goes through this assume_alignment result,
// not through %m0 directly.
%assume_align_0 = memref.assume_alignment %m0, 64 : memref<4096x4096xf16>
837+
// NOTE(review): %assume_align_1 has no uses; the second read in the loop
// accesses %m1 directly. Confirm this is intended.
%assume_align_1 = memref.assume_alignment %m1, 64 : memref<4096x4096xf16>
838+
scf.for %arg0 = %c256 to %c4096 step %c256 {
839+
// Read at constant indices [%c0, %c0]; independent of %arg0.
%1 = vector.transfer_read %assume_align_0[%c0, %c0], %cst_0 {in_bounds = [true, true]} : memref<4096x4096xf16>, vector<16x16xf16>
840+
// Read at indices [%arg0, %arg0]; changes on every iteration.
%2 = vector.transfer_read %m1[%arg0, %arg0], %cst_0 {in_bounds = [true, true]} : memref<4096x4096xf16>, vector<16x16xf16>
841+
// %2 is passed as both the first and second operand; %1 is the third operand.
%3 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %2, %2, %1 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
842+
// Write the result back to the same memref and indices that %1 was read from.
vector.transfer_write %3, %assume_align_0[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xf16>, memref<4096x4096xf16>
843+
}
844+
return
845+
}
846+
847+
// Transform-dialect script for this test case: it matches the func.func ops in
// the payload rooted at %arg1 and applies
// transform.structured.hoist_redundant_vector_transfers to the matched handle.
module attributes {transform.with_named_sequence} {
848+
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
849+
// Collect a handle to the func.func ops found in %arg1.
%0 = transform.structured.match ops{["func.func"]} in %arg1
850+
: (!transform.any_op) -> !transform.any_op
851+
// Apply the hoisting transform to the matched ops; the result handle is not used.
transform.structured.hoist_redundant_vector_transfers %0
852+
: (!transform.any_op) -> !transform.any_op
853+
transform.yield
854+
}
855+
}

0 commit comments

Comments
 (0)