@@ -285,14 +285,17 @@ static bool hasTensorSignature(func::FuncOp funcOp) {
285285}
286286
287287// / Store all functions of the `moduleOp` in `orderedFuncOps`, sorted by
288- // / callee-caller order (i.e. callees without callers first).
288+ // / callee-caller order (i.e., callees without callers first). Store all
289+ // / remaining functions (i.e., the ones that call each other recursively) in
290+ // / `remainingFuncOps`.
291+ // /
289292// / Store the map of FuncOp to all its callers in `callerMap`.
290- // / Return `failure()` if a cycle of calls is detected or if we are unable to
291- // / retrieve the called FuncOp from any func::CallOp.
292- static LogicalResult
293- getFuncOpsOrderedByCalls (ModuleOp moduleOp,
294- SmallVectorImpl<func::FuncOp> &orderedFuncOps,
295- FuncCallerMap &callerMap) {
293+ // /
294+ // / Return `failure()` if we are unable to retrieve the called FuncOp from
295+ // / any func::CallOp.
296+ static LogicalResult getFuncOpsOrderedByCalls (
297+ ModuleOp moduleOp, SmallVectorImpl<func::FuncOp> &orderedFuncOps,
298+ SmallVectorImpl<func::FuncOp> &remainingFuncOps, FuncCallerMap &callerMap) {
296299 // For each FuncOp, the set of functions called by it (i.e. the union of
297300 // symbols of all nested func::CallOp).
298301 DenseMap<func::FuncOp, DenseSet<func::FuncOp>> calledBy;
@@ -326,19 +329,25 @@ getFuncOpsOrderedByCalls(ModuleOp moduleOp,
326329 });
327330 if (res.wasInterrupted ())
328331 return failure ();
332+
329333 // Iteratively remove function operations that do not call any of the
330- // functions remaining in the callCounter map and add them to the worklist .
334+ // functions remaining in the callCounter map and add them to ordered list .
331335 while (!numberCallOpsContainedInFuncOp.empty ()) {
332336 auto it = llvm::find_if (numberCallOpsContainedInFuncOp,
333337 [](auto entry) { return entry.getSecond () == 0 ; });
334338 if (it == numberCallOpsContainedInFuncOp.end ())
335- return moduleOp.emitOpError (
336- " expected callgraph to be free of circular dependencies." );
339+ break ;
337340 orderedFuncOps.push_back (it->getFirst ());
338341 for (auto callee : calledBy[it->getFirst ()])
339342 numberCallOpsContainedInFuncOp[callee]--;
340343 numberCallOpsContainedInFuncOp.erase (it);
341344 }
345+
346+ // Put all other functions in the list of remaining functions. These are
347+ // functions that call each each circularly.
348+ for (auto it : numberCallOpsContainedInFuncOp)
349+ remainingFuncOps.push_back (it.first );
350+
342351 return success ();
343352}
344353
@@ -379,15 +388,17 @@ mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp,
379388 FuncAnalysisState &funcState = getOrCreateFuncAnalysisState (state);
380389
381390 // A list of functions in the order in which they are analyzed + bufferized.
382- SmallVector<func::FuncOp> orderedFuncOps;
391+ SmallVector<func::FuncOp> orderedFuncOps, remainingFuncOps ;
383392
384393 // A mapping of FuncOps to their callers.
385394 FuncCallerMap callerMap;
386395
387- if (failed (getFuncOpsOrderedByCalls (moduleOp, orderedFuncOps, callerMap)))
396+ if (failed (getFuncOpsOrderedByCalls (moduleOp, orderedFuncOps,
397+ remainingFuncOps, callerMap)))
388398 return failure ();
389399
390- // Analyze ops.
400+ // Analyze ops in order. Starting with functions that are not calling any
401+ // other functions.
391402 for (func::FuncOp funcOp : orderedFuncOps) {
392403 if (!state.getOptions ().isOpAllowed (funcOp))
393404 continue ;
@@ -411,6 +422,25 @@ mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp,
411422 funcState.analyzedFuncOps [funcOp] = FuncOpAnalysisState::Analyzed;
412423 }
413424
425+ // Analyze all other ops.
426+ for (func::FuncOp funcOp : remainingFuncOps) {
427+ if (!state.getOptions ().isOpAllowed (funcOp))
428+ continue ;
429+
430+ // Gather equivalence info for CallOps.
431+ equivalenceAnalysis (funcOp, state, funcState);
432+
433+ // Analyze funcOp.
434+ if (failed (analyzeOp (funcOp, state, statistics)))
435+ return failure ();
436+
437+ // TODO: We currently skip all function argument analyses for functions
438+ // that call each other circularly. These analyses do not support recursive
439+ // calls yet. The `BufferizableOpInterface` implementations of `func`
440+ // dialect ops return conservative results in the absence of analysis
441+ // information.
442+ }
443+
414444 return success ();
415445}
416446
@@ -430,13 +460,20 @@ LogicalResult mlir::bufferization::bufferizeModuleOp(
430460 IRRewriter rewriter (moduleOp.getContext ());
431461
432462 // A list of functions in the order in which they are analyzed + bufferized.
433- SmallVector<func::FuncOp> orderedFuncOps;
463+ SmallVector<func::FuncOp> orderedFuncOps, remainingFuncOps ;
434464
435465 // A mapping of FuncOps to their callers.
436466 FuncCallerMap callerMap;
437467
438- if (failed (getFuncOpsOrderedByCalls (moduleOp, orderedFuncOps, callerMap)))
468+ // Try to bufferize functions in calling order. I.e., first bufferize
469+ // functions that do not call other functions. This allows us to infer
470+ // accurate buffer types for function return values. Functions that call
471+ // each other recursively are bufferized in an unspecified order at the end.
472+ // We may use unnecessarily "complex" (in terms of layout map) buffer types.
473+ if (failed (getFuncOpsOrderedByCalls (moduleOp, orderedFuncOps,
474+ remainingFuncOps, callerMap)))
439475 return failure ();
476+ llvm::append_range (orderedFuncOps, remainingFuncOps);
440477
441478 // Bufferize functions.
442479 for (func::FuncOp funcOp : orderedFuncOps) {
0 commit comments