Skip to content

Commit 3054b21

Browse files
committed
[mlir][bufferization]-Try to move the needed values for subsetExtract in EmptyTensorElimination
In this MR, we will handle the case were we may fail finding a legal/suitable insertion point for the subsetExtract which is about to replace the empty tensor. For this reason, now we try also to move the needed values which are responsible to create the `subsetExtract` before the candidate insertion point (tensor.empty about to be eliminated).
1 parent ba1a870 commit 3054b21

File tree

2 files changed

+44
-18
lines changed

2 files changed

+44
-18
lines changed

mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp

Lines changed: 37 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -28,23 +28,32 @@ namespace bufferization {
2828
using namespace mlir;
2929
using namespace mlir::bufferization;
3030

31+
/// Return true if `val` is in scope at the given
32+
/// `insertionPoint`.
33+
static bool valueDominateInsertionPoint(const DominanceInfo &domInfo,
34+
Operation *insertionPoint, Value val) {
35+
if (auto bbArg = dyn_cast<BlockArgument>(val)) {
36+
Block *owner = bbArg.getOwner();
37+
if (!owner->findAncestorOpInBlock(*insertionPoint))
38+
return false;
39+
} else {
40+
auto opResult = cast<OpResult>(val);
41+
if (!domInfo.properlyDominates(opResult.getOwner(), insertionPoint))
42+
return false;
43+
}
44+
return true;
45+
}
46+
3147
/// Return true if all `neededValues` are in scope at the given
3248
/// `insertionPoint`.
3349
static bool
3450
neededValuesDominateInsertionPoint(const DominanceInfo &domInfo,
3551
Operation *insertionPoint,
3652
const SmallVector<Value> &neededValues) {
37-
for (Value val : neededValues) {
38-
if (auto bbArg = dyn_cast<BlockArgument>(val)) {
39-
Block *owner = bbArg.getOwner();
40-
if (!owner->findAncestorOpInBlock(*insertionPoint))
41-
return false;
42-
} else {
43-
auto opResult = cast<OpResult>(val);
44-
if (!domInfo.properlyDominates(opResult.getOwner(), insertionPoint))
45-
return false;
46-
}
47-
}
53+
for (Value val : neededValues)
54+
if (!valueDominateInsertionPoint(domInfo, insertionPoint, val))
55+
return false;
56+
4857
return true;
4958
}
5059

@@ -65,6 +74,23 @@ findValidInsertionPoint(Operation *emptyTensorOp,
6574
const SmallVector<Value> &neededValues) {
6675
DominanceInfo domInfo;
6776

77+
// Trying to move the needed values before the `emptyTensorOp`.
78+
for (Value val : neededValues) {
79+
if (valueDominateInsertionPoint(domInfo, emptyTensorOp, val))
80+
continue;
81+
Operation *definingOp = val.getDefiningOp();
82+
if (!definingOp)
83+
continue;
84+
85+
bool isItSafeToMoveOp =
86+
llvm::all_of(definingOp->getOperands(), [&](Value operand) {
87+
return valueDominateInsertionPoint(domInfo, emptyTensorOp, operand);
88+
});
89+
90+
if (isItSafeToMoveOp)
91+
definingOp->moveBefore(emptyTensorOp);
92+
}
93+
6894
// Gather all possible insertion points: the location of `emptyTensorOp` and
6995
// right after the definition of each value in `neededValues`.
7096
SmallVector<Operation *> insertionPointCandidates;

mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -368,21 +368,21 @@ func.func @multiple_materialize_in_destination_buffer(%m: memref<5xf32>, %f: f32
368368

369369
// -----
370370

371-
// `EmptyTensorElimination` fails to find a valid insertion
372-
// point for the new injected `SubsetExtraction`.
371+
// `EmptyTensorElimination` finds a valid insertion
372+
// point for the new injected `SubsetExtraction` by
373+
// trying to move the needed value for the extraction.
373374
// CHECK-LABEL: func.func @fail_to_eliminate_any_empty_tensors
374375
func.func @fail_to_eliminate_any_empty_tensors() -> tensor<5x6x128xf32> {
375376
%cst_1 = arith.constant 1.0 : f32
376377
%cst_2 = arith.constant 2.0 : f32
377378
// CHECK: memref.alloc
378-
// CHECK: memref.alloc
379-
// CHECK: memref.alloc
379+
// CHECK-NOT: memref.alloc
380380
%empty_1 = tensor.empty() : tensor<5x6x64xf32>
381381
%res_1 = linalg.fill ins(%cst_1 : f32) outs(%empty_1 : tensor<5x6x64xf32>) -> tensor<5x6x64xf32>
382382
%empty_2 = tensor.empty() : tensor<5x6x64xf32>
383383
%res_2 = linalg.fill ins(%cst_2 : f32) outs(%empty_2 : tensor<5x6x64xf32>) -> tensor<5x6x64xf32>
384384
%cancatenated_empty = tensor.empty() : tensor<5x6x128xf32>
385-
// CHECK: memref.copy
385+
// CHECK-NOT: memref.copy
386386
%inserted_slice_1 = tensor.insert_slice %res_1 into %cancatenated_empty[0, 0, 0][5, 6, 64][1, 1, 1]
387387
: tensor<5x6x64xf32> into tensor<5x6x128xf32>
388388
%inserted_slice_2 = tensor.insert_slice %res_2 into %inserted_slice_1[0, 0, 64][5, 6, 64][1, 1, 1]
@@ -397,13 +397,13 @@ func.func @succeed_to_eliminate_one_empty_tensor() -> tensor<5x6x128xf32> {
397397
%cst_1 = arith.constant 1.0 : f32
398398
%cst_2 = arith.constant 2.0 : f32
399399
// CHECK: memref.alloc
400-
// CHECK: memref.alloc
400+
// CHECK-NOT: memref.alloc
401401
%cancatenated_empty = tensor.empty() : tensor<5x6x128xf32>
402402
%empty_1 = tensor.empty() : tensor<5x6x64xf32>
403403
%res_1 = linalg.fill ins(%cst_1 : f32) outs(%empty_1 : tensor<5x6x64xf32>) -> tensor<5x6x64xf32>
404404
%empty_2 = tensor.empty() : tensor<5x6x64xf32>
405405
%res_2 = linalg.fill ins(%cst_2 : f32) outs(%empty_2 : tensor<5x6x64xf32>) -> tensor<5x6x64xf32>
406-
// CHECK: memref.copy
406+
// CHECK-NOT: memref.copy
407407
%inserted_slice_1 = tensor.insert_slice %res_1 into %cancatenated_empty[0, 0, 0][5, 6, 64][1, 1, 1]
408408
: tensor<5x6x64xf32> into tensor<5x6x128xf32>
409409
%inserted_slice_2 = tensor.insert_slice %res_2 into %inserted_slice_1[0, 0, 64][5, 6, 64][1, 1, 1]

0 commit comments

Comments
 (0)