@@ -285,14 +285,17 @@ static bool hasTensorSignature(func::FuncOp funcOp) {
285285}
286286
287287// / Store all functions of the `moduleOp` in `orderedFuncOps`, sorted by
288- // / callee-caller order (i.e. callees without callers first).
288+ // / callee-caller order (i.e., callees without callers first). Store all
289+ // / remaining functions (i.e., the ones that call each other recursively) in
290+ // / `remainingFuncOps`.
291+ // /
289292// / Store the map of FuncOp to all its callers in `callerMap`.
290- // / Return `failure()` if a cycle of calls is detected or if we are unable to
291- // / retrieve the called FuncOp from any func::CallOp.
292- static LogicalResult
293- getFuncOpsOrderedByCalls (ModuleOp moduleOp,
294- SmallVectorImpl<func::FuncOp> &orderedFuncOps,
295- FuncCallerMap &callerMap) {
293+ // /
294+ // / Return `failure()` if we are unable to retrieve the called FuncOp from
295+ // / any func::CallOp.
296+ static LogicalResult getFuncOpsOrderedByCalls (
297+ ModuleOp moduleOp, SmallVectorImpl<func::FuncOp> &orderedFuncOps,
298+ SmallVectorImpl<func::FuncOp> &remainingFuncOps, FuncCallerMap &callerMap) {
296299 // For each FuncOp, the set of functions called by it (i.e. the union of
297300 // symbols of all nested func::CallOp).
298301 DenseMap<func::FuncOp, DenseSet<func::FuncOp>> calledBy;
@@ -326,19 +329,25 @@ getFuncOpsOrderedByCalls(ModuleOp moduleOp,
326329 });
327330 if (res.wasInterrupted ())
328331 return failure ();
332+
329333 // Iteratively remove function operations that do not call any of the
330- // functions remaining in the callCounter map and add them to the worklist .
334+ // functions remaining in the callCounter map and add them to the ordered list.
331335 while (!numberCallOpsContainedInFuncOp.empty ()) {
332336 auto it = llvm::find_if (numberCallOpsContainedInFuncOp,
333337 [](auto entry) { return entry.getSecond () == 0 ; });
334338 if (it == numberCallOpsContainedInFuncOp.end ())
335- return moduleOp.emitOpError (
336- " expected callgraph to be free of circular dependencies." );
339+ break ;
337340 orderedFuncOps.push_back (it->getFirst ());
338341 for (auto callee : calledBy[it->getFirst ()])
339342 numberCallOpsContainedInFuncOp[callee]--;
340343 numberCallOpsContainedInFuncOp.erase (it);
341344 }
345+
346+ // Put all other functions in the list of remaining functions. These are
347+ // functions that are part of a call cycle or that call into one.
348+ for (auto it : numberCallOpsContainedInFuncOp)
349+ remainingFuncOps.push_back (it.first );
350+
342351 return success ();
343352}
344353
@@ -378,16 +387,23 @@ mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp,
378387 " expected that function boundary bufferization is activated" );
379388 FuncAnalysisState &funcState = getOrCreateFuncAnalysisState (state);
380389
381- // A list of functions in the order in which they are analyzed + bufferized.
390+ // A list of non-circular functions in the order in which they are analyzed
391+ // and bufferized.
382392 SmallVector<func::FuncOp> orderedFuncOps;
393+ // A list of all other functions. I.e., functions that call each other
394+ // recursively. For these, we analyze the function body but not the function
395+ // boundary.
396+ SmallVector<func::FuncOp> remainingFuncOps;
383397
384398 // A mapping of FuncOps to their callers.
385399 FuncCallerMap callerMap;
386400
387- if (failed (getFuncOpsOrderedByCalls (moduleOp, orderedFuncOps, callerMap)))
401+ if (failed (getFuncOpsOrderedByCalls (moduleOp, orderedFuncOps,
402+ remainingFuncOps, callerMap)))
388403 return failure ();
389404
390- // Analyze ops.
405+ // Analyze functions in order. Starting with functions that are not calling
406+ // any other functions.
391407 for (func::FuncOp funcOp : orderedFuncOps) {
392408 if (!state.getOptions ().isOpAllowed (funcOp))
393409 continue ;
@@ -411,6 +427,25 @@ mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp,
411427 funcState.analyzedFuncOps [funcOp] = FuncOpAnalysisState::Analyzed;
412428 }
413429
430+ // Analyze all other functions. All function boundary analyses are skipped.
431+ for (func::FuncOp funcOp : remainingFuncOps) {
432+ if (!state.getOptions ().isOpAllowed (funcOp))
433+ continue ;
434+
435+ // Gather equivalence info for CallOps.
436+ equivalenceAnalysis (funcOp, state, funcState);
437+
438+ // Analyze funcOp.
439+ if (failed (analyzeOp (funcOp, state, statistics)))
440+ return failure ();
441+
442+ // TODO: We currently skip all function argument analyses for functions
443+ // that call each other circularly. These analyses do not support recursive
444+ // calls yet. The `BufferizableOpInterface` implementations of `func`
445+ // dialect ops return conservative results in the absence of analysis
446+ // information.
447+ }
448+
414449 return success ();
415450}
416451
@@ -429,14 +464,26 @@ LogicalResult mlir::bufferization::bufferizeModuleOp(
429464 " expected that function boundary bufferization is activated" );
430465 IRRewriter rewriter (moduleOp.getContext ());
431466
432- // A list of functions in the order in which they are analyzed + bufferized.
467+ // A list of non-circular functions in the order in which they are analyzed
468+ // and bufferized.
433469 SmallVector<func::FuncOp> orderedFuncOps;
470+ // A list of all other functions. I.e., functions that call each other
471+ // recursively. These are bufferized after the ordered functions, in an
472+ // unspecified order.
473+ SmallVector<func::FuncOp> remainingFuncOps;
434474
435475 // A mapping of FuncOps to their callers.
436476 FuncCallerMap callerMap;
437477
438- if (failed (getFuncOpsOrderedByCalls (moduleOp, orderedFuncOps, callerMap)))
478+ // Try to bufferize functions in calling order. I.e., first bufferize
479+ // functions that do not call other functions. This allows us to infer
480+ // accurate buffer types for function return values. Functions that call
481+ // each other recursively are bufferized in an unspecified order at the end.
482+ // We may use unnecessarily "complex" (in terms of layout map) buffer types.
483+ if (failed (getFuncOpsOrderedByCalls (moduleOp, orderedFuncOps,
484+ remainingFuncOps, callerMap)))
439485 return failure ();
486+ llvm::append_range (orderedFuncOps, remainingFuncOps);
440487
441488 // Bufferize functions.
442489 for (func::FuncOp funcOp : orderedFuncOps) {
0 commit comments