Skip to content

Commit 309f07f

Browse files
committed
Compute symbol table if no FuncAnalysisState is registered
1 parent 983b2c6 commit 309f07f

File tree

2 files changed

+31
-26
lines changed

2 files changed

+31
-26
lines changed

mlir/include/mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ struct FuncAnalysisState : public OneShotAnalysisState::Extension {
7070
DenseMap<FuncOp, FuncOpAnalysisState> analyzedFuncOps;
7171

7272
/// A collection of cached SymbolTables used for faster function lookup.
73-
mutable mlir::SymbolTableCollection symbolTable;
73+
mutable SymbolTableCollection symbolTables;
7474

7575
/// This function is called right before analyzing the given FuncOp. It
7676
/// initializes the data structures for the FuncOp in this state object.

mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp

Lines changed: 30 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -77,33 +77,38 @@ getBufferizedFunctionArgType(FuncOp funcOp, int64_t index,
7777

7878
/// Return the FuncOp called by `callOp`.
7979
static FuncOp getCalledFunction(CallOpInterface callOp,
80-
mlir::SymbolTableCollection &symbolTable) {
80+
SymbolTableCollection &symbolTables) {
8181
SymbolRefAttr sym =
8282
llvm::dyn_cast_if_present<SymbolRefAttr>(callOp.getCallableForCallee());
8383
if (!sym)
8484
return nullptr;
8585
return dyn_cast_or_null<FuncOp>(
86-
symbolTable.lookupNearestSymbolFrom(callOp, sym));
86+
symbolTables.lookupNearestSymbolFrom(callOp, sym));
8787
}
8888

89-
/// Get or create FuncAnalysisState.
90-
static const FuncAnalysisState &
91-
getOrCreateFuncAnalysisState(const AnalysisState &state) {
92-
assert(isa<OneShotAnalysisState>(state) && "expected OneShotAnalysisState");
93-
94-
// Unfortunately, at the moment the BufferizableOpInterface methods do provide
95-
// a const reference to the AnalysisState class, and the only way to
96-
// dynamically add an extension is to const_cast it to a non-const reference.
97-
// Should the const qualifier be dropped from the interface?
98-
auto &oneShotAnalysisState =
99-
static_cast<OneShotAnalysisState &>(const_cast<AnalysisState &>(state));
89+
/// Return the FuncOp called by `callOp`.
90+
static FuncOp getCalledFunction(CallOpInterface callOp,
91+
const AnalysisState &state) {
92+
auto &oneShotAnalysisState = static_cast<const OneShotAnalysisState &>(state);
10093

101-
auto *result = oneShotAnalysisState.getExtension<FuncAnalysisState>();
94+
if (auto *funcAnalysisState =
95+
oneShotAnalysisState.getExtension<FuncAnalysisState>()) {
96+
// Use the cached symbol tables.
97+
return getCalledFunction(callOp, funcAnalysisState->symbolTables);
98+
}
10299

103-
if (result)
104-
return *result;
100+
SymbolTableCollection symbolTables;
101+
return getCalledFunction(callOp, symbolTables);
102+
}
105103

106-
return oneShotAnalysisState.addExtension<FuncAnalysisState>();
104+
/// Get FuncAnalysisState.
105+
static const FuncAnalysisState &
106+
getFuncAnalysisState(const AnalysisState &state) {
107+
assert(isa<OneShotAnalysisState>(state) && "expected OneShotAnalysisState");
108+
auto *result = static_cast<const OneShotAnalysisState &>(state)
109+
.getExtension<FuncAnalysisState>();
110+
assert(result && "FuncAnalysisState does not exist");
111+
return *result;
107112
}
108113

109114
/// Return the state (phase) of analysis of the FuncOp.
@@ -146,44 +151,44 @@ struct CallOpInterface
146151
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
147152
const AnalysisState &state) const {
148153
func::CallOp callOp = cast<func::CallOp>(op);
149-
const FuncAnalysisState &funcState = getOrCreateFuncAnalysisState(state);
150-
FuncOp funcOp = getCalledFunction(callOp, funcState.symbolTable);
154+
FuncOp funcOp = getCalledFunction(callOp, state);
151155
assert(funcOp && "expected CallOp to a FuncOp");
152156

153157
if (getFuncOpAnalysisState(state, funcOp) != FuncOpAnalysisState::Analyzed)
154158
// FuncOp not analyzed yet. Assume that OpOperand is read.
155159
return true;
156160

161+
const FuncAnalysisState &funcState = getFuncAnalysisState(state);
157162
return funcState.readBbArgs.lookup(funcOp).contains(
158163
opOperand.getOperandNumber());
159164
}
160165

161166
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
162167
const AnalysisState &state) const {
163168
func::CallOp callOp = cast<func::CallOp>(op);
164-
const FuncAnalysisState &funcState = getOrCreateFuncAnalysisState(state);
165-
FuncOp funcOp = getCalledFunction(callOp, funcState.symbolTable);
169+
FuncOp funcOp = getCalledFunction(callOp, state);
166170
assert(funcOp && "expected CallOp to a FuncOp");
167171

168172
if (getFuncOpAnalysisState(state, funcOp) != FuncOpAnalysisState::Analyzed)
169173
// FuncOp not analyzed yet. Assume that OpOperand is written.
170174
return true;
171175

176+
const FuncAnalysisState &funcState = getFuncAnalysisState(state);
172177
return funcState.writtenBbArgs.lookup(funcOp).contains(
173178
opOperand.getOperandNumber());
174179
}
175180

176181
AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
177182
const AnalysisState &state) const {
178183
func::CallOp callOp = cast<func::CallOp>(op);
179-
const FuncAnalysisState &funcState = getOrCreateFuncAnalysisState(state);
180-
FuncOp funcOp = getCalledFunction(callOp, funcState.symbolTable);
184+
FuncOp funcOp = getCalledFunction(callOp, state);
181185
assert(funcOp && "expected CallOp to a FuncOp");
182186
if (getFuncOpAnalysisState(state, funcOp) != FuncOpAnalysisState::Analyzed)
183187
// FuncOp not analyzed yet. Any OpResult may be aliasing.
184188
return detail::unknownGetAliasingValues(opOperand);
185189

186190
// Get aliasing results from state.
191+
const FuncAnalysisState &funcState = getFuncAnalysisState(state);
187192
auto aliasingReturnVals =
188193
funcState.aliasingReturnVals.lookup(funcOp).lookup(
189194
opOperand.getOperandNumber());
@@ -212,7 +217,7 @@ struct CallOpInterface
212217
auto callOp = cast<func::CallOp>(op);
213218

214219
// TODO Avoid recomputing the symbol tables every time.
215-
mlir::SymbolTableCollection symbolTable;
220+
SymbolTableCollection symbolTable;
216221

217222
FuncOp funcOp = getCalledFunction(callOp, symbolTable);
218223
assert(funcOp && "expected CallOp to a FuncOp");
@@ -260,7 +265,7 @@ struct CallOpInterface
260265
SmallVector<Value> newOperands;
261266

262267
// TODO Avoid recomputing the symbol tables every time.
263-
mlir::SymbolTableCollection symbolTable;
268+
SymbolTableCollection symbolTable;
264269

265270
FuncOp funcOp = getCalledFunction(callOp, symbolTable);
266271
assert(funcOp && "expected CallOp to a FuncOp");

0 commit comments

Comments
 (0)