From 1db8e794fc718ba3720789659ebba47c9e7e05a8 Mon Sep 17 00:00:00 2001 From: Sam Date: Tue, 17 Dec 2024 12:57:55 -0600 Subject: [PATCH 1/3] Enable any `AllocationOpInterface` with `hoistStaticAllocs` option --- .../Bufferization/Transforms/BufferResultsToOutParams.cpp | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp index b7755b2be8483..b4d2d6b0c5da8 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp @@ -6,6 +6,7 @@ // //===----------------------------------------------------------------------===// +#include "mlir/Dialect/Bufferization/IR/AllocationOpInterface.h" #include "mlir/Dialect/Bufferization/Transforms/Passes.h" #include "mlir/Dialect/Func/IR/FuncOps.h" @@ -121,7 +122,8 @@ static LogicalResult updateReturnOps(func::FuncOp func, OpBuilder builder(op); for (auto [orig, arg] : llvm::zip(copyIntoOutParams, appendedEntryArgs)) { if (hoistStaticAllocs && - isa_and_nonnull(orig.getDefiningOp()) && + isa_and_nonnull( + orig.getDefiningOp()) && mlir::cast(orig.getType()).hasStaticShape()) { orig.replaceAllUsesWith(arg); orig.getDefiningOp()->erase(); From bf145cfeb45263d879eb998ce47cf6dd93aa2b06 Mon Sep 17 00:00:00 2001 From: Sam Date: Thu, 19 Dec 2024 11:46:37 -0600 Subject: [PATCH 2/3] add and use custom allocation function --- .../Dialect/Bufferization/Transforms/Passes.h | 13 +++++++++- .../Transforms/BufferResultsToOutParams.cpp | 24 ++++++++++++++----- 2 files changed, 30 insertions(+), 7 deletions(-) diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h index fe43a05c81fdc..966438956fc6c 100644 --- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h @@ -6,6 +6,7 @@ namespace mlir { class FunctionOpInterface; +class MemRefType; class ModuleOp; class RewritePatternSet; class OpBuilder; @@ -38,7 +39,7 @@ std::unique_ptr createOwnershipBasedBufferDeallocationPass( DeallocationOptions options = DeallocationOptions()); /// Creates a pass that finds all temporary allocations -/// and attempts to move the deallocation after the last user/dependency +/// and attempts to move the deallocation after the last user/dependency /// of the allocation, thereby optimizing allocation liveness. std::unique_ptr createOptimizeAllocationLivenessPass(); @@ -157,6 +158,12 @@ std::unique_ptr createBufferLoopHoistingPass(); // Options struct for BufferResultsToOutParams pass. // Note: defined only here, not in tablegen. struct BufferResultsToOutParamsOpts { + /// Allocator function: Generate a memref allocation with the given type. + /// Since `promoteBufferResultsToOutParams` doesn't allow dynamically shaped + /// results, we don't allow passing a range of values for dynamic dims. + using AllocationFn = + std::function(OpBuilder &, Location, MemRefType)>; + /// Memcpy function: Generate a memcpy between two memrefs. using MemCpyFn = std::function; @@ -167,6 +174,10 @@ struct BufferResultsToOutParamsOpts { return true; }; + /// Allocation function; used to allocate a memref. + /// If this is empty, memref.alloc is used + std::optional allocationFn; + /// Memcpy function; used to create a copy between two memrefs. /// If this is empty, memref.copy is used. std::optional memCpyFn; diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp index b4d2d6b0c5da8..545b6ca009c03 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp @@ -22,6 +22,7 @@ namespace bufferization { } // namespace mlir using namespace mlir; +using AllocationFn = bufferization::BufferResultsToOutParamsOpts::AllocationFn; using MemCpyFn = bufferization::BufferResultsToOutParamsOpts::MemCpyFn; /// Return `true` if the given MemRef type has a fully dynamic layout. @@ -141,9 +142,8 @@ static LogicalResult updateReturnOps(func::FuncOp func, // Updates all CallOps in the scope of the given ModuleOp by allocating // temporary buffers for newly introduced out params. -static LogicalResult -updateCalls(ModuleOp module, - const bufferization::BufferResultsToOutParamsOpts &options) { +static LogicalResult updateCalls(ModuleOp module, AllocationFn allocationFn, + std::function filterFn) { bool didFail = false; SymbolTable symtab(module); module.walk([&](func::CallOp op) { @@ -154,7 +154,7 @@ updateCalls(ModuleOp module, didFail = true; return; } - if (!options.filterFn(&callee)) + if (!filterFn(&callee)) return; SmallVector replaceWithNewCallResults; SmallVector replaceWithOutParams; @@ -177,7 +177,13 @@ updateCalls(ModuleOp module, auto allocType = MemRefType::get(memrefType.getShape(), memrefType.getElementType(), AffineMap(), memrefType.getMemorySpace()); - Value outParam = builder.create(op.getLoc(), allocType); + auto maybeOutParam = allocationFn(builder, op.getLoc(), allocType); + if (failed(maybeOutParam)) { + op.emitError() << "failed to create allocation op"; + didFail = true; + return; + } + Value outParam = maybeOutParam.value(); if (!hasStaticIdentityLayout(memrefType)) { // Layout maps are already checked in `updateFuncOp`. assert(hasFullyDynamicLayoutMap(memrefType) && @@ -226,7 +232,13 @@ LogicalResult mlir::bufferization::promoteBufferResultsToOutParams( return failure(); } } - if (failed(updateCalls(module, options))) + auto defaultAllocationFn = [](OpBuilder &builder, Location loc, + MemRefType type) { + return builder.create(loc, type).getResult(); + }; + if (failed(updateCalls(module, + options.allocationFn.value_or(defaultAllocationFn), + options.filterFn))) return failure(); return success(); } From cc788e0b16ac112c0ea214c04a1de943d8101f32 Mon Sep 17 00:00:00 2001 From: Sam Date: Fri, 20 Dec 2024 10:14:42 -0600 Subject: [PATCH 3/3] Move default alloc/copy fns to `BufferResultsToOutParamsOpts` struct --- .../Dialect/Bufferization/Transforms/Passes.h | 16 ++++++-- .../Transforms/BufferResultsToOutParams.cpp | 38 +++++++------------ 2 files changed, 25 insertions(+), 29 deletions(-) diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h index 966438956fc6c..c8e456a1d7e38 100644 --- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h @@ -2,6 +2,7 @@ #define MLIR_DIALECT_BUFFERIZATION_TRANSFORMS_PASSES_H #include "mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Pass/Pass.h" namespace mlir { @@ -175,12 +176,19 @@ struct BufferResultsToOutParamsOpts { }; /// Allocation function; used to allocate a memref. - /// If this is empty, memref.alloc is used - std::optional allocationFn; + /// Default memref.alloc is used + AllocationFn allocationFn = [](OpBuilder &builder, Location loc, + MemRefType type) { + return builder.create(loc, type).getResult(); + }; /// Memcpy function; used to create a copy between two memrefs. - /// If this is empty, memref.copy is used. - std::optional memCpyFn; + /// Default memref.copy is used. + MemCpyFn memCpyFn = [](OpBuilder &builder, Location loc, Value from, + Value to) { + builder.create(loc, from, to); + return success(); + }; /// If true, the pass adds a "bufferize.result" attribute to each output /// parameter. diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp index 545b6ca009c03..2502744cb3f58 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp @@ -107,10 +107,9 @@ updateFuncOp(func::FuncOp func, // Updates all ReturnOps in the scope of the given func::FuncOp by either // keeping them as return values or copying the associated buffer contents into // the given out-params. -static LogicalResult updateReturnOps(func::FuncOp func, - ArrayRef appendedEntryArgs, - MemCpyFn memCpyFn, - bool hoistStaticAllocs) { +static LogicalResult +updateReturnOps(func::FuncOp func, ArrayRef appendedEntryArgs, + const bufferization::BufferResultsToOutParamsOpts &options) { auto res = func.walk([&](func::ReturnOp op) { SmallVector copyIntoOutParams; SmallVector keepAsReturnOperands; @@ -122,14 +121,14 @@ static LogicalResult updateReturnOps(func::FuncOp func, } OpBuilder builder(op); for (auto [orig, arg] : llvm::zip(copyIntoOutParams, appendedEntryArgs)) { - if (hoistStaticAllocs && + if (options.hoistStaticAllocs && isa_and_nonnull( orig.getDefiningOp()) && mlir::cast(orig.getType()).hasStaticShape()) { orig.replaceAllUsesWith(arg); orig.getDefiningOp()->erase(); } else { - if (failed(memCpyFn(builder, op.getLoc(), orig, arg))) + if (failed(options.memCpyFn(builder, op.getLoc(), orig, arg))) return WalkResult::interrupt(); } } @@ -142,8 +141,9 @@ static LogicalResult updateReturnOps(func::FuncOp func, // Updates all CallOps in the scope of the given ModuleOp by allocating // temporary buffers for newly introduced out params. -static LogicalResult updateCalls(ModuleOp module, AllocationFn allocationFn, - std::function filterFn) { +static LogicalResult +updateCalls(ModuleOp module, + const bufferization::BufferResultsToOutParamsOpts &options) { bool didFail = false; SymbolTable symtab(module); module.walk([&](func::CallOp op) { @@ -154,7 +154,7 @@ static LogicalResult updateCalls(ModuleOp module, AllocationFn allocationFn, didFail = true; return; } - if (!filterFn(&callee)) + if (!options.filterFn(&callee)) return; SmallVector replaceWithNewCallResults; SmallVector replaceWithOutParams; @@ -177,7 +177,8 @@ static LogicalResult updateCalls(ModuleOp module, AllocationFn allocationFn, auto allocType = MemRefType::get(memrefType.getShape(), memrefType.getElementType(), AffineMap(), memrefType.getMemorySpace()); - auto maybeOutParam = allocationFn(builder, op.getLoc(), allocType); + auto maybeOutParam = + options.allocationFn(builder, op.getLoc(), allocType); if (failed(maybeOutParam)) { op.emitError() << "failed to create allocation op"; didFail = true; @@ -221,24 +222,11 @@ LogicalResult mlir::bufferization::promoteBufferResultsToOutParams( return failure(); if (func.isExternal()) continue; - auto defaultMemCpyFn = [](OpBuilder &builder, Location loc, Value from, - Value to) { - builder.create(loc, from, to); - return success(); - }; - if (failed(updateReturnOps(func, appendedEntryArgs, - options.memCpyFn.value_or(defaultMemCpyFn), - options.hoistStaticAllocs))) { + if (failed(updateReturnOps(func, appendedEntryArgs, options))) { return failure(); } } - auto defaultAllocationFn = [](OpBuilder &builder, Location loc, - MemRefType type) { - return builder.create(loc, type).getResult(); - }; - if (failed(updateCalls(module, - options.allocationFn.value_or(defaultAllocationFn), - options.filterFn))) + if (failed(updateCalls(module, options))) return failure(); return success(); }