Skip to content

Commit 0227b79

Browse files
authored
[mlir][nfc] Minor cleanups in DeadCodeAnalysis (#159232)
* Remove `getOperandValuesImpl` since its only used once. * Extract common logic from `DeadCodeAnalysis::visitRegion{BranchOperation,Terminator}` into a new function `DeadCodeAnalysis::visitRegionBranchEdges`. In particular, both functions do the following: * Detect live region branch edges (similar to CFGEdge); * For each edge, mark the successor program point as executable (so that subsequent program gets visited); * For each edge, store the information of the predecessor op and arguments (so that other analyses know what states to join into the successor program point). One caveat is that, before this PR, in `visitRegionTerminator`, the successor program point is only marked as live if it is the start of a block; after this PR, the successor program point is consistently marked as live regardless what it is, which makes the behavior equal to `visitBranchOperation`. This minor fix improves consistency, but at this point it is still NFC, because the rest of the dataflow analysis framework only cares about liveness at block level, and the liveness information in the middle of a block isn't read anyway. This probably will change once [early-exits](https://discourse.llvm.org/t/rfc-region-based-control-flow-with-early-exits-in-mlir/76998) are supported.
1 parent 2bf62e7 commit 0227b79

File tree

2 files changed

+38
-50
lines changed

2 files changed

+38
-50
lines changed

mlir/include/mlir/Analysis/DataFlow/DeadCodeAnalysis.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
#include "mlir/Analysis/DataFlowFramework.h"
1919
#include "mlir/IR/SymbolTable.h"
20+
#include "mlir/Interfaces/ControlFlowInterfaces.h"
2021
#include "llvm/ADT/SmallPtrSet.h"
2122
#include <optional>
2223

@@ -200,6 +201,13 @@ class DeadCodeAnalysis : public DataFlowAnalysis {
200201
/// which are live from the current block.
201202
void visitBranchOperation(BranchOpInterface branch);
202203

204+
/// Visit region branch edges from `predecessorOp` to a list of successors.
205+
/// For each edge, mark the successor program point as executable, and record
206+
/// the predecessor information in its `PredecessorState`.
207+
void visitRegionBranchEdges(RegionBranchOpInterface regionBranchOp,
208+
Operation *predecessorOp,
209+
const SmallVector<RegionSuccessor> &successors);
210+
203211
/// Visit the given region branch operation, which defines regions, and
204212
/// compute any necessary lattice state. This also resolves the lattice state
205213
/// of both the operation results and any nested regions.

mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp

Lines changed: 30 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -444,30 +444,21 @@ void DeadCodeAnalysis::visitCallOperation(CallOpInterface call) {
444444
/// Get the constant values of the operands of an operation. If any of the
445445
/// constant value lattices are uninitialized, return std::nullopt to indicate
446446
/// the analysis should bail out.
447-
static std::optional<SmallVector<Attribute>> getOperandValuesImpl(
448-
Operation *op,
449-
function_ref<const Lattice<ConstantValue> *(Value)> getLattice) {
447+
std::optional<SmallVector<Attribute>>
448+
DeadCodeAnalysis::getOperandValues(Operation *op) {
450449
SmallVector<Attribute> operands;
451450
operands.reserve(op->getNumOperands());
452451
for (Value operand : op->getOperands()) {
453-
const Lattice<ConstantValue> *cv = getLattice(operand);
452+
Lattice<ConstantValue> *cv = getOrCreate<Lattice<ConstantValue>>(operand);
453+
cv->useDefSubscribe(this);
454454
// If any of the operands' values are uninitialized, bail out.
455455
if (cv->getValue().isUninitialized())
456-
return {};
456+
return std::nullopt;
457457
operands.push_back(cv->getValue().getConstantValue());
458458
}
459459
return operands;
460460
}
461461

462-
std::optional<SmallVector<Attribute>>
463-
DeadCodeAnalysis::getOperandValues(Operation *op) {
464-
return getOperandValuesImpl(op, [&](Value value) {
465-
auto *lattice = getOrCreate<Lattice<ConstantValue>>(value);
466-
lattice->useDefSubscribe(this);
467-
return lattice;
468-
});
469-
}
470-
471462
void DeadCodeAnalysis::visitBranchOperation(BranchOpInterface branch) {
472463
LDBG() << "visitBranchOperation: "
473464
<< OpWithFlags(branch.getOperation(), OpPrintingFlags().skipRegions());
@@ -498,23 +489,8 @@ void DeadCodeAnalysis::visitRegionBranchOperation(
498489

499490
SmallVector<RegionSuccessor> successors;
500491
branch.getEntrySuccessorRegions(*operands, successors);
501-
for (const RegionSuccessor &successor : successors) {
502-
// The successor can be either an entry block or the parent operation.
503-
ProgramPoint *point =
504-
successor.getSuccessor()
505-
? getProgramPointBefore(&successor.getSuccessor()->front())
506-
: getProgramPointAfter(branch);
507-
// Mark the entry block as executable.
508-
auto *state = getOrCreate<Executable>(point);
509-
propagateIfChanged(state, state->setToLive());
510-
LDBG() << "Marked region successor live: " << point;
511-
// Add the parent op as a predecessor.
512-
auto *predecessors = getOrCreate<PredecessorState>(point);
513-
propagateIfChanged(
514-
predecessors,
515-
predecessors->join(branch, successor.getSuccessorInputs()));
516-
LDBG() << "Added region branch as predecessor for successor: " << point;
517-
}
492+
493+
visitRegionBranchEdges(branch, branch.getOperation(), successors);
518494
}
519495

520496
void DeadCodeAnalysis::visitRegionTerminator(Operation *op,
@@ -530,26 +506,30 @@ void DeadCodeAnalysis::visitRegionTerminator(Operation *op,
530506
else
531507
branch.getSuccessorRegions(op->getParentRegion(), successors);
532508

533-
// Mark successor region entry blocks as executable and add this op to the
534-
// list of predecessors.
509+
visitRegionBranchEdges(branch, op, successors);
510+
}
511+
512+
void DeadCodeAnalysis::visitRegionBranchEdges(
513+
RegionBranchOpInterface regionBranchOp, Operation *predecessorOp,
514+
const SmallVector<RegionSuccessor> &successors) {
535515
for (const RegionSuccessor &successor : successors) {
536-
PredecessorState *predecessors;
537-
if (Region *region = successor.getSuccessor()) {
538-
auto *state =
539-
getOrCreate<Executable>(getProgramPointBefore(&region->front()));
540-
propagateIfChanged(state, state->setToLive());
541-
LDBG() << "Marked region entry block live for region: " << region;
542-
predecessors = getOrCreate<PredecessorState>(
543-
getProgramPointBefore(&region->front()));
544-
} else {
545-
// Add this terminator as a predecessor to the parent op.
546-
predecessors =
547-
getOrCreate<PredecessorState>(getProgramPointAfter(branch));
548-
}
549-
propagateIfChanged(predecessors,
550-
predecessors->join(op, successor.getSuccessorInputs()));
551-
LDBG() << "Added region terminator as predecessor for successor: "
552-
<< (successor.getSuccessor() ? "region entry" : "parent op");
516+
// The successor can be either an entry block or the parent operation.
517+
ProgramPoint *point =
518+
successor.getSuccessor()
519+
? getProgramPointBefore(&successor.getSuccessor()->front())
520+
: getProgramPointAfter(regionBranchOp);
521+
522+
// Mark the entry block as executable.
523+
auto *state = getOrCreate<Executable>(point);
524+
propagateIfChanged(state, state->setToLive());
525+
LDBG() << "Marked region successor live: " << point;
526+
527+
// Add the parent op as a predecessor.
528+
auto *predecessors = getOrCreate<PredecessorState>(point);
529+
propagateIfChanged(
530+
predecessors,
531+
predecessors->join(predecessorOp, successor.getSuccessorInputs()));
532+
LDBG() << "Added region branch as predecessor for successor: " << point;
553533
}
554534
}
555535

0 commit comments

Comments
 (0)