Skip to content

Commit cc09909

Browse files
committed
[mlir][bufferization]-Replace only one use in TensorEmptyElimination
This MR handles the second case, where we want to replace only the specific use that was visited in the `use-def` chain (when traversing from the tensor.insert_slice's source). Replacing all uses of the tensor.empty may introduce additional read effects after bufferization of the specific subset extract/subview, which should not be the case; replacing only the tracked use thus eliminates potential copies.
1 parent 3054b21 commit cc09909

File tree

5 files changed

+57
-42
lines changed

5 files changed

+57
-42
lines changed

mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -459,7 +459,8 @@ class AnalysisState {
459459
/// Starting from `value`, follow the use-def chain in reverse, always
460460
/// selecting the aliasing OpOperands. Find and return Values for which
461461
/// `condition` evaluates to true. OpOperands of such matching Values are not
462-
/// traversed any further.
462+
/// traversed any further. The visited aliasing OpOperands will be preserved
463+
/// through `visitedOpOperands`.
463464
///
464465
/// When reaching the end of a chain, also return the last Value of that
465466
/// chain if `config.alwaysIncludeLeaves` is set.
@@ -484,7 +485,8 @@ class AnalysisState {
484485
/// `config`.
485486
SetVector<Value> findValueInReverseUseDefChain(
486487
Value value, llvm::function_ref<bool(Value)> condition,
487-
TraversalConfig config = TraversalConfig()) const;
488+
TraversalConfig config = TraversalConfig(),
489+
llvm::DenseSet<OpOperand*> *visitedOpOperands = nullptr) const;
488490

489491
/// Find the values that may define the contents of the given value at
490492
/// runtime. A block argument is always a definition. An OpResult is a

mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -483,10 +483,12 @@ bool AnalysisState::isValueRead(Value value) const {
483483
// Starting from `value`, follow the use-def chain in reverse, always selecting
484484
// the aliasing OpOperands. Find and return Values for which `condition`
485485
// evaluates to true. OpOperands of such matching Values are not traversed any
486-
// further.
486+
// further. The visited aliasing OpOperands will be preserved through
487+
// `visitedOpOperands`.
487488
llvm::SetVector<Value> AnalysisState::findValueInReverseUseDefChain(
488489
Value value, llvm::function_ref<bool(Value)> condition,
489-
TraversalConfig config) const {
490+
TraversalConfig config,
491+
llvm::DenseSet<OpOperand*> *visitedOpOperands) const {
490492
llvm::DenseSet<Value> visited;
491493
llvm::SetVector<Value> result, workingSet;
492494
workingSet.insert(value);
@@ -553,6 +555,8 @@ llvm::SetVector<Value> AnalysisState::findValueInReverseUseDefChain(
553555
}
554556

555557
workingSet.insert(a.opOperand->get());
558+
if (visitedOpOperands)
559+
visitedOpOperands->insert(a.opOperand);
556560
}
557561
}
558562

mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp

Lines changed: 29 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -57,44 +57,40 @@ neededValuesDominateInsertionPoint(const DominanceInfo &domInfo,
5757
return true;
5858
}
5959

60-
/// Return true if the given `insertionPoint` dominates all uses of
61-
/// `emptyTensorOp`.
62-
static bool insertionPointDominatesUses(const DominanceInfo &domInfo,
63-
Operation *insertionPoint,
64-
Operation *emptyTensorOp) {
65-
return llvm::all_of(emptyTensorOp->getUsers(), [&](Operation *user) {
66-
return domInfo.dominates(insertionPoint, user);
67-
});
68-
}
69-
70-
/// Find a valid insertion point for a replacement of `emptyTensorOp`, assuming
60+
/// Find a valid insertion point for a replacement of `useToBeEliminated`, assuming
7161
/// that the replacement may use any value from `neededValues`.
7262
static Operation *
73-
findValidInsertionPoint(Operation *emptyTensorOp,
63+
findValidInsertionPoint(OpOperand *useToBeEliminated,
7464
const SmallVector<Value> &neededValues) {
7565
DominanceInfo domInfo;
7666

67+
Operation * candidateInsertionPoint = useToBeEliminated->getOwner();
68+
assert(isa<OpResult>(useToBeEliminated->get()) && "expected a result value");
69+
// `tensor.empty` and its user are in different blocks.
70+
if (useToBeEliminated->getOwner()->getBlock() != useToBeEliminated->get().getDefiningOp()->getBlock())
71+
candidateInsertionPoint = useToBeEliminated->get().getDefiningOp();
72+
7773
// Trying to move the needed values before the `emptyTensorOp`.
7874
for (Value val : neededValues) {
79-
if (valueDominateInsertionPoint(domInfo, emptyTensorOp, val))
75+
if (valueDominateInsertionPoint(domInfo, candidateInsertionPoint, val))
8076
continue;
8177
Operation *definingOp = val.getDefiningOp();
8278
if (!definingOp)
8379
continue;
8480

8581
bool isItSafeToMoveOp =
8682
llvm::all_of(definingOp->getOperands(), [&](Value operand) {
87-
return valueDominateInsertionPoint(domInfo, emptyTensorOp, operand);
83+
return valueDominateInsertionPoint(domInfo, candidateInsertionPoint, operand);
8884
});
8985

9086
if (isItSafeToMoveOp)
91-
definingOp->moveBefore(emptyTensorOp);
87+
definingOp->moveBefore(candidateInsertionPoint);
9288
}
9389

94-
// Gather all possible insertion points: the location of `emptyTensorOp` and
90+
// Gather all possible insertion points: the location of `candidateInsertionPoint` and
9591
// right after the definition of each value in `neededValues`.
9692
SmallVector<Operation *> insertionPointCandidates;
97-
insertionPointCandidates.push_back(emptyTensorOp);
93+
insertionPointCandidates.push_back(candidateInsertionPoint);
9894
for (Value val : neededValues) {
9995
// Note: The anchor op is using all of `neededValues`, so:
10096
// * in case of a block argument: There must be at least one op in the block
@@ -116,8 +112,8 @@ findValidInsertionPoint(Operation *emptyTensorOp,
116112
if (!neededValuesDominateInsertionPoint(domInfo, insertionPoint,
117113
neededValues))
118114
continue;
119-
// Check if the insertion point is before all uses.
120-
if (!insertionPointDominatesUses(domInfo, insertionPoint, emptyTensorOp))
115+
// Check if the insertion point is before the use to be replaced.
116+
if (!domInfo.dominates(insertionPoint, useToBeEliminated->getOwner()))
121117
continue;
122118
return insertionPoint;
123119
}
@@ -129,8 +125,9 @@ findValidInsertionPoint(Operation *emptyTensorOp,
129125
LogicalResult mlir::bufferization::eliminateEmptyTensors(
130126
RewriterBase &rewriter, Operation *op, OneShotAnalysisState &state) {
131127
OpBuilder::InsertionGuard g(rewriter);
132-
128+
llvm::DenseSet<OpOperand *> visitedOpOperands;
133129
op->walk([&](SubsetInsertionOpInterface op) {
130+
visitedOpOperands.clear();
134131
OpOperand &source = op.getSourceOperand();
135132
// Skip operands that do not bufferize inplace. "tensor.empty" could still
136133
// be replaced, but the transformation may not be beneficial.
@@ -158,15 +155,23 @@ LogicalResult mlir::bufferization::eliminateEmptyTensors(
158155
SetVector<Value> emptyTensors = state.findValueInReverseUseDefChain(
159156
source.get(), /*condition=*/
160157
[&](Value val) { return val.getDefiningOp<tensor::EmptyOp>(); },
161-
config);
158+
config, &visitedOpOperands);
162159

163160
for (Value v : emptyTensors) {
164161
Operation *emptyTensorOp = v.getDefiningOp();
165162

163+
// Find the use to be replaced from the use-def chain.
164+
auto iter = llvm::find_if(visitedOpOperands, [&emptyTensorOp](OpOperand *opOperand){
165+
return llvm::count(emptyTensorOp->getUses(), *opOperand);
166+
});
167+
if (iter == visitedOpOperands.end())
168+
continue;
169+
OpOperand *useToBeReplaced = *iter;
170+
166171
// Find a suitable insertion point. If no suitable insertion point for
167172
// the replacement can be found, skip this replacement.
168173
Operation *insertionPoint =
169-
findValidInsertionPoint(emptyTensorOp, neededValues);
174+
findValidInsertionPoint(useToBeReplaced, neededValues);
170175
if (!insertionPoint)
171176
continue;
172177

@@ -185,8 +190,8 @@ LogicalResult mlir::bufferization::eliminateEmptyTensors(
185190
replacement = rewriter.create<tensor::CastOp>(v.getLoc(), v.getType(),
186191
replacement);
187192
}
188-
// Replace the tensor::EmptyOp.
189-
rewriter.replaceOp(emptyTensorOp, replacement);
193+
// Replace the specific use of the tensor::EmptyOp.
194+
useToBeReplaced->getOwner()->setOperand(useToBeReplaced->getOperandNumber(), replacement);
190195
state.resetCache();
191196
}
192197

mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-analysis-empty-tensor-elimination.mlir

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,9 +52,8 @@ func.func @buffer_forwarding_no_conflict(%arg0: tensor<?xf32> {bufferization.wri
5252

5353
// CHECK-LABEL: func @buffer_forwarding_conflict_with_different_element_type
5454
func.func @buffer_forwarding_conflict_with_different_element_type(%arg0: tensor<?xf32> {bufferization.writable = true}, %arg1: index) -> (tensor<?xf32>, tensor<?xf32>) {
55-
// CHECK: tensor.extract_slice
56-
// CHECK-SAME: {__inplace_operands_attr__ = ["true", "none"]
5755
%cst = arith.constant 0.000000e+00 : f32
56+
// CHECK: bufferization.alloc_tensor(%arg1)
5857
%0 = tensor.empty(%arg1) : tensor<?xf32>
5958

6059
// CHECK: bufferization.alloc_tensor(%arg1)
@@ -64,6 +63,10 @@ func.func @buffer_forwarding_conflict_with_different_element_type(%arg0: tensor<
6463
// CHECK-SAME: {__inplace_operands_attr__ = ["true", "true"]
6564
%2 = linalg.copy ins(%0 : tensor<?xf32>) outs(%1 : tensor<?xbf16>) -> tensor<?xbf16>
6665

66+
67+
// CHECK: tensor.extract_slice
68+
// CHECK-SAME: {__inplace_operands_attr__ = ["true", "none"]
69+
6770
// CHECK: linalg.copy
6871
// CHECK-SAME: {__inplace_operands_attr__ = ["true", "true"]
6972
%3 = linalg.copy ins(%2 : tensor<?xbf16>) outs(%0 : tensor<?xf32>) -> tensor<?xf32>

mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -396,7 +396,7 @@ func.func @fail_to_eliminate_any_empty_tensors() -> tensor<5x6x128xf32> {
396396
func.func @succeed_to_eliminate_one_empty_tensor() -> tensor<5x6x128xf32> {
397397
%cst_1 = arith.constant 1.0 : f32
398398
%cst_2 = arith.constant 2.0 : f32
399-
// CHECK: memref.alloc
399+
// CHECK: memref.alloc() {alignment = 64 : i64} : memref<5x6x128xf32>
400400
// CHECK-NOT: memref.alloc
401401
%cancatenated_empty = tensor.empty() : tensor<5x6x128xf32>
402402
%empty_1 = tensor.empty() : tensor<5x6x64xf32>
@@ -413,10 +413,9 @@ func.func @succeed_to_eliminate_one_empty_tensor() -> tensor<5x6x128xf32> {
413413

414414
// -----
415415

416-
// `EmptyTensorElimination` replaces all of the uses of the tensor
417-
// empty with the new injected `SubsetExtraction`, without to consider
418-
// the specific use has been tracked, sometimes creating a non existent
419-
// bufferization conflicts.
416+
// `EmptyTensorElimination` will replace the specific use of the tensor
417+
// empty with the new injected `SubsetExtraction`, i.e. the specific use
418+
// which has been tracked.
420419

421420
// CHECK-ELIM-LABEL: func.func @mutli_use_of_the_same_tensor_empty
422421
// CHECK-LABEL: func.func @mutli_use_of_the_same_tensor_empty
@@ -425,15 +424,16 @@ func.func @mutli_use_of_the_same_tensor_empty() -> tensor<5x6x128xf32> {
425424
%cst_2 = arith.constant 2.0 : f32
426425
%cancatenated_empty = tensor.empty() : tensor<5x6x128xf32>
427426
%empty_1 = tensor.empty() : tensor<5x6x64xf32>
428-
// CHECK-ELIM: %[[VAL_3:.*]] = tensor.extract_slice
429-
// CHECK-ELIM: linalg.fill ins(%[[VAL_0:.*]] : f32) outs(%[[VAL_3]]
430-
// CHECK-ELIM: linalg.fill ins(%[[VAL_1:.*]] : f32) outs(%[[VAL_3]]
427+
// CHECK-ELIM: %[[VAL_4:.*]] = tensor.extract_slice %[[VAL_2:.*]]
428+
// CHECK-ELIM: linalg.fill ins(%[[VAL_0:.*]] : f32) outs(%[[VAL_4]] : tensor<5x6x64xf32>)
429+
// CHECK-ELIM: %[[VAL_6:.*]] = tensor.insert_slice
431430
%res_1 = linalg.fill ins(%cst_1 : f32) outs(%empty_1 : tensor<5x6x64xf32>) -> tensor<5x6x64xf32>
431+
// CHECK-ELIM: %[[VAL_7:.*]] = tensor.extract_slice %[[VAL_6]]
432+
// CHECK-ELIM: %[[VAL_8:.*]] = linalg.fill ins(%[[VAL_1:.*]] : f32) outs(%[[VAL_7]] : tensor<5x6x64xf32>)
432433
%res_2 = linalg.fill ins(%cst_2 : f32) outs(%empty_1 : tensor<5x6x64xf32>) -> tensor<5x6x64xf32>
433-
// CHECK: memref.copy
434+
// CHECK-NOT: memref.copy
434435
%inserted_slice_1 = tensor.insert_slice %res_1 into %cancatenated_empty[0, 0, 0][5, 6, 64][1, 1, 1]
435436
: tensor<5x6x64xf32> into tensor<5x6x128xf32>
436-
// CHECK: memref.copy
437437
%inserted_slice_2 = tensor.insert_slice %res_2 into %inserted_slice_1[0, 0, 64][5, 6, 64][1, 1, 1]
438438
: tensor<5x6x64xf32> into tensor<5x6x128xf32>
439439
return %inserted_slice_2 : tensor<5x6x128xf32>
@@ -446,7 +446,8 @@ func.func @mutli_use_of_the_same_tensor_empty_creates_non_existent_read(%arg1: t
446446
-> (tensor<5x6x128xf32>, tensor<5x6x64xf32>) {
447447
%cst_1 = arith.constant 1.0 : f32
448448
%empty_1 = tensor.empty() : tensor<5x6x64xf32>
449-
// CHECK: memref.alloc
449+
// CHECK: memref.alloc() {alignment = 64 : i64} : memref<5x6x64xf32>
450+
// CHECK-NOT: memref.alloc
450451
%res_1 = linalg.fill ins(%cst_1 : f32) outs(%empty_1 : tensor<5x6x64xf32>) -> tensor<5x6x64xf32>
451452
%res_2 = linalg.generic{
452453
indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
@@ -458,7 +459,7 @@ func.func @mutli_use_of_the_same_tensor_empty_creates_non_existent_read(%arg1: t
458459
%res = arith.addf %in, %in : f32
459460
linalg.yield %res : f32
460461
} -> tensor<5x6x64xf32>
461-
// CHECK: memref.copy
462+
// CHECK-NOT: memref.copy
462463
%inserted_slice_1 = tensor.insert_slice %res_1 into %arg1[0, 0, 0][5, 6, 64][1, 1, 1]
463464
: tensor<5x6x64xf32> into tensor<5x6x128xf32>
464465
return %inserted_slice_1, %res_2 : tensor<5x6x128xf32>, tensor<5x6x64xf32>

0 commit comments

Comments
 (0)