Skip to content

Commit 64839fb

Browse files
[mlir][bufferization] Empty tensor elimination for materialize_in_destination (llvm#65468)
This revision adds support for empty tensor elimination to "bufferization.materialize_in_destination" by implementing the `SubsetInsertionOpInterface`. Furthermore, the One-Shot Bufferize conflict detection is improved for "bufferization.materialize_in_destination".
1 parent be2723d commit 64839fb

File tree

6 files changed

+87
-1
lines changed

6 files changed

+87
-1
lines changed

mlir/include/mlir/Dialect/Bufferization/IR/Bufferization.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,9 @@
1212
#include "mlir/Bytecode/BytecodeOpInterface.h"
1313
#include "mlir/Dialect/Bufferization/IR/AllocationOpInterface.h"
1414
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
15+
#include "mlir/Dialect/Bufferization/IR/SubsetInsertionOpInterface.h"
1516
#include "mlir/Interfaces/CopyOpInterface.h"
17+
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
1618
#include "mlir/Interfaces/InferTypeOpInterface.h"
1719

1820
//===----------------------------------------------------------------------===//

mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
include "mlir/Dialect/Bufferization/IR/AllocationOpInterface.td"
1313
include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td"
1414
include "mlir/Dialect/Bufferization/IR/BufferizationBase.td"
15+
include "mlir/Dialect/Bufferization/IR/SubsetInsertionOpInterface.td"
16+
include "mlir/Interfaces/DestinationStyleOpInterface.td"
1517
include "mlir/Interfaces/InferTypeOpInterface.td"
1618
include "mlir/Interfaces/SideEffectInterfaces.td"
1719
include "mlir/Interfaces/CopyOpInterface.td"
@@ -215,7 +217,11 @@ def Bufferization_CloneOp : Bufferization_Op<"clone", [
215217
def Bufferization_MaterializeInDestinationOp
216218
: Bufferization_Op<"materialize_in_destination",
217219
[BufferizableOpInterface, SameOperandsAndResultType,
218-
DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>]> {
220+
DestinationStyleOpInterface,
221+
DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
222+
DeclareOpInterfaceMethods<SubsetInsertionOpInterface,
223+
["getSourceOperand", "getValuesNeededToBuildSubsetExtraction",
224+
"buildSubsetExtraction", "isEquivalentSubset"]>]> {
219225
let summary = "copy a tensor";
220226

221227
let description = [{
@@ -248,12 +254,19 @@ def Bufferization_MaterializeInDestinationOp
248254
bool bufferizesToMemoryWrite(OpOperand &opOperand,
249255
const AnalysisState &state);
250256

257+
bool bufferizesToElementwiseAccess(const AnalysisState &state,
258+
ArrayRef<OpOperand *> opOperands);
259+
251260
AliasingValueList getAliasingValues(
252261
OpOperand &opOperand, const AnalysisState &state);
253262

254263
RankedTensorType getType() {
255264
return ::llvm::cast<RankedTensorType>(getResult().getType());
256265
}
266+
267+
std::pair<int64_t, int64_t> getDpsInitsPositionRange() {
268+
return {1, 2}; // `dest` operand
269+
}
257270
}];
258271

259272
let assemblyFormat = "$source `in` $dest attr-dict `:` type($source)";

mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -576,13 +576,40 @@ MaterializeInDestinationOp::bufferize(RewriterBase &rewriter,
576576
return success();
577577
}
578578

579+
bool MaterializeInDestinationOp::bufferizesToElementwiseAccess(
580+
const AnalysisState &state, ArrayRef<OpOperand *> opOperands) {
581+
// As elements are copied from the "source" buffer to the "dest" buffer,
582+
// already copied elements are not read a second time.
583+
return true;
584+
}
585+
579586
LogicalResult MaterializeInDestinationOp::reifyResultShapes(
580587
OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
581588
reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
582589
reifiedReturnShapes[0] = tensor::getMixedSizes(builder, getLoc(), getDest());
583590
return success();
584591
}
585592

593+
Value MaterializeInDestinationOp::buildSubsetExtraction(OpBuilder &builder,
594+
Location loc) {
595+
// The subset is the entire destination tensor.
596+
return getDest();
597+
}
598+
599+
bool MaterializeInDestinationOp::isEquivalentSubset(
600+
Value candidate, function_ref<bool(Value, Value)> equivalenceFn) {
601+
return equivalenceFn(getDest(), candidate);
602+
}
603+
604+
SmallVector<Value>
605+
MaterializeInDestinationOp::getValuesNeededToBuildSubsetExtraction() {
606+
return {getDest()};
607+
}
608+
609+
OpOperand &MaterializeInDestinationOp::getSourceOperand() {
610+
return getOperation()->getOpOperand(0) /*source*/;
611+
}
612+
586613
//===----------------------------------------------------------------------===//
587614
// ToTensorOp
588615
//===----------------------------------------------------------------------===//

mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-analysis.mlir

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,3 +158,31 @@ func.func @bbarg_of_unknown_op_2(%f: f32) {
158158
// CHECK: {__inplace_operands_attr__ = ["false"]} : (tensor<10xf32>) -> ()
159159
return
160160
}
161+
162+
// -----
163+
164+
// CHECK: func @materialize_in_destination_aliasing(
165+
func.func @materialize_in_destination_aliasing(%t: tensor<?xf32>, %p1: index, %p2: index, %sz: index) -> tensor<5xf32> {
166+
%buffer = tensor.empty(%sz) : tensor<?xf32>
167+
// CHECK: tensor.extract_slice
168+
// CHECK-SAME: {__inplace_operands_attr__ = ["true", "none"]}
169+
%src = tensor.extract_slice %t[%p1][5][1] : tensor<?xf32> to tensor<5xf32>
170+
// CHECK: tensor.extract_slice
171+
// CHECK-SAME: {__inplace_operands_attr__ = ["false", "none"]}
172+
%dest = tensor.extract_slice %t[%p2][5][1] : tensor<?xf32> to tensor<5xf32>
173+
// CHECK: bufferization.materialize_in_destination
174+
// CHECK-SAME: {__inplace_operands_attr__ = ["true", "true"]}
175+
%r = bufferization.materialize_in_destination %src in %dest : tensor<5xf32>
176+
return %r : tensor<5xf32>
177+
}
178+
179+
// -----
180+
181+
// CHECK: func @materialize_in_destination(
182+
func.func @materialize_in_destination(%t: tensor<?xf32>, %sz: index) -> tensor<?xf32> {
183+
%buffer = tensor.empty(%sz) : tensor<?xf32>
184+
// CHECK: bufferization.materialize_in_destination
185+
// CHECK-SAME: {__inplace_operands_attr__ = ["true", "true"]}
186+
%r = bufferization.materialize_in_destination %buffer in %buffer : tensor<?xf32>
187+
return %r : tensor<?xf32>
188+
}

mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -291,3 +291,17 @@ func.func @regression_multiple_insertion_points(%t1: tensor<?x?xf32>) -> tensor<
291291
%2 = tensor.insert_slice %filled into %t1 [%0, %1] [2, 5] [1, 1] : tensor<2x5xf32> into tensor<?x?xf32>
292292
return %2 : tensor<?x?xf32>
293293
}
294+
295+
// -----
296+
297+
// CHECK-LABEL: func @materialize_in_destination(
298+
// CHECK-SAME: %[[m:.*]]: memref<5xf32, strided<[?], offset: ?>>,
299+
// CHECK: linalg.fill {{.*}} outs(%[[m]]
300+
// CHECK: return %[[m]]
301+
func.func @materialize_in_destination(%t: tensor<5xf32>, %f: f32) -> tensor<5xf32> {
302+
%0 = tensor.empty() : tensor<5xf32>
303+
%filled = linalg.fill ins(%f : f32) outs(%0 : tensor<5xf32>) -> tensor<5xf32>
304+
%1 = bufferization.materialize_in_destination %filled in %t : tensor<5xf32>
305+
return %1 : tensor<5xf32>
306+
}
307+

utils/bazel/llvm-project-overlay/mlir/BUILD.bazel

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12092,6 +12092,8 @@ gentbl_cc_library(
1209212092
deps = [
1209312093
":BufferizableOpInterfaceTdFiles",
1209412094
":BufferizationOpsTdFiles",
12095+
":DestinationStyleOpInterfaceTdFiles",
12096+
":SubsetInsertionOpInterfaceTdFiles",
1209512097
],
1209612098
)
1209712099

0 commit comments

Comments
 (0)