@@ -76,13 +76,14 @@ getBufferizedFunctionArgType(FuncOp funcOp, int64_t index,
7676}
7777
7878// / Return the FuncOp called by `callOp`.
79- static FuncOp getCalledFunction (CallOpInterface callOp) {
79+ static FuncOp getCalledFunction (CallOpInterface callOp,
80+ mlir::SymbolTableCollection &symbolTable) {
8081 SymbolRefAttr sym =
8182 llvm::dyn_cast_if_present<SymbolRefAttr>(callOp.getCallableForCallee ());
8283 if (!sym)
8384 return nullptr ;
8485 return dyn_cast_or_null<FuncOp>(
85- SymbolTable:: lookupNearestSymbolFrom (callOp, sym));
86+ symbolTable. lookupNearestSymbolFrom (callOp, sym));
8687}
8788
8889// / Get FuncAnalysisState.
@@ -135,44 +136,44 @@ struct CallOpInterface
135136 bool bufferizesToMemoryRead (Operation *op, OpOperand &opOperand,
136137 const AnalysisState &state) const {
137138 func::CallOp callOp = cast<func::CallOp>(op);
138- FuncOp funcOp = getCalledFunction (callOp);
139+ const FuncAnalysisState &funcState = getFuncAnalysisState (state);
140+ FuncOp funcOp = getCalledFunction (callOp, funcState.symbolTable );
139141 assert (funcOp && " expected CallOp to a FuncOp" );
140142
141143 if (getFuncOpAnalysisState (state, funcOp) != FuncOpAnalysisState::Analyzed)
142144 // FuncOp not analyzed yet. Assume that OpOperand is read.
143145 return true ;
144146
145- const FuncAnalysisState &funcState = getFuncAnalysisState (state);
146147 return funcState.readBbArgs .lookup (funcOp).contains (
147148 opOperand.getOperandNumber ());
148149 }
149150
150151 bool bufferizesToMemoryWrite (Operation *op, OpOperand &opOperand,
151152 const AnalysisState &state) const {
152153 func::CallOp callOp = cast<func::CallOp>(op);
153- FuncOp funcOp = getCalledFunction (callOp);
154+ const FuncAnalysisState &funcState = getFuncAnalysisState (state);
155+ FuncOp funcOp = getCalledFunction (callOp, funcState.symbolTable );
154156 assert (funcOp && " expected CallOp to a FuncOp" );
155157
156158 if (getFuncOpAnalysisState (state, funcOp) != FuncOpAnalysisState::Analyzed)
157159 // FuncOp not analyzed yet. Assume that OpOperand is written.
158160 return true ;
159161
160- const FuncAnalysisState &funcState = getFuncAnalysisState (state);
161162 return funcState.writtenBbArgs .lookup (funcOp).contains (
162163 opOperand.getOperandNumber ());
163164 }
164165
165166 AliasingValueList getAliasingValues (Operation *op, OpOperand &opOperand,
166167 const AnalysisState &state) const {
167168 func::CallOp callOp = cast<func::CallOp>(op);
168- FuncOp funcOp = getCalledFunction (callOp);
169+ const FuncAnalysisState &funcState = getFuncAnalysisState (state);
170+ FuncOp funcOp = getCalledFunction (callOp, funcState.symbolTable );
169171 assert (funcOp && " expected CallOp to a FuncOp" );
170172 if (getFuncOpAnalysisState (state, funcOp) != FuncOpAnalysisState::Analyzed)
171173 // FuncOp not analyzed yet. Any OpResult may be aliasing.
172174 return detail::unknownGetAliasingValues (opOperand);
173175
174176 // Get aliasing results from state.
175- const FuncAnalysisState &funcState = getFuncAnalysisState (state);
176177 auto aliasingReturnVals =
177178 funcState.aliasingReturnVals .lookup (funcOp).lookup (
178179 opOperand.getOperandNumber ());
@@ -199,7 +200,11 @@ struct CallOpInterface
199200 getBufferType (Operation *op, Value value, const BufferizationOptions &options,
200201 SmallVector<Value> &invocationStack) const {
201202 auto callOp = cast<func::CallOp>(op);
202- FuncOp funcOp = getCalledFunction (callOp);
203+
204+ // TODO Avoid recomputing the symbol tables every time.
205+ mlir::SymbolTableCollection symbolTable;
206+
207+ FuncOp funcOp = getCalledFunction (callOp, symbolTable);
203208 assert (funcOp && " expected CallOp to a FuncOp" );
204209
205210 // If the callee was already bufferized, we can directly take the type from
@@ -243,7 +248,11 @@ struct CallOpInterface
243248 // 2. Rewrite tensor operands as memrefs based on type of the already
244249 // bufferized callee.
245250 SmallVector<Value> newOperands;
246- FuncOp funcOp = getCalledFunction (callOp);
251+
252+ // TODO Avoid recomputing the symbol tables every time.
253+ mlir::SymbolTableCollection symbolTable;
254+
255+ FuncOp funcOp = getCalledFunction (callOp, symbolTable);
247256 assert (funcOp && " expected CallOp to a FuncOp" );
248257 FunctionType funcType = funcOp.getFunctionType ();
249258
0 commit comments