@@ -1,4 +1,4 @@
-//===- AllocTensorElimination.cpp - alloc_tensor op elimination -----------===//
+//===- EmptyTensorElimination.cpp - tensor.empty op elimination -----------===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -10,15 +10,15 @@
 
 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
-#include "mlir/Dialect/Bufferization/Transforms/AllocTensorElimination.h"
+#include "mlir/Dialect/Bufferization/Transforms/EmptyTensorElimination.h"
 #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/Dominance.h"
 #include "mlir/Pass/Pass.h"
 
 namespace mlir {
 namespace bufferization {
-#define GEN_PASS_DEF_ALLOCTENSORELIMINATION
+#define GEN_PASS_DEF_EMPTYTENSORELIMINATION
 #include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc"
 } // namespace bufferization
 } // namespace mlir
@@ -47,27 +47,27 @@ neededValuesDominateInsertionPoint(const DominanceInfo &domInfo,
 }
 
 /// Return true if the given `insertionPoint` dominates all uses of
-/// `allocTensorOp`.
+/// `emptyTensorOp`.
 static bool insertionPointDominatesUses(const DominanceInfo &domInfo,
                                         Operation *insertionPoint,
-                                        Operation *allocTensorOp) {
-  for (Operation *user : allocTensorOp->getUsers())
+                                        Operation *emptyTensorOp) {
+  for (Operation *user : emptyTensorOp->getUsers())
     if (!domInfo.dominates(insertionPoint, user))
       return false;
   return true;
 }
 
-/// Find a valid insertion point for a replacement of `allocTensorOp`, assuming
+/// Find a valid insertion point for a replacement of `emptyTensorOp`, assuming
 /// that the replacement may use any value from `neededValues`.
 static Operation *
-findValidInsertionPoint(Operation *allocTensorOp,
+findValidInsertionPoint(Operation *emptyTensorOp,
                         const SmallVector<Value> &neededValues) {
   DominanceInfo domInfo;
 
-  // Gather all possible insertion points: the location of `allocTensorOp` and
+  // Gather all possible insertion points: the location of `emptyTensorOp` and
   // right after the definition of each value in `neededValues`.
   SmallVector<Operation *> insertionPointCandidates;
-  insertionPointCandidates.push_back(allocTensorOp);
+  insertionPointCandidates.push_back(emptyTensorOp);
   for (Value val : neededValues) {
     // Note: The anchor op is using all of `neededValues`, so:
     // * in case of a block argument: There must be at least one op in the block
@@ -90,7 +90,7 @@ findValidInsertionPoint(Operation *allocTensorOp,
                                             neededValues))
       continue;
     // Check if the insertion point is before all uses.
-    if (!insertionPointDominatesUses(domInfo, insertionPoint, allocTensorOp))
+    if (!insertionPointDominatesUses(domInfo, insertionPoint, emptyTensorOp))
       continue;
     return insertionPoint;
   }
@@ -99,12 +99,12 @@ findValidInsertionPoint(Operation *allocTensorOp,
   return nullptr;
 }
 
-/// Try to eliminate AllocTensorOps inside `op`. An AllocTensorOp is replaced
+/// Try to eliminate tensor::EmptyOps inside `op`. A tensor::EmptyOp is replaced
 /// with the result of `rewriteFunc` if it is anchored on a matching
 /// OpOperand. "Anchored" means that there is a path on the reverse SSA use-def
 /// chain, starting from the OpOperand and always following the aliasing
-/// OpOperand, that eventually ends at a single AllocTensorOp.
-LogicalResult mlir::bufferization::eliminateAllocTensors(
+/// OpOperand, that eventually ends at a single tensor::EmptyOp.
+LogicalResult mlir::bufferization::eliminateEmptyTensors(
     RewriterBase &rewriter, Operation *op, AnalysisState &state,
     AnchorMatchFn anchorMatchFunc, RewriteFn rewriteFunc) {
   OpBuilder::InsertionGuard g(rewriter);
@@ -119,56 +119,40 @@ LogicalResult mlir::bufferization::eliminateAllocTensors(
       // Is this a matching OpOperand?
       if (!anchorMatchFunc(operand, neededValues))
         continue;
-      SetVector<Value> maybeAllocTensor =
-          state.findValueInReverseUseDefChain(operand.get(), [&](Value val) {
-            // Continue traversal until this function returns true.
-            OpResult opResult = val.dyn_cast<OpResult>();
-            if (!opResult)
-              return true;
-            SmallVector<OpOperand *> opOperands =
-                state.getAliasingOpOperand(opResult);
-            if (!llvm::all_of(opOperands, [&](OpOperand *operand) {
-                  return state.isInPlace(*operand);
-                }))
-              return true;
-            // Only equivalent tensors are supported at the moment.
-            // TODO: Support cases such as extract_slice(alloc_tensor)
-            return !llvm::all_of(opOperands, [&](OpOperand *operand) {
-              return state.areEquivalentBufferizedValues(operand->get(),
-                                                         opResult);
-            });
-          });
+      SetVector<Value> maybeEmptyTensor = state.findValueInReverseUseDefChain(
+          operand.get(), /*condition=*/[&](Value val) { return false; },
+          /*followEquivalentOnly=*/true);
 
       // Replace only if the reverse use-def chain ends at exactly one
-      // AllocTensorOp.
-      if (maybeAllocTensor.size() != 1 ||
-          !maybeAllocTensor.front().getDefiningOp<AllocTensorOp>())
+      // tensor::EmptyOp.
+      if (maybeEmptyTensor.size() != 1 ||
+          !maybeEmptyTensor.front().getDefiningOp<tensor::EmptyOp>())
         return WalkResult::skip();
-      Value allocTensor = maybeAllocTensor.front();
+      Value emptyTensor = maybeEmptyTensor.front();
 
       // Replace only if the types match.
       // TODO: This could be extended to support IR such as:
-      // %0 = bufferization.alloc_tensor : tensor<128xf32>
+      // %0 = tensor.empty() : tensor<128xf32>
       // %1 = "some_op"(%0) : (tensor<128xf32>) -> (tensor<128xf32>)
       // %2 = tensor.expand_shape %1 ...
       // %3 = tensor.insert_slice %2 into ...
-      if (allocTensor.getType() != operand.get().getType())
+      if (emptyTensor.getType() != operand.get().getType())
         return WalkResult::skip();
 
       // Find a suitable insertion point.
       Operation *insertionPoint =
-          findValidInsertionPoint(allocTensor.getDefiningOp(), neededValues);
+          findValidInsertionPoint(emptyTensor.getDefiningOp(), neededValues);
       if (!insertionPoint)
         continue;
 
-      // Create a replacement for the AllocTensorOp.
+      // Create a replacement for the tensor::EmptyOp.
       rewriter.setInsertionPoint(insertionPoint);
-      Value replacement = rewriteFunc(rewriter, allocTensor.getLoc(), operand);
+      Value replacement = rewriteFunc(rewriter, emptyTensor.getLoc(), operand);
       if (!replacement)
         continue;
 
-      // Replace the AllocTensorOp.
-      rewriter.replaceOp(allocTensor.getDefiningOp(), replacement);
+      // Replace the tensor::EmptyOp.
+      rewriter.replaceOp(emptyTensor.getDefiningOp(), replacement);
     }
 
     // Advance to the next operation.
@@ -178,34 +162,35 @@ LogicalResult mlir::bufferization::eliminateAllocTensors(
   return failure(status.wasInterrupted());
 }
 
-/// Try to eliminate AllocTensorOps inside `op`. An AllocTensorOp can be
+/// Try to eliminate tensor::EmptyOps inside `op`. A tensor::EmptyOp can be
 /// eliminated if it is eventually inserted into another tensor (and some other
 /// conditions are met).
 ///
 /// E.g.:
-/// %0 = linalg.alloc_tensor
+/// %0 = tensor.empty()
 /// %1 = linalg.fill(%cst, %0) {inplace = [true]}
 /// %2 = tensor.insert_slice %1 into %t[10][20][1]
 ///
-/// AllocTensorOp elimination will try to fill %t inplace instead of filling a
+/// tensor::EmptyOp elimination will try to fill %t inplace instead of filling a
 /// new allocation %0 and inserting it into %t. This is done by replacing the
-/// AllocTensorOp with:
+/// tensor::EmptyOp with:
 ///
 /// %0 = tensor.extract_slice %t[10][20][1]
 ///
 /// The analysis looks for matching ExtractSliceOp/InsertSliceOp pairs and lets
 /// those bufferize inplace in the absence of other conflicts.
 ///
-/// Starting from an InsertSliceOp, an AllocTensorOp at the end of the insert
+/// Starting from an InsertSliceOp, a tensor::EmptyOp at the end of the insert
 /// source's reverse use-def chain is eliminated if:
 /// * On the reverse use-def chain path from the InsertSliceOp to the
-///   AllocTensorOp, all ops were decided to bufferize inplace and the buffer
+///   tensor::EmptyOp, all ops were decided to bufferize inplace and the buffer
 ///   relation is "equivalent" (TODO: can be relaxed if needed).
-/// * The reverse use-def chain has exactly one end, which is the AllocTensorOp.
+/// * The reverse use-def chain has exactly one end, which is the
+///   tensor::EmptyOp.
 LogicalResult
-mlir::bufferization::insertSliceAnchoredAllocTensorEliminationStep(
+mlir::bufferization::insertSliceAnchoredEmptyTensorEliminationStep(
     RewriterBase &rewriter, Operation *op, AnalysisState &state) {
-  return eliminateAllocTensors(
+  return eliminateEmptyTensors(
       rewriter, op, state,
       /*anchorMatchFunc=*/
       [&](OpOperand &operand, SmallVector<Value> &neededValues) {
@@ -239,10 +224,10 @@ mlir::bufferization::insertSliceAnchoredAllocTensorEliminationStep(
 }
 
 namespace {
-struct AllocTensorElimination
-    : public bufferization::impl::AllocTensorEliminationBase<
-          AllocTensorElimination> {
-  AllocTensorElimination() = default;
+struct EmptyTensorElimination
+    : public bufferization::impl::EmptyTensorEliminationBase<
+          EmptyTensorElimination> {
+  EmptyTensorElimination() = default;
 
   void runOnOperation() override;
@@ -253,7 +238,7 @@ struct AllocTensorElimination
 };
 } // namespace
 
-void AllocTensorElimination::runOnOperation() {
+void EmptyTensorElimination::runOnOperation() {
   Operation *op = getOperation();
   OneShotBufferizationOptions options;
   OneShotAnalysisState state(op, options);
@@ -263,11 +248,11 @@ void AllocTensorElimination::runOnOperation() {
   }
 
   IRRewriter rewriter(op->getContext());
-  if (failed(bufferization::insertSliceAnchoredAllocTensorEliminationStep(
+  if (failed(bufferization::insertSliceAnchoredEmptyTensorEliminationStep(
           rewriter, op, state)))
     signalPassFailure();
 }
 
-std::unique_ptr<Pass> mlir::bufferization::createAllocTensorEliminationPass() {
-  return std::make_unique<AllocTensorElimination>();
+std::unique_ptr<Pass> mlir::bufferization::createEmptyTensorEliminationPass() {
+  return std::make_unique<EmptyTensorElimination>();
 }
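
A minimal usage sketch for the renamed pass, assuming the standard mlir::PassManager API and that the generated Passes.h declarations now expose createEmptyTensorEliminationPass, as this diff implies; the helper name runEmptyTensorElimination is illustrative only, not part of the change:

#include "mlir/Dialect/Bufferization/Transforms/Passes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/PassManager.h"

// Illustrative helper (not part of this diff): schedule the renamed pass on a
// module and run it, the same way the old createAllocTensorEliminationPass
// would have been scheduled.
mlir::LogicalResult runEmptyTensorElimination(mlir::ModuleOp module) {
  mlir::PassManager pm(module.getContext());
  pm.addPass(mlir::bufferization::createEmptyTensorEliminationPass());
  return pm.run(module); // returns failure() if the pass signals failure
}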
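
For callers building a custom elimination step on top of eliminateEmptyTensors, here is a sketch under stated assumptions: the callback shapes are inferred purely from the call sites visible in this diff (the anchor callback receives an OpOperand & plus a SmallVector<Value> & of needed values and returns whether it matched; the rewrite callback is invoked as rewriteFunc(rewriter, emptyTensor.getLoc(), operand) and returns the replacement Value). The helper name and the placeholder bodies are hypothetical, not the library's actual API surface:

// Hypothetical sketch: a custom elimination step wired through the renamed
// entry point. Callback signatures are inferred from this file's call sites;
// treat them as assumptions, not authoritative declarations.
static mlir::LogicalResult
runCustomEliminationStep(mlir::RewriterBase &rewriter, mlir::Operation *op,
                         mlir::bufferization::AnalysisState &state) {
  return mlir::bufferization::eliminateEmptyTensors(
      rewriter, op, state,
      /*anchorMatchFunc=*/
      [](mlir::OpOperand &operand,
         llvm::SmallVector<mlir::Value> &neededValues) {
        // Decide whether `operand` anchors an elimination and push any values
        // the replacement may use into `neededValues`. Placeholder: no match.
        return false;
      },
      /*rewriteFunc=*/
      [](mlir::OpBuilder &b, mlir::Location loc, mlir::OpOperand &operand) {
        // Build the replacement SSA value at the insertion point chosen by
        // findValidInsertionPoint. Placeholder: decline to rewrite.
        return mlir::Value();
      });
}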