diff --git a/mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp b/mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp index 509f5202be08d..65df355216f74 100644 --- a/mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp +++ b/mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp @@ -294,7 +294,34 @@ RunLivenessAnalysis::RunLivenessAnalysis(Operation *op) { solver.load(symbolTable); LDBG() << "Initializing and running solver"; (void)solver.initializeAndRun(op); - LDBG() << "RunLivenessAnalysis initialized for op: " << op->getName(); + LDBG() << "RunLivenessAnalysis initialized for op: " << op->getName() + << " check on unreachable code now:"; + // The framework doesn't visit operations in dead blocks, so we need to + // explicitly mark them as dead. + op->walk([&](Operation *op) { + if (op->getNumResults() == 0) + return; + for (auto result : llvm::enumerate(op->getResults())) { + if (getLiveness(result.value())) + continue; + LDBG() << "Result: " << result.index() << " of " + << OpWithFlags(op, OpPrintingFlags().skipRegions()) + << " has no liveness info (unreachable), mark dead"; + solver.getOrCreateState(result.value()); + } + for (auto ®ion : op->getRegions()) { + for (auto &block : region) { + for (auto blockArg : llvm::enumerate(block.getArguments())) { + if (getLiveness(blockArg.value())) + continue; + LDBG() << "Block argument: " << blockArg.index() << " of " + << OpWithFlags(op, OpPrintingFlags().skipRegions()) + << " has no liveness info, mark dead"; + solver.getOrCreateState(blockArg.value()); + } + } + } + }); } const Liveness *RunLivenessAnalysis::getLiveness(Value val) { diff --git a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp index e625f626d12fd..13a3e1480c836 100644 --- a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp +++ b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp @@ -19,12 +19,15 @@ #include "mlir/Interfaces/ControlFlowInterfaces.h" #include "mlir/Support/LLVM.h" #include "llvm/ADT/STLExtras.h" +#include "llvm/Support/DebugLog.h" #include #include using namespace mlir; using namespace mlir::dataflow; +#define DEBUG_TYPE "dataflow" + //===----------------------------------------------------------------------===// // AbstractSparseLattice //===----------------------------------------------------------------------===// @@ -64,22 +67,36 @@ AbstractSparseForwardDataFlowAnalysis::initialize(Operation *top) { LogicalResult AbstractSparseForwardDataFlowAnalysis::initializeRecursively(Operation *op) { + LDBG() << "Initializing recursively for operation: " << op->getName(); + // Initialize the analysis by visiting every owner of an SSA value (all // operations and blocks). - if (failed(visitOperation(op))) + if (failed(visitOperation(op))) { + LDBG() << "Failed to visit operation: " << op->getName(); return failure(); + } for (Region ®ion : op->getRegions()) { + LDBG() << "Processing region with " << region.getBlocks().size() + << " blocks"; for (Block &block : region) { + LDBG() << "Processing block with " << block.getNumArguments() + << " arguments"; getOrCreate(getProgramPointBefore(&block)) ->blockContentSubscribe(this); visitBlock(&block); - for (Operation &op : block) - if (failed(initializeRecursively(&op))) + for (Operation &op : block) { + LDBG() << "Recursively initializing nested operation: " << op.getName(); + if (failed(initializeRecursively(&op))) { + LDBG() << "Failed to initialize nested operation: " << op.getName(); return failure(); + } + } } } + LDBG() << "Successfully completed recursive initialization for operation: " + << op->getName(); return success(); } @@ -409,11 +426,20 @@ static MutableArrayRef operandsToOpOperands(OperandRange &operands) { LogicalResult AbstractSparseBackwardDataFlowAnalysis::visitOperation(Operation *op) { + LDBG() << "Visiting operation: " << op->getName() << " with " + << op->getNumOperands() << " operands and " << op->getNumResults() + << " results"; + // If we're in a dead block, bail out. if (op->getBlock() != nullptr && - !getOrCreate(getProgramPointBefore(op->getBlock()))->isLive()) + !getOrCreate(getProgramPointBefore(op->getBlock())) + ->isLive()) { + LDBG() << "Operation is in dead block, bailing out"; return success(); + } + LDBG() << "Creating lattice elements for " << op->getNumOperands() + << " operands and " << op->getNumResults() << " results"; SmallVector operandLattices = getLatticeElements(op->getOperands()); SmallVector resultLattices = @@ -422,11 +448,15 @@ AbstractSparseBackwardDataFlowAnalysis::visitOperation(Operation *op) { // Block arguments of region branch operations flow back into the operands // of the parent op if (auto branch = dyn_cast(op)) { + LDBG() << "Processing RegionBranchOpInterface operation"; visitRegionSuccessors(branch, operandLattices); return success(); } if (auto branch = dyn_cast(op)) { + LDBG() << "Processing BranchOpInterface operation with " + << op->getNumSuccessors() << " successors"; + // Block arguments of successor blocks flow back into our operands. // We remember all operands not forwarded to any block in a BitVector. @@ -463,6 +493,7 @@ AbstractSparseBackwardDataFlowAnalysis::visitOperation(Operation *op) { // For function calls, connect the arguments of the entry blocks to the // operands of the call op that are forwarded to these arguments. if (auto call = dyn_cast(op)) { + LDBG() << "Processing CallOpInterface operation"; Operation *callableOp = call.resolveCallableInTable(&symbolTable); if (auto callable = dyn_cast_or_null(callableOp)) { // Not all operands of a call op forward to arguments. Such operands are @@ -513,6 +544,7 @@ AbstractSparseBackwardDataFlowAnalysis::visitOperation(Operation *op) { // of this op itself and the operands of the terminators of the regions of // this op. if (auto terminator = dyn_cast(op)) { + LDBG() << "Processing RegionBranchTerminatorOpInterface operation"; if (auto branch = dyn_cast(op->getParentOp())) { visitRegionSuccessorsFromTerminator(terminator, branch); return success(); @@ -520,12 +552,16 @@ AbstractSparseBackwardDataFlowAnalysis::visitOperation(Operation *op) { } if (op->hasTrait()) { + LDBG() << "Processing ReturnLike operation"; // Going backwards, the operands of the return are derived from the // results of all CallOps calling this CallableOp. - if (auto callable = dyn_cast(op->getParentOp())) + if (auto callable = dyn_cast(op->getParentOp())) { + LDBG() << "Callable parent found, visiting callable operation"; return visitCallableOperation(op, callable, operandLattices); + } } + LDBG() << "Using default visitOperationImpl for operation: " << op->getName(); return visitOperationImpl(op, operandLattices, resultLattices); } diff --git a/mlir/lib/Transforms/RemoveDeadValues.cpp b/mlir/lib/Transforms/RemoveDeadValues.cpp index 4ccb83f3ad298..02dad69e49614 100644 --- a/mlir/lib/Transforms/RemoveDeadValues.cpp +++ b/mlir/lib/Transforms/RemoveDeadValues.cpp @@ -258,18 +258,17 @@ static SmallVector operandsToOpOperands(OperandRange operands) { static void processSimpleOp(Operation *op, RunLivenessAnalysis &la, DenseSet &nonLiveSet, RDVFinalCleanupList &cl) { - LDBG() << "Processing simple op: " << *op; if (!isMemoryEffectFree(op) || hasLive(op->getResults(), nonLiveSet, la)) { - LDBG() - << "Simple op is not memory effect free or has live results, skipping: " - << *op; + LDBG() << "Simple op is not memory effect free or has live results, " + "preserving it: " + << OpWithFlags(op, OpPrintingFlags().skipRegions()); return; } LDBG() << "Simple op has all dead results and is memory effect free, scheduling " "for removal: " - << *op; + << OpWithFlags(op, OpPrintingFlags().skipRegions()); cl.operations.push_back(op); collectNonLiveValues(nonLiveSet, op->getResults(), BitVector(op->getNumResults(), true)); @@ -728,19 +727,31 @@ static void processBranchOp(BranchOpInterface branchOp, RunLivenessAnalysis &la, /// Removes dead values collected in RDVFinalCleanupList. /// To be run once when all dead values have been collected. static void cleanUpDeadVals(RDVFinalCleanupList &list) { + LDBG() << "Starting cleanup of dead values..."; + // 1. Operations + LDBG() << "Cleaning up " << list.operations.size() << " operations"; for (auto &op : list.operations) { + LDBG() << "Erasing operation: " + << OpWithFlags(op, OpPrintingFlags().skipRegions()); op->dropAllUses(); op->erase(); } // 2. Values + LDBG() << "Cleaning up " << list.values.size() << " values"; for (auto &v : list.values) { + LDBG() << "Dropping all uses of value: " << v; v.dropAllUses(); } // 3. Functions + LDBG() << "Cleaning up " << list.functions.size() << " functions"; for (auto &f : list.functions) { + LDBG() << "Cleaning up function: " << f.funcOp.getOperation()->getName(); + LDBG() << " Erasing " << f.nonLiveArgs.count() << " non-live arguments"; + LDBG() << " Erasing " << f.nonLiveRets.count() + << " non-live return values"; // Some functions may not allow erasing arguments or results. These calls // return failure in such cases without modifying the function, so it's okay // to proceed. @@ -749,44 +760,67 @@ static void cleanUpDeadVals(RDVFinalCleanupList &list) { } // 4. Operands + LDBG() << "Cleaning up " << list.operands.size() << " operand lists"; for (OperationToCleanup &o : list.operands) { - if (o.op->getNumOperands() > 0) + if (o.op->getNumOperands() > 0) { + LDBG() << "Erasing " << o.nonLive.count() + << " non-live operands from operation: " + << OpWithFlags(o.op, OpPrintingFlags().skipRegions()); o.op->eraseOperands(o.nonLive); + } } // 5. Results + LDBG() << "Cleaning up " << list.results.size() << " result lists"; for (auto &r : list.results) { + LDBG() << "Erasing " << r.nonLive.count() + << " non-live results from operation: " + << OpWithFlags(r.op, OpPrintingFlags().skipRegions()); dropUsesAndEraseResults(r.op, r.nonLive); } // 6. Blocks + LDBG() << "Cleaning up " << list.blocks.size() << " block argument lists"; for (auto &b : list.blocks) { // blocks that are accessed via multiple codepaths processed once if (b.b->getNumArguments() != b.nonLiveArgs.size()) continue; + LDBG() << "Erasing " << b.nonLiveArgs.count() + << " non-live arguments from block: " << b.b; // it iterates backwards because erase invalidates all successor indexes for (int i = b.nonLiveArgs.size() - 1; i >= 0; --i) { if (!b.nonLiveArgs[i]) continue; + LDBG() << " Erasing block argument " << i << ": " << b.b->getArgument(i); b.b->getArgument(i).dropAllUses(); b.b->eraseArgument(i); } } // 7. Successor Operands + LDBG() << "Cleaning up " << list.successorOperands.size() + << " successor operand lists"; for (auto &op : list.successorOperands) { SuccessorOperands successorOperands = op.branch.getSuccessorOperands(op.successorIndex); // blocks that are accessed via multiple codepaths processed once if (successorOperands.size() != op.nonLiveOperands.size()) continue; + LDBG() << "Erasing " << op.nonLiveOperands.count() + << " non-live successor operands from successor " + << op.successorIndex << " of branch: " + << OpWithFlags(op.branch, OpPrintingFlags().skipRegions()); // it iterates backwards because erase invalidates all successor indexes for (int i = successorOperands.size() - 1; i >= 0; --i) { if (!op.nonLiveOperands[i]) continue; + LDBG() << " Erasing successor operand " << i << ": " + << successorOperands[i]; successorOperands.erase(i); } } + + LDBG() << "Finished cleanup of dead values"; } struct RemoveDeadValues : public impl::RemoveDeadValuesBase { diff --git a/mlir/test/Analysis/DataFlow/test-liveness-analysis.mlir b/mlir/test/Analysis/DataFlow/test-liveness-analysis.mlir index a89a0f4084e99..3748be74eb0f3 100644 --- a/mlir/test/Analysis/DataFlow/test-liveness-analysis.mlir +++ b/mlir/test/Analysis/DataFlow/test-liveness-analysis.mlir @@ -283,3 +283,23 @@ func.func @test_10_negative() -> (i32) { %0:2 = func.call @private_1() : () -> (i32, i32) return %0#0 : i32 } + +// ----- + +// Test that we correctly set a liveness value for operations in dead block. +// These won't be visited by the dataflow framework so the analysis need to +// explicitly manage them. +// CHECK-LABEL: test_tag: dead_block_cmpi: +// CHECK-NEXT: operand #0: not live +// CHECK-NEXT: operand #1: not live +// CHECK-NEXT: result #0: not live +func.func @dead_block() { + %false = arith.constant false + %zero = arith.constant 0 : i64 + cf.cond_br %false, ^bb1, ^bb4 + ^bb1: + %3 = arith.cmpi eq, %zero, %zero {tag = "dead_block_cmpi"} : i64 + cf.br ^bb1 + ^bb4: + return +} diff --git a/mlir/test/Transforms/remove-dead-values.mlir b/mlir/test/Transforms/remove-dead-values.mlir index 9ded6a30d9c95..0f8d757086e87 100644 --- a/mlir/test/Transforms/remove-dead-values.mlir +++ b/mlir/test/Transforms/remove-dead-values.mlir @@ -571,3 +571,24 @@ module @return_void_with_unused_argument { } } +// ----- + +// CHECK-LABEL: module @dynamically_unreachable +module @dynamically_unreachable { + func.func @dynamically_unreachable() { + // This value is used by an operation in a dynamically unreachable block. + %zero = arith.constant 0 : i64 + + // Dataflow analysis knows from the constant condition that + // ^bb1 is unreachable + %false = arith.constant false + cf.cond_br %false, ^bb1, ^bb4 + ^bb1: + // This unreachable operation should be removed. + // CHECK-NOT: arith.cmpi + %3 = arith.cmpi eq, %zero, %zero : i64 + cf.br ^bb1 + ^bb4: + return + } +} diff --git a/mlir/test/lib/Analysis/DataFlow/TestLivenessAnalysis.cpp b/mlir/test/lib/Analysis/DataFlow/TestLivenessAnalysis.cpp index 43005e22584c2..8e2f03b644e49 100644 --- a/mlir/test/lib/Analysis/DataFlow/TestLivenessAnalysis.cpp +++ b/mlir/test/lib/Analysis/DataFlow/TestLivenessAnalysis.cpp @@ -33,7 +33,6 @@ struct TestLivenessAnalysisPass void runOnOperation() override { auto &livenessAnalysis = getAnalysis(); - Operation *op = getOperation(); raw_ostream &os = llvm::outs();