@@ -310,21 +310,19 @@ static bool hasTensorSignature(func::FuncOp funcOp) {
310310// / any func::CallOp.
311311static LogicalResult getFuncOpsOrderedByCalls (
312312 ModuleOp moduleOp, SmallVectorImpl<func::FuncOp> &orderedFuncOps,
313- SmallVectorImpl<func::FuncOp> &remainingFuncOps, FuncCallerMap &callerMap) {
313+ SmallVectorImpl<func::FuncOp> &remainingFuncOps, FuncCallerMap &callerMap,
314+ SymbolTableCollection &symbolTables) {
314315 // For each FuncOp, the set of functions called by it (i.e. the union of
315316 // symbols of all nested func::CallOp).
316317 DenseMap<func::FuncOp, DenseSet<func::FuncOp>> calledBy;
317318 // For each FuncOp, the number of func::CallOp it contains.
318319 DenseMap<func::FuncOp, unsigned > numberCallOpsContainedInFuncOp;
319320
320- // TODO Avoid recomputing the symbol tables every time.
321- mlir::SymbolTableCollection symbolTable;
322-
323321 for (func::FuncOp funcOp : moduleOp.getOps <func::FuncOp>()) {
324322 // Collect function calls and populate the caller map.
325323 numberCallOpsContainedInFuncOp[funcOp] = 0 ;
326324 WalkResult res = funcOp.walk ([&](func::CallOp callOp) -> WalkResult {
327- func::FuncOp calledFunction = getCalledFunction (callOp, symbolTable );
325+ func::FuncOp calledFunction = getCalledFunction (callOp, symbolTables );
328326 assert (calledFunction && " could not retrieved called func::FuncOp" );
329327 // If the called function does not have any tensors in its signature, then
330328 // it is not necessary to bufferize the callee before the caller.
@@ -458,7 +456,8 @@ mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp,
458456 FuncCallerMap callerMap;
459457
460458 if (failed (getFuncOpsOrderedByCalls (moduleOp, orderedFuncOps,
461- remainingFuncOps, callerMap)))
459+ remainingFuncOps, callerMap,
460+ funcState.symbolTables )))
462461 return failure ();
463462
464463 // Analyze functions in order. Starting with functions that are not calling
@@ -534,7 +533,8 @@ LogicalResult mlir::bufferization::bufferizeModuleOp(
534533 // each other recursively are bufferized in an unspecified order at the end.
535534 // We may use unnecessarily "complex" (in terms of layout map) buffer types.
536535 if (failed (getFuncOpsOrderedByCalls (moduleOp, orderedFuncOps,
537- remainingFuncOps, callerMap)))
536+ remainingFuncOps, callerMap,
537+ state.getSymbolTables ())))
538538 return failure ();
539539 llvm::append_range (orderedFuncOps, remainingFuncOps);
540540
0 commit comments