diff --git a/mlir/include/mlir/Analysis/DataFlow/DeadCodeAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/DeadCodeAnalysis.h index 2250db823b551..c7c405e1423cb 100644 --- a/mlir/include/mlir/Analysis/DataFlow/DeadCodeAnalysis.h +++ b/mlir/include/mlir/Analysis/DataFlow/DeadCodeAnalysis.h @@ -229,6 +229,13 @@ class DeadCodeAnalysis : public DataFlowAnalysis { /// considered an external callable. Operation *analysisScope; + /// Whether the analysis scope has a symbol table. This is used to avoid + /// resolving callables outside the analysis scope. + /// It is updated when recursing into a region in case where the top-level + /// operation does not have a symbol table, but one is encountered in a nested + /// region. + bool hasSymbolTable = false; + /// A symbol table used for O(1) symbol lookups during simplification. SymbolTableCollection symbolTable; }; diff --git a/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp b/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp index 9424eff3e6b6f..131c49c44171b 100644 --- a/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp +++ b/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp @@ -22,6 +22,7 @@ #include "mlir/Interfaces/CallInterfaces.h" #include "mlir/Interfaces/ControlFlowInterfaces.h" #include "mlir/Support/LLVM.h" +#include "llvm/ADT/ScopeExit.h" #include "llvm/Support/Casting.h" #include "llvm/Support/Debug.h" #include "llvm/Support/DebugLog.h" @@ -159,6 +160,7 @@ void DeadCodeAnalysis::initializeSymbolCallables(Operation *top) { LDBG() << "[init] Entering initializeSymbolCallables for top-level op: " << OpWithFlags(top, OpPrintingFlags().skipRegions()); analysisScope = top; + hasSymbolTable = top->hasTrait(); auto walkFn = [&](Operation *symTable, bool allUsesVisible) { LDBG() << "[init] Processing symbol table op: " << OpWithFlags(symTable, OpPrintingFlags().skipRegions()); @@ -260,14 +262,25 @@ LogicalResult DeadCodeAnalysis::initializeRecursively(Operation *op) { return failure(); } // Recurse on nested operations. - for (Region ®ion : op->getRegions()) { - LDBG() << "[init] Recursing into region of op: " - << OpWithFlags(op, OpPrintingFlags().skipRegions()); - for (Operation &nestedOp : region.getOps()) { - LDBG() << "[init] Recursing into nested op: " - << OpWithFlags(&nestedOp, OpPrintingFlags().skipRegions()); - if (failed(initializeRecursively(&nestedOp))) - return failure(); + if (op->getNumRegions()) { + // If we haven't seen a symbol table yet, check if the current operation + // has one. If so, update the flag to allow for resolving callables in + // nested regions. + bool savedHasSymbolTable = hasSymbolTable; + auto restoreHasSymbolTable = + llvm::make_scope_exit([&]() { hasSymbolTable = savedHasSymbolTable; }); + if (!hasSymbolTable && op->hasTrait()) + hasSymbolTable = true; + + for (Region ®ion : op->getRegions()) { + LDBG() << "[init] Recursing into region of op: " + << OpWithFlags(op, OpPrintingFlags().skipRegions()); + for (Operation &nestedOp : region.getOps()) { + LDBG() << "[init] Recursing into nested op: " + << OpWithFlags(&nestedOp, OpPrintingFlags().skipRegions()); + if (failed(initializeRecursively(&nestedOp))) + return failure(); + } } } LDBG() << "[init] Finished initializeRecursively for op: " @@ -388,7 +401,13 @@ LogicalResult DeadCodeAnalysis::visit(ProgramPoint *point) { void DeadCodeAnalysis::visitCallOperation(CallOpInterface call) { LDBG() << "visitCallOperation: " << OpWithFlags(call.getOperation(), OpPrintingFlags().skipRegions()); - Operation *callableOp = call.resolveCallableInTable(&symbolTable); + + Operation *callableOp = nullptr; + if (hasSymbolTable) + callableOp = call.resolveCallableInTable(&symbolTable); + else + LDBG() + << "No symbol table present in analysis scope, can't resolve callable"; // A call to a externally-defined callable has unknown predecessors. const auto isExternalCallable = [this](Operation *op) { diff --git a/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp b/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp index d05374f667a51..b51465bc31ec3 100644 --- a/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp +++ b/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp @@ -64,10 +64,12 @@ void AbstractDenseForwardDataFlowAnalysis::visitCallOperation( AbstractDenseLattice *after) { // Allow for customizing the behavior of calls to external symbols, including // when the analysis is explicitly marked as non-interprocedural. - auto callable = - dyn_cast_if_present(call.resolveCallable()); - if (!getSolverConfig().isInterprocedural() || - (callable && !callable.getCallableRegion())) { + auto isExternalCallable = [&]() { + auto callable = + dyn_cast_if_present(call.resolveCallable()); + return callable && !callable.getCallableRegion(); + }; + if (!getSolverConfig().isInterprocedural() || isExternalCallable()) { return visitCallControlFlowTransfer( call, CallControlFlowAction::ExternalCallee, before, after); } @@ -290,6 +292,12 @@ AbstractDenseBackwardDataFlowAnalysis::visit(ProgramPoint *point) { void AbstractDenseBackwardDataFlowAnalysis::visitCallOperation( CallOpInterface call, const AbstractDenseLattice &after, AbstractDenseLattice *before) { + // If the solver is not interprocedural, let the hook handle it as an external + // callee. + if (!getSolverConfig().isInterprocedural()) + return visitCallControlFlowTransfer( + call, CallControlFlowAction::ExternalCallee, after, before); + // Find the callee. Operation *callee = call.resolveCallableInTable(&symbolTable); @@ -297,12 +305,10 @@ void AbstractDenseBackwardDataFlowAnalysis::visitCallOperation( // No region means the callee is only declared in this module. // If that is the case or if the solver is not interprocedural, // let the hook handle it. - if (!getSolverConfig().isInterprocedural() || - (callable && (!callable.getCallableRegion() || - callable.getCallableRegion()->empty()))) { + if (callable && + (!callable.getCallableRegion() || callable.getCallableRegion()->empty())) return visitCallControlFlowTransfer( call, CallControlFlowAction::ExternalCallee, after, before); - } if (!callable) return setToExitState(before); diff --git a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp index 13a3e1480c836..0d2e2ed85549d 100644 --- a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp +++ b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp @@ -228,10 +228,12 @@ LogicalResult AbstractSparseForwardDataFlowAnalysis::visitCallOperation( ArrayRef resultLattices) { // If the call operation is to an external function, attempt to infer the // results from the call arguments. - auto callable = - dyn_cast_if_present(call.resolveCallable()); - if (!getSolverConfig().isInterprocedural() || - (callable && !callable.getCallableRegion())) { + auto isExternalCallable = [&]() { + auto callable = + dyn_cast_if_present(call.resolveCallable()); + return callable && !callable.getCallableRegion(); + }; + if (!getSolverConfig().isInterprocedural() || isExternalCallable()) { visitExternalCallImpl(call, operandLattices, resultLattices); return success(); } diff --git a/mlir/lib/Analysis/DataFlowFramework.cpp b/mlir/lib/Analysis/DataFlowFramework.cpp index 7e1b4052027d3..9352ab02f7472 100644 --- a/mlir/lib/Analysis/DataFlowFramework.cpp +++ b/mlir/lib/Analysis/DataFlowFramework.cpp @@ -9,6 +9,7 @@ #include "mlir/Analysis/DataFlowFramework.h" #include "mlir/IR/Location.h" #include "mlir/IR/Operation.h" +#include "mlir/IR/SymbolTable.h" #include "mlir/IR/Value.h" #include "llvm/ADT/ScopeExit.h" #include "llvm/ADT/iterator.h" @@ -109,6 +110,12 @@ LogicalResult DataFlowSolver::initializeAndRun(Operation *top) { isRunning = true; auto guard = llvm::make_scope_exit([&]() { isRunning = false; }); + bool isInterprocedural = config.isInterprocedural(); + auto restoreInterprocedural = llvm::make_scope_exit( + [&]() { config.setInterprocedural(isInterprocedural); }); + if (isInterprocedural && !top->hasTrait()) + config.setInterprocedural(false); + // Initialize equivalent lattice anchors. for (DataFlowAnalysis &analysis : llvm::make_pointee_range(childAnalyses)) { analysis.initializeEquivalentLatticeAnchor(top);