|
22 | 22 | #include "mlir/Interfaces/CallInterfaces.h" |
23 | 23 | #include "mlir/Interfaces/ControlFlowInterfaces.h" |
24 | 24 | #include "mlir/Support/LLVM.h" |
| 25 | +#include "llvm/ADT/ScopeExit.h" |
25 | 26 | #include "llvm/Support/Casting.h" |
26 | 27 | #include "llvm/Support/Debug.h" |
27 | 28 | #include "llvm/Support/DebugLog.h" |
@@ -159,6 +160,7 @@ void DeadCodeAnalysis::initializeSymbolCallables(Operation *top) { |
159 | 160 | LDBG() << "[init] Entering initializeSymbolCallables for top-level op: " |
160 | 161 | << OpWithFlags(top, OpPrintingFlags().skipRegions()); |
161 | 162 | analysisScope = top; |
| 163 | + hasSymbolTable = top->hasTrait<OpTrait::SymbolTable>(); |
162 | 164 | auto walkFn = [&](Operation *symTable, bool allUsesVisible) { |
163 | 165 | LDBG() << "[init] Processing symbol table op: " |
164 | 166 | << OpWithFlags(symTable, OpPrintingFlags().skipRegions()); |
@@ -260,14 +262,25 @@ LogicalResult DeadCodeAnalysis::initializeRecursively(Operation *op) { |
260 | 262 | return failure(); |
261 | 263 | } |
262 | 264 | // Recurse on nested operations. |
263 | | - for (Region ®ion : op->getRegions()) { |
264 | | - LDBG() << "[init] Recursing into region of op: " |
265 | | - << OpWithFlags(op, OpPrintingFlags().skipRegions()); |
266 | | - for (Operation &nestedOp : region.getOps()) { |
267 | | - LDBG() << "[init] Recursing into nested op: " |
268 | | - << OpWithFlags(&nestedOp, OpPrintingFlags().skipRegions()); |
269 | | - if (failed(initializeRecursively(&nestedOp))) |
270 | | - return failure(); |
| 265 | + if (op->getNumRegions()) { |
| 266 | + // If we haven't seen a symbol table yet, check if the current operation |
| 267 | + // has one. If so, update the flag to allow for resolving callables in |
| 268 | + // nested regions. |
| 269 | + bool savedHasSymbolTable = hasSymbolTable; |
| 270 | + auto restoreHasSymbolTable = |
| 271 | + llvm::make_scope_exit([&]() { hasSymbolTable = savedHasSymbolTable; }); |
| 272 | + if (!hasSymbolTable && op->hasTrait<OpTrait::SymbolTable>()) |
| 273 | + hasSymbolTable = true; |
| 274 | + |
| 275 | + for (Region ®ion : op->getRegions()) { |
| 276 | + LDBG() << "[init] Recursing into region of op: " |
| 277 | + << OpWithFlags(op, OpPrintingFlags().skipRegions()); |
| 278 | + for (Operation &nestedOp : region.getOps()) { |
| 279 | + LDBG() << "[init] Recursing into nested op: " |
| 280 | + << OpWithFlags(&nestedOp, OpPrintingFlags().skipRegions()); |
| 281 | + if (failed(initializeRecursively(&nestedOp))) |
| 282 | + return failure(); |
| 283 | + } |
271 | 284 | } |
272 | 285 | } |
273 | 286 | LDBG() << "[init] Finished initializeRecursively for op: " |
@@ -388,7 +401,13 @@ LogicalResult DeadCodeAnalysis::visit(ProgramPoint *point) { |
388 | 401 | void DeadCodeAnalysis::visitCallOperation(CallOpInterface call) { |
389 | 402 | LDBG() << "visitCallOperation: " |
390 | 403 | << OpWithFlags(call.getOperation(), OpPrintingFlags().skipRegions()); |
391 | | - Operation *callableOp = call.resolveCallableInTable(&symbolTable); |
| 404 | + |
| 405 | + Operation *callableOp = nullptr; |
| 406 | + if (hasSymbolTable) |
| 407 | + callableOp = call.resolveCallableInTable(&symbolTable); |
| 408 | + else |
| 409 | + LDBG() |
| 410 | + << "No symbol table present in analysis scope, can't resolve callable"; |
392 | 411 |
|
393 | 412 | // A call to a externally-defined callable has unknown predecessors. |
394 | 413 | const auto isExternalCallable = [this](Operation *op) { |
|
0 commit comments