Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions mlir/include/mlir/IR/Operation.h
Original file line number Diff line number Diff line change
Expand Up @@ -1106,6 +1106,8 @@ inline raw_ostream &operator<<(raw_ostream &os, const Operation &op) {
/// useful to act as a "stream modifier" to customize printing an operation
/// with a stream using the operator<< overload, e.g.:
/// llvm::dbgs() << OpWithFlags(op, OpPrintingFlags().skipRegions());
/// This always prints the operation with the local scope, to avoid introducing
/// spurious newlines in the stream.
class OpWithFlags {
public:
OpWithFlags(Operation *op, OpPrintingFlags flags = {})
Expand All @@ -1116,11 +1118,11 @@ class OpWithFlags {
private:
Operation *op;
OpPrintingFlags theFlags;
friend raw_ostream &operator<<(raw_ostream &os, const OpWithFlags &op);
friend raw_ostream &operator<<(raw_ostream &os, OpWithFlags op);
};

inline raw_ostream &operator<<(raw_ostream &os,
const OpWithFlags &opWithFlags) {
inline raw_ostream &operator<<(raw_ostream &os, OpWithFlags opWithFlags) {
opWithFlags.flags().useLocalScope();
opWithFlags.op->print(os, opWithFlags.flags());
return os;
}
Expand Down
85 changes: 57 additions & 28 deletions mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,9 +79,17 @@ void Executable::onUpdate(DataFlowSolver *solver) const {
void PredecessorState::print(raw_ostream &os) const {
if (allPredecessorsKnown())
os << "(all) ";
os << "predecessors:\n";
for (Operation *op : getKnownPredecessors())
os << " " << *op << "\n";
os << "predecessors:";
if (getKnownPredecessors().empty())
os << " (none)";
else
os << "\n";
llvm::interleave(
getKnownPredecessors(), os,
[&](Operation *op) {
os << " " << OpWithFlags(op, OpPrintingFlags().skipRegions());
},
"\n");
}

ChangeResult PredecessorState::join(Operation *predecessor) {
Expand Down Expand Up @@ -128,15 +136,16 @@ DeadCodeAnalysis::DeadCodeAnalysis(DataFlowSolver &solver)

LogicalResult DeadCodeAnalysis::initialize(Operation *top) {
LDBG() << "Initializing DeadCodeAnalysis for top-level op: "
<< top->getName();
<< OpWithFlags(top, OpPrintingFlags().skipRegions());
// Mark the top-level blocks as executable.
for (Region &region : top->getRegions()) {
if (region.empty())
continue;
auto *state =
getOrCreate<Executable>(getProgramPointBefore(&region.front()));
propagateIfChanged(state, state->setToLive());
LDBG() << "Marked entry block live for region in op: " << top->getName();
LDBG() << "Marked entry block live for region in op: "
<< OpWithFlags(top, OpPrintingFlags().skipRegions());
}

// Mark as overdefined the predecessors of symbol callables with potentially
Expand All @@ -151,14 +160,16 @@ void DeadCodeAnalysis::initializeSymbolCallables(Operation *top) {
<< OpWithFlags(top, OpPrintingFlags().skipRegions());
analysisScope = top;
auto walkFn = [&](Operation *symTable, bool allUsesVisible) {
LDBG() << "[init] Processing symbol table op: " << symTable->getName();
LDBG() << "[init] Processing symbol table op: "
<< OpWithFlags(symTable, OpPrintingFlags().skipRegions());
Region &symbolTableRegion = symTable->getRegion(0);
Block *symbolTableBlock = &symbolTableRegion.front();

bool foundSymbolCallable = false;
for (auto callable : symbolTableBlock->getOps<CallableOpInterface>()) {
LDBG() << "[init] Found CallableOpInterface: "
<< callable.getOperation()->getName();
<< OpWithFlags(callable.getOperation(),
OpPrintingFlags().skipRegions());
Region *callableRegion = callable.getCallableRegion();
if (!callableRegion)
continue;
Expand All @@ -173,7 +184,8 @@ void DeadCodeAnalysis::initializeSymbolCallables(Operation *top) {
getOrCreate<PredecessorState>(getProgramPointAfter(callable));
propagateIfChanged(state, state->setHasUnknownPredecessors());
LDBG() << "[init] Marked callable as having unknown predecessors: "
<< callable.getOperation()->getName();
<< OpWithFlags(callable.getOperation(),
OpPrintingFlags().skipRegions());
}
foundSymbolCallable = true;
}
Expand All @@ -196,7 +208,8 @@ void DeadCodeAnalysis::initializeSymbolCallables(Operation *top) {
propagateIfChanged(state, state->setHasUnknownPredecessors());
LDBG() << "[init] Marked nested callable as "
"having unknown predecessors: "
<< callable.getOperation()->getName();
<< OpWithFlags(callable.getOperation(),
OpPrintingFlags().skipRegions());
});
}

Expand All @@ -212,7 +225,7 @@ void DeadCodeAnalysis::initializeSymbolCallables(Operation *top) {
propagateIfChanged(state, state->setHasUnknownPredecessors());
LDBG() << "[init] Found non-call use for symbol, "
"marked as having unknown predecessors: "
<< symbol->getName();
<< OpWithFlags(symbol, OpPrintingFlags().skipRegions());
}
};
SymbolTable::walkSymbolTables(top, /*allSymUsesVisible=*/!top->getBlock(),
Expand All @@ -235,7 +248,8 @@ LogicalResult DeadCodeAnalysis::initializeRecursively(Operation *op) {
// Initialize the analysis by visiting every op with control-flow semantics.
if (op->getNumRegions() || op->getNumSuccessors() ||
isRegionOrCallableReturn(op) || isa<CallOpInterface>(op)) {
LDBG() << "[init] Visiting op with control-flow semantics: " << *op;
LDBG() << "[init] Visiting op with control-flow semantics: "
<< OpWithFlags(op, OpPrintingFlags().skipRegions());
// When the liveness of the parent block changes, make sure to
// re-invoke the analysis on the op.
if (op->getBlock())
Expand All @@ -247,7 +261,8 @@ LogicalResult DeadCodeAnalysis::initializeRecursively(Operation *op) {
}
// Recurse on nested operations.
for (Region &region : op->getRegions()) {
LDBG() << "[init] Recursing into region of op: " << op->getName();
LDBG() << "[init] Recursing into region of op: "
<< OpWithFlags(op, OpPrintingFlags().skipRegions());
for (Operation &nestedOp : region.getOps()) {
LDBG() << "[init] Recursing into nested op: "
<< OpWithFlags(&nestedOp, OpPrintingFlags().skipRegions());
Expand All @@ -270,14 +285,16 @@ void DeadCodeAnalysis::markEdgeLive(Block *from, Block *to) {
}

void DeadCodeAnalysis::markEntryBlocksLive(Operation *op) {
LDBG() << "Marking entry blocks live for op: " << op->getName();
LDBG() << "Marking entry blocks live for op: "
<< OpWithFlags(op, OpPrintingFlags().skipRegions());
for (Region &region : op->getRegions()) {
if (region.empty())
continue;
auto *state =
getOrCreate<Executable>(getProgramPointBefore(&region.front()));
propagateIfChanged(state, state->setToLive());
LDBG() << "Marked entry block live for region in op: " << op->getName();
LDBG() << "Marked entry block live for region in op: "
<< OpWithFlags(op, OpPrintingFlags().skipRegions());
}
}

Expand All @@ -286,32 +303,37 @@ LogicalResult DeadCodeAnalysis::visit(ProgramPoint *point) {
if (point->isBlockStart())
return success();
Operation *op = point->getPrevOp();
LDBG() << "Visiting operation: " << *op;
LDBG() << "Visiting operation: "
<< OpWithFlags(op, OpPrintingFlags().skipRegions());

// If the parent block is not executable, there is nothing to do.
if (op->getBlock() != nullptr &&
!getOrCreate<Executable>(getProgramPointBefore(op->getBlock()))
->isLive()) {
LDBG() << "Parent block not live, skipping op: " << *op;
LDBG() << "Parent block not live, skipping op: "
<< OpWithFlags(op, OpPrintingFlags().skipRegions());
return success();
}

// We have a live call op. Add this as a live predecessor of the callee.
if (auto call = dyn_cast<CallOpInterface>(op)) {
LDBG() << "Visiting call operation: " << *op;
LDBG() << "Visiting call operation: "
<< OpWithFlags(op, OpPrintingFlags().skipRegions());
visitCallOperation(call);
}

// Visit the regions.
if (op->getNumRegions()) {
// Check if we can reason about the region control-flow.
if (auto branch = dyn_cast<RegionBranchOpInterface>(op)) {
LDBG() << "Visiting region branch operation: " << *op;
LDBG() << "Visiting region branch operation: "
<< OpWithFlags(op, OpPrintingFlags().skipRegions());
visitRegionBranchOperation(branch);

// Check if this is a callable operation.
} else if (auto callable = dyn_cast<CallableOpInterface>(op)) {
LDBG() << "Visiting callable operation: " << *op;
LDBG() << "Visiting callable operation: "
<< OpWithFlags(op, OpPrintingFlags().skipRegions());
const auto *callsites = getOrCreateFor<PredecessorState>(
getProgramPointAfter(op), getProgramPointAfter(callable));

Expand All @@ -323,19 +345,22 @@ LogicalResult DeadCodeAnalysis::visit(ProgramPoint *point) {

// Otherwise, conservatively mark all entry blocks as executable.
} else {
LDBG() << "Marking all entry blocks live for op: " << *op;
LDBG() << "Marking all entry blocks live for op: "
<< OpWithFlags(op, OpPrintingFlags().skipRegions());
markEntryBlocksLive(op);
}
}

if (isRegionOrCallableReturn(op)) {
if (auto branch = dyn_cast<RegionBranchOpInterface>(op->getParentOp())) {
LDBG() << "Visiting region terminator: " << *op;
LDBG() << "Visiting region terminator: "
<< OpWithFlags(op, OpPrintingFlags().skipRegions());
// Visit the exiting terminator of a region.
visitRegionTerminator(op, branch);
} else if (auto callable =
dyn_cast<CallableOpInterface>(op->getParentOp())) {
LDBG() << "Visiting callable terminator: " << *op;
LDBG() << "Visiting callable terminator: "
<< OpWithFlags(op, OpPrintingFlags().skipRegions());
// Visit the exiting terminator of a callable.
visitCallableTerminator(op, callable);
}
Expand All @@ -344,12 +369,14 @@ LogicalResult DeadCodeAnalysis::visit(ProgramPoint *point) {
if (op->getNumSuccessors()) {
// Check if we can reason about the control-flow.
if (auto branch = dyn_cast<BranchOpInterface>(op)) {
LDBG() << "Visiting branch operation: " << *op;
LDBG() << "Visiting branch operation: "
<< OpWithFlags(op, OpPrintingFlags().skipRegions());
visitBranchOperation(branch);

// Otherwise, conservatively mark all successors as exectuable.
} else {
LDBG() << "Marking all successors live for op: " << *op;
LDBG() << "Marking all successors live for op: "
<< OpWithFlags(op, OpPrintingFlags().skipRegions());
for (Block *successor : op->getSuccessors())
markEdgeLive(op->getBlock(), successor);
}
Expand All @@ -359,7 +386,8 @@ LogicalResult DeadCodeAnalysis::visit(ProgramPoint *point) {
}

void DeadCodeAnalysis::visitCallOperation(CallOpInterface call) {
LDBG() << "visitCallOperation: " << call.getOperation()->getName();
LDBG() << "visitCallOperation: "
<< OpWithFlags(call.getOperation(), OpPrintingFlags().skipRegions());
Operation *callableOp = call.resolveCallableInTable(&symbolTable);

// A call to a externally-defined callable has unknown predecessors.
Expand Down Expand Up @@ -442,7 +470,8 @@ void DeadCodeAnalysis::visitBranchOperation(BranchOpInterface branch) {

void DeadCodeAnalysis::visitRegionBranchOperation(
RegionBranchOpInterface branch) {
LDBG() << "visitRegionBranchOperation: " << branch.getOperation()->getName();
LDBG() << "visitRegionBranchOperation: "
<< OpWithFlags(branch.getOperation(), OpPrintingFlags().skipRegions());
// Try to deduce which regions are executable.
std::optional<SmallVector<Attribute>> operands = getOperandValues(branch);
if (!operands)
Expand Down Expand Up @@ -519,14 +548,14 @@ void DeadCodeAnalysis::visitCallableTerminator(Operation *op,
if (canResolve) {
propagateIfChanged(predecessors, predecessors->join(op));
LDBG() << "Added callable terminator as predecessor for callsite: "
<< predecessor->getName();
<< OpWithFlags(predecessor, OpPrintingFlags().skipRegions());
} else {
// If the terminator is not a return-like, then conservatively assume we
// can't resolve the predecessor.
propagateIfChanged(predecessors,
predecessors->setHasUnknownPredecessors());
LDBG() << "Could not resolve callable terminator for callsite: "
<< predecessor->getName();
<< OpWithFlags(predecessor, OpPrintingFlags().skipRegions());
}
}
}
9 changes: 5 additions & 4 deletions mlir/lib/Analysis/DataFlowFramework.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,11 +62,12 @@ void ProgramPoint::print(raw_ostream &os) const {
return;
}
if (!isBlockStart()) {
os << "<after operation>:";
return getPrevOp()->print(os, OpPrintingFlags().skipRegions());
os << "<after operation>:"
<< OpWithFlags(getPrevOp(), OpPrintingFlags().skipRegions());
return;
}
os << "<before operation>:";
return getNextOp()->print(os, OpPrintingFlags().skipRegions());
os << "<before operation>:"
<< OpWithFlags(getNextOp(), OpPrintingFlags().skipRegions());
}

//===----------------------------------------------------------------------===//
Expand Down