Skip to content

Commit dfaebe7

Browse files
authored
[MLIR] Fix Liveness analysis handling of unreachable code (#153973)
This patch is forcing all values to be initialized by the LivenessAnalysis, even in dead blocks. The dataflow framework will skip visiting values when its already knows that a block is dynamically unreachable, so this requires specific handling. Downstream code could consider that the absence of liveness is the same a "dead". However as the code is mutated, new value can be introduced, and a transformation like "RemoveDeadValue" must conservatively consider that the absence of liveness information meant that we weren't sure if a value was dead (it could be a newly introduced value. Fixes #153906
1 parent 191e7eb commit dfaebe7

File tree

6 files changed

+150
-13
lines changed

6 files changed

+150
-13
lines changed

mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -294,7 +294,34 @@ RunLivenessAnalysis::RunLivenessAnalysis(Operation *op) {
294294
solver.load<LivenessAnalysis>(symbolTable);
295295
LDBG() << "Initializing and running solver";
296296
(void)solver.initializeAndRun(op);
297-
LDBG() << "RunLivenessAnalysis initialized for op: " << op->getName();
297+
LDBG() << "RunLivenessAnalysis initialized for op: " << op->getName()
298+
<< " check on unreachable code now:";
299+
// The framework doesn't visit operations in dead blocks, so we need to
300+
// explicitly mark them as dead.
301+
op->walk([&](Operation *op) {
302+
if (op->getNumResults() == 0)
303+
return;
304+
for (auto result : llvm::enumerate(op->getResults())) {
305+
if (getLiveness(result.value()))
306+
continue;
307+
LDBG() << "Result: " << result.index() << " of "
308+
<< OpWithFlags(op, OpPrintingFlags().skipRegions())
309+
<< " has no liveness info (unreachable), mark dead";
310+
solver.getOrCreateState<Liveness>(result.value());
311+
}
312+
for (auto &region : op->getRegions()) {
313+
for (auto &block : region) {
314+
for (auto blockArg : llvm::enumerate(block.getArguments())) {
315+
if (getLiveness(blockArg.value()))
316+
continue;
317+
LDBG() << "Block argument: " << blockArg.index() << " of "
318+
<< OpWithFlags(op, OpPrintingFlags().skipRegions())
319+
<< " has no liveness info, mark dead";
320+
solver.getOrCreateState<Liveness>(blockArg.value());
321+
}
322+
}
323+
}
324+
});
298325
}
299326

300327
const Liveness *RunLivenessAnalysis::getLiveness(Value val) {

mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp

Lines changed: 41 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,15 @@
1919
#include "mlir/Interfaces/ControlFlowInterfaces.h"
2020
#include "mlir/Support/LLVM.h"
2121
#include "llvm/ADT/STLExtras.h"
22+
#include "llvm/Support/DebugLog.h"
2223
#include <cassert>
2324
#include <optional>
2425

2526
using namespace mlir;
2627
using namespace mlir::dataflow;
2728

29+
#define DEBUG_TYPE "dataflow"
30+
2831
//===----------------------------------------------------------------------===//
2932
// AbstractSparseLattice
3033
//===----------------------------------------------------------------------===//
@@ -64,22 +67,36 @@ AbstractSparseForwardDataFlowAnalysis::initialize(Operation *top) {
6467

6568
LogicalResult
6669
AbstractSparseForwardDataFlowAnalysis::initializeRecursively(Operation *op) {
70+
LDBG() << "Initializing recursively for operation: " << op->getName();
71+
6772
// Initialize the analysis by visiting every owner of an SSA value (all
6873
// operations and blocks).
69-
if (failed(visitOperation(op)))
74+
if (failed(visitOperation(op))) {
75+
LDBG() << "Failed to visit operation: " << op->getName();
7076
return failure();
77+
}
7178

7279
for (Region &region : op->getRegions()) {
80+
LDBG() << "Processing region with " << region.getBlocks().size()
81+
<< " blocks";
7382
for (Block &block : region) {
83+
LDBG() << "Processing block with " << block.getNumArguments()
84+
<< " arguments";
7485
getOrCreate<Executable>(getProgramPointBefore(&block))
7586
->blockContentSubscribe(this);
7687
visitBlock(&block);
77-
for (Operation &op : block)
78-
if (failed(initializeRecursively(&op)))
88+
for (Operation &op : block) {
89+
LDBG() << "Recursively initializing nested operation: " << op.getName();
90+
if (failed(initializeRecursively(&op))) {
91+
LDBG() << "Failed to initialize nested operation: " << op.getName();
7992
return failure();
93+
}
94+
}
8095
}
8196
}
8297

98+
LDBG() << "Successfully completed recursive initialization for operation: "
99+
<< op->getName();
83100
return success();
84101
}
85102

@@ -409,11 +426,20 @@ static MutableArrayRef<OpOperand> operandsToOpOperands(OperandRange &operands) {
409426

410427
LogicalResult
411428
AbstractSparseBackwardDataFlowAnalysis::visitOperation(Operation *op) {
429+
LDBG() << "Visiting operation: " << op->getName() << " with "
430+
<< op->getNumOperands() << " operands and " << op->getNumResults()
431+
<< " results";
432+
412433
// If we're in a dead block, bail out.
413434
if (op->getBlock() != nullptr &&
414-
!getOrCreate<Executable>(getProgramPointBefore(op->getBlock()))->isLive())
435+
!getOrCreate<Executable>(getProgramPointBefore(op->getBlock()))
436+
->isLive()) {
437+
LDBG() << "Operation is in dead block, bailing out";
415438
return success();
439+
}
416440

441+
LDBG() << "Creating lattice elements for " << op->getNumOperands()
442+
<< " operands and " << op->getNumResults() << " results";
417443
SmallVector<AbstractSparseLattice *> operandLattices =
418444
getLatticeElements(op->getOperands());
419445
SmallVector<const AbstractSparseLattice *> resultLattices =
@@ -422,11 +448,15 @@ AbstractSparseBackwardDataFlowAnalysis::visitOperation(Operation *op) {
422448
// Block arguments of region branch operations flow back into the operands
423449
// of the parent op
424450
if (auto branch = dyn_cast<RegionBranchOpInterface>(op)) {
451+
LDBG() << "Processing RegionBranchOpInterface operation";
425452
visitRegionSuccessors(branch, operandLattices);
426453
return success();
427454
}
428455

429456
if (auto branch = dyn_cast<BranchOpInterface>(op)) {
457+
LDBG() << "Processing BranchOpInterface operation with "
458+
<< op->getNumSuccessors() << " successors";
459+
430460
// Block arguments of successor blocks flow back into our operands.
431461

432462
// We remember all operands not forwarded to any block in a BitVector.
@@ -463,6 +493,7 @@ AbstractSparseBackwardDataFlowAnalysis::visitOperation(Operation *op) {
463493
// For function calls, connect the arguments of the entry blocks to the
464494
// operands of the call op that are forwarded to these arguments.
465495
if (auto call = dyn_cast<CallOpInterface>(op)) {
496+
LDBG() << "Processing CallOpInterface operation";
466497
Operation *callableOp = call.resolveCallableInTable(&symbolTable);
467498
if (auto callable = dyn_cast_or_null<CallableOpInterface>(callableOp)) {
468499
// Not all operands of a call op forward to arguments. Such operands are
@@ -513,19 +544,24 @@ AbstractSparseBackwardDataFlowAnalysis::visitOperation(Operation *op) {
513544
// of this op itself and the operands of the terminators of the regions of
514545
// this op.
515546
if (auto terminator = dyn_cast<RegionBranchTerminatorOpInterface>(op)) {
547+
LDBG() << "Processing RegionBranchTerminatorOpInterface operation";
516548
if (auto branch = dyn_cast<RegionBranchOpInterface>(op->getParentOp())) {
517549
visitRegionSuccessorsFromTerminator(terminator, branch);
518550
return success();
519551
}
520552
}
521553

522554
if (op->hasTrait<OpTrait::ReturnLike>()) {
555+
LDBG() << "Processing ReturnLike operation";
523556
// Going backwards, the operands of the return are derived from the
524557
// results of all CallOps calling this CallableOp.
525-
if (auto callable = dyn_cast<CallableOpInterface>(op->getParentOp()))
558+
if (auto callable = dyn_cast<CallableOpInterface>(op->getParentOp())) {
559+
LDBG() << "Callable parent found, visiting callable operation";
526560
return visitCallableOperation(op, callable, operandLattices);
561+
}
527562
}
528563

564+
LDBG() << "Using default visitOperationImpl for operation: " << op->getName();
529565
return visitOperationImpl(op, operandLattices, resultLattices);
530566
}
531567

mlir/lib/Transforms/RemoveDeadValues.cpp

Lines changed: 40 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -258,18 +258,17 @@ static SmallVector<OpOperand *> operandsToOpOperands(OperandRange operands) {
258258
static void processSimpleOp(Operation *op, RunLivenessAnalysis &la,
259259
DenseSet<Value> &nonLiveSet,
260260
RDVFinalCleanupList &cl) {
261-
LDBG() << "Processing simple op: " << *op;
262261
if (!isMemoryEffectFree(op) || hasLive(op->getResults(), nonLiveSet, la)) {
263-
LDBG()
264-
<< "Simple op is not memory effect free or has live results, skipping: "
265-
<< *op;
262+
LDBG() << "Simple op is not memory effect free or has live results, "
263+
"preserving it: "
264+
<< OpWithFlags(op, OpPrintingFlags().skipRegions());
266265
return;
267266
}
268267

269268
LDBG()
270269
<< "Simple op has all dead results and is memory effect free, scheduling "
271270
"for removal: "
272-
<< *op;
271+
<< OpWithFlags(op, OpPrintingFlags().skipRegions());
273272
cl.operations.push_back(op);
274273
collectNonLiveValues(nonLiveSet, op->getResults(),
275274
BitVector(op->getNumResults(), true));
@@ -728,19 +727,31 @@ static void processBranchOp(BranchOpInterface branchOp, RunLivenessAnalysis &la,
728727
/// Removes dead values collected in RDVFinalCleanupList.
729728
/// To be run once when all dead values have been collected.
730729
static void cleanUpDeadVals(RDVFinalCleanupList &list) {
730+
LDBG() << "Starting cleanup of dead values...";
731+
731732
// 1. Operations
733+
LDBG() << "Cleaning up " << list.operations.size() << " operations";
732734
for (auto &op : list.operations) {
735+
LDBG() << "Erasing operation: "
736+
<< OpWithFlags(op, OpPrintingFlags().skipRegions());
733737
op->dropAllUses();
734738
op->erase();
735739
}
736740

737741
// 2. Values
742+
LDBG() << "Cleaning up " << list.values.size() << " values";
738743
for (auto &v : list.values) {
744+
LDBG() << "Dropping all uses of value: " << v;
739745
v.dropAllUses();
740746
}
741747

742748
// 3. Functions
749+
LDBG() << "Cleaning up " << list.functions.size() << " functions";
743750
for (auto &f : list.functions) {
751+
LDBG() << "Cleaning up function: " << f.funcOp.getOperation()->getName();
752+
LDBG() << " Erasing " << f.nonLiveArgs.count() << " non-live arguments";
753+
LDBG() << " Erasing " << f.nonLiveRets.count()
754+
<< " non-live return values";
744755
// Some functions may not allow erasing arguments or results. These calls
745756
// return failure in such cases without modifying the function, so it's okay
746757
// to proceed.
@@ -749,44 +760,67 @@ static void cleanUpDeadVals(RDVFinalCleanupList &list) {
749760
}
750761

751762
// 4. Operands
763+
LDBG() << "Cleaning up " << list.operands.size() << " operand lists";
752764
for (OperationToCleanup &o : list.operands) {
753-
if (o.op->getNumOperands() > 0)
765+
if (o.op->getNumOperands() > 0) {
766+
LDBG() << "Erasing " << o.nonLive.count()
767+
<< " non-live operands from operation: "
768+
<< OpWithFlags(o.op, OpPrintingFlags().skipRegions());
754769
o.op->eraseOperands(o.nonLive);
770+
}
755771
}
756772

757773
// 5. Results
774+
LDBG() << "Cleaning up " << list.results.size() << " result lists";
758775
for (auto &r : list.results) {
776+
LDBG() << "Erasing " << r.nonLive.count()
777+
<< " non-live results from operation: "
778+
<< OpWithFlags(r.op, OpPrintingFlags().skipRegions());
759779
dropUsesAndEraseResults(r.op, r.nonLive);
760780
}
761781

762782
// 6. Blocks
783+
LDBG() << "Cleaning up " << list.blocks.size() << " block argument lists";
763784
for (auto &b : list.blocks) {
764785
// blocks that are accessed via multiple codepaths processed once
765786
if (b.b->getNumArguments() != b.nonLiveArgs.size())
766787
continue;
788+
LDBG() << "Erasing " << b.nonLiveArgs.count()
789+
<< " non-live arguments from block: " << b.b;
767790
// it iterates backwards because erase invalidates all successor indexes
768791
for (int i = b.nonLiveArgs.size() - 1; i >= 0; --i) {
769792
if (!b.nonLiveArgs[i])
770793
continue;
794+
LDBG() << " Erasing block argument " << i << ": " << b.b->getArgument(i);
771795
b.b->getArgument(i).dropAllUses();
772796
b.b->eraseArgument(i);
773797
}
774798
}
775799

776800
// 7. Successor Operands
801+
LDBG() << "Cleaning up " << list.successorOperands.size()
802+
<< " successor operand lists";
777803
for (auto &op : list.successorOperands) {
778804
SuccessorOperands successorOperands =
779805
op.branch.getSuccessorOperands(op.successorIndex);
780806
// blocks that are accessed via multiple codepaths processed once
781807
if (successorOperands.size() != op.nonLiveOperands.size())
782808
continue;
809+
LDBG() << "Erasing " << op.nonLiveOperands.count()
810+
<< " non-live successor operands from successor "
811+
<< op.successorIndex << " of branch: "
812+
<< OpWithFlags(op.branch, OpPrintingFlags().skipRegions());
783813
// it iterates backwards because erase invalidates all successor indexes
784814
for (int i = successorOperands.size() - 1; i >= 0; --i) {
785815
if (!op.nonLiveOperands[i])
786816
continue;
817+
LDBG() << " Erasing successor operand " << i << ": "
818+
<< successorOperands[i];
787819
successorOperands.erase(i);
788820
}
789821
}
822+
823+
LDBG() << "Finished cleanup of dead values";
790824
}
791825

792826
struct RemoveDeadValues : public impl::RemoveDeadValuesBase<RemoveDeadValues> {

mlir/test/Analysis/DataFlow/test-liveness-analysis.mlir

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -283,3 +283,23 @@ func.func @test_10_negative() -> (i32) {
283283
%0:2 = func.call @private_1() : () -> (i32, i32)
284284
return %0#0 : i32
285285
}
286+
287+
// -----
288+
289+
// Test that we correctly set a liveness value for operations in dead block.
290+
// These won't be visited by the dataflow framework so the analysis need to
291+
// explicitly manage them.
292+
// CHECK-LABEL: test_tag: dead_block_cmpi:
293+
// CHECK-NEXT: operand #0: not live
294+
// CHECK-NEXT: operand #1: not live
295+
// CHECK-NEXT: result #0: not live
296+
func.func @dead_block() {
297+
%false = arith.constant false
298+
%zero = arith.constant 0 : i64
299+
cf.cond_br %false, ^bb1, ^bb4
300+
^bb1:
301+
%3 = arith.cmpi eq, %zero, %zero {tag = "dead_block_cmpi"} : i64
302+
cf.br ^bb1
303+
^bb4:
304+
return
305+
}

mlir/test/Transforms/remove-dead-values.mlir

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -571,3 +571,24 @@ module @return_void_with_unused_argument {
571571
}
572572
}
573573

574+
// -----
575+
576+
// CHECK-LABEL: module @dynamically_unreachable
577+
module @dynamically_unreachable {
578+
func.func @dynamically_unreachable() {
579+
// This value is used by an operation in a dynamically unreachable block.
580+
%zero = arith.constant 0 : i64
581+
582+
// Dataflow analysis knows from the constant condition that
583+
// ^bb1 is unreachable
584+
%false = arith.constant false
585+
cf.cond_br %false, ^bb1, ^bb4
586+
^bb1:
587+
// This unreachable operation should be removed.
588+
// CHECK-NOT: arith.cmpi
589+
%3 = arith.cmpi eq, %zero, %zero : i64
590+
cf.br ^bb1
591+
^bb4:
592+
return
593+
}
594+
}

mlir/test/lib/Analysis/DataFlow/TestLivenessAnalysis.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,6 @@ struct TestLivenessAnalysisPass
3333

3434
void runOnOperation() override {
3535
auto &livenessAnalysis = getAnalysis<RunLivenessAnalysis>();
36-
3736
Operation *op = getOperation();
3837

3938
raw_ostream &os = llvm::outs();

0 commit comments

Comments
 (0)