diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp index fc6424a25ac70..d1d106220a38c 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp @@ -341,15 +341,25 @@ static LogicalResult getFuncOpsOrderedByCalls( // Iteratively remove function operations that do not call any of the // functions remaining in the callCounter map and add them to ordered list. - while (!numberCallOpsContainedInFuncOp.empty()) { - auto it = llvm::find_if(numberCallOpsContainedInFuncOp, - [](auto entry) { return entry.getSecond() == 0; }); - if (it == numberCallOpsContainedInFuncOp.end()) - break; - orderedFuncOps.push_back(it->getFirst()); - for (auto callee : calledBy[it->getFirst()]) - numberCallOpsContainedInFuncOp[callee]--; - numberCallOpsContainedInFuncOp.erase(it); + SmallVector worklist; + + for (const auto &entry : numberCallOpsContainedInFuncOp) { + if (entry.second == 0) + worklist.push_back(entry.first); + } + + while (!worklist.empty()) { + func::FuncOp func = worklist.pop_back_val(); + orderedFuncOps.push_back(func); + + for (func::FuncOp caller : calledBy[func]) { + auto &count = numberCallOpsContainedInFuncOp[caller]; + + if (--count == 0) + worklist.push_back(caller); + } + + numberCallOpsContainedInFuncOp.erase(func); } // Put all other functions in the list of remaining functions. These are