Skip to content

Commit 9e7183e

Browse files
authored
[Codegen] Canonicalize loops and subviews after copy vectorization (#22344)
The new `memref.copy` tiling and vectorization patterns can generate loops and subviews that can be canonicalized further. If this isn't done, it can lead to errors in a later `FoldMemRefAliasOp` pass as the subviews are supposed to be canonicalized at that point. Signed-off-by: Jorn Tuyls <[email protected]>
1 parent 3815582 commit 9e7183e

File tree

4 files changed

+118
-38
lines changed

4 files changed

+118
-38
lines changed

compiler/src/iree/compiler/Codegen/Common/VectorizeMemrefCopy.cpp

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,11 @@
55
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
66

77
#include "iree/compiler/Codegen/Common/Passes.h"
8+
#include "iree/compiler/Codegen/Transforms/Transforms.h"
89
#include "iree/compiler/Codegen/Utils/Utils.h"
910
#include "mlir/Dialect/Arith/IR/Arith.h"
1011
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
12+
#include "mlir/Dialect/MemRef/IR/MemRef.h"
1113
#include "mlir/Dialect/SCF/IR/SCF.h"
1214
#include "mlir/Dialect/Vector/IR/VectorOps.h"
1315
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -60,6 +62,11 @@ struct TileLinalgCopy final : OpRewritePattern<memref::CopyOp> {
6062
for (Operation *tiledOp : tilingResult->tiledOps) {
6163
tiledOp->setAttr(kIsTiled, mlir::UnitAttr::get(copyOp.getContext()));
6264
}
65+
// Put a marker on the loop ops, so they can be targeted for
66+
// simplification.
67+
for (LoopLikeOpInterface loop : llvm::reverse(tilingResult->loops)) {
68+
loop->setAttr(kIsTiled, mlir::UnitAttr::get(loop.getContext()));
69+
}
6370
if (tilingInterfaceOp->use_empty()) {
6471
rewriter.eraseOp(tilingInterfaceOp);
6572
}
@@ -104,12 +111,18 @@ struct VectorizeMemrefCopyPass final
104111
patterns.add<TileLinalgCopy>(&getContext());
105112
patterns.add<linalg::CopyVectorizationPattern>(&getContext());
106113
patterns.add<ConvertLinalgCopyToMemrefCopy>(&getContext());
114+
// Try to remove generated single iteration loops and canonicalize generated
115+
// subview operations.
116+
populateRemoveSingleIterationLoopPattern(
117+
patterns,
118+
[&](scf::ForOp forOp) -> bool { return forOp->hasAttr(kIsTiled); });
119+
memref::SubViewOp::getCanonicalizationPatterns(patterns, ctx);
107120
(void)applyPatternsGreedily(funcOp, std::move(patterns));
108121

109122
// Clean up the temporary isTiled markers.
110-
funcOp->walk([](memref::CopyOp copyOp) {
111-
if (copyOp->hasAttr(kIsTiled)) {
112-
copyOp->removeAttr(kIsTiled);
123+
funcOp->walk([](Operation *op) {
124+
if (op->hasAttr(kIsTiled)) {
125+
op->removeAttr(kIsTiled);
113126
}
114127
});
115128
}

compiler/src/iree/compiler/Codegen/Common/test/vectorize_memref_copy.mlir

Lines changed: 80 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,9 @@ func.func @memref_copy(%source: memref<2x2xf32>, %dest: memref<2x2xf32>) {
99
// CHECK-SAME: %[[SOURCE:[A-Za-z0-9]+]]: memref<2x2xf32>
1010
// CHECK-SAME: %[[DEST:[A-Za-z0-9]+]]: memref<2x2xf32>
1111
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
12-
// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
13-
// CHECK: scf.for %[[ARG2:.+]] = %[[C0]] to %[[C2]] step %[[C2]]
14-
// CHECK: scf.for %[[ARG3:.+]] = %[[C0]] to %[[C2]] step %[[C2]]
15-
// CHECK: %[[SOURCE_SUBVIEW:.+]] = memref.subview %[[SOURCE]][%[[ARG2]], %[[ARG3]]] [2, 2] [1, 1]
16-
// CHECK: %[[DEST_SUBVIEW:.+]] = memref.subview %[[DEST]][%[[ARG2]], %[[ARG3]]] [2, 2] [1, 1]
17-
// CHECK: %[[RD:.+]] = vector.transfer_read %[[SOURCE_SUBVIEW]]
18-
// CHECK: vector.transfer_write %[[RD]], %[[DEST_SUBVIEW]]
12+
// CHECK-DAG: %[[POISON:.+]] = ub.poison : f32
13+
// CHECK: %[[RD:.+]] = vector.transfer_read %[[SOURCE]][%[[C0]], %[[C0]]], %[[POISON]] {in_bounds = [true, true]} : memref<2x2xf32>, vector<2x2xf32>
14+
// CHECK: vector.transfer_write %[[RD]], %[[DEST]][%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<2x2xf32>, memref<2x2xf32>
1915

2016
// -----
2117

@@ -28,13 +24,9 @@ func.func @linalg_copy(%source: memref<2x2xf32>, %dest: memref<2x2xf32>) {
2824
// CHECK-SAME: %[[SOURCE:[A-Za-z0-9]+]]: memref<2x2xf32>
2925
// CHECK-SAME: %[[DEST:[A-Za-z0-9]+]]: memref<2x2xf32>
3026
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
31-
// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
32-
// CHECK: scf.for %[[ARG2:.+]] = %[[C0]] to %[[C2]] step %[[C2]]
33-
// CHECK: scf.for %[[ARG3:.+]] = %[[C0]] to %[[C2]] step %[[C2]]
34-
// CHECK: %[[SOURCE_SUBVIEW:.+]] = memref.subview %[[SOURCE]][%[[ARG2]], %[[ARG3]]] [2, 2] [1, 1]
35-
// CHECK: %[[DEST_SUBVIEW:.+]] = memref.subview %[[DEST]][%[[ARG2]], %[[ARG3]]] [2, 2] [1, 1]
36-
// CHECK: %[[RD:.+]] = vector.transfer_read %[[SOURCE_SUBVIEW]]
37-
// CHECK: vector.transfer_write %[[RD]], %[[DEST_SUBVIEW]]
27+
// CHECK-DAG: %[[POISON:.+]] = ub.poison : f32
28+
// CHECK: %[[RD:.+]] = vector.transfer_read %[[SOURCE]][%[[C0]], %[[C0]]], %[[POISON]] {in_bounds = [true, true]} : memref<2x2xf32>, vector<2x2xf32>
29+
// CHECK: vector.transfer_write %[[RD]], %[[DEST]][%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<2x2xf32>, memref<2x2xf32>
3830

3931
// -----
4032

@@ -44,6 +36,7 @@ func.func @memref_copy_not_multiple_of_preferred(%source: memref<2x6xf32>, %dest
4436
memref.copy %source, %dest : memref<2x6xf32> to memref<2x6xf32>
4537
return
4638
}
39+
4740
// CHECK-LABEL: func.func @memref_copy_not_multiple_of_preferred
4841
// CHECK-SAME: %[[SOURCE:[A-Za-z0-9]+]]: memref<2x6xf32>
4942
// CHECK-SAME: %[[DEST:[A-Za-z0-9]+]]: memref<2x6xf32>
@@ -74,11 +67,10 @@ func.func @memref_copy_not_multiple_on_penultimate_dim(%source: memref<3x2xf32>,
7467
// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
7568
// CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index
7669
// CHECK: scf.for %[[ARG2:.+]] = %[[C0]] to %[[C3]] step %[[C2]]
77-
// CHECK: scf.for %[[ARG3:.+]] = %[[C0]] to %[[C2]] step %[[C2]]
78-
// CHECK: %[[MIN:.+]] = affine.min affine_map<(d0) -> (-d0 + 3, 2)>(%[[ARG2]])
79-
// CHECK: %[[SOURCE_SUBVIEW:.+]] = memref.subview %[[SOURCE]][%[[ARG2]], %[[ARG3]]] [%[[MIN]], 2] [1, 1]
80-
// CHECK: %[[DEST_SUBVIEW:.+]] = memref.subview %[[DEST]][%[[ARG2]], %[[ARG3]]] [%[[MIN]], 2] [1, 1]
81-
// CHECK: memref.copy %[[SOURCE_SUBVIEW]], %[[DEST_SUBVIEW]]
70+
// CHECK: %[[MIN:.+]] = affine.min affine_map<(d0) -> (-d0 + 3, 2)>(%[[ARG2]])
71+
// CHECK: %[[SOURCE_SUBVIEW:.+]] = memref.subview %[[SOURCE]][%[[ARG2]], 0] [%[[MIN]], 2] [1, 1]
72+
// CHECK: %[[DEST_SUBVIEW:.+]] = memref.subview %[[DEST]][%[[ARG2]], 0] [%[[MIN]], 2] [1, 1]
73+
// CHECK: memref.copy %[[SOURCE_SUBVIEW]], %[[DEST_SUBVIEW]]
8274

8375
// -----
8476

@@ -91,14 +83,12 @@ func.func @memref_copy_dynamic(%source: memref<?x4xf32>, %dest: memref<?x4xf32>)
9183
// CHECK-SAME: %[[DEST:[A-Za-z0-9]+]]: memref<?x4xf32>
9284
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
9385
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
94-
// CHECK-DAG: %[[C4:.+]] = arith.constant 4 : index
9586
// CHECK-DAG: %[[DIM:.+]] = memref.dim %[[SOURCE]], %[[C0]] : memref<?x4xf32>
9687
// CHECK: scf.for %[[ARG2:.+]] = %[[C0]] to %[[DIM]] step %[[C1]]
97-
// CHECK: scf.for %[[ARG3:.+]] = %[[C0]] to %[[C4]] step %[[C4]]
98-
// CHECK: %[[SOURCE_SUBVIEW:.+]] = memref.subview %[[SOURCE]][%[[ARG2]], %[[ARG3]]] [1, 4] [1, 1]
99-
// CHECK: %[[DEST_SUBVIEW:.+]] = memref.subview %[[DEST]][%[[ARG2]], %[[ARG3]]] [1, 4] [1, 1]
100-
// CHECK: %[[RD:.+]] = vector.transfer_read %[[SOURCE_SUBVIEW]]
101-
// CHECK: vector.transfer_write %[[RD]], %[[DEST_SUBVIEW]]
88+
// CHECK: %[[SOURCE_SUBVIEW:.+]] = memref.subview %[[SOURCE]][%[[ARG2]], 0] [1, 4] [1, 1]
89+
// CHECK: %[[DEST_SUBVIEW:.+]] = memref.subview %[[DEST]][%[[ARG2]], 0] [1, 4] [1, 1]
90+
// CHECK: %[[RD:.+]] = vector.transfer_read %[[SOURCE_SUBVIEW]]
91+
// CHECK: vector.transfer_write %[[RD]], %[[DEST_SUBVIEW]]
10292

10393
// -----
10494

@@ -119,3 +109,68 @@ func.func @memref_copy_dynamic_inner_dim(%source: memref<4x?xf32>, %dest: memref
119109
// CHECK: %[[SOURCE_SUBVIEW:.+]] = memref.subview %[[SOURCE]][%[[ARG2]], %[[ARG3]]] [1, %[[MIN]]] [1, 1]
120110
// CHECK: %[[DEST_SUBVIEW:.+]] = memref.subview %[[DEST]][%[[ARG2]], %[[ARG3]]] [1, %[[MIN]]] [1, 1]
121111
// CHECK: memref.copy %[[SOURCE_SUBVIEW]], %[[DEST_SUBVIEW]]
112+
113+
// -----
114+
115+
// Test that the single iteration loops are removed and the subview ops are canonicalized
116+
// (`memref<1x?xbf16, strided<[4, 1]>` instead of `memref<1x?xbf16, strided<[4, 1], offset: ?>`).
117+
118+
func.func @memref_copy_fully_dynamic(%source: memref<1x4xbf16>, %dest: memref<32x?xbf16, strided<[40, 1], offset: ?>>, %dim: index) {
119+
%c0 = arith.constant 0 : index
120+
scf.forall (%arg0) in (3) {
121+
%0 = affine.min affine_map<(d0) -> (d0 * -16 + 40, 16)>(%arg0)
122+
%1:2 = affine.delinearize_index %dim into (2, 64) : index, index
123+
%2:3 = affine.delinearize_index %1#1 into (4, 16) : index, index, index
124+
%3 = affine.linearize_index disjoint [%2#1, %c0] by (4, 4) : index
125+
%4 = affine.linearize_index disjoint [%1#0, %2#2] by (2, 16) : index
126+
%5 = affine.max affine_map<()[s0] -> (-s0 + 32, 0)>()[%4]
127+
%6 = affine.min affine_map<()[s0] -> (1, s0)>()[%5]
128+
%7 = affine.max affine_map<(d0)[s0] -> (0, d0 - s0)>(%0)[%3]
129+
%8 = affine.min affine_map<(d0) -> (4, d0)>(%7)
130+
%subview_0 = memref.subview %source[0, 0] [%6, %8] [1, 1] : memref<1x4xbf16> to memref<?x?xbf16, strided<[4, 1]>>
131+
%subview_1 = memref.subview %dest[%4, %3] [%6, %8] [1, 1] : memref<32x?xbf16, strided<[40, 1], offset: ?>> to memref<?x?xbf16, strided<[40, 1], offset: ?>>
132+
memref.copy %subview_0, %subview_1 : memref<?x?xbf16, strided<[4, 1]>> to memref<?x?xbf16, strided<[40, 1], offset: ?>>
133+
}
134+
return
135+
}
136+
// CHECK-LABEL: func.func @memref_copy_fully_dynamic
137+
// CHECK-SAME: %[[SOURCE:[A-Za-z0-9]+]]: memref<1x4xbf16>
138+
// CHECK-SAME: %[[DEST:[A-Za-z0-9]+]]: memref<32x?xbf16, strided<[40, 1], offset: ?>>
139+
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
140+
// CHECK-DAG: %[[LIN_0:.+]] = affine.linearize_index disjoint [%{{.+}}, %{{.+}}] by (4, 4) : index
141+
// CHECK-DAG: %[[LIN_1:.+]] = affine.linearize_index disjoint [%{{.+}}, %{{.+}}] by (2, 16) : index
142+
// CHECK-DAG: %[[MIN_0:.+]] = affine.min affine_map<()[s0] -> (1, s0)>()[%{{.+}}]
143+
// CHECK-DAG: %[[MIN_1:.+]] = affine.min affine_map<(d0) -> (4, d0)>(%{{.+}})
144+
// CHECK-DAG: %[[SUBVIEW_0:.+]] = memref.subview %[[SOURCE]][0, 0] [%[[MIN_0]], %[[MIN_1]]] [1, 1]
145+
// CHECK-SAME: memref<1x4xbf16> to memref<?x?xbf16, strided<[4, 1]>>
146+
// CHECK-DAG: %[[SUBVIEW_1:.+]] = memref.subview %[[DEST]][%[[LIN_1]], %[[LIN_0]]] [%[[MIN_0]], %[[MIN_1]]] [1, 1]
147+
// CHECK-SAME: memref<32x?xbf16, strided<[40, 1], offset: ?>> to memref<?x?xbf16, strided<[40, 1], offset: ?>>
148+
// CHECK-DAG: %[[CMP_0:.+]] = arith.cmpi sgt, %[[MIN_0]], %[[C0]] : index
149+
// CHECK: scf.if %[[CMP_0]] {
150+
// CHECK: %[[CMP_1:.+]] = arith.cmpi sgt, %[[MIN_1]], %[[C0]] : index
151+
// CHECK: scf.if %[[CMP_1]] {
152+
// CHECK: %[[MIN_2:.+]] = affine.min affine_map<(d0)[s0] -> (-d0 + s0, 8)>(%[[C0]])[%[[MIN_1]]]
153+
// CHECK: %[[SUBVIEW_2:.+]] = memref.subview %[[SUBVIEW_0]][0, 0] [1, %[[MIN_2]]] [1, 1]
154+
// CHECK-SAME: memref<?x?xbf16, strided<[4, 1]>> to memref<1x?xbf16, strided<[4, 1]>>
155+
// CHECK: %[[SUBVIEW_3:.+]] = memref.subview %[[SUBVIEW_1]][0, 0] [1, %[[MIN_2]]] [1, 1]
156+
// CHECK-SAME: memref<?x?xbf16, strided<[40, 1], offset: ?>> to memref<1x?xbf16, strided<[40, 1], offset: ?>>
157+
// CHECK: memref.copy %[[SUBVIEW_2]], %[[SUBVIEW_3]]
158+
159+
// -----
160+
161+
// Test that scf.for operations with `_is_tiled` attribute are simplified. The `memref.copy` should still be vectorized as well.
162+
163+
func.func @for_with_tiled_attr(%source: memref<4x?xf32>, %dest: memref<4x?xf32>) {
164+
%c0 = arith.constant 0 : index
165+
%c1 = arith.constant 1 : index
166+
scf.for %arg0 = %c0 to %c1 step %c1 {
167+
%subview_0 = memref.subview %source[%arg0, 0] [4, 1] [1, 1] : memref<4x?xf32> to memref<4x1xf32, strided<[?, 1], offset: ?>>
168+
%subview_1 = memref.subview %dest[%arg0, 0] [4, 1] [1, 1] : memref<4x?xf32> to memref<4x1xf32, strided<[?, 1], offset: ?>>
169+
memref.copy %subview_0, %subview_1 : memref<4x1xf32, strided<[?, 1], offset: ?>> to memref<4x1xf32, strided<[?, 1], offset: ?>>
170+
} {_is_tiled}
171+
return
172+
}
173+
// CHECK-LABEL: func.func @for_with_tiled_attr
174+
// CHECK-NOT: scf.for
175+
// CHECK: vector.transfer_read
176+
// CHECK: vector.transfer_write

compiler/src/iree/compiler/Codegen/Transforms/RemoveSingleIterationLoop.cpp

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -60,17 +60,22 @@ static void replaceForWithIf(PatternRewriter &rewriter, scf::ForOp op,
6060
namespace {
6161
/// Rewriting pattern that replaces single-iteration loops with their bodies.
6262
struct SimplifyTrivialLoops : public OpRewritePattern<scf::ForOp> {
63-
using Base::Base;
63+
64+
SimplifyTrivialLoops(MLIRContext *context, ForControlFnRef controlFn)
65+
: OpRewritePattern(context), controlFn(controlFn) {}
6466

6567
LogicalResult matchAndRewrite(scf::ForOp op,
6668
PatternRewriter &rewriter) const override {
67-
if (!(neverRunsSecondIteration(op))) {
68-
return failure();
69+
if (controlFn && !controlFn(op)) {
70+
return rewriter.notifyMatchFailure(
71+
op, "doesn't match according to the control function");
6972
}
70-
71-
// The second iteration is never run
72-
// so the loop atmost can have 1 iteration. Inline its body and remove the
73-
// loop.
73+
if (!neverRunsSecondIteration(op)) {
74+
return rewriter.notifyMatchFailure(op,
75+
"is not a single-iteration for loop");
76+
}
77+
// The second iteration is never run so the loop at most can have 1
78+
// iteration. Inline its body and remove the loop.
7479
SmallVector<Value> blockArgs;
7580
blockArgs.reserve(op.getInitArgs().size() + 1);
7681
blockArgs.push_back(op.getLowerBound());
@@ -82,12 +87,16 @@ struct SimplifyTrivialLoops : public OpRewritePattern<scf::ForOp> {
8287
}
8388
return success();
8489
}
90+
91+
private:
92+
ForControlFnRef controlFn;
8593
};
8694

8795
} // namespace
8896

89-
void populateRemoveSingleIterationLoopPattern(RewritePatternSet &patterns) {
90-
patterns.add<SimplifyTrivialLoops>(patterns.getContext());
97+
void populateRemoveSingleIterationLoopPattern(RewritePatternSet &patterns,
98+
ForControlFnRef controlFn) {
99+
patterns.add<SimplifyTrivialLoops>(patterns.getContext(), controlFn);
91100
}
92101

93102
} // namespace mlir::iree_compiler

compiler/src/iree/compiler/Codegen/Transforms/Transforms.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@
2525

2626
namespace mlir::iree_compiler {
2727

28+
using ForControlFnRef = llvm::function_ref<bool(scf::ForOp)>;
29+
2830
/// Get the `offsets`, `sizes` and `strides` for a `storeOp` (or `loadOp`). This
2931
/// method clones the operations that generate the `Value`s used for
3032
/// specifying the offsets, sizesm strides and dynamic dims of the
@@ -100,7 +102,8 @@ using GetMinMaxExprFn =
100102

101103
/// Insert pattern to remove single iteration loop. The pattern will detect
102104
/// single iteration loops based on the range returned ValueBoundsOpInterface.
103-
void populateRemoveSingleIterationLoopPattern(RewritePatternSet &patterns);
105+
void populateRemoveSingleIterationLoopPattern(
106+
RewritePatternSet &patterns, ForControlFnRef controlFn = nullptr);
104107

105108
// Group of Alloc operations that have overlapping liveranges.
106109
using AliasGroup = SmallVector<Operation *>;

0 commit comments

Comments
 (0)