Skip to content

Commit 8499232

Browse files
committed
Cache symbol tables during OneShotBufferization analyses
1 parent 7ec1e0f commit 8499232

File tree

3 files changed

+31
-13
lines changed

3 files changed

+31
-13
lines changed

mlir/include/mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,9 @@ struct FuncAnalysisState : public OneShotAnalysisState::Extension {
6969
/// analyzed.
7070
DenseMap<FuncOp, FuncOpAnalysisState> analyzedFuncOps;
7171

72+
/// A collection of cached SymbolTables used for faster function lookup.
73+
mutable mlir::SymbolTableCollection symbolTable;
74+
7275
/// This function is called right before analyzing the given FuncOp. It
7376
/// initializes the data structures for the FuncOp in this state object.
7477
void startFunctionAnalysis(FuncOp funcOp);

mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -76,13 +76,14 @@ getBufferizedFunctionArgType(FuncOp funcOp, int64_t index,
7676
}
7777

7878
/// Return the FuncOp called by `callOp`.
79-
static FuncOp getCalledFunction(CallOpInterface callOp) {
79+
static FuncOp getCalledFunction(CallOpInterface callOp,
80+
mlir::SymbolTableCollection &symbolTable) {
8081
SymbolRefAttr sym =
8182
llvm::dyn_cast_if_present<SymbolRefAttr>(callOp.getCallableForCallee());
8283
if (!sym)
8384
return nullptr;
8485
return dyn_cast_or_null<FuncOp>(
85-
SymbolTable::lookupNearestSymbolFrom(callOp, sym));
86+
symbolTable.lookupNearestSymbolFrom(callOp, sym));
8687
}
8788

8889
/// Get FuncAnalysisState.
@@ -135,44 +136,44 @@ struct CallOpInterface
135136
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
136137
const AnalysisState &state) const {
137138
func::CallOp callOp = cast<func::CallOp>(op);
138-
FuncOp funcOp = getCalledFunction(callOp);
139+
const FuncAnalysisState &funcState = getFuncAnalysisState(state);
140+
FuncOp funcOp = getCalledFunction(callOp, funcState.symbolTable);
139141
assert(funcOp && "expected CallOp to a FuncOp");
140142

141143
if (getFuncOpAnalysisState(state, funcOp) != FuncOpAnalysisState::Analyzed)
142144
// FuncOp not analyzed yet. Assume that OpOperand is read.
143145
return true;
144146

145-
const FuncAnalysisState &funcState = getFuncAnalysisState(state);
146147
return funcState.readBbArgs.lookup(funcOp).contains(
147148
opOperand.getOperandNumber());
148149
}
149150

150151
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
151152
const AnalysisState &state) const {
152153
func::CallOp callOp = cast<func::CallOp>(op);
153-
FuncOp funcOp = getCalledFunction(callOp);
154+
const FuncAnalysisState &funcState = getFuncAnalysisState(state);
155+
FuncOp funcOp = getCalledFunction(callOp, funcState.symbolTable);
154156
assert(funcOp && "expected CallOp to a FuncOp");
155157

156158
if (getFuncOpAnalysisState(state, funcOp) != FuncOpAnalysisState::Analyzed)
157159
// FuncOp not analyzed yet. Assume that OpOperand is written.
158160
return true;
159161

160-
const FuncAnalysisState &funcState = getFuncAnalysisState(state);
161162
return funcState.writtenBbArgs.lookup(funcOp).contains(
162163
opOperand.getOperandNumber());
163164
}
164165

165166
AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
166167
const AnalysisState &state) const {
167168
func::CallOp callOp = cast<func::CallOp>(op);
168-
FuncOp funcOp = getCalledFunction(callOp);
169+
const FuncAnalysisState &funcState = getFuncAnalysisState(state);
170+
FuncOp funcOp = getCalledFunction(callOp, funcState.symbolTable);
169171
assert(funcOp && "expected CallOp to a FuncOp");
170172
if (getFuncOpAnalysisState(state, funcOp) != FuncOpAnalysisState::Analyzed)
171173
// FuncOp not analyzed yet. Any OpResult may be aliasing.
172174
return detail::unknownGetAliasingValues(opOperand);
173175

174176
// Get aliasing results from state.
175-
const FuncAnalysisState &funcState = getFuncAnalysisState(state);
176177
auto aliasingReturnVals =
177178
funcState.aliasingReturnVals.lookup(funcOp).lookup(
178179
opOperand.getOperandNumber());
@@ -199,7 +200,11 @@ struct CallOpInterface
199200
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
200201
SmallVector<Value> &invocationStack) const {
201202
auto callOp = cast<func::CallOp>(op);
202-
FuncOp funcOp = getCalledFunction(callOp);
203+
204+
// TODO Avoid recomputing the symbol tables every time.
205+
mlir::SymbolTableCollection symbolTable;
206+
207+
FuncOp funcOp = getCalledFunction(callOp, symbolTable);
203208
assert(funcOp && "expected CallOp to a FuncOp");
204209

205210
// If the callee was already bufferized, we can directly take the type from
@@ -243,7 +248,11 @@ struct CallOpInterface
243248
// 2. Rewrite tensor operands as memrefs based on type of the already
244249
// bufferized callee.
245250
SmallVector<Value> newOperands;
246-
FuncOp funcOp = getCalledFunction(callOp);
251+
252+
// TODO Avoid recomputing the symbol tables every time.
253+
mlir::SymbolTableCollection symbolTable;
254+
255+
FuncOp funcOp = getCalledFunction(callOp, symbolTable);
247256
assert(funcOp && "expected CallOp to a FuncOp");
248257
FunctionType funcType = funcOp.getFunctionType();
249258

mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -280,13 +280,15 @@ static void removeBufferizationAttributes(BlockArgument bbArg) {
280280
}
281281

282282
/// Return the func::FuncOp called by `callOp`.
283-
static func::FuncOp getCalledFunction(func::CallOp callOp) {
283+
static func::FuncOp
284+
getCalledFunction(func::CallOp callOp,
285+
mlir::SymbolTableCollection &symbolTable) {
284286
SymbolRefAttr sym =
285287
llvm::dyn_cast_if_present<SymbolRefAttr>(callOp.getCallableForCallee());
286288
if (!sym)
287289
return nullptr;
288290
return dyn_cast_or_null<func::FuncOp>(
289-
SymbolTable::lookupNearestSymbolFrom(callOp, sym));
291+
symbolTable.lookupNearestSymbolFrom(callOp, sym));
290292
}
291293

292294
/// Return "true" if the given function signature has tensor semantics.
@@ -314,11 +316,15 @@ static LogicalResult getFuncOpsOrderedByCalls(
314316
DenseMap<func::FuncOp, DenseSet<func::FuncOp>> calledBy;
315317
// For each FuncOp, the number of func::CallOp it contains.
316318
DenseMap<func::FuncOp, unsigned> numberCallOpsContainedInFuncOp;
319+
320+
// TODO Avoid recomputing the symbol tables every time.
321+
mlir::SymbolTableCollection symbolTable;
322+
317323
for (func::FuncOp funcOp : moduleOp.getOps<func::FuncOp>()) {
318324
// Collect function calls and populate the caller map.
319325
numberCallOpsContainedInFuncOp[funcOp] = 0;
320326
WalkResult res = funcOp.walk([&](func::CallOp callOp) -> WalkResult {
321-
func::FuncOp calledFunction = getCalledFunction(callOp);
327+
func::FuncOp calledFunction = getCalledFunction(callOp, symbolTable);
322328
assert(calledFunction && "could not retrieved called func::FuncOp");
323329
// If the called function does not have any tensors in its signature, then
324330
// it is not necessary to bufferize the callee before the caller.

0 commit comments

Comments
 (0)