1414#include " mlir/Transforms/Passes.h"
1515
1616#include " mlir/IR/SymbolTable.h"
17+ #include " llvm/Support/Debug.h"
1718
1819namespace mlir {
1920#define GEN_PASS_DEF_SYMBOLDCE
@@ -22,6 +23,8 @@ namespace mlir {
2223
2324using namespace mlir ;
2425
26+ #define DEBUG_TYPE " symbol-dce"
27+
2528namespace {
2629struct SymbolDCE : public impl ::SymbolDCEBase<SymbolDCE> {
2730 void runOnOperation () override ;
@@ -84,6 +87,8 @@ LogicalResult SymbolDCE::computeLiveness(Operation *symbolTableOp,
8487 SymbolTableCollection &symbolTable,
8588 bool symbolTableIsHidden,
8689 DenseSet<Operation *> &liveSymbols) {
90+ LLVM_DEBUG (llvm::dbgs () << " computeLiveness: " << symbolTableOp->getName ()
91+ << " \n " );
8792 // A worklist of live operations to propagate uses from.
8893 SmallVector<Operation *, 16 > worklist;
8994
@@ -105,36 +110,70 @@ LogicalResult SymbolDCE::computeLiveness(Operation *symbolTableOp,
105110 }
106111
107112 // Process the set of symbols that were known to be live, adding new symbols
108- // that are referenced within.
113+ // that are referenced within. For operations that are not symbol tables, it
114+ // considers the liveness with respect to the op itself rather than scope of
115+ // nested symbol tables by enqueuing all the top level operations for
116+ // consideration.
109117 while (!worklist.empty ()) {
110118 Operation *op = worklist.pop_back_val ();
119+ LLVM_DEBUG (llvm::dbgs () << " processing: " << op->getName () << " \n " );
111120
112121 // If this is a symbol table, recursively compute its liveness.
113122 if (op->hasTrait <OpTrait::SymbolTable>()) {
114123 // The internal symbol table is hidden if the parent is, if its not a
115124 // symbol, or if it is a private symbol.
116125 SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(op);
117126 bool symIsHidden = symbolTableIsHidden || !symbol || symbol.isPrivate ();
127+ LLVM_DEBUG (llvm::dbgs () << " \t symbol table: " << op->getName ()
128+ << " is hidden: " << symIsHidden << " \n " );
118129 if (failed (computeLiveness (op, symbolTable, symIsHidden, liveSymbols)))
119130 return failure ();
131+ } else {
132+ LLVM_DEBUG (llvm::dbgs ()
133+ << " \t non-symbol table: " << op->getName () << " \n " );
134+ // If the op is not a symbol table, then, unless op itself is dead which
135+ // would be handled by DCE, we need to check all the regions and blocks
136+ // within the op to find the uses (e.g., consider visibility within op as
137+ // if top level rather than relying on pure symbol table visibility). This
138+ // is more conservative than SymbolTable::walkSymbolTables in the case
139+ // where there is again SymbolTable information to take advantage of.
140+ for (auto ®ion : op->getRegions ())
141+ for (auto &block : region.getBlocks ())
142+ for (Operation &op : block)
143+ if (op.getNumRegions ())
144+ worklist.push_back (&op);
120145 }
121146
147+ // Get the first parent symbol table op. Note: due to enqueueing of
148+ // top-level ops, we may not have a symbol table parent here, but if we do
149+ // not, then we also don't have a symbol.
150+ Operation *parentOp = op->getParentOp ();
151+ if (!parentOp->hasTrait <OpTrait::SymbolTable>())
152+ continue ;
153+
122154 // Collect the uses held by this operation.
123155 std::optional<SymbolTable::UseRange> uses = SymbolTable::getSymbolUses (op);
124156 if (!uses) {
125157 return op->emitError ()
126- << " operation contains potentially unknown symbol table, "
127- " meaning that we can't reliable compute symbol uses" ;
158+ << " operation contains potentially unknown symbol table, meaning "
159+ << " that we can't reliable compute symbol uses" ;
128160 }
129161
130162 SmallVector<Operation *, 4 > resolvedSymbols;
163+ LLVM_DEBUG (llvm::dbgs () << " uses of " << op->getName () << " \n " );
131164 for (const SymbolTable::SymbolUse &use : *uses) {
165+ LLVM_DEBUG (llvm::dbgs () << " \t use: " << use.getUser () << " \n " );
132166 // Lookup the symbols referenced by this use.
133167 resolvedSymbols.clear ();
134- if (failed (symbolTable.lookupSymbolIn (
135- op-> getParentOp (), use. getSymbolRef (), resolvedSymbols)))
168+ if (failed (symbolTable.lookupSymbolIn (parentOp, use. getSymbolRef (),
169+ resolvedSymbols)))
136170 // Ignore references to unknown symbols.
137171 continue ;
172+ LLVM_DEBUG ({
173+ llvm::dbgs () << " \t\t resolved symbols: " ;
174+ llvm::interleaveComma (resolvedSymbols, llvm::dbgs ());
175+ llvm::dbgs () << " \n " ;
176+ });
138177
139178 // Mark each of the resolved symbols as live.
140179 for (Operation *resolvedSymbol : resolvedSymbols)
0 commit comments