@@ -77,33 +77,38 @@ getBufferizedFunctionArgType(FuncOp funcOp, int64_t index,
7777
7878// / Return the FuncOp called by `callOp`.
7979static FuncOp getCalledFunction (CallOpInterface callOp,
80- mlir:: SymbolTableCollection &symbolTable ) {
80+ SymbolTableCollection &symbolTables ) {
8181 SymbolRefAttr sym =
8282 llvm::dyn_cast_if_present<SymbolRefAttr>(callOp.getCallableForCallee ());
8383 if (!sym)
8484 return nullptr ;
8585 return dyn_cast_or_null<FuncOp>(
86- symbolTable .lookupNearestSymbolFrom (callOp, sym));
86+ symbolTables .lookupNearestSymbolFrom (callOp, sym));
8787}
8888
89- // / Get or create FuncAnalysisState.
90- static const FuncAnalysisState &
91- getOrCreateFuncAnalysisState (const AnalysisState &state) {
92- assert (isa<OneShotAnalysisState>(state) && " expected OneShotAnalysisState" );
93-
94- // Unfortunately, at the moment the BufferizableOpInterface methods do provide
95- // a const reference to the AnalysisState class, and the only way to
96- // dynamically add an extension is to const_cast it to a non-const reference.
97- // Should the const qualifier be dropped from the interface?
98- auto &oneShotAnalysisState =
99- static_cast <OneShotAnalysisState &>(const_cast <AnalysisState &>(state));
89+ // / Return the FuncOp called by `callOp`.
90+ static FuncOp getCalledFunction (CallOpInterface callOp,
91+ const AnalysisState &state) {
92+ auto &oneShotAnalysisState = static_cast <const OneShotAnalysisState &>(state);
10093
101- auto *result = oneShotAnalysisState.getExtension <FuncAnalysisState>();
94+ if (auto *funcAnalysisState =
95+ oneShotAnalysisState.getExtension <FuncAnalysisState>()) {
96+ // Use the cached symbol tables.
97+ return getCalledFunction (callOp, funcAnalysisState->symbolTables );
98+ }
10299
103- if (result)
104- return *result;
100+ SymbolTableCollection symbolTables;
101+ return getCalledFunction (callOp, symbolTables);
102+ }
105103
106- return oneShotAnalysisState.addExtension <FuncAnalysisState>();
104+ // / Get FuncAnalysisState.
105+ static const FuncAnalysisState &
106+ getFuncAnalysisState (const AnalysisState &state) {
107+ assert (isa<OneShotAnalysisState>(state) && " expected OneShotAnalysisState" );
108+ auto *result = static_cast <const OneShotAnalysisState &>(state)
109+ .getExtension <FuncAnalysisState>();
110+ assert (result && " FuncAnalysisState does not exist" );
111+ return *result;
107112}
108113
109114// / Return the state (phase) of analysis of the FuncOp.
@@ -146,44 +151,44 @@ struct CallOpInterface
146151 bool bufferizesToMemoryRead (Operation *op, OpOperand &opOperand,
147152 const AnalysisState &state) const {
148153 func::CallOp callOp = cast<func::CallOp>(op);
149- const FuncAnalysisState &funcState = getOrCreateFuncAnalysisState (state);
150- FuncOp funcOp = getCalledFunction (callOp, funcState.symbolTable );
154+ FuncOp funcOp = getCalledFunction (callOp, state);
151155 assert (funcOp && " expected CallOp to a FuncOp" );
152156
153157 if (getFuncOpAnalysisState (state, funcOp) != FuncOpAnalysisState::Analyzed)
154158 // FuncOp not analyzed yet. Assume that OpOperand is read.
155159 return true ;
156160
161+ const FuncAnalysisState &funcState = getFuncAnalysisState (state);
157162 return funcState.readBbArgs .lookup (funcOp).contains (
158163 opOperand.getOperandNumber ());
159164 }
160165
161166 bool bufferizesToMemoryWrite (Operation *op, OpOperand &opOperand,
162167 const AnalysisState &state) const {
163168 func::CallOp callOp = cast<func::CallOp>(op);
164- const FuncAnalysisState &funcState = getOrCreateFuncAnalysisState (state);
165- FuncOp funcOp = getCalledFunction (callOp, funcState.symbolTable );
169+ FuncOp funcOp = getCalledFunction (callOp, state);
166170 assert (funcOp && " expected CallOp to a FuncOp" );
167171
168172 if (getFuncOpAnalysisState (state, funcOp) != FuncOpAnalysisState::Analyzed)
169173 // FuncOp not analyzed yet. Assume that OpOperand is written.
170174 return true ;
171175
176+ const FuncAnalysisState &funcState = getFuncAnalysisState (state);
172177 return funcState.writtenBbArgs .lookup (funcOp).contains (
173178 opOperand.getOperandNumber ());
174179 }
175180
176181 AliasingValueList getAliasingValues (Operation *op, OpOperand &opOperand,
177182 const AnalysisState &state) const {
178183 func::CallOp callOp = cast<func::CallOp>(op);
179- const FuncAnalysisState &funcState = getOrCreateFuncAnalysisState (state);
180- FuncOp funcOp = getCalledFunction (callOp, funcState.symbolTable );
184+ FuncOp funcOp = getCalledFunction (callOp, state);
181185 assert (funcOp && " expected CallOp to a FuncOp" );
182186 if (getFuncOpAnalysisState (state, funcOp) != FuncOpAnalysisState::Analyzed)
183187 // FuncOp not analyzed yet. Any OpResult may be aliasing.
184188 return detail::unknownGetAliasingValues (opOperand);
185189
186190 // Get aliasing results from state.
191+ const FuncAnalysisState &funcState = getFuncAnalysisState (state);
187192 auto aliasingReturnVals =
188193 funcState.aliasingReturnVals .lookup (funcOp).lookup (
189194 opOperand.getOperandNumber ());
@@ -212,7 +217,7 @@ struct CallOpInterface
212217 auto callOp = cast<func::CallOp>(op);
213218
214219 // TODO Avoid recomputing the symbol tables every time.
215- mlir:: SymbolTableCollection symbolTable;
220+ SymbolTableCollection symbolTable;
216221
217222 FuncOp funcOp = getCalledFunction (callOp, symbolTable);
218223 assert (funcOp && " expected CallOp to a FuncOp" );
@@ -260,7 +265,7 @@ struct CallOpInterface
260265 SmallVector<Value> newOperands;
261266
262267 // TODO Avoid recomputing the symbol tables every time.
263- mlir:: SymbolTableCollection symbolTable;
268+ SymbolTableCollection symbolTable;
264269
265270 FuncOp funcOp = getCalledFunction (callOp, symbolTable);
266271 assert (funcOp && " expected CallOp to a FuncOp" );
0 commit comments