Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 22 additions & 3 deletions mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@
#define MLIR_DIALECT_BUFFERIZATION_TRANSFORMS_PASSES_H

#include "mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Pass/Pass.h"

namespace mlir {
class FunctionOpInterface;
class MemRefType;
class ModuleOp;
class RewritePatternSet;
class OpBuilder;
Expand Down Expand Up @@ -38,7 +40,7 @@ std::unique_ptr<Pass> createOwnershipBasedBufferDeallocationPass(
DeallocationOptions options = DeallocationOptions());

/// Creates a pass that finds all temporary allocations
/// and attempts to move the deallocation after the last user/dependency
/// and attempts to move the deallocation after the last user/dependency
/// of the allocation, thereby optimizing allocation liveness.
std::unique_ptr<Pass> createOptimizeAllocationLivenessPass();

Expand Down Expand Up @@ -157,6 +159,12 @@ std::unique_ptr<Pass> createBufferLoopHoistingPass();
// Options struct for BufferResultsToOutParams pass.
// Note: defined only here, not in tablegen.
struct BufferResultsToOutParamsOpts {
/// Allocator function: Generate a memref allocation with the given type.
/// Since `promoteBufferResultsToOutParams` doesn't allow dynamically shaped
/// results, we don't allow passing a range of values for dynamic dims.
using AllocationFn =
std::function<FailureOr<Value>(OpBuilder &, Location, MemRefType)>;

/// Memcpy function: Generate a memcpy between two memrefs.
using MemCpyFn =
std::function<LogicalResult(OpBuilder &, Location, Value, Value)>;
Expand All @@ -167,9 +175,20 @@ struct BufferResultsToOutParamsOpts {
return true;
};

/// Allocation function; used to allocate a memref.
/// Default memref.alloc is used
AllocationFn allocationFn = [](OpBuilder &builder, Location loc,
MemRefType type) {
return builder.create<memref::AllocOp>(loc, type).getResult();
};

/// Memcpy function; used to create a copy between two memrefs.
/// If this is empty, memref.copy is used.
std::optional<MemCpyFn> memCpyFn;
/// Default memref.copy is used.
MemCpyFn memCpyFn = [](OpBuilder &builder, Location loc, Value from,
Value to) {
builder.create<memref::CopyOp>(loc, from, to);
return success();
};

/// If true, the pass adds a "bufferize.result" attribute to each output
/// parameter.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Bufferization/IR/AllocationOpInterface.h"
#include "mlir/Dialect/Bufferization/Transforms/Passes.h"

#include "mlir/Dialect/Func/IR/FuncOps.h"
Expand All @@ -21,6 +22,7 @@ namespace bufferization {
} // namespace mlir

using namespace mlir;
using AllocationFn = bufferization::BufferResultsToOutParamsOpts::AllocationFn;
using MemCpyFn = bufferization::BufferResultsToOutParamsOpts::MemCpyFn;

/// Return `true` if the given MemRef type has a fully dynamic layout.
Expand Down Expand Up @@ -105,10 +107,9 @@ updateFuncOp(func::FuncOp func,
// Updates all ReturnOps in the scope of the given func::FuncOp by either
// keeping them as return values or copying the associated buffer contents into
// the given out-params.
static LogicalResult updateReturnOps(func::FuncOp func,
ArrayRef<BlockArgument> appendedEntryArgs,
MemCpyFn memCpyFn,
bool hoistStaticAllocs) {
static LogicalResult
updateReturnOps(func::FuncOp func, ArrayRef<BlockArgument> appendedEntryArgs,
const bufferization::BufferResultsToOutParamsOpts &options) {
auto res = func.walk([&](func::ReturnOp op) {
SmallVector<Value, 6> copyIntoOutParams;
SmallVector<Value, 6> keepAsReturnOperands;
Expand All @@ -120,13 +121,14 @@ static LogicalResult updateReturnOps(func::FuncOp func,
}
OpBuilder builder(op);
for (auto [orig, arg] : llvm::zip(copyIntoOutParams, appendedEntryArgs)) {
if (hoistStaticAllocs &&
isa_and_nonnull<memref::AllocOp>(orig.getDefiningOp()) &&
if (options.hoistStaticAllocs &&
isa_and_nonnull<bufferization::AllocationOpInterface>(
orig.getDefiningOp()) &&
mlir::cast<MemRefType>(orig.getType()).hasStaticShape()) {
orig.replaceAllUsesWith(arg);
orig.getDefiningOp()->erase();
} else {
if (failed(memCpyFn(builder, op.getLoc(), orig, arg)))
if (failed(options.memCpyFn(builder, op.getLoc(), orig, arg)))
return WalkResult::interrupt();
}
}
Expand Down Expand Up @@ -175,7 +177,14 @@ updateCalls(ModuleOp module,
auto allocType =
MemRefType::get(memrefType.getShape(), memrefType.getElementType(),
AffineMap(), memrefType.getMemorySpace());
Value outParam = builder.create<memref::AllocOp>(op.getLoc(), allocType);
auto maybeOutParam =
options.allocationFn(builder, op.getLoc(), allocType);
if (failed(maybeOutParam)) {
op.emitError() << "failed to create allocation op";
didFail = true;
return;
}
Value outParam = maybeOutParam.value();
if (!hasStaticIdentityLayout(memrefType)) {
// Layout maps are already checked in `updateFuncOp`.
assert(hasFullyDynamicLayoutMap(memrefType) &&
Expand Down Expand Up @@ -213,14 +222,7 @@ LogicalResult mlir::bufferization::promoteBufferResultsToOutParams(
return failure();
if (func.isExternal())
continue;
auto defaultMemCpyFn = [](OpBuilder &builder, Location loc, Value from,
Value to) {
builder.create<memref::CopyOp>(loc, from, to);
return success();
};
if (failed(updateReturnOps(func, appendedEntryArgs,
options.memCpyFn.value_or(defaultMemCpyFn),
options.hoistStaticAllocs))) {
if (failed(updateReturnOps(func, appendedEntryArgs, options))) {
return failure();
}
}
Expand Down
Loading