From 2095a7cec48091ede9847a66c28eb0163fdf728d Mon Sep 17 00:00:00 2001 From: Michele Scuttari Date: Fri, 30 May 2025 09:26:20 +0200 Subject: [PATCH] Reduce complexity of searching circular function calls --- .../Transforms/OneShotModuleBufferize.cpp | 28 +++++++++++++------ 1 file changed, 19 insertions(+), 9 deletions(-) diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp index dee2af8271ce8..03df1f85b526e 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp @@ -343,15 +343,25 @@ static LogicalResult getFuncOpsOrderedByCalls( // Iteratively remove function operations that do not call any of the // functions remaining in the callCounter map and add them to ordered list. - while (!numberCallOpsContainedInFuncOp.empty()) { - auto it = llvm::find_if(numberCallOpsContainedInFuncOp, - [](auto entry) { return entry.getSecond() == 0; }); - if (it == numberCallOpsContainedInFuncOp.end()) - break; - orderedFuncOps.push_back(it->getFirst()); - for (auto callee : calledBy[it->getFirst()]) - numberCallOpsContainedInFuncOp[callee]--; - numberCallOpsContainedInFuncOp.erase(it); + SmallVector worklist; + + for (const auto &entry : numberCallOpsContainedInFuncOp) { + if (entry.second == 0) + worklist.push_back(entry.first); + } + + while (!worklist.empty()) { + func::FuncOp func = worklist.pop_back_val(); + orderedFuncOps.push_back(func); + + for (func::FuncOp caller : calledBy[func]) { + auto &count = numberCallOpsContainedInFuncOp[caller]; + + if (--count == 0) + worklist.push_back(caller); + } + + numberCallOpsContainedInFuncOp.erase(func); } // Put all other functions in the list of remaining functions. These are