Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

namespace mlir {
class FunctionOpInterface;
class MemRefType;
class ModuleOp;
class RewritePatternSet;
class OpBuilder;
Expand Down Expand Up @@ -38,7 +39,7 @@ std::unique_ptr<Pass> createOwnershipBasedBufferDeallocationPass(
DeallocationOptions options = DeallocationOptions());

/// Creates a pass that finds all temporary allocations
/// and attempts to move the deallocation after the last user/dependency
/// and attempts to move the deallocation after the last user/dependency
/// of the allocation, thereby optimizing allocation liveness.
std::unique_ptr<Pass> createOptimizeAllocationLivenessPass();

Expand Down Expand Up @@ -157,6 +158,12 @@ std::unique_ptr<Pass> createBufferLoopHoistingPass();
// Options struct for BufferResultsToOutParams pass.
// Note: defined only here, not in tablegen.
struct BufferResultsToOutParamsOpts {
/// Allocator function: Generate a memref allocation with the given type.
/// Since `promoteBufferResultsToOutParams` doesn't allow dynamically shaped
/// results, we don't allow passing a range of values for dynamic dims.
using AllocationFn =
std::function<FailureOr<Value>(OpBuilder &, Location, MemRefType)>;

/// Memcpy function: Generate a memcpy between two memrefs.
using MemCpyFn =
std::function<LogicalResult(OpBuilder &, Location, Value, Value)>;
Expand All @@ -167,6 +174,10 @@ struct BufferResultsToOutParamsOpts {
return true;
};

/// Allocation function; used to allocate a memref.
/// If this is empty, memref.alloc is used
std::optional<AllocationFn> allocationFn;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could this be avoided by adding a buildAlloc interface method to AllocationOpInterface? Maybe the transformation could also supported mixed alloc ops then>

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure that would work actually. wouldn't you need an instance of an AllocationOpInterface to call the method? In this context there is no such instance, since we are creating a new one. Maybe i'm missing something. Also I do think I prefer this since it mirrors alloc and copy customization in bufferization.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, you only have a memref SSA value, right?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yah, this is used in updateCalls for allocating at the call site and uses the arg types to allocate


/// Memcpy function; used to create a copy between two memrefs.
/// If this is empty, memref.copy is used.
std::optional<MemCpyFn> memCpyFn;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Bufferization/IR/AllocationOpInterface.h"
#include "mlir/Dialect/Bufferization/Transforms/Passes.h"

#include "mlir/Dialect/Func/IR/FuncOps.h"
Expand All @@ -21,6 +22,7 @@ namespace bufferization {
} // namespace mlir

using namespace mlir;
using AllocationFn = bufferization::BufferResultsToOutParamsOpts::AllocationFn;
using MemCpyFn = bufferization::BufferResultsToOutParamsOpts::MemCpyFn;

/// Return `true` if the given MemRef type has a fully dynamic layout.
Expand Down Expand Up @@ -121,7 +123,8 @@ static LogicalResult updateReturnOps(func::FuncOp func,
OpBuilder builder(op);
for (auto [orig, arg] : llvm::zip(copyIntoOutParams, appendedEntryArgs)) {
if (hoistStaticAllocs &&
isa_and_nonnull<memref::AllocOp>(orig.getDefiningOp()) &&
isa_and_nonnull<bufferization::AllocationOpInterface>(
orig.getDefiningOp()) &&
mlir::cast<MemRefType>(orig.getType()).hasStaticShape()) {
orig.replaceAllUsesWith(arg);
orig.getDefiningOp()->erase();
Expand All @@ -139,9 +142,8 @@ static LogicalResult updateReturnOps(func::FuncOp func,

// Updates all CallOps in the scope of the given ModuleOp by allocating
// temporary buffers for newly introduced out params.
static LogicalResult
updateCalls(ModuleOp module,
const bufferization::BufferResultsToOutParamsOpts &options) {
static LogicalResult updateCalls(ModuleOp module, AllocationFn allocationFn,
std::function<bool(func::FuncOp *)> filterFn) {
bool didFail = false;
SymbolTable symtab(module);
module.walk([&](func::CallOp op) {
Expand All @@ -152,7 +154,7 @@ updateCalls(ModuleOp module,
didFail = true;
return;
}
if (!options.filterFn(&callee))
if (!filterFn(&callee))
return;
SmallVector<Value, 6> replaceWithNewCallResults;
SmallVector<Value, 6> replaceWithOutParams;
Expand All @@ -175,7 +177,13 @@ updateCalls(ModuleOp module,
auto allocType =
MemRefType::get(memrefType.getShape(), memrefType.getElementType(),
AffineMap(), memrefType.getMemorySpace());
Value outParam = builder.create<memref::AllocOp>(op.getLoc(), allocType);
auto maybeOutParam = allocationFn(builder, op.getLoc(), allocType);
if (failed(maybeOutParam)) {
op.emitError() << "failed to create allocation op";
didFail = true;
return;
}
Value outParam = maybeOutParam.value();
if (!hasStaticIdentityLayout(memrefType)) {
// Layout maps are already checked in `updateFuncOp`.
assert(hasFullyDynamicLayoutMap(memrefType) &&
Expand Down Expand Up @@ -224,7 +232,13 @@ LogicalResult mlir::bufferization::promoteBufferResultsToOutParams(
return failure();
}
}
if (failed(updateCalls(module, options)))
auto defaultAllocationFn = [](OpBuilder &builder, Location loc,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we could give keep this as a default value in the struct, then the field doesn't have to be optional and you can pass the options to updateCalls. Not sure if it's possible without including MemRefOps.h in Passes.h though. (Maybe we can split the header?)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sure that sounds good to me. i think i would want to do that with the copy op too

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i was just kinda following what had already been done with copy. i think they should be treated the same in this context either way

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure if it's possible without including MemRefOps.h in Passes.h though. (Maybe we can split the header?)

If you are opposed to including MemRefOps.h I could just forward declare the appropriate ops if you prefer. I have no particular preference and honestly don't know what's best practice in this case.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

actually, no. i don't think forward declaration would work here

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i went ahead and just included MemRef.h for now. Passes.h was already included. Naively, I think this is ok since the default use of this would have to link the memref dialect anyway if I'm not mistaken. But let me know if this is no good

MemRefType type) {
return builder.create<memref::AllocOp>(loc, type).getResult();
};
if (failed(updateCalls(module,
options.allocationFn.value_or(defaultAllocationFn),
options.filterFn)))
return failure();
return success();
}
Expand Down
Loading