Skip to content

Commit 2507f1e

Browse files
Separate implementation of the symbol-dce pass.
1 parent 1906c3e commit 2507f1e

File tree

5 files changed

+199
-158
lines changed

5 files changed

+199
-158
lines changed

mlir/include/mlir/Transforms/Passes.td

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -509,10 +509,6 @@ def SymbolDCE : Pass<"symbol-dce"> {
509509
information on `Symbols`.
510510
}];
511511
let constructor = "mlir::createSymbolDCEPass()";
512-
513-
let statistics = [
514-
Statistic<"numDCE", "num-dce'd", "Number of symbols DCE'd">,
515-
];
516512
}
517513

518514
def SymbolPrivatize : Pass<"symbol-privatize"> {
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
//===- SymbolDceUtils.h.h ---------------------------------------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_TRANSFORMS_SYMBOLDCEUTILS_H
10+
#define MLIR_TRANSFORMS_SYMBOLDCEUTILS_H
11+
12+
#include "mlir/Support/LLVM.h"
13+
14+
#include "llvm/ADT/SmallVector.h"
15+
#include "llvm/Support/LogicalResult.h"
16+
17+
namespace mlir {
18+
19+
class Operation;
20+
class SymbolTableCollection;
21+
22+
/// Eliminate dead symbols in the symbolTableOp.
23+
LogicalResult symbolDce(Operation *);
24+
25+
/// Compute the liveness of the symbols within the given symbol table.
26+
/// `symbolTableIsHidden` is true if this symbol table is known to be
27+
/// unaccessible from operations in its parent regions.
28+
LogicalResult computeLiveness(Operation *, SymbolTableCollection &, bool,
29+
DenseSet<Operation *> &);
30+
} // end namespace mlir
31+
32+
#endif // MLIR_TRANSFORMS_SYMBOLDCEUTILS_H

mlir/lib/Transforms/SymbolDCE.cpp

Lines changed: 2 additions & 154 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,7 @@
1414
#include "mlir/Transforms/Passes.h"
1515

1616
#include "mlir/IR/Operation.h"
17-
#include "mlir/IR/SymbolTable.h"
18-
#include "llvm/Support/Debug.h"
19-
#include "llvm/Support/DebugLog.h"
20-
#include "llvm/Support/InterleavedRange.h"
17+
#include "mlir/Transforms/SymbolDceUtils.h"
2118

2219
namespace mlir {
2320
#define GEN_PASS_DEF_SYMBOLDCE
@@ -26,165 +23,16 @@ namespace mlir {
2623

2724
using namespace mlir;
2825

29-
#define DEBUG_TYPE "symbol-dce"
30-
3126
namespace {
3227
struct SymbolDCE : public impl::SymbolDCEBase<SymbolDCE> {
3328
void runOnOperation() override;
34-
35-
/// Compute the liveness of the symbols within the given symbol table.
36-
/// `symbolTableIsHidden` is true if this symbol table is known to be
37-
/// unaccessible from operations in its parent regions.
38-
LogicalResult computeLiveness(Operation *symbolTableOp,
39-
SymbolTableCollection &symbolTable,
40-
bool symbolTableIsHidden,
41-
DenseSet<Operation *> &liveSymbols);
4229
};
4330
} // namespace
4431

4532
void SymbolDCE::runOnOperation() {
4633
Operation *symbolTableOp = getOperation();
47-
48-
// SymbolDCE should only be run on operations that define a symbol table.
49-
if (!symbolTableOp->hasTrait<OpTrait::SymbolTable>()) {
50-
symbolTableOp->emitOpError()
51-
<< " was scheduled to run under SymbolDCE, but does not define a "
52-
"symbol table";
34+
if (failed(symbolDce(symbolTableOp)))
5335
return signalPassFailure();
54-
}
55-
56-
// A flag that signals if the top level symbol table is hidden, i.e. not
57-
// accessible from parent scopes.
58-
bool symbolTableIsHidden = true;
59-
SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(symbolTableOp);
60-
if (symbolTableOp->getParentOp() && symbol)
61-
symbolTableIsHidden = symbol.isPrivate();
62-
63-
// Compute the set of live symbols within the symbol table.
64-
DenseSet<Operation *> liveSymbols;
65-
SymbolTableCollection symbolTable;
66-
if (failed(computeLiveness(symbolTableOp, symbolTable, symbolTableIsHidden,
67-
liveSymbols)))
68-
return signalPassFailure();
69-
70-
// After computing the liveness, delete all of the symbols that were found to
71-
// be dead.
72-
symbolTableOp->walk([&](Operation *nestedSymbolTable) {
73-
if (!nestedSymbolTable->hasTrait<OpTrait::SymbolTable>())
74-
return;
75-
for (auto &block : nestedSymbolTable->getRegion(0)) {
76-
for (Operation &op : llvm::make_early_inc_range(block)) {
77-
if (isa<SymbolOpInterface>(&op) && !liveSymbols.count(&op)) {
78-
op.erase();
79-
++numDCE;
80-
}
81-
}
82-
}
83-
});
84-
}
85-
86-
/// Compute the liveness of the symbols within the given symbol table.
87-
/// `symbolTableIsHidden` is true if this symbol table is known to be
88-
/// unaccessible from operations in its parent regions.
89-
LogicalResult SymbolDCE::computeLiveness(Operation *symbolTableOp,
90-
SymbolTableCollection &symbolTable,
91-
bool symbolTableIsHidden,
92-
DenseSet<Operation *> &liveSymbols) {
93-
LDBG() << "computeLiveness: "
94-
<< OpWithFlags(symbolTableOp, OpPrintingFlags().skipRegions());
95-
// A worklist of live operations to propagate uses from.
96-
SmallVector<Operation *, 16> worklist;
97-
98-
// Walk the symbols within the current symbol table, marking the symbols that
99-
// are known to be live.
100-
for (auto &block : symbolTableOp->getRegion(0)) {
101-
// Add all non-symbols or symbols that can't be discarded.
102-
for (Operation &op : block) {
103-
SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(&op);
104-
if (!symbol) {
105-
worklist.push_back(&op);
106-
continue;
107-
}
108-
bool isDiscardable = (symbolTableIsHidden || symbol.isPrivate()) &&
109-
symbol.canDiscardOnUseEmpty();
110-
if (!isDiscardable && liveSymbols.insert(&op).second)
111-
worklist.push_back(&op);
112-
}
113-
}
114-
115-
// Process the set of symbols that were known to be live, adding new symbols
116-
// that are referenced within. For operations that are not symbol tables, it
117-
// considers the liveness with respect to the op itself rather than scope of
118-
// nested symbol tables by enqueuing all the top level operations for
119-
// consideration.
120-
while (!worklist.empty()) {
121-
Operation *op = worklist.pop_back_val();
122-
LDBG() << "processing: "
123-
<< OpWithFlags(op, OpPrintingFlags().skipRegions());
124-
125-
// If this is a symbol table, recursively compute its liveness.
126-
if (op->hasTrait<OpTrait::SymbolTable>()) {
127-
// The internal symbol table is hidden if the parent is, if its not a
128-
// symbol, or if it is a private symbol.
129-
SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(op);
130-
bool symIsHidden = symbolTableIsHidden || !symbol || symbol.isPrivate();
131-
LDBG() << "\tsymbol table: "
132-
<< OpWithFlags(op, OpPrintingFlags().skipRegions())
133-
<< " is hidden: " << symIsHidden;
134-
if (failed(computeLiveness(op, symbolTable, symIsHidden, liveSymbols)))
135-
return failure();
136-
} else {
137-
LDBG() << "\tnon-symbol table: "
138-
<< OpWithFlags(op, OpPrintingFlags().skipRegions());
139-
// If the op is not a symbol table, then, unless op itself is dead which
140-
// would be handled by DCE, we need to check all the regions and blocks
141-
// within the op to find the uses (e.g., consider visibility within op as
142-
// if top level rather than relying on pure symbol table visibility). This
143-
// is more conservative than SymbolTable::walkSymbolTables in the case
144-
// where there is again SymbolTable information to take advantage of.
145-
for (auto &region : op->getRegions())
146-
for (auto &block : region.getBlocks())
147-
for (Operation &op : block)
148-
if (op.getNumRegions())
149-
worklist.push_back(&op);
150-
}
151-
152-
// Get the first parent symbol table op. Note: due to enqueueing of
153-
// top-level ops, we may not have a symbol table parent here, but if we do
154-
// not, then we also don't have a symbol.
155-
Operation *parentOp = op->getParentOp();
156-
if (!parentOp->hasTrait<OpTrait::SymbolTable>())
157-
continue;
158-
159-
// Collect the uses held by this operation.
160-
std::optional<SymbolTable::UseRange> uses = SymbolTable::getSymbolUses(op);
161-
if (!uses) {
162-
return op->emitError()
163-
<< "operation contains potentially unknown symbol table, meaning "
164-
<< "that we can't reliable compute symbol uses";
165-
}
166-
167-
SmallVector<Operation *, 4> resolvedSymbols;
168-
LDBG() << "uses of " << OpWithFlags(op, OpPrintingFlags().skipRegions());
169-
for (const SymbolTable::SymbolUse &use : *uses) {
170-
LDBG() << "\tuse: " << use.getUser();
171-
// Lookup the symbols referenced by this use.
172-
resolvedSymbols.clear();
173-
if (failed(symbolTable.lookupSymbolIn(parentOp, use.getSymbolRef(),
174-
resolvedSymbols)))
175-
// Ignore references to unknown symbols.
176-
continue;
177-
LDBG() << "\t\tresolved symbols: "
178-
<< llvm::interleaved(resolvedSymbols, ", ");
179-
180-
// Mark each of the resolved symbols as live.
181-
for (Operation *resolvedSymbol : resolvedSymbols)
182-
if (liveSymbols.insert(resolvedSymbol).second)
183-
worklist.push_back(resolvedSymbol);
184-
}
185-
}
186-
187-
return success();
18836
}
18937

19038
std::unique_ptr<Pass> mlir::createSymbolDCEPass() {

mlir/lib/Transforms/Utils/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ add_mlir_library(MLIRTransformUtils
1010
LoopInvariantCodeMotionUtils.cpp
1111
RegionUtils.cpp
1212
WalkPatternRewriteDriver.cpp
13+
SymbolDceUtils.cpp
1314

1415
ADDITIONAL_HEADER_DIRS
1516
${MLIR_MAIN_INCLUDE_DIR}/mlir/Transforms
Lines changed: 164 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,164 @@
1+
//===- SymbolDceUtils.cpp -------------------------------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file implements eliminate dead symbols
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#include "mlir/Transforms/SymbolDceUtils.h"
14+
#include "mlir/IR/Operation.h"
15+
#include "mlir/IR/SymbolTable.h"
16+
17+
#include "llvm/Support/Debug.h"
18+
#include "llvm/Support/DebugLog.h"
19+
#include "llvm/Support/InterleavedRange.h"
20+
21+
#define DEBUG_TYPE "symbol-dce"
22+
23+
llvm::LogicalResult mlir::symbolDce(Operation *symbolTableOp) {
24+
// SymbolDCE should only be run on operations that define a symbol table.
25+
if (!symbolTableOp->hasTrait<OpTrait::SymbolTable>()) {
26+
symbolTableOp->emitOpError()
27+
<< " was scheduled to run under SymbolDCE, but does not define a "
28+
"symbol table";
29+
return failure();
30+
}
31+
32+
// A flag that signals if the top level symbol table is hidden, i.e. not
33+
// accessible from parent scopes.
34+
bool symbolTableIsHidden = true;
35+
SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(symbolTableOp);
36+
if (symbolTableOp->getParentOp() && symbol)
37+
symbolTableIsHidden = symbol.isPrivate();
38+
39+
// Compute the set of live symbols within the symbol table.
40+
DenseSet<Operation *> liveSymbols;
41+
SymbolTableCollection symbolTable;
42+
if (failed(computeLiveness(symbolTableOp, symbolTable, symbolTableIsHidden,
43+
liveSymbols)))
44+
return failure();
45+
46+
// After computing the liveness, delete all of the symbols that were found to
47+
// be dead.
48+
symbolTableOp->walk([&](Operation *nestedSymbolTable) {
49+
if (!nestedSymbolTable->hasTrait<OpTrait::SymbolTable>())
50+
return;
51+
for (auto &block : nestedSymbolTable->getRegion(0)) {
52+
for (Operation &op : llvm::make_early_inc_range(block)) {
53+
if (isa<SymbolOpInterface>(&op) && !liveSymbols.count(&op)) {
54+
op.erase();
55+
}
56+
}
57+
}
58+
});
59+
return success();
60+
}
61+
62+
/// Compute the liveness of the symbols within the given symbol table.
63+
/// `symbolTableIsHidden` is true if this symbol table is known to be
64+
/// unaccessible from operations in its parent regions.
65+
llvm::LogicalResult mlir::computeLiveness(Operation *symbolTableOp,
66+
SymbolTableCollection &symbolTable,
67+
bool symbolTableIsHidden,
68+
DenseSet<Operation *> &liveSymbols) {
69+
LDBG() << "computeLiveness: "
70+
<< OpWithFlags(symbolTableOp, OpPrintingFlags().skipRegions());
71+
// A worklist of live operations to propagate uses from.
72+
SmallVector<Operation *, 16> worklist;
73+
74+
// Walk the symbols within the current symbol table, marking the symbols that
75+
// are known to be live.
76+
for (auto &block : symbolTableOp->getRegion(0)) {
77+
// Add all non-symbols or symbols that can't be discarded.
78+
for (Operation &op : block) {
79+
SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(&op);
80+
if (!symbol) {
81+
worklist.push_back(&op);
82+
continue;
83+
}
84+
bool isDiscardable = (symbolTableIsHidden || symbol.isPrivate()) &&
85+
symbol.canDiscardOnUseEmpty();
86+
if (!isDiscardable && liveSymbols.insert(&op).second)
87+
worklist.push_back(&op);
88+
}
89+
}
90+
91+
// Process the set of symbols that were known to be live, adding new symbols
92+
// that are referenced within. For operations that are not symbol tables, it
93+
// considers the liveness with respect to the op itself rather than scope of
94+
// nested symbol tables by enqueuing all the top level operations for
95+
// consideration.
96+
while (!worklist.empty()) {
97+
Operation *op = worklist.pop_back_val();
98+
LDBG() << "processing: "
99+
<< OpWithFlags(op, OpPrintingFlags().skipRegions());
100+
101+
// If this is a symbol table, recursively compute its liveness.
102+
if (op->hasTrait<OpTrait::SymbolTable>()) {
103+
// The internal symbol table is hidden if the parent is, if its not a
104+
// symbol, or if it is a private symbol.
105+
SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(op);
106+
bool symIsHidden = symbolTableIsHidden || !symbol || symbol.isPrivate();
107+
LDBG() << "\tsymbol table: "
108+
<< OpWithFlags(op, OpPrintingFlags().skipRegions())
109+
<< " is hidden: " << symIsHidden;
110+
if (failed(computeLiveness(op, symbolTable, symIsHidden, liveSymbols)))
111+
return failure();
112+
} else {
113+
LDBG() << "\tnon-symbol table: "
114+
<< OpWithFlags(op, OpPrintingFlags().skipRegions());
115+
// If the op is not a symbol table, then, unless op itself is dead which
116+
// would be handled by DCE, we need to check all the regions and blocks
117+
// within the op to find the uses (e.g., consider visibility within op as
118+
// if top level rather than relying on pure symbol table visibility). This
119+
// is more conservative than SymbolTable::walkSymbolTables in the case
120+
// where there is again SymbolTable information to take advantage of.
121+
for (auto &region : op->getRegions())
122+
for (auto &block : region.getBlocks())
123+
for (Operation &op : block)
124+
if (op.getNumRegions())
125+
worklist.push_back(&op);
126+
}
127+
128+
// Get the first parent symbol table op. Note: due to enqueueing of
129+
// top-level ops, we may not have a symbol table parent here, but if we do
130+
// not, then we also don't have a symbol.
131+
Operation *parentOp = op->getParentOp();
132+
if (!parentOp->hasTrait<OpTrait::SymbolTable>())
133+
continue;
134+
135+
// Collect the uses held by this operation.
136+
std::optional<SymbolTable::UseRange> uses = SymbolTable::getSymbolUses(op);
137+
if (!uses) {
138+
return op->emitError()
139+
<< "operation contains potentially unknown symbol table, meaning "
140+
<< "that we can't reliable compute symbol uses";
141+
}
142+
143+
SmallVector<Operation *, 4> resolvedSymbols;
144+
LDBG() << "uses of " << OpWithFlags(op, OpPrintingFlags().skipRegions());
145+
for (const SymbolTable::SymbolUse &use : *uses) {
146+
LDBG() << "\tuse: " << use.getUser();
147+
// Lookup the symbols referenced by this use.
148+
resolvedSymbols.clear();
149+
if (failed(symbolTable.lookupSymbolIn(parentOp, use.getSymbolRef(),
150+
resolvedSymbols)))
151+
// Ignore references to unknown symbols.
152+
continue;
153+
LDBG() << "\t\tresolved symbols: "
154+
<< llvm::interleaved(resolvedSymbols, ", ");
155+
156+
// Mark each of the resolved symbols as live.
157+
for (Operation *resolvedSymbol : resolvedSymbols)
158+
if (liveSymbols.insert(resolvedSymbol).second)
159+
worklist.push_back(resolvedSymbol);
160+
}
161+
}
162+
163+
return success();
164+
}

0 commit comments

Comments
 (0)