Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions mlir/include/mlir/Analysis/DataFlow/DeadCodeAnalysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

#include "mlir/Analysis/DataFlowFramework.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "llvm/ADT/SmallPtrSet.h"
#include <optional>

Expand Down Expand Up @@ -200,6 +201,13 @@ class DeadCodeAnalysis : public DataFlowAnalysis {
/// which are live from the current block.
void visitBranchOperation(BranchOpInterface branch);

/// Visit region branch edges from `predecessorOp` to a list of successors.
/// For each edge, mark the successor program point as executable, and record
/// the predecessor information in its `PredecessorState`.
void visitRegionBranchEdges(RegionBranchOpInterface regionBranchOp,
Operation *predecessorOp,
const SmallVector<RegionSuccessor> &successors);

/// Visit the given region branch operation, which defines regions, and
/// compute any necessary lattice state. This also resolves the lattice state
/// of both the operation results and any nested regions.
Expand Down
80 changes: 30 additions & 50 deletions mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -444,30 +444,21 @@ void DeadCodeAnalysis::visitCallOperation(CallOpInterface call) {
/// Get the constant values of the operands of an operation. If any of the
/// constant value lattices are uninitialized, return std::nullopt to indicate
/// the analysis should bail out.
static std::optional<SmallVector<Attribute>> getOperandValuesImpl(
Operation *op,
function_ref<const Lattice<ConstantValue> *(Value)> getLattice) {
std::optional<SmallVector<Attribute>>
DeadCodeAnalysis::getOperandValues(Operation *op) {
SmallVector<Attribute> operands;
operands.reserve(op->getNumOperands());
for (Value operand : op->getOperands()) {
const Lattice<ConstantValue> *cv = getLattice(operand);
Lattice<ConstantValue> *cv = getOrCreate<Lattice<ConstantValue>>(operand);
cv->useDefSubscribe(this);
// If any of the operands' values are uninitialized, bail out.
if (cv->getValue().isUninitialized())
return {};
return std::nullopt;
operands.push_back(cv->getValue().getConstantValue());
}
return operands;
}

std::optional<SmallVector<Attribute>>
DeadCodeAnalysis::getOperandValues(Operation *op) {
return getOperandValuesImpl(op, [&](Value value) {
auto *lattice = getOrCreate<Lattice<ConstantValue>>(value);
lattice->useDefSubscribe(this);
return lattice;
});
}

void DeadCodeAnalysis::visitBranchOperation(BranchOpInterface branch) {
LDBG() << "visitBranchOperation: "
<< OpWithFlags(branch.getOperation(), OpPrintingFlags().skipRegions());
Expand Down Expand Up @@ -498,23 +489,8 @@ void DeadCodeAnalysis::visitRegionBranchOperation(

SmallVector<RegionSuccessor> successors;
branch.getEntrySuccessorRegions(*operands, successors);
for (const RegionSuccessor &successor : successors) {
// The successor can be either an entry block or the parent operation.
ProgramPoint *point =
successor.getSuccessor()
? getProgramPointBefore(&successor.getSuccessor()->front())
: getProgramPointAfter(branch);
// Mark the entry block as executable.
auto *state = getOrCreate<Executable>(point);
propagateIfChanged(state, state->setToLive());
LDBG() << "Marked region successor live: " << point;
// Add the parent op as a predecessor.
auto *predecessors = getOrCreate<PredecessorState>(point);
propagateIfChanged(
predecessors,
predecessors->join(branch, successor.getSuccessorInputs()));
LDBG() << "Added region branch as predecessor for successor: " << point;
}

visitRegionBranchEdges(branch, branch.getOperation(), successors);
}

void DeadCodeAnalysis::visitRegionTerminator(Operation *op,
Expand All @@ -530,26 +506,30 @@ void DeadCodeAnalysis::visitRegionTerminator(Operation *op,
else
branch.getSuccessorRegions(op->getParentRegion(), successors);

// Mark successor region entry blocks as executable and add this op to the
// list of predecessors.
visitRegionBranchEdges(branch, op, successors);
}

void DeadCodeAnalysis::visitRegionBranchEdges(
RegionBranchOpInterface regionBranchOp, Operation *predecessorOp,
const SmallVector<RegionSuccessor> &successors) {
for (const RegionSuccessor &successor : successors) {
PredecessorState *predecessors;
if (Region *region = successor.getSuccessor()) {
auto *state =
getOrCreate<Executable>(getProgramPointBefore(&region->front()));
propagateIfChanged(state, state->setToLive());
LDBG() << "Marked region entry block live for region: " << region;
predecessors = getOrCreate<PredecessorState>(
getProgramPointBefore(&region->front()));
} else {
// Add this terminator as a predecessor to the parent op.
predecessors =
getOrCreate<PredecessorState>(getProgramPointAfter(branch));
}
propagateIfChanged(predecessors,
predecessors->join(op, successor.getSuccessorInputs()));
LDBG() << "Added region terminator as predecessor for successor: "
<< (successor.getSuccessor() ? "region entry" : "parent op");
// The successor can be either an entry block or the parent operation.
ProgramPoint *point =
successor.getSuccessor()
? getProgramPointBefore(&successor.getSuccessor()->front())
: getProgramPointAfter(regionBranchOp);

// Mark the entry block as executable.
auto *state = getOrCreate<Executable>(point);
propagateIfChanged(state, state->setToLive());
LDBG() << "Marked region successor live: " << point;

// Add the parent op as a predecessor.
auto *predecessors = getOrCreate<PredecessorState>(point);
propagateIfChanged(
predecessors,
predecessors->join(predecessorOp, successor.getSuccessorInputs()));
LDBG() << "Added region branch as predecessor for successor: " << point;
}
}

Expand Down