Skip to content

Commit 019f5b9

Browse files
committed
Remove caching option and separate extension
1 parent 45e0383 commit 019f5b9

File tree

9 files changed

+16
-61
lines changed

9 files changed

+16
-61
lines changed

mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -639,16 +639,13 @@ class BufferizationState {
639639
return const_cast<BufferizationState *>(this)->getExtension<Ty>();
640640
}
641641

642+
/// Get a reference to the collection of cached symbol tables.
643+
SymbolTableCollection &getSymbolTables();
644+
645+
private:
642646
/// Extensions attached to the state, identified by the TypeID of their type.
643647
/// Only one extension of any given type is allowed.
644648
DenseMap<TypeID, std::unique_ptr<Extension>> extensions;
645-
};
646-
647-
/// Extra bufferization state that is required for bufferization of operations
648-
/// declaring a symbol or a symbol table.
649-
struct SymbolBufferizationState : public BufferizationState::Extension {
650-
SymbolBufferizationState(BufferizationState &state)
651-
: BufferizationState::Extension(state) {}
652649

653650
/// The cached symbol tables.
654651
/// The user is expected to update / invalidate the cached symbol tables if

mlir/include/mlir/Dialect/Bufferization/Transforms/BufferUtils.h

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -127,11 +127,6 @@ FailureOr<memref::GlobalOp> getGlobalFor(arith::ConstantOp constantOp,
127127
uint64_t alignment,
128128
Attribute memorySpace = {});
129129

130-
FailureOr<memref::GlobalOp> getGlobalFor(arith::ConstantOp op,
131-
BufferizationState &state,
132-
uint64_t alignment,
133-
Attribute memorySpace);
134-
135130
void removeSymbol(Operation *op, BufferizationState &state);
136131

137132
void insertSymbol(Operation *op, BufferizationState &state);

mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -52,12 +52,6 @@ struct OneShotBufferizationOptions : public BufferizationOptions {
5252
/// `AnalysisHeuristic::Fuzzer`. The fuzzer should be used only with
5353
/// `testAnalysisOnly = true`.
5454
unsigned analysisFuzzerSeed = 0;
55-
56-
/// Enable caching of symbol tables. If enabled, the SymbolBufferizationState
57-
/// class is attached to the bufferization state and the user is required to
58-
/// keep the cached symbol tables consistent with respect to the performed
59-
/// bufferizations.
60-
bool cacheSymbolTables = false;
6155
};
6256

6357
/// State for analysis-enabled bufferization. This class keeps track of alias

mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,8 @@ struct ConstantOpInterface
4747
// Create global memory segment and replace tensor with memref pointing to
4848
// that memory segment.
4949
FailureOr<memref::GlobalOp> globalOp =
50-
getGlobalFor(constantOp, state, options.bufferAlignment, memorySpace);
50+
getGlobalFor(constantOp, state.getSymbolTables(),
51+
options.bufferAlignment, memorySpace);
5152
if (failed(globalOp))
5253
return failure();
5354
memref::GlobalOp globalMemref = *globalOp;

mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,10 @@ void AnalysisState::resetCache() {
127127

128128
BufferizationState::Extension::~Extension() = default;
129129

130+
SymbolTableCollection &BufferizationState::getSymbolTables() {
131+
return symbolTables;
132+
}
133+
130134
Region *bufferization::getNextEnclosingRepetitiveRegion(
131135
Region *region, const BufferizationOptions &options) {
132136
assert(isRepetitiveRegion(region, options) && "expected repetitive region");

mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -85,10 +85,6 @@ transform::OneShotBufferizeOp::apply(transform::TransformRewriter &rewriter,
8585
auto payloadOps = state.getPayloadOps(getTarget());
8686
BufferizationState bufferizationState;
8787

88-
if (options.cacheSymbolTables) {
89-
bufferizationState.addExtension<SymbolBufferizationState>();
90-
}
91-
9288
for (Operation *target : payloadOps) {
9389
if (!isa<ModuleOp, FunctionOpInterface>(target))
9490
return emitSilenceableError() << "expected module or function target";

mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp

Lines changed: 6 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -161,40 +161,17 @@ bufferization::getGlobalFor(arith::ConstantOp constantOp,
161161
}
162162

163163
namespace mlir::bufferization {
164-
FailureOr<memref::GlobalOp> getGlobalFor(arith::ConstantOp op,
165-
BufferizationState &state,
166-
uint64_t alignment,
167-
Attribute memorySpace) {
168-
if (auto *symbolBufferizationState =
169-
state.getExtension<SymbolBufferizationState>()) {
170-
// Use the cached symbol tables.
171-
return getGlobalFor(op, symbolBufferizationState->symbolTables, alignment,
172-
memorySpace);
173-
}
174-
175-
SymbolTableCollection symbolTables;
176-
return getGlobalFor(op, symbolTables, alignment, memorySpace);
177-
}
178-
179164
void removeSymbol(Operation *op, BufferizationState &state) {
180-
if (auto *symbolBufferizationState =
181-
state.getExtension<SymbolBufferizationState>()) {
182-
SymbolTable &symbolTable =
183-
symbolBufferizationState->symbolTables.getSymbolTable(
184-
op->getParentWithTrait<OpTrait::SymbolTable>());
165+
SymbolTable &symbolTable = state.getSymbolTables().getSymbolTable(
166+
op->getParentWithTrait<OpTrait::SymbolTable>());
185167

186-
symbolTable.remove(op);
187-
}
168+
symbolTable.remove(op);
188169
}
189170

190171
void insertSymbol(Operation *op, BufferizationState &state) {
191-
if (auto *symbolBufferizationState =
192-
state.getExtension<SymbolBufferizationState>()) {
193-
SymbolTable &symbolTable =
194-
symbolBufferizationState->symbolTables.getSymbolTable(
195-
op->getParentWithTrait<OpTrait::SymbolTable>());
172+
SymbolTable &symbolTable = state.getSymbolTables().getSymbolTable(
173+
op->getParentWithTrait<OpTrait::SymbolTable>());
196174

197-
symbolTable.insert(op);
198-
}
175+
symbolTable.insert(op);
199176
}
200177
} // namespace mlir::bufferization

mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -163,10 +163,6 @@ struct OneShotBufferizePass
163163

164164
BufferizationState state;
165165

166-
if (opt.cacheSymbolTables) {
167-
state.addExtension<SymbolBufferizationState>();
168-
}
169-
170166
BufferizationStatistics statistics;
171167
ModuleOp moduleOp = getOperation();
172168
if (opt.bufferizeFunctionBoundaries) {

mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -116,11 +116,6 @@ class SparsificationAndBufferizationPass
116116

117117
bufferization::BufferizationState bufferizationState;
118118

119-
if (updatedOptions.cacheSymbolTables) {
120-
bufferizationState
121-
.addExtension<bufferization::SymbolBufferizationState>();
122-
}
123-
124119
if (failed(bufferization::bufferizeModuleOp(cast<ModuleOp>(getOperation()),
125120
updatedOptions,
126121
bufferizationState)))

0 commit comments

Comments
 (0)