Skip to content

Commit 706c1da

Browse files
committed
[mlir][hoisting] Support memref.assume_alignment in linalg hoisting
All ViewLike operations are excluded by hoisting optimization. But assume_alignment just mark memref's alignment, we should check its memref instead of itself.
1 parent 3b387e7 commit 706c1da

File tree

2 files changed

+27
-10
lines changed

2 files changed

+27
-10
lines changed

mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,24 @@ static bool noAliasingUseInLoop(vector::TransferReadOp transferRead,
199199
return true;
200200
}
201201

202+
static bool skipViewLike(Operation *source0, Operation *source1) {
203+
bool viewLikeCheck = true;
204+
auto assumeAlignOp = dyn_cast_or_null<memref::AssumeAlignmentOp>(source0);
205+
if (assumeAlignOp && source0 == source1) {
206+
Value sourceMemRef = assumeAlignOp.getMemref();
207+
Operation *sourceOp = sourceMemRef.getDefiningOp();
208+
return isa_and_nonnull<ViewLikeOpInterface>(sourceOp);
209+
}
210+
211+
if (source0 && isa_and_nonnull<ViewLikeOpInterface>(source0))
212+
return true;
213+
214+
if (source1 && isa_and_nonnull<ViewLikeOpInterface>(source1))
215+
return true;
216+
217+
return false;
218+
}
219+
202220
void mlir::linalg::hoistRedundantVectorTransfers(Operation *root,
203221
bool verifyNonZeroTrip) {
204222
bool changed = true;
@@ -312,12 +330,10 @@ void mlir::linalg::hoistRedundantVectorTransfers(Operation *root,
312330
transferRead.getPermutationMap() != transferWrite.getPermutationMap())
313331
return WalkResult::advance();
314332

315-
auto *source = transferRead.getBase().getDefiningOp();
316-
if (source && isa_and_nonnull<ViewLikeOpInterface>(source))
317-
return WalkResult::advance();
333+
auto *source0 = transferRead.getBase().getDefiningOp();
334+
auto *source1 = transferWrite.getBase().getDefiningOp();
318335

319-
source = transferWrite.getBase().getDefiningOp();
320-
if (source && isa_and_nonnull<ViewLikeOpInterface>(source))
336+
if (skipViewLike(source0, source1))
321337
return WalkResult::advance();
322338

323339
// TODO: may want to memoize this information for performance but it

mlir/test/Dialect/Linalg/hoisting.mlir

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -816,12 +816,13 @@ module attributes {transform.with_named_sequence} {
816816
// CHECK-NEXT: %alloc = memref.alloc() : memref<4096x4096xf16>
817817
// CHECK-NEXT: %alloc_0 = memref.alloc() : memref<4096x4096xf16>
818818
// CHECK-NEXT: %assume_align = memref.assume_alignment %alloc, 64 : memref<4096x4096xf16>
819-
// CHECK-NEXT: scf.for %arg0 = %c256 to %c4096 step %c256 {
820-
// CHECK-NEXT: %0 = vector.transfer_read %assume_align[%c0, %c0], %cst {in_bounds = [true, true]} : memref<4096x4096xf16>, vector<16x16xf16>
821-
// CHECK-NEXT: %1 = vector.transfer_read %alloc_0[%arg0, %arg0], %cst {in_bounds = [true, true]} : memref<4096x4096xf16>, vector<16x16xf16>
822-
// CHECK-NEXT: %2 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %1, %1, %0 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
823-
// CHECK-NEXT: vector.transfer_write %2, %assume_align[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xf16>, memref<4096x4096xf16>
819+
// CHECK-NEXT: %0 = vector.transfer_read %assume_align[%c0, %c0], %cst {in_bounds = [true, true]} : memref<4096x4096xf16>, vector<16x16xf16>
820+
// CHECK-NEXT: %1 = scf.for %arg0 = %c256 to %c4096 step %c256 iter_args(%arg1 = %0) -> (vector<16x16xf16>) {
821+
// CHECK-NEXT: %2 = vector.transfer_read %alloc_0[%arg0, %arg0], %cst {in_bounds = [true, true]} : memref<4096x4096xf16>, vector<16x16xf16>
822+
// CHECK-NEXT: %3 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %2, %2, %arg1 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
823+
// CHECK-NEXT: scf.yield %3 : vector<16x16xf16>
824824
// CHECK-NEXT: }
825+
// CHECK-NEXT: vector.transfer_write %1, %assume_align[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xf16>, memref<4096x4096xf16>
825826
// CHECK-NEXT: return
826827
// CHECK-NEXT: }
827828

0 commit comments

Comments
 (0)