Skip to content

Commit d8db47a

Browse files
committed
[mlir][vector] Hoist transfer pairs when the source is AssumeAlignmentOp
ffb9bbf makes memref::AssumeAlignmentOp implement ViewLikeOpInterface, which disables hoisting whenever an AssumeAlignmentOp is present. Previously this was not an issue because the op did not have a result. Now that the op has a result, hoisting does not work if the transfer ops operate on the result of an AssumeAlignmentOp. Signed-off-by: hanhanW <[email protected]>
1 parent b85e929 commit d8db47a

File tree

3 files changed

+19
-9
lines changed

3 files changed

+19
-9
lines changed

mlir/include/mlir/Dialect/Linalg/Transforms/Hoisting.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,8 @@ namespace linalg {
2727
/// dominated by the transfer_write (i.e. no aliasing between the write and
2828
/// the read across the loop)
2929
/// 4. The source operands for vector.transfer_{read|write} do not originate
30-
/// from Ops implementing ViewLikeOpInterface (to reduce the risk of
31-
/// aliasing).
30+
/// from ops implementing ViewLikeOpInterface (to reduce the risk of
31+
/// aliasing), except memref::AssumeAlignmentOp.
3232
/// 5. If `verifyNonZeroTrip` is true, then the lower bound of the loop must
3333
/// be statically smaller than the upper bound of the loop, guaranteeing that
3434
/// the loop body will execute at least once.
@@ -39,8 +39,8 @@ namespace linalg {
3939
///
4040
/// TODO: To further improve hoisting opportunities, fold aliasing memref
4141
/// operations into respective vector.transfer_{read|write} operations and
42-
/// avoid using ops implementing ViewLikeOpInterface as the source for transfer
43-
/// Ops.
42+
/// avoid using ops implementing ViewLikeOpInterface, except
43+
/// memref::AssumeAlignmentOp, as the source for transfer ops.
4444
///
4545
/// WARNING: This hoisting does not model parallelism and is generally incorrect
4646
/// when used on distributed loops with memref semantics!

mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
#include "mlir/Dialect/Func/IR/FuncOps.h"
2222
#include "mlir/Dialect/Linalg/IR/Linalg.h"
2323
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
24+
#include "mlir/Dialect/MemRef/IR/MemRef.h"
2425
#include "mlir/Dialect/SCF/IR/SCF.h"
2526
#include "mlir/Dialect/SCF/Utils/Utils.h"
2627
#include "mlir/Dialect/Tensor/IR/Tensor.h"
@@ -303,7 +304,8 @@ void mlir::linalg::hoistRedundantVectorTransfers(Operation *root,
303304
// 1. indices, vector type and permutation map are the same (i.e., the
304305
// transfer_read/transfer_write ops are matching),
305306
// 2. source operands for transfer.{read|write} do not originate from
306-
// Ops implementing ViewLikeOpInterface.
307+
// ops implementing ViewLikeOpInterface, except
308+
// memref::AssumeAlignmentOp.
307309
// 3. no other operations in the loop access the same memref except
308310
// for transfer_read/transfer_write accessing statically disjoint
309311
// slices.
@@ -313,11 +315,13 @@ void mlir::linalg::hoistRedundantVectorTransfers(Operation *root,
313315
return WalkResult::advance();
314316

315317
auto *source = transferRead.getBase().getDefiningOp();
316-
if (source && isa_and_nonnull<ViewLikeOpInterface>(source))
318+
if (source && isa_and_nonnull<ViewLikeOpInterface>(source) &&
319+
!isa<memref::AssumeAlignmentOp>(source))
317320
return WalkResult::advance();
318321

319322
source = transferWrite.getBase().getDefiningOp();
320-
if (source && isa_and_nonnull<ViewLikeOpInterface>(source))
323+
if (source && isa_and_nonnull<ViewLikeOpInterface>(source) &&
324+
!isa<memref::AssumeAlignmentOp>(source))
321325
return WalkResult::advance();
322326

323327
// TODO: may want to memoize this information for performance but it

mlir/test/Dialect/Linalg/hoisting.mlir

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,13 @@ func.func @hoist_vector_transfer_pairs(
1818
%val: index, %lb : index, %ub : index, %step: index, %cmp: i1) {
1919
%c0 = arith.constant 0 : index
2020
%cst = arith.constant 0.0 : f32
21+
%assume_align = memref.assume_alignment %memref0, 64 : memref<?x?xf32>
2122

2223
// CHECK: vector.transfer_read %{{.*}} : memref<?x?xf32>, vector<1xf32>
23-
// CHECK: scf.for %[[I:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] iter_args({{.*}}) -> (vector<1xf32>) {
24+
// CHECK: vector.transfer_read %{{.*}} : memref<?x?xf32>, vector<1xf32>
25+
// CHECK: scf.for %[[I:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] iter_args({{.*}}) -> (vector<1xf32>, vector<1xf32>) {
2426
// CHECK: vector.transfer_read %{{.*}} : memref<?x?xf32>, vector<2xf32>
25-
// CHECK: scf.for %[[J:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] iter_args({{.*}}) -> (vector<1xf32>, vector<2xf32>) {
27+
// CHECK: scf.for %[[J:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] iter_args({{.*}}) -> (vector<1xf32>, vector<2xf32>, vector<1xf32>) {
2628
// CHECK: vector.transfer_read %{{.*}} : memref<?x?xf32>, vector<3xf32>
2729
// CHECK: vector.transfer_read %{{.*}} : memref<?x?xf32>, vector<4xf32>
2830
// CHECK: "some_crippling_use"(%[[MEMREF4]]) : (memref<?x?xf32>) -> ()
@@ -43,6 +45,7 @@ func.func @hoist_vector_transfer_pairs(
4345
// CHECK: scf.yield {{.*}} : vector<1xf32>
4446
// CHECK: }
4547
// CHECK: vector.transfer_write %{{.*}} : vector<1xf32>, memref<?x?xf32>
48+
// CHECK: vector.transfer_write %{{.*}} : vector<1xf32>, memref<?x?xf32>
4649
// CHECK: "unrelated_use"(%[[MEMREF1]]) : (memref<?x?xf32>) -> ()
4750
scf.for %i = %lb to %ub step %step {
4851
scf.for %j = %lb to %ub step %step {
@@ -53,19 +56,22 @@ func.func @hoist_vector_transfer_pairs(
5356
"some_crippling_use"(%memref4) : (memref<?x?xf32>) -> ()
5457
%r4 = vector.transfer_read %memref4[%c0, %c0], %cst: memref<?x?xf32>, vector<5xf32>
5558
%r5 = vector.transfer_read %memref5[%c0, %c0], %cst: memref<?x?xf32>, vector<6xf32>
59+
%r6 = vector.transfer_read %assume_align[%c0, %c0], %cst: memref<?x?xf32>, vector<1xf32>
5660
"some_crippling_use"(%memref5) : (memref<?x?xf32>) -> ()
5761
%u0 = "some_use"(%r0) : (vector<1xf32>) -> vector<1xf32>
5862
%u1 = "some_use"(%r1) : (vector<2xf32>) -> vector<2xf32>
5963
%u2 = "some_use"(%memref2, %r2) : (memref<?x?xf32>, vector<3xf32>) -> vector<3xf32>
6064
%u3 = "some_use"(%r3) : (vector<4xf32>) -> vector<4xf32>
6165
%u4 = "some_use"(%r4) : (vector<5xf32>) -> vector<5xf32>
6266
%u5 = "some_use"(%r5) : (vector<6xf32>) -> vector<6xf32>
67+
%u6 = "some_use"(%r6) : (vector<1xf32>) -> vector<1xf32>
6368
vector.transfer_write %u0, %memref1[%c0, %c0] : vector<1xf32>, memref<?x?xf32>
6469
vector.transfer_write %u1, %memref0[%i, %i] : vector<2xf32>, memref<?x?xf32>
6570
vector.transfer_write %u2, %memref2[%c0, %c0] : vector<3xf32>, memref<?x?xf32>
6671
vector.transfer_write %u3, %memref3[%c0, %c0] : vector<4xf32>, memref<?x?xf32>
6772
vector.transfer_write %u4, %memref4[%c0, %c0] : vector<5xf32>, memref<?x?xf32>
6873
vector.transfer_write %u5, %memref5[%c0, %c0] : vector<6xf32>, memref<?x?xf32>
74+
vector.transfer_write %u6, %assume_align[%c0, %c0] : vector<1xf32>, memref<?x?xf32>
6975
"some_crippling_use"(%memref3) : (memref<?x?xf32>) -> ()
7076
}
7177
"unrelated_use"(%memref0) : (memref<?x?xf32>) -> ()

0 commit comments

Comments
 (0)