From 7df538609d965b6edfc402e014972b27715e29ae Mon Sep 17 00:00:00 2001 From: Tom Eccles Date: Tue, 7 Jan 2025 11:37:37 +0000 Subject: [PATCH] [flang][StackArrays] track pointers through fir.convert This does add a little computational complexity because now every freemem operation has to be tested for every allocation. This could be improved with some more memoisation but I think it is easier to read this way. Let me know if you would prefer me to change this to pre-compute the normalised addresses each freemem operation is using. Weirdly, this change resulted in a verifier failure for the fir.declare in the previous test case. Maybe it was previously removed as dead code and now it isn't. Anyway I fixed that too. --- .../lib/Optimizer/Transforms/StackArrays.cpp | 37 +++++++++++-------- flang/test/Transforms/stack-arrays.fir | 20 +++++++++- 2 files changed, 41 insertions(+), 16 deletions(-) diff --git a/flang/lib/Optimizer/Transforms/StackArrays.cpp b/flang/lib/Optimizer/Transforms/StackArrays.cpp index bdcb8199b790d..2a9d3397e87b0 100644 --- a/flang/lib/Optimizer/Transforms/StackArrays.cpp +++ b/flang/lib/Optimizer/Transforms/StackArrays.cpp @@ -330,6 +330,18 @@ std::optional LatticePoint::get(mlir::Value val) const { return it->second; } +static mlir::Value lookThroughDeclaresAndConverts(mlir::Value value) { + while (mlir::Operation *op = value.getDefiningOp()) { + if (auto declareOp = llvm::dyn_cast(op)) + value = declareOp.getMemref(); + else if (auto convertOp = llvm::dyn_cast(op)) + value = convertOp->getOperand(0); + else + return value; + } + return value; +} + mlir::LogicalResult AllocationAnalysis::visitOperation( mlir::Operation *op, const LatticePoint &before, LatticePoint *after) { LLVM_DEBUG(llvm::dbgs() << "StackArrays: Visiting operation: " << *op @@ -363,10 +375,10 @@ mlir::LogicalResult AllocationAnalysis::visitOperation( mlir::Value operand = op->getOperand(0); // Note: StackArrays is scheduled in the pass pipeline after lowering hlfir - // to fir. Therefore, we only need to handle `fir::DeclareOp`s. - if (auto declareOp = - llvm::dyn_cast_if_present(operand.getDefiningOp())) - operand = declareOp.getMemref(); + // to fir. Therefore, we only need to handle `fir::DeclareOp`s. Also look + // past converts in case the pointer was changed between different pointer + // types. + operand = lookThroughDeclaresAndConverts(operand); std::optional operandState = before.get(operand); if (operandState && *operandState == AllocationState::Allocated) { @@ -535,17 +547,12 @@ AllocMemConversion::matchAndRewrite(fir::AllocMemOp allocmem, // remove freemem operations llvm::SmallVector erases; - for (mlir::Operation *user : allocmem.getOperation()->getUsers()) { - if (auto declareOp = mlir::dyn_cast_if_present(user)) { - for (mlir::Operation *user : declareOp->getUsers()) { - if (mlir::isa(user)) - erases.push_back(user); - } - } - - if (mlir::isa(user)) - erases.push_back(user); - } + mlir::Operation *parent = allocmem->getParentOp(); + // TODO: this shouldn't need to be re-calculated for every allocmem + parent->walk([&](fir::FreeMemOp freeOp) { + if (lookThroughDeclaresAndConverts(freeOp->getOperand(0)) == allocmem) + erases.push_back(freeOp); + }); // now we are done iterating the users, it is safe to mutate them for (mlir::Operation *erase : erases) diff --git a/flang/test/Transforms/stack-arrays.fir b/flang/test/Transforms/stack-arrays.fir index 66cd2a5aa910b..444136d53e034 100644 --- a/flang/test/Transforms/stack-arrays.fir +++ b/flang/test/Transforms/stack-arrays.fir @@ -379,7 +379,8 @@ func.func @placement_loop_declare() { %3 = arith.addi %c1, %c2 : index // operand is now available %4 = fir.allocmem !fir.array, %3 - %5 = fir.declare %4 {uniq_name = "temp"} : (!fir.heap>) -> !fir.heap> + %shape = fir.shape %3 : (index) -> !fir.shape<1> + %5 = fir.declare %4(%shape) {uniq_name = "temp"} : (!fir.heap>, !fir.shape<1>) -> !fir.heap> // ... fir.freemem %5 : !fir.heap> fir.result %3, %c1_i32 : index, i32 @@ -400,3 +401,20 @@ func.func @placement_loop_declare() { // CHECK-NEXT: } // CHECK-NEXT: return // CHECK-NEXT: } + +// Can we look through fir.convert and fir.declare? +func.func @lookthrough() { + %0 = fir.allocmem !fir.array<42xi32> + %c42 = arith.constant 42 : index + %shape = fir.shape %c42 : (index) -> !fir.shape<1> + %1 = fir.declare %0(%shape) {uniq_name = "name"} : (!fir.heap>, !fir.shape<1>) -> !fir.heap> + %2 = fir.convert %1 : (!fir.heap>) -> !fir.ref> + // use the ref so the converts aren't folded + %3 = fir.load %2 : !fir.ref> + %4 = fir.convert %2 : (!fir.ref>) -> !fir.heap> + fir.freemem %4 : !fir.heap> + return +} +// CHECK: func.func @lookthrough() { +// CHECK: fir.alloca !fir.array<42xi32> +// CHECK-NOT: fir.freemem