1- // ===- ModuleBufferization.cpp - Bufferization across Func. Boundaries ----===//
1+ // ===- OneShotModuleBufferize.cpp - Bufferization across Func. Boundaries
2+ // ----===//
23//
34// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
45// See https://llvm.org/LICENSE.txt for license information.
89//
910// Module Bufferization is an extension of One-Shot Bufferize that
1011// bufferizes function boundaries. It provides `BufferizableOpInterface`
11- // implementations for FuncOp, CallOp and ReturnOp.
12+ // implementations for FuncOp, CallOp and ReturnOp. Although it is named
13+ // Module Bufferization, it may operate on any SymbolTable.
1214//
13- // Module Bufferization is run via `runOneShotModuleBufferize(ModuleOp, ...)`.
14- // This function analyzes the given module and determines the order of analysis
15- // and bufferization: Functions that are called are processed before their
16- // respective callers.
15+ // Module Bufferization is run via `runOneShotModuleBufferize(SymbolTableOp,
16+ // ...)`. This function analyzes the given op and determines the order of
17+ // analysis and bufferization: Functions that are called are processed before
18+ // their respective callers.
1719//
1820// After analyzing a FuncOp, additional information about its bbArgs is
1921// gathered and stored in `FuncAnalysisState`.
@@ -309,34 +311,37 @@ static bool hasTensorSignature(func::FuncOp funcOp) {
309311// / Return `failure()` if we are unable to retrieve the called FuncOp from
310312// / any func::CallOp.
311313static LogicalResult getFuncOpsOrderedByCalls (
312- ModuleOp moduleOp, SmallVectorImpl<func::FuncOp> &orderedFuncOps,
314+ Operation * moduleOp, SmallVectorImpl<func::FuncOp> &orderedFuncOps,
313315 SmallVectorImpl<func::FuncOp> &remainingFuncOps, FuncCallerMap &callerMap,
314316 SymbolTableCollection &symbolTables) {
315317 // For each FuncOp, the set of functions called by it (i.e. the union of
316318 // symbols of all nested func::CallOp).
317319 DenseMap<func::FuncOp, DenseSet<func::FuncOp>> calledBy;
318320 // For each FuncOp, the number of func::CallOp it contains.
319321 DenseMap<func::FuncOp, unsigned > numberCallOpsContainedInFuncOp;
320-
321- for (func::FuncOp funcOp : moduleOp.getOps <func::FuncOp>()) {
322- // Collect function calls and populate the caller map.
323- numberCallOpsContainedInFuncOp[funcOp] = 0 ;
324- WalkResult res = funcOp.walk ([&](func::CallOp callOp) -> WalkResult {
325- func::FuncOp calledFunction = getCalledFunction (callOp, symbolTables);
326- assert (calledFunction && " could not retrieved called func::FuncOp" );
327- // If the called function does not have any tensors in its signature, then
328- // it is not necessary to bufferize the callee before the caller.
329- if (!hasTensorSignature (calledFunction))
330- return WalkResult::skip ();
331-
332- callerMap[calledFunction].insert (callOp);
333- if (calledBy[calledFunction].insert (funcOp).second ) {
334- numberCallOpsContainedInFuncOp[funcOp]++;
322+ for (mlir::Region ®ion : moduleOp->getRegions ()) {
323+ for (mlir::Block &block : region.getBlocks ()) {
324+ for (func::FuncOp funcOp : block.getOps <func::FuncOp>()) {
325+ // Collect function calls and populate the caller map.
326+ numberCallOpsContainedInFuncOp[funcOp] = 0 ;
327+ WalkResult res = funcOp.walk ([&](func::CallOp callOp) -> WalkResult {
328+ func::FuncOp calledFunction = getCalledFunction (callOp, symbolTables);
329+ assert (calledFunction && " could not retrieved called func::FuncOp" );
330+ // If the called function does not have any tensors in its signature,
331+ // then it is not necessary to bufferize the callee before the caller.
332+ if (!hasTensorSignature (calledFunction))
333+ return WalkResult::skip ();
334+
335+ callerMap[calledFunction].insert (callOp);
336+ if (calledBy[calledFunction].insert (funcOp).second ) {
337+ numberCallOpsContainedInFuncOp[funcOp]++;
338+ }
339+ return WalkResult::advance ();
340+ });
341+ if (res.wasInterrupted ())
342+ return failure ();
335343 }
336- return WalkResult::advance ();
337- });
338- if (res.wasInterrupted ())
339- return failure ();
344+ }
340345 }
341346
342347 // Iteratively remove function operations that do not call any of the
@@ -447,7 +452,7 @@ static void foldMemRefCasts(func::FuncOp funcOp) {
447452}
448453
449454LogicalResult
450- mlir::bufferization::analyzeModuleOp (ModuleOp moduleOp,
455+ mlir::bufferization::analyzeModuleOp (Operation * moduleOp,
451456 OneShotAnalysisState &state,
452457 BufferizationStatistics *statistics) {
453458 assert (state.getOptions ().bufferizeFunctionBoundaries &&
@@ -512,19 +517,23 @@ mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp,
512517}
513518
514519void mlir::bufferization::removeBufferizationAttributesInModule (
515- ModuleOp moduleOp) {
516- for (auto op : moduleOp.getOps <func::FuncOp>()) {
517- for (BlockArgument bbArg : op.getArguments ())
518- removeBufferizationAttributes (bbArg);
520+ Operation *moduleOp) {
521+ for (mlir::Region ®ion : moduleOp->getRegions ()) {
522+ for (mlir::Block &block : region.getBlocks ()) {
523+ for (func::FuncOp funcOp : block.getOps <func::FuncOp>()) {
524+ for (BlockArgument bbArg : funcOp.getArguments ())
525+ removeBufferizationAttributes (bbArg);
526+ }
527+ }
519528 }
520529}
521530
522531LogicalResult mlir::bufferization::bufferizeModuleOp (
523- ModuleOp moduleOp, const OneShotBufferizationOptions &options,
532+ Operation * moduleOp, const OneShotBufferizationOptions &options,
524533 BufferizationState &state, BufferizationStatistics *statistics) {
525534 assert (options.bufferizeFunctionBoundaries &&
526535 " expected that function boundary bufferization is activated" );
527- IRRewriter rewriter (moduleOp. getContext ());
536+ IRRewriter rewriter (moduleOp-> getContext ());
528537
529538 // A list of non-circular functions in the order in which they are analyzed
530539 // and bufferized.
@@ -571,12 +580,17 @@ LogicalResult mlir::bufferization::bufferizeModuleOp(
571580 }
572581
573582 // Bufferize all other ops.
574- for (Operation &op : llvm::make_early_inc_range (moduleOp.getOps ())) {
575- // Functions were already bufferized.
576- if (isa<func::FuncOp>(&op) || op.hasTrait <OpTrait::SymbolTable>())
577- continue ;
578- if (failed (bufferizeOp (&op, options, state, statistics)))
579- return failure ();
583+ for (mlir::Region ®ion : moduleOp->getRegions ()) {
584+ for (mlir::Block &block : region.getBlocks ()) {
585+ for (mlir::Operation &op :
586+ llvm::make_early_inc_range (block.getOperations ())) {
587+ // Functions were already bufferized.
588+ if (isa<func::FuncOp>(&op) || op.hasTrait <OpTrait::SymbolTable>())
589+ continue ;
590+ if (failed (bufferizeOp (&op, options, state, statistics)))
591+ return failure ();
592+ }
593+ }
580594 }
581595
582596 // Post-pass cleanup of function argument attributes.
@@ -586,7 +600,7 @@ LogicalResult mlir::bufferization::bufferizeModuleOp(
586600}
587601
588602LogicalResult mlir::bufferization::runOneShotModuleBufferize (
589- ModuleOp moduleOp, const OneShotBufferizationOptions &options,
603+ Operation * moduleOp, const OneShotBufferizationOptions &options,
590604 BufferizationState &state, BufferizationStatistics *statistics) {
591605 assert (options.bufferizeFunctionBoundaries &&
592606 " expected that function boundary bufferization is activated" );
0 commit comments