Skip to content

Commit e62681e

Browse files
[mlir][bufferize] Eliminate tensor.empty ops instead of bufferization.alloc_tensor ops
tensor.empty op elimination is an optimization that brings IR in a more bufferization-friendly form. E.g.: ``` %0 = tensor.empty() %1 = linalg.fill(%cst, %0) {inplace = [true]} %2 = tensor.insert_slice %1 into %t[10][20][1] ``` Is rewritten to: ``` %0 = tensor.extract_slice %t[10][20][1] %1 = linalg.fill(%cst, %0) {inplace = [true]} %2 = tensor.insert_slice %1 into %t[10][20][1] ``` This optimization used to operate on bufferization.alloc_tensor ops. This is not correct because the documentation of bufferization.alloc_tensor says that it always bufferizes to an allocation. Instead, this optimization should operate on tensor.empty ops, which can then be lowered to bufferization.alloc_tensor ops (if they don't get eliminated). Differential Revision: https://reviews.llvm.org/D137162
1 parent 882ddab commit e62681e

File tree

7 files changed

+81
-100
lines changed

7 files changed

+81
-100
lines changed
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,20 @@
1-
//===- AllocTensorElimination.h - alloc_tensor op elimination -------------===//
1+
//===- EmptyTensorElimination.h - tensor.empty op elimination -------------===//
22
//
33
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
44
// See https://llvm.org/LICENSE.txt for license information.
55
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
66
//
77
//===----------------------------------------------------------------------===//
88

9-
#ifndef MLIR_DIALECT_BUFFERIZATION_TRANSFORMS_ALLOCTENSORELIMINATION_H
10-
#define MLIR_DIALECT_BUFFERIZATION_TRANSFORMS_ALLOCTENSORELIMINATION_H
9+
#ifndef MLIR_DIALECT_BUFFERIZATION_TRANSFORMS_EMPTYTENSORELIMINATION_H
10+
#define MLIR_DIALECT_BUFFERIZATION_TRANSFORMS_EMPTYTENSORELIMINATION_H
1111

1212
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
1313

1414
namespace mlir {
1515
namespace bufferization {
1616

17-
/// A function that matches anchor OpOperands for AllocTensorOp elimination.
17+
/// A function that matches anchor OpOperands for tensor::EmptyOp elimination.
1818
/// If an OpOperand is matched, the function should populate the SmallVector
1919
/// with all values that are needed during `RewriteFn` to produce the
2020
/// replacement value.
@@ -23,26 +23,26 @@ using AnchorMatchFn = std::function<bool(OpOperand &, SmallVector<Value> &)>;
2323
/// A function that rewrites matched anchors.
2424
using RewriteFn = std::function<Value(OpBuilder &, Location, OpOperand &)>;
2525

26-
/// Try to eliminate AllocTensorOps inside `op`.
26+
/// Try to eliminate tensor::EmptyOps inside `op`.
2727
///
28-
/// * `rewriteFunc` generates the replacement for the AllocTensorOp.
29-
/// * Only AllocTensorOps that are anchored on a matching OpOperand as per
28+
/// * `rewriteFunc` generates the replacement for the tensor::EmptyOp.
29+
/// * Only tensor::EmptyOps that are anchored on a matching OpOperand as per
3030
/// `anchorMatchFunc` are considered. "Anchored" means that there is a path
3131
/// on the reverse SSA use-def chain, starting from the OpOperand and always
3232
/// following the aliasing OpOperand, that eventually ends at a single
33-
/// AllocTensorOp.
34-
LogicalResult eliminateAllocTensors(RewriterBase &rewriter, Operation *op,
33+
/// tensor::EmptyOp.
34+
LogicalResult eliminateEmptyTensors(RewriterBase &rewriter, Operation *op,
3535
bufferization::AnalysisState &state,
3636
AnchorMatchFn anchorMatchFunc,
3737
RewriteFn rewriteFunc);
3838

39-
/// Try to eliminate AllocTensorOps inside `op` that are anchored on an
39+
/// Try to eliminate tensor::EmptyOps inside `op` that are anchored on an
4040
/// InsertSliceOp, i.e., if it is eventually inserted into another tensor
4141
/// (and some other conditions are met).
42-
LogicalResult insertSliceAnchoredAllocTensorEliminationStep(
42+
LogicalResult insertSliceAnchoredEmptyTensorEliminationStep(
4343
RewriterBase &rewriter, Operation *op, bufferization::AnalysisState &state);
4444

4545
} // namespace bufferization
4646
} // namespace mlir
4747

48-
#endif // MLIR_DIALECT_BUFFERIZATION_TRANSFORMS_ALLOCTENSORELIMINATION_H
48+
#endif // MLIR_DIALECT_BUFFERIZATION_TRANSFORMS_EMPTYTENSORELIMINATION_H

mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,9 +90,9 @@ createPromoteBuffersToStackPass(unsigned maxAllocSizeInBytes = 1024,
9090
std::unique_ptr<Pass>
9191
createPromoteBuffersToStackPass(std::function<bool(Value)> isSmallAlloc);
9292

93-
/// Create a pass that tries to eliminate alloc_tensor ops that are anchored on
93+
/// Create a pass that tries to eliminate tensor.empty ops that are anchored on
9494
/// insert_slice ops.
95-
std::unique_ptr<Pass> createAllocTensorEliminationPass();
95+
std::unique_ptr<Pass> createEmptyTensorEliminationPass();
9696

9797
/// Create a pass that bufferizes ops from the bufferization dialect.
9898
std::unique_ptr<Pass> createBufferizationBufferizePass();

mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -371,16 +371,16 @@ def TensorCopyInsertion : Pass<"tensor-copy-insertion"> {
371371
let constructor = "mlir::bufferization::createTensorCopyInsertionPass()";
372372
}
373373

374-
def AllocTensorElimination : Pass<"eliminate-alloc-tensors"> {
375-
let summary = "Try to eliminate all alloc_tensor ops.";
374+
def EmptyTensorElimination : Pass<"eliminate-empty-tensors"> {
375+
let summary = "Try to eliminate all tensor.empty ops.";
376376
let description = [{
377-
This pass tries to eliminate all insert_slice op-anchored alloc_tensor ops.
378-
I.e., when a value that is equivalent to an alloc_tensor op is inserted into
377+
This pass tries to eliminate all insert_slice op-anchored tensor.empty ops.
378+
I.e., when a value that is equivalent to an tensor.empty op is inserted into
379379
another tensor, this pass tries to rewrite the IR in such a way that the
380380
destination tensor of the insert_slice op is used directly instead of the
381-
alloc_tensor result.
381+
tensor.empty result.
382382
}];
383-
let constructor = "mlir::bufferization::createAllocTensorEliminationPass()";
383+
let constructor = "mlir::bufferization::createEmptyTensorEliminationPass()";
384384
}
385385

386386
#endif // MLIR_DIALECT_BUFFERIZATION_TRANSFORMS_PASSES

mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
add_mlir_dialect_library(MLIRBufferizationTransforms
2-
AllocTensorElimination.cpp
32
Bufferize.cpp
43
BufferDeallocation.cpp
54
BufferOptimizations.cpp
65
BufferResultsToOutParams.cpp
76
BufferUtils.cpp
87
BufferViewFlowAnalysis.cpp
98
DropEquivalentBufferResults.cpp
9+
EmptyTensorElimination.cpp
1010
EmptyTensorToAllocTensor.cpp
1111
FuncBufferizableOpInterfaceImpl.cpp
1212
OneShotAnalysis.cpp

mlir/lib/Dialect/Bufferization/Transforms/AllocTensorElimination.cpp renamed to mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp

Lines changed: 46 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
//===- AllocTensorElimination.cpp - alloc_tensor op elimination -----------===//
1+
//===- EmptyTensorElimination.cpp - tensor.empty op elimination -----------===//
22
//
33
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
44
// See https://llvm.org/LICENSE.txt for license information.
@@ -10,15 +10,15 @@
1010

1111
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
1212
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
13-
#include "mlir/Dialect/Bufferization/Transforms/AllocTensorElimination.h"
13+
#include "mlir/Dialect/Bufferization/Transforms/EmptyTensorElimination.h"
1414
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
1515
#include "mlir/Dialect/Tensor/IR/Tensor.h"
1616
#include "mlir/IR/Dominance.h"
1717
#include "mlir/Pass/Pass.h"
1818

1919
namespace mlir {
2020
namespace bufferization {
21-
#define GEN_PASS_DEF_ALLOCTENSORELIMINATION
21+
#define GEN_PASS_DEF_EMPTYTENSORELIMINATION
2222
#include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc"
2323
} // namespace bufferization
2424
} // namespace mlir
@@ -47,27 +47,27 @@ neededValuesDominateInsertionPoint(const DominanceInfo &domInfo,
4747
}
4848

4949
/// Return true if the given `insertionPoint` dominates all uses of
50-
/// `allocTensorOp`.
50+
/// `emptyTensorOp`.
5151
static bool insertionPointDominatesUses(const DominanceInfo &domInfo,
5252
Operation *insertionPoint,
53-
Operation *allocTensorOp) {
54-
for (Operation *user : allocTensorOp->getUsers())
53+
Operation *emptyTensorOp) {
54+
for (Operation *user : emptyTensorOp->getUsers())
5555
if (!domInfo.dominates(insertionPoint, user))
5656
return false;
5757
return true;
5858
}
5959

60-
/// Find a valid insertion point for a replacement of `allocTensorOp`, assuming
60+
/// Find a valid insertion point for a replacement of `emptyTensorOp`, assuming
6161
/// that the replacement may use any value from `neededValues`.
6262
static Operation *
63-
findValidInsertionPoint(Operation *allocTensorOp,
63+
findValidInsertionPoint(Operation *emptyTensorOp,
6464
const SmallVector<Value> &neededValues) {
6565
DominanceInfo domInfo;
6666

67-
// Gather all possible insertion points: the location of `allocTensorOp` and
67+
// Gather all possible insertion points: the location of `emptyTensorOp` and
6868
// right after the definition of each value in `neededValues`.
6969
SmallVector<Operation *> insertionPointCandidates;
70-
insertionPointCandidates.push_back(allocTensorOp);
70+
insertionPointCandidates.push_back(emptyTensorOp);
7171
for (Value val : neededValues) {
7272
// Note: The anchor op is using all of `neededValues`, so:
7373
// * in case of a block argument: There must be at least one op in the block
@@ -90,7 +90,7 @@ findValidInsertionPoint(Operation *allocTensorOp,
9090
neededValues))
9191
continue;
9292
// Check if the insertion point is before all uses.
93-
if (!insertionPointDominatesUses(domInfo, insertionPoint, allocTensorOp))
93+
if (!insertionPointDominatesUses(domInfo, insertionPoint, emptyTensorOp))
9494
continue;
9595
return insertionPoint;
9696
}
@@ -99,12 +99,12 @@ findValidInsertionPoint(Operation *allocTensorOp,
9999
return nullptr;
100100
}
101101

102-
/// Try to eliminate AllocTensorOps inside `op`. An AllocTensorOp is replaced
102+
/// Try to eliminate tensor::EmptyOps inside `op`. A tensor::EmptyOp is replaced
103103
/// with the result of `rewriteFunc` if it is anchored on a matching
104104
/// OpOperand. "Anchored" means that there is a path on the reverse SSA use-def
105105
/// chain, starting from the OpOperand and always following the aliasing
106-
/// OpOperand, that eventually ends at a single AllocTensorOp.
107-
LogicalResult mlir::bufferization::eliminateAllocTensors(
106+
/// OpOperand, that eventually ends at a single tensor::EmptyOp.
107+
LogicalResult mlir::bufferization::eliminateEmptyTensors(
108108
RewriterBase &rewriter, Operation *op, AnalysisState &state,
109109
AnchorMatchFn anchorMatchFunc, RewriteFn rewriteFunc) {
110110
OpBuilder::InsertionGuard g(rewriter);
@@ -119,56 +119,40 @@ LogicalResult mlir::bufferization::eliminateAllocTensors(
119119
// Is this a matching OpOperand?
120120
if (!anchorMatchFunc(operand, neededValues))
121121
continue;
122-
SetVector<Value> maybeAllocTensor =
123-
state.findValueInReverseUseDefChain(operand.get(), [&](Value val) {
124-
// Continue traversal until this function returns true.
125-
OpResult opResult = val.dyn_cast<OpResult>();
126-
if (!opResult)
127-
return true;
128-
SmallVector<OpOperand *> opOperands =
129-
state.getAliasingOpOperand(opResult);
130-
if (!llvm::all_of(opOperands, [&](OpOperand *operand) {
131-
return state.isInPlace(*operand);
132-
}))
133-
return true;
134-
// Only equivalent tensors are supported at the moment.
135-
// TODO: Support cases such as extract_slice(alloc_tensor)
136-
return !llvm::all_of(opOperands, [&](OpOperand *operand) {
137-
return state.areEquivalentBufferizedValues(operand->get(),
138-
opResult);
139-
});
140-
});
122+
SetVector<Value> maybeEmptyTensor = state.findValueInReverseUseDefChain(
123+
operand.get(), /*condition=*/[&](Value val) { return false; },
124+
/*followEquivalentOnly=*/true);
141125

142126
// Replace only if the reverse use-def chain ends at exactly one
143-
// AllocTensorOp.
144-
if (maybeAllocTensor.size() != 1 ||
145-
!maybeAllocTensor.front().getDefiningOp<AllocTensorOp>())
127+
// tensor::EmptyOp.
128+
if (maybeEmptyTensor.size() != 1 ||
129+
!maybeEmptyTensor.front().getDefiningOp<tensor::EmptyOp>())
146130
return WalkResult::skip();
147-
Value allocTensor = maybeAllocTensor.front();
131+
Value emptyTensor = maybeEmptyTensor.front();
148132

149133
// Replace only if the types match.
150134
// TODO: This could be extended to support IR such as:
151-
// %0 = bufferization.alloc_tensor : tensor<128xf32>
135+
// %0 = tensor.empty() : tensor<128xf32>
152136
// %1 = "some_op"(%0) : (tensor<128xf32>) -> (tensor<128xf32>)
153137
// %2 = tensor.expand_shape %1 ...
154138
// %3 = tensor.insert_slice %2 into ...
155-
if (allocTensor.getType() != operand.get().getType())
139+
if (emptyTensor.getType() != operand.get().getType())
156140
return WalkResult::skip();
157141

158142
// Find a suitable insertion point.
159143
Operation *insertionPoint =
160-
findValidInsertionPoint(allocTensor.getDefiningOp(), neededValues);
144+
findValidInsertionPoint(emptyTensor.getDefiningOp(), neededValues);
161145
if (!insertionPoint)
162146
continue;
163147

164-
// Create a replacement for the AllocTensorOp.
148+
// Create a replacement for the tensor::EmptyOp.
165149
rewriter.setInsertionPoint(insertionPoint);
166-
Value replacement = rewriteFunc(rewriter, allocTensor.getLoc(), operand);
150+
Value replacement = rewriteFunc(rewriter, emptyTensor.getLoc(), operand);
167151
if (!replacement)
168152
continue;
169153

170-
// Replace the AllocTensorOp.
171-
rewriter.replaceOp(allocTensor.getDefiningOp(), replacement);
154+
// Replace the tensor::EmptyOp.
155+
rewriter.replaceOp(emptyTensor.getDefiningOp(), replacement);
172156
}
173157

174158
// Advance to the next operation.
@@ -178,34 +162,35 @@ LogicalResult mlir::bufferization::eliminateAllocTensors(
178162
return failure(status.wasInterrupted());
179163
}
180164

181-
/// Try to eliminate AllocTensorOps inside `op`. An AllocTensorOp can be
165+
/// Try to eliminate tensor::EmptyOps inside `op`. An tensor::EmptyOp can be
182166
/// eliminated if it is eventually inserted into another tensor (and some other
183167
/// conditions are met).
184168
///
185169
/// E.g.:
186-
/// %0 = linalg.alloc_tensor
170+
/// %0 = tensor.empty()
187171
/// %1 = linalg.fill(%cst, %0) {inplace = [true]}
188172
/// %2 = tensor.insert_slice %1 into %t[10][20][1]
189173
///
190-
/// AllocTensorOp elimination will try to fill %t inplace instead of filling a
174+
/// tensor::EmptyOp elimination will try to fill %t inplace instead of filling a
191175
/// new allocation %0 and inserting it into %t. This is done by replacing the
192-
/// AllocTensorOp with:
176+
/// tensor::EmptyOp with:
193177
///
194178
/// %0 = tensor.extract_slice %t[10][20][1]
195179
///
196180
/// The analysis looks for matching ExtractSliceOp/InsertSliceOp pairs and lets
197181
/// those bufferize inplace in the absence of other conflicts.
198182
///
199-
/// Starting from an InsertSliceOp, an AllocTensorOp at the end of the insert
183+
/// Starting from an InsertSliceOp, an tensor::EmptyOp at the end of the insert
200184
/// source's reverse use-def chain is eliminated if:
201185
/// * On the reverse use-def chain path from the InsertSliceOp to the
202-
/// AllocTensorOp, all ops were decided to bufferize inplace and the buffer
186+
/// tensor::EmptyOp, all ops were decided to bufferize inplace and the buffer
203187
/// relation is "equivalent" (TODO: can be relaxed if needed).
204-
/// * The reverse use-def chain has exactly one end, which is the AllocTensorOp.
188+
/// * The reverse use-def chain has exactly one end, which is the
189+
/// tensor::EmptyOp.
205190
LogicalResult
206-
mlir::bufferization::insertSliceAnchoredAllocTensorEliminationStep(
191+
mlir::bufferization::insertSliceAnchoredEmptyTensorEliminationStep(
207192
RewriterBase &rewriter, Operation *op, AnalysisState &state) {
208-
return eliminateAllocTensors(
193+
return eliminateEmptyTensors(
209194
rewriter, op, state,
210195
/*anchorMatchFunc=*/
211196
[&](OpOperand &operand, SmallVector<Value> &neededValues) {
@@ -239,10 +224,10 @@ mlir::bufferization::insertSliceAnchoredAllocTensorEliminationStep(
239224
}
240225

241226
namespace {
242-
struct AllocTensorElimination
243-
: public bufferization::impl::AllocTensorEliminationBase<
244-
AllocTensorElimination> {
245-
AllocTensorElimination() = default;
227+
struct EmptyTensorElimination
228+
: public bufferization::impl::EmptyTensorEliminationBase<
229+
EmptyTensorElimination> {
230+
EmptyTensorElimination() = default;
246231

247232
void runOnOperation() override;
248233

@@ -253,7 +238,7 @@ struct AllocTensorElimination
253238
};
254239
} // namespace
255240

256-
void AllocTensorElimination::runOnOperation() {
241+
void EmptyTensorElimination::runOnOperation() {
257242
Operation *op = getOperation();
258243
OneShotBufferizationOptions options;
259244
OneShotAnalysisState state(op, options);
@@ -263,11 +248,11 @@ void AllocTensorElimination::runOnOperation() {
263248
}
264249

265250
IRRewriter rewriter(op->getContext());
266-
if (failed(bufferization::insertSliceAnchoredAllocTensorEliminationStep(
251+
if (failed(bufferization::insertSliceAnchoredEmptyTensorEliminationStep(
267252
rewriter, op, state)))
268253
signalPassFailure();
269254
}
270255

271-
std::unique_ptr<Pass> mlir::bufferization::createAllocTensorEliminationPass() {
272-
return std::make_unique<AllocTensorElimination>();
256+
std::unique_ptr<Pass> mlir::bufferization::createEmptyTensorEliminationPass() {
257+
return std::make_unique<EmptyTensorElimination>();
273258
}

mlir/test/Dialect/Linalg/one-shot-bufferize-analysis-init-tensor-elimination.mlir renamed to mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-analysis-empty-tensor-elimination.mlir

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,12 @@
1-
// RUN: mlir-opt %s -eliminate-alloc-tensors -one-shot-bufferize="bufferize-function-boundaries test-analysis-only allow-return-allocs" -split-input-file | FileCheck %s
2-
3-
//===----------------------------------------------------------------------===//
4-
// AllocTensorOp elimination
5-
//===----------------------------------------------------------------------===//
1+
// RUN: mlir-opt %s -eliminate-empty-tensors -empty-tensor-to-alloc-tensor -one-shot-bufferize="bufferize-function-boundaries test-analysis-only allow-return-allocs" -split-input-file | FileCheck %s
62

73
// CHECK-LABEL: func @buffer_forwarding_conflict
84
func.func @buffer_forwarding_conflict(%arg0: tensor<?xf32> {bufferization.writable = true}, %arg1: index) -> (tensor<?xf32>, tensor<?xf32>) {
95
%cst = arith.constant 0.000000e+00 : f32
106
// CHECK: tensor.extract_slice
117
// CHECK-SAME: {__inplace_operands_attr__ = ["false", "none"]
128
// Instead of allocating, share buffer with some inplace bufferization?
13-
%0 = bufferization.alloc_tensor(%arg1) : tensor<?xf32>
9+
%0 = tensor.empty(%arg1) : tensor<?xf32>
1410

1511
// CHECK: linalg.fill
1612
// CHECK-SAME: {__inplace_operands_attr__ = ["none", "true"]
@@ -37,7 +33,7 @@ func.func @buffer_forwarding_no_conflict(%arg0: tensor<?xf32> {bufferization.wri
3733
// CHECK: tensor.extract_slice
3834
// CHECK-SAME: {__inplace_operands_attr__ = ["true", "none"]
3935
// Instead of allocating, share buffer with some inplace bufferization?
40-
%0 = bufferization.alloc_tensor(%arg1) : tensor<?xf32>
36+
%0 = tensor.empty(%arg1) : tensor<?xf32>
4137

4238
// CHECK: linalg.fill
4339
// CHECK-SAME: {__inplace_operands_attr__ = ["none", "true"]

0 commit comments

Comments
 (0)