Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 76 additions & 5 deletions mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,14 @@
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Support/LLVM.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include <cassert>
#include <optional>

#define DEBUG_TYPE "dead-code-analysis"
#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")

using namespace mlir;
using namespace mlir::dataflow;

Expand Down Expand Up @@ -122,13 +127,15 @@ DeadCodeAnalysis::DeadCodeAnalysis(DataFlowSolver &solver)
}

LogicalResult DeadCodeAnalysis::initialize(Operation *top) {
LDBG("Initializing DeadCodeAnalysis for top-level op: " << top->getName());
// Mark the top-level blocks as executable.
for (Region &region : top->getRegions()) {
if (region.empty())
continue;
auto *state =
getOrCreate<Executable>(getProgramPointBefore(&region.front()));
propagateIfChanged(state, state->setToLive());
LDBG("Marked entry block live for region in op: " << top->getName());
}

// Mark as overdefined the predecessors of symbol callables with potentially
Expand All @@ -139,13 +146,18 @@ LogicalResult DeadCodeAnalysis::initialize(Operation *top) {
}

void DeadCodeAnalysis::initializeSymbolCallables(Operation *top) {
LDBG("[init] Entering initializeSymbolCallables for top-level op: "
<< top->getName());
analysisScope = top;
auto walkFn = [&](Operation *symTable, bool allUsesVisible) {
LDBG("[init] Processing symbol table op: " << symTable->getName());
Region &symbolTableRegion = symTable->getRegion(0);
Block *symbolTableBlock = &symbolTableRegion.front();

bool foundSymbolCallable = false;
for (auto callable : symbolTableBlock->getOps<CallableOpInterface>()) {
LDBG("[init] Found CallableOpInterface: "
<< callable.getOperation()->getName());
Region *callableRegion = callable.getCallableRegion();
if (!callableRegion)
continue;
Expand All @@ -159,6 +171,8 @@ void DeadCodeAnalysis::initializeSymbolCallables(Operation *top) {
auto *state =
getOrCreate<PredecessorState>(getProgramPointAfter(callable));
propagateIfChanged(state, state->setHasUnknownPredecessors());
LDBG("[init] Marked callable as having unknown predecessors: "
<< callable.getOperation()->getName());
}
foundSymbolCallable = true;
}
Expand All @@ -173,10 +187,15 @@ void DeadCodeAnalysis::initializeSymbolCallables(Operation *top) {
if (!uses) {
// If we couldn't gather the symbol uses, conservatively assume that
// we can't track information for any nested symbols.
LDBG("[init] Could not gather symbol uses, conservatively marking "
"all nested callables as having unknown predecessors");
return top->walk([&](CallableOpInterface callable) {
auto *state =
getOrCreate<PredecessorState>(getProgramPointAfter(callable));
propagateIfChanged(state, state->setHasUnknownPredecessors());
LDBG("[init] Marked nested callable as "
"having unknown predecessors: "
<< callable.getOperation()->getName());
});
}

Expand All @@ -190,10 +209,15 @@ void DeadCodeAnalysis::initializeSymbolCallables(Operation *top) {
continue;
auto *state = getOrCreate<PredecessorState>(getProgramPointAfter(symbol));
propagateIfChanged(state, state->setHasUnknownPredecessors());
LDBG("[init] Found non-call use for symbol, "
"marked as having unknown predecessors: "
<< symbol->getName());
}
};
SymbolTable::walkSymbolTables(top, /*allSymUsesVisible=*/!top->getBlock(),
walkFn);
LDBG("[init] Finished initializeSymbolCallables for top-level op: "
<< top->getName());
}

/// Returns true if the operation is a returning terminator in region
Expand All @@ -205,9 +229,12 @@ static bool isRegionOrCallableReturn(Operation *op) {
}

LogicalResult DeadCodeAnalysis::initializeRecursively(Operation *op) {
LDBG("[init] Entering initializeRecursively for op: " << op->getName()
<< " at " << op);
// Initialize the analysis by visiting every op with control-flow semantics.
if (op->getNumRegions() || op->getNumSuccessors() ||
isRegionOrCallableReturn(op) || isa<CallOpInterface>(op)) {
LDBG("[init] Visiting op with control-flow semantics: " << *op);
// When the liveness of the parent block changes, make sure to re-invoke the
// analysis on the op.
if (op->getBlock())
Expand All @@ -218,14 +245,22 @@ LogicalResult DeadCodeAnalysis::initializeRecursively(Operation *op) {
return failure();
}
// Recurse on nested operations.
for (Region &region : op->getRegions())
for (Operation &op : region.getOps())
if (failed(initializeRecursively(&op)))
for (Region &region : op->getRegions()) {
LDBG("[init] Recursing into region of op: " << op->getName());
for (Operation &nestedOp : region.getOps()) {
LDBG("[init] Recursing into nested op: " << nestedOp.getName() << " at "
<< &nestedOp);
if (failed(initializeRecursively(&nestedOp)))
return failure();
}
}
LDBG("[init] Finished initializeRecursively for op: " << op->getName()
<< " at " << op);
return success();
}

void DeadCodeAnalysis::markEdgeLive(Block *from, Block *to) {
LDBG("Marking edge live from block " << from << " to block " << to);
auto *state = getOrCreate<Executable>(getProgramPointBefore(to));
propagateIfChanged(state, state->setToLive());
auto *edgeState =
Expand All @@ -234,37 +269,48 @@ void DeadCodeAnalysis::markEdgeLive(Block *from, Block *to) {
}

void DeadCodeAnalysis::markEntryBlocksLive(Operation *op) {
LDBG("Marking entry blocks live for op: " << op->getName());
for (Region &region : op->getRegions()) {
if (region.empty())
continue;
auto *state =
getOrCreate<Executable>(getProgramPointBefore(&region.front()));
propagateIfChanged(state, state->setToLive());
LDBG("Marked entry block live for region in op: " << op->getName());
}
}

LogicalResult DeadCodeAnalysis::visit(ProgramPoint *point) {
LDBG("Visiting program point: " << point << " " << *point);
if (point->isBlockStart())
return success();
Operation *op = point->getPrevOp();
LDBG("Visiting operation: " << *op);

// If the parent block is not executable, there is nothing to do.
if (op->getBlock() != nullptr &&
!getOrCreate<Executable>(getProgramPointBefore(op->getBlock()))->isLive())
!getOrCreate<Executable>(getProgramPointBefore(op->getBlock()))
->isLive()) {
LDBG("Parent block not live, skipping op: " << *op);
return success();
}

// We have a live call op. Add this as a live predecessor of the callee.
if (auto call = dyn_cast<CallOpInterface>(op))
if (auto call = dyn_cast<CallOpInterface>(op)) {
LDBG("Visiting call operation: " << *op);
visitCallOperation(call);
}

// Visit the regions.
if (op->getNumRegions()) {
// Check if we can reason about the region control-flow.
if (auto branch = dyn_cast<RegionBranchOpInterface>(op)) {
LDBG("Visiting region branch operation: " << *op);
visitRegionBranchOperation(branch);

// Check if this is a callable operation.
} else if (auto callable = dyn_cast<CallableOpInterface>(op)) {
LDBG("Visiting callable operation: " << *op);
const auto *callsites = getOrCreateFor<PredecessorState>(
getProgramPointAfter(op), getProgramPointAfter(callable));

Expand All @@ -276,16 +322,19 @@ LogicalResult DeadCodeAnalysis::visit(ProgramPoint *point) {

// Otherwise, conservatively mark all entry blocks as executable.
} else {
LDBG("Marking all entry blocks live for op: " << *op);
markEntryBlocksLive(op);
}
}

if (isRegionOrCallableReturn(op)) {
if (auto branch = dyn_cast<RegionBranchOpInterface>(op->getParentOp())) {
LDBG("Visiting region terminator: " << *op);
// Visit the exiting terminator of a region.
visitRegionTerminator(op, branch);
} else if (auto callable =
dyn_cast<CallableOpInterface>(op->getParentOp())) {
LDBG("Visiting callable terminator: " << *op);
// Visit the exiting terminator of a callable.
visitCallableTerminator(op, callable);
}
Expand All @@ -294,10 +343,12 @@ LogicalResult DeadCodeAnalysis::visit(ProgramPoint *point) {
if (op->getNumSuccessors()) {
// Check if we can reason about the control-flow.
if (auto branch = dyn_cast<BranchOpInterface>(op)) {
LDBG("Visiting branch operation: " << *op);
visitBranchOperation(branch);

// Otherwise, conservatively mark all successors as exectuable.
} else {
LDBG("Marking all successors live for op: " << *op);
for (Block *successor : op->getSuccessors())
markEdgeLive(op->getBlock(), successor);
}
Expand All @@ -307,6 +358,7 @@ LogicalResult DeadCodeAnalysis::visit(ProgramPoint *point) {
}

void DeadCodeAnalysis::visitCallOperation(CallOpInterface call) {
LDBG("visitCallOperation: " << call.getOperation()->getName());
Operation *callableOp = call.resolveCallableInTable(&symbolTable);

// A call to a externally-defined callable has unknown predecessors.
Expand All @@ -329,11 +381,15 @@ void DeadCodeAnalysis::visitCallOperation(CallOpInterface call) {
auto *callsites =
getOrCreate<PredecessorState>(getProgramPointAfter(callableOp));
propagateIfChanged(callsites, callsites->join(call));
LDBG("Added callsite as predecessor for callable: "
<< callableOp->getName());
} else {
// Mark this call op's predecessors as overdefined.
auto *predecessors =
getOrCreate<PredecessorState>(getProgramPointAfter(call));
propagateIfChanged(predecessors, predecessors->setHasUnknownPredecessors());
LDBG("Marked call op's predecessors as unknown for: "
<< call.getOperation()->getName());
}
}

Expand Down Expand Up @@ -365,22 +421,26 @@ DeadCodeAnalysis::getOperandValues(Operation *op) {
}

void DeadCodeAnalysis::visitBranchOperation(BranchOpInterface branch) {
LDBG("visitBranchOperation: " << branch.getOperation()->getName());
// Try to deduce a single successor for the branch.
std::optional<SmallVector<Attribute>> operands = getOperandValues(branch);
if (!operands)
return;

if (Block *successor = branch.getSuccessorForOperands(*operands)) {
markEdgeLive(branch->getBlock(), successor);
LDBG("Branch has single successor: " << successor);
} else {
// Otherwise, mark all successors as executable and outgoing edges.
for (Block *successor : branch->getSuccessors())
markEdgeLive(branch->getBlock(), successor);
LDBG("Branch has multiple/all successors live");
}
}

void DeadCodeAnalysis::visitRegionBranchOperation(
RegionBranchOpInterface branch) {
LDBG("visitRegionBranchOperation: " << branch.getOperation()->getName());
// Try to deduce which regions are executable.
std::optional<SmallVector<Attribute>> operands = getOperandValues(branch);
if (!operands)
Expand All @@ -397,16 +457,19 @@ void DeadCodeAnalysis::visitRegionBranchOperation(
// Mark the entry block as executable.
auto *state = getOrCreate<Executable>(point);
propagateIfChanged(state, state->setToLive());
LDBG("Marked region successor live: " << point);
// Add the parent op as a predecessor.
auto *predecessors = getOrCreate<PredecessorState>(point);
propagateIfChanged(
predecessors,
predecessors->join(branch, successor.getSuccessorInputs()));
LDBG("Added region branch as predecessor for successor: " << point);
}
}

void DeadCodeAnalysis::visitRegionTerminator(Operation *op,
RegionBranchOpInterface branch) {
LDBG("visitRegionTerminator: " << *op);
std::optional<SmallVector<Attribute>> operands = getOperandValues(op);
if (!operands)
return;
Expand All @@ -425,6 +488,7 @@ void DeadCodeAnalysis::visitRegionTerminator(Operation *op,
auto *state =
getOrCreate<Executable>(getProgramPointBefore(&region->front()));
propagateIfChanged(state, state->setToLive());
LDBG("Marked region entry block live for region: " << region);
predecessors = getOrCreate<PredecessorState>(
getProgramPointBefore(&region->front()));
} else {
Expand All @@ -434,11 +498,14 @@ void DeadCodeAnalysis::visitRegionTerminator(Operation *op,
}
propagateIfChanged(predecessors,
predecessors->join(op, successor.getSuccessorInputs()));
LDBG("Added region terminator as predecessor for successor: "
<< (successor.getSuccessor() ? "region entry" : "parent op"));
}
}

void DeadCodeAnalysis::visitCallableTerminator(Operation *op,
CallableOpInterface callable) {
LDBG("visitCallableTerminator: " << *op);
// Add as predecessors to all callsites this return op.
auto *callsites = getOrCreateFor<PredecessorState>(
getProgramPointAfter(op), getProgramPointAfter(callable));
Expand All @@ -449,11 +516,15 @@ void DeadCodeAnalysis::visitCallableTerminator(Operation *op,
getOrCreate<PredecessorState>(getProgramPointAfter(predecessor));
if (canResolve) {
propagateIfChanged(predecessors, predecessors->join(op));
LDBG("Added callable terminator as predecessor for callsite: "
<< predecessor->getName());
} else {
// If the terminator is not a return-like, then conservatively assume we
// can't resolve the predecessor.
propagateIfChanged(predecessors,
predecessors->setHasUnknownPredecessors());
LDBG("Could not resolve callable terminator for callsite: "
<< predecessor->getName());
}
}
}
Loading
Loading