Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions mlir/include/mlir/Analysis/DataFlow/DeadCodeAnalysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,13 @@ class DeadCodeAnalysis : public DataFlowAnalysis {
/// considered an external callable.
Operation *analysisScope;

/// Whether the analysis scope has a symbol table. This is used to avoid
/// resolving callables outside the analysis scope.
/// It is updated when recursing into a region in case where the top-level
/// operation does not have a symbol table, but one is encountered in a nested
/// region.
bool hasSymbolTable = false;

/// A symbol table used for O(1) symbol lookups during simplification.
SymbolTableCollection symbolTable;
};
Expand Down
37 changes: 28 additions & 9 deletions mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include "mlir/Interfaces/CallInterfaces.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/DebugLog.h"
Expand Down Expand Up @@ -159,6 +160,7 @@ void DeadCodeAnalysis::initializeSymbolCallables(Operation *top) {
LDBG() << "[init] Entering initializeSymbolCallables for top-level op: "
<< OpWithFlags(top, OpPrintingFlags().skipRegions());
analysisScope = top;
hasSymbolTable = top->hasTrait<OpTrait::SymbolTable>();
auto walkFn = [&](Operation *symTable, bool allUsesVisible) {
LDBG() << "[init] Processing symbol table op: "
<< OpWithFlags(symTable, OpPrintingFlags().skipRegions());
Expand Down Expand Up @@ -260,14 +262,25 @@ LogicalResult DeadCodeAnalysis::initializeRecursively(Operation *op) {
return failure();
}
// Recurse on nested operations.
for (Region &region : op->getRegions()) {
LDBG() << "[init] Recursing into region of op: "
<< OpWithFlags(op, OpPrintingFlags().skipRegions());
for (Operation &nestedOp : region.getOps()) {
LDBG() << "[init] Recursing into nested op: "
<< OpWithFlags(&nestedOp, OpPrintingFlags().skipRegions());
if (failed(initializeRecursively(&nestedOp)))
return failure();
if (op->getNumRegions()) {
// If we haven't seen a symbol table yet, check if the current operation
// has one. If so, update the flag to allow for resolving callables in
// nested regions.
bool savedHasSymbolTable = hasSymbolTable;
auto restoreHasSymbolTable =
llvm::make_scope_exit([&]() { hasSymbolTable = savedHasSymbolTable; });
if (!hasSymbolTable && op->hasTrait<OpTrait::SymbolTable>())
hasSymbolTable = true;

for (Region &region : op->getRegions()) {
LDBG() << "[init] Recursing into region of op: "
<< OpWithFlags(op, OpPrintingFlags().skipRegions());
for (Operation &nestedOp : region.getOps()) {
LDBG() << "[init] Recursing into nested op: "
<< OpWithFlags(&nestedOp, OpPrintingFlags().skipRegions());
if (failed(initializeRecursively(&nestedOp)))
return failure();
}
}
}
LDBG() << "[init] Finished initializeRecursively for op: "
Expand Down Expand Up @@ -388,7 +401,13 @@ LogicalResult DeadCodeAnalysis::visit(ProgramPoint *point) {
void DeadCodeAnalysis::visitCallOperation(CallOpInterface call) {
LDBG() << "visitCallOperation: "
<< OpWithFlags(call.getOperation(), OpPrintingFlags().skipRegions());
Operation *callableOp = call.resolveCallableInTable(&symbolTable);

Operation *callableOp = nullptr;
if (hasSymbolTable)
callableOp = call.resolveCallableInTable(&symbolTable);
else
LDBG()
<< "No symbol table present in analysis scope, can't resolve callable";

// A call to a externally-defined callable has unknown predecessors.
const auto isExternalCallable = [this](Operation *op) {
Expand Down
22 changes: 14 additions & 8 deletions mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,10 +64,12 @@ void AbstractDenseForwardDataFlowAnalysis::visitCallOperation(
AbstractDenseLattice *after) {
// Allow for customizing the behavior of calls to external symbols, including
// when the analysis is explicitly marked as non-interprocedural.
auto callable =
dyn_cast_if_present<CallableOpInterface>(call.resolveCallable());
if (!getSolverConfig().isInterprocedural() ||
(callable && !callable.getCallableRegion())) {
auto isExternalCallable = [&]() {
auto callable =
dyn_cast_if_present<CallableOpInterface>(call.resolveCallable());
return callable && !callable.getCallableRegion();
};
if (!getSolverConfig().isInterprocedural() || isExternalCallable()) {
return visitCallControlFlowTransfer(
call, CallControlFlowAction::ExternalCallee, before, after);
}
Expand Down Expand Up @@ -290,19 +292,23 @@ AbstractDenseBackwardDataFlowAnalysis::visit(ProgramPoint *point) {
void AbstractDenseBackwardDataFlowAnalysis::visitCallOperation(
CallOpInterface call, const AbstractDenseLattice &after,
AbstractDenseLattice *before) {
// If the solver is not interprocedural, let the hook handle it as an external
// callee.
if (!getSolverConfig().isInterprocedural())
return visitCallControlFlowTransfer(
call, CallControlFlowAction::ExternalCallee, after, before);

// Find the callee.
Operation *callee = call.resolveCallableInTable(&symbolTable);

auto callable = dyn_cast_or_null<CallableOpInterface>(callee);
// No region means the callee is only declared in this module.
// If that is the case or if the solver is not interprocedural,
// let the hook handle it.
if (!getSolverConfig().isInterprocedural() ||
(callable && (!callable.getCallableRegion() ||
callable.getCallableRegion()->empty()))) {
if (callable &&
(!callable.getCallableRegion() || callable.getCallableRegion()->empty()))
return visitCallControlFlowTransfer(
call, CallControlFlowAction::ExternalCallee, after, before);
}

if (!callable)
return setToExitState(before);
Expand Down
10 changes: 6 additions & 4 deletions mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -228,10 +228,12 @@ LogicalResult AbstractSparseForwardDataFlowAnalysis::visitCallOperation(
ArrayRef<AbstractSparseLattice *> resultLattices) {
// If the call operation is to an external function, attempt to infer the
// results from the call arguments.
auto callable =
dyn_cast_if_present<CallableOpInterface>(call.resolveCallable());
if (!getSolverConfig().isInterprocedural() ||
(callable && !callable.getCallableRegion())) {
auto isExternalCallable = [&]() {
auto callable =
dyn_cast_if_present<CallableOpInterface>(call.resolveCallable());
return callable && !callable.getCallableRegion();
};
if (!getSolverConfig().isInterprocedural() || isExternalCallable()) {
visitExternalCallImpl(call, operandLattices, resultLattices);
return success();
}
Expand Down
7 changes: 7 additions & 0 deletions mlir/lib/Analysis/DataFlowFramework.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include "mlir/Analysis/DataFlowFramework.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/IR/Value.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/iterator.h"
Expand Down Expand Up @@ -109,6 +110,12 @@ LogicalResult DataFlowSolver::initializeAndRun(Operation *top) {
isRunning = true;
auto guard = llvm::make_scope_exit([&]() { isRunning = false; });

bool isInterprocedural = config.isInterprocedural();
auto restoreInterprocedural = llvm::make_scope_exit(
[&]() { config.setInterprocedural(isInterprocedural); });
if (isInterprocedural && !top->hasTrait<OpTrait::SymbolTable>())
config.setInterprocedural(false);

// Initialize equivalent lattice anchors.
for (DataFlowAnalysis &analysis : llvm::make_pointee_range(childAnalyses)) {
analysis.initializeEquivalentLatticeAnchor(top);
Expand Down