diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OptimizeAllocationLiveness.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OptimizeAllocationLiveness.cpp index 5178d4a62f374..e17b39cd7e371 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/OptimizeAllocationLiveness.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/OptimizeAllocationLiveness.cpp @@ -17,7 +17,10 @@ #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/IR/Operation.h" +#include "mlir/IR/Value.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/ErrorHandling.h" #define DEBUG_TYPE "optimize-allocation-liveness" #define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ") @@ -88,6 +91,19 @@ static bool hasMemoryAllocEffect(MemoryEffectOpInterface memEffectOp) { return false; } +/// Extracts OpResult's with Allocate effects from given op +static SmallVector +collectAllocations(MemoryEffectOpInterface allocOp) { + SmallVector effects; + allocOp.getEffects(effects); + SmallVector allocResults; + for (const MemoryEffects::EffectInstance &it : effects) + if (isa(it.getEffect())) + if (auto val = it.getValue(); val && val.getDefiningOp() == allocOp) + allocResults.push_back(cast(val)); + return allocResults; +} + struct OptimizeAllocationLiveness : public bufferization::impl::OptimizeAllocationLivenessPassBase< OptimizeAllocationLiveness> { @@ -109,7 +125,15 @@ struct OptimizeAllocationLiveness auto allocOp = memEffectOp; LDBG("Checking alloc op: " << allocOp); - auto deallocOp = findUserWithFreeSideEffect(allocOp->getResult(0)); + SmallVector allocationResults = collectAllocations(allocOp); + // Multiple allocations from a single op are not considered here yet. + if (allocationResults.size() != 1) + return WalkResult::advance(); + + OpResult allocResult = allocationResults[0]; + LDBG("On allocation result: " << allocResult); + + auto *deallocOp = findUserWithFreeSideEffect(allocResult); if (!deallocOp || (deallocOp->getBlock() != allocOp->getBlock())) { // The pass handles allocations that have a single dealloc op in the // same block. We also should not hoist the dealloc op out of @@ -119,9 +143,9 @@ struct OptimizeAllocationLiveness Operation *lastUser = nullptr; const BufferViewFlowAnalysis::ValueSetT &deps = - analysis.resolve(allocOp->getResult(0)); + analysis.resolve(allocResult); for (auto dep : llvm::make_early_inc_range(deps)) { - for (auto user : dep.getUsers()) { + for (auto *user : dep.getUsers()) { // We are looking for a non dealloc op user. // check if user is the dealloc op itself. if (user == deallocOp) diff --git a/mlir/test/Dialect/Bufferization/Transforms/optimize-allocation-liveness.mlir b/mlir/test/Dialect/Bufferization/Transforms/optimize-allocation-liveness.mlir index 5f5a0ce54e2c1..63d33e3a88bed 100644 --- a/mlir/test/Dialect/Bufferization/Transforms/optimize-allocation-liveness.mlir +++ b/mlir/test/Dialect/Bufferization/Transforms/optimize-allocation-liveness.mlir @@ -209,3 +209,28 @@ func.func private @test_conditional_deallocation() -> memref<32xf32, 1> { return %3 : memref<32xf32, 1> } + +// ----- +// CHECK-LABEL: func.func private @test_alloc_with_multiple_results() { +// CHECK: %[[ID1:.+]], %[[ALLOC1:.+]] = test.alloc_with_multiple_results : index, memref<64xf32> +// CHECK: memref.expand_shape %[[ALLOC1]] +// CHECK: memref.dealloc %[[ALLOC1]] : memref<64xf32> +// CHECK: %[[ID2:.+]], %[[ALLOC2:.+]] = test.alloc_with_multiple_results : index, memref<64xf32> +// CHECK: memref.expand_shape %[[ALLOC2]] +// CHECK: memref.dealloc %[[ALLOC2]] : memref<64xf32> +// CHECK: return +// CHECK: } + +// This test will check that allocations with multiple results and allocated +// buffer at non-zero position are accepted. +func.func private @test_alloc_with_multiple_results() -> () { + %id1, %alloc1 = test.alloc_with_multiple_results : index, memref<64xf32> + %expand_shape1 = memref.expand_shape %alloc1 [[0, 1]] output_shape [8, 8] : memref<64xf32> into memref<8x8xf32> + + %id2, %alloc2 = test.alloc_with_multiple_results : index, memref<64xf32> + %expand_shape2 = memref.expand_shape %alloc2 [[0, 1]] output_shape [8, 8] : memref<64xf32> into memref<8x8xf32> + + memref.dealloc %alloc1 : memref<64xf32> + memref.dealloc %alloc2 : memref<64xf32> + return +} diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td index d8024145e711f..31be00ace1384 100644 --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -3441,4 +3441,16 @@ def TestMultiSlotAlloca : TEST_Op<"multi_slot_alloca", let assemblyFormat = "attr-dict `:` functional-type(operands, results)"; } +//===----------------------------------------------------------------------===// +// Test allocation Ops +//===----------------------------------------------------------------------===// + +def TestAllocWithMultipleResults : TEST_Op<"alloc_with_multiple_results"> { + let results = (outs Index:$index, + Res:$memref); + let assemblyFormat = [{ + attr-dict `:` type($index) `,` type($memref) + }]; +} + #endif // TEST_OPS