Skip to content

Conversation

@christopherbate
Copy link
Contributor

@christopherbate christopherbate commented Feb 19, 2025

The existing OneShotModuleBufferize will analyze and bufferize
operations which are in nested symbol tables (e.g. nested
builtin.module, gpu.module, or similar operations). This
behavior is untested and likely unintentional given other
limitations of OneShotModuleBufferize (func.call can't call
into nested symbol tables). This change reverses the existing
behavior so that the operations considered by the analysis and
bufferization exclude any operations in nested symbol table
scopes. Users who desire to bufferize nested modules can still do
so by applying the transformation in a pass pipeline or in a
custom pass. This further enables controlling the order in which
modules are bufferized as well as allowing use of different
options for different kinds of modules.

@llvmbot llvmbot added mlir mlir:bufferization Bufferization infrastructure labels Feb 19, 2025
@llvmbot
Copy link
Member

llvmbot commented Feb 19, 2025

@llvm/pr-subscribers-mlir-bufferization

@llvm/pr-subscribers-mlir

Author: Christopher Bate (christopherbate)

Changes

The existing OneShotModuleBufferize will analyze and bufferize
operations which are in nested symbol tables (e.g. nested
builtin.module, gpu.module, or similar operations). This
behavior is untested and likely unintential given other
limitations of OneShotModuleBufferize (func.call can't call
into nested symbol tables). This change reverses the existing
behavior so that the operations considered by the analysis and
bufferization exclude any operations in nested symbol table
scopes. Users who desire to bufferize nested modules can still do
so by applying the transformation in a pass pipeline or in a
custom pass. This further enables controlling the order in which
moduels are bufferized as well as allowing use of different
options for different kinds of modules.


Full diff: https://github.com/llvm/llvm-project/pull/127726.diff

4 Files Affected:

  • (modified) mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp (+9-9)
  • (modified) mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp (+12-3)
  • (modified) mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir (+14)
  • (modified) mlir/test/Dialect/Bufferization/Transforms/transform-ops.mlir (+17-19)
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
index 71ea0fd9d43cd..00085dc1b61dd 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
@@ -329,7 +329,7 @@ static bool hasTensorSignature(func::FuncOp funcOp) {
 /// Store all functions of the `moduleOp` in `orderedFuncOps`, sorted by
 /// callee-caller order (i.e., callees without callers first). Store all
 /// remaining functions (i.e., the ones that call each other recursively) in
-/// `remainingFuncOps`.
+/// `remainingFuncOps`. Does not traverse nested symbol tables.
 ///
 /// Store the map of FuncOp to all its callers in `callerMap`.
 ///
@@ -343,10 +343,10 @@ static LogicalResult getFuncOpsOrderedByCalls(
   DenseMap<func::FuncOp, DenseSet<func::FuncOp>> calledBy;
   // For each FuncOp, the number of func::CallOp it contains.
   DenseMap<func::FuncOp, unsigned> numberCallOpsContainedInFuncOp;
-  WalkResult res = moduleOp.walk([&](func::FuncOp funcOp) -> WalkResult {
+  for (func::FuncOp funcOp : moduleOp.getOps<func::FuncOp>()) {
     // Collect function calls and populate the caller map.
     numberCallOpsContainedInFuncOp[funcOp] = 0;
-    return funcOp.walk([&](func::CallOp callOp) -> WalkResult {
+    WalkResult res = funcOp.walk([&](func::CallOp callOp) -> WalkResult {
       func::FuncOp calledFunction = getCalledFunction(callOp);
       assert(calledFunction && "could not retrieved called func::FuncOp");
       // If the called function does not have any tensors in its signature, then
@@ -360,9 +360,9 @@ static LogicalResult getFuncOpsOrderedByCalls(
       }
       return WalkResult::advance();
     });
-  });
-  if (res.wasInterrupted())
-    return failure();
+    if (res.wasInterrupted())
+      return failure();
+  }
 
   // Iteratively remove function operations that do not call any of the
   // functions remaining in the callCounter map and add them to ordered list.
@@ -533,10 +533,10 @@ mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp,
 
 void mlir::bufferization::removeBufferizationAttributesInModule(
     ModuleOp moduleOp) {
-  moduleOp.walk([&](func::FuncOp op) {
+  for (auto op : moduleOp.getOps<func::FuncOp>()) {
     for (BlockArgument bbArg : op.getArguments())
       removeBufferizationAttributes(bbArg);
-  });
+  }
 }
 
 LogicalResult mlir::bufferization::bufferizeModuleOp(
@@ -592,7 +592,7 @@ LogicalResult mlir::bufferization::bufferizeModuleOp(
   // Bufferize all other ops.
   for (Operation &op : llvm::make_early_inc_range(moduleOp.getOps())) {
     // Functions were already bufferized.
-    if (isa<func::FuncOp>(&op))
+    if (isa<func::FuncOp>(&op) || op.hasTrait<OpTrait::SymbolTable>())
       continue;
     if (failed(bufferizeOp(&op, options, statistics)))
       return failure();
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp b/mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp
index 6db60b75b302b..4326b19f3104d 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp
@@ -52,14 +52,23 @@ mlir::bufferization::insertTensorCopies(Operation *op,
                                         const AnalysisState &state) {
   IRRewriter rewriter(op->getContext());
 
-  WalkResult result = op->walk([&](Operation *op) {
-    auto bufferizableOp = state.getOptions().dynCastBufferizableOp(op);
+  // It may be more efficient to walk in pre-order here, but the current
+  // implementation visits regions of ops even if they are not allowed or
+  // bufferizable, and existing tests rely on this behavior.
+  // For now, only exclude nested operations if they are in a different symbol
+  // table scope.
+  WalkResult result = op->walk([&](Operation *nestedOp) {
+    if (op->hasTrait<OpTrait::SymbolTable>() &&
+        nestedOp->getParentWithTrait<OpTrait::SymbolTable>() != op)
+      return WalkResult::skip();
+
+    auto bufferizableOp = state.getOptions().dynCastBufferizableOp(nestedOp);
     if (!bufferizableOp)
       return WalkResult::skip();
 
     // Find inplacability conflicts and resolve them. (Typically with explicit
     // tensor copies in the form of AllocTensorOps.)
-    rewriter.setInsertionPoint(op);
+    rewriter.setInsertionPoint(nestedOp);
     if (failed(bufferizableOp.resolveConflicts(rewriter, state)))
       return WalkResult::interrupt();
 
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir
index ec2fb58ee03f8..e7797d4bc50a9 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir
@@ -796,3 +796,17 @@ func.func @result_type_mismatch(%c: i1) -> tensor<5xf32> {
   return %1 : tensor<5xf32>
 }
 
+
+// -----
+
+// CHECK-LABEL: @outer_func({{.+}}: memref<
+func.func @outer_func(%t: tensor<5xf32>) -> tensor<5xf32> {
+  return %t : tensor<5xf32>
+}
+
+module @inner_module {
+  // CHECK: @inner_func({{.+}}: tensor<5xf32> {bufferization.writable = false})
+  func.func @inner_func(%t: tensor<5xf32> {bufferization.writable = false}) -> tensor<5xf32> {
+    return %t : tensor<5xf32>
+  }
+}
diff --git a/mlir/test/Dialect/Bufferization/Transforms/transform-ops.mlir b/mlir/test/Dialect/Bufferization/Transforms/transform-ops.mlir
index 3c50a9e72d9d9..a2741abbda3b0 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/transform-ops.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/transform-ops.mlir
@@ -111,23 +111,21 @@ module attributes {transform.with_named_sequence} {
   }
 }
 
-module {
-  // CHECK-LABEL: func @test_function(
-  //  CHECK-SAME:     %[[A:.*]]: tensor<?xf32>
-  func.func @test_function(%A : tensor<?xf32>, %v : vector<4xf32>) -> (tensor<?xf32>) {
-    %c0 = arith.constant 0 : index
-
-    // CHECK: %[[A_memref:.*]] = bufferization.to_memref %[[A]]
-    // CHECK: %[[dim:.*]] = memref.dim %[[A_memref]]
-    // CHECK: %[[alloc:.*]] = memref.alloc(%[[dim]])
-    // CHECK: memref.copy %[[A_memref]], %[[alloc]]
-    // CHECK: vector.transfer_write %{{.*}}, %[[alloc]]
-    // CHECK: %[[res_tensor:.*]] = bufferization.to_tensor %[[alloc]]
-    %0 = vector.transfer_write %v, %A[%c0] : vector<4xf32>, tensor<?xf32>
-
-    // CHECK: return %[[res_tensor]]
-    return %0 : tensor<?xf32>
-  }
+// CHECK-LABEL: func @test_function(
+//  CHECK-SAME:     %[[A:.*]]: tensor<?xf32>
+func.func @test_function(%A : tensor<?xf32>, %v : vector<4xf32>) -> (tensor<?xf32>) {
+  %c0 = arith.constant 0 : index
+
+  // CHECK: %[[A_memref:.*]] = bufferization.to_memref %[[A]]
+  // CHECK: %[[dim:.*]] = memref.dim %[[A_memref]]
+  // CHECK: %[[alloc:.*]] = memref.alloc(%[[dim]])
+  // CHECK: memref.copy %[[A_memref]], %[[alloc]]
+  // CHECK: vector.transfer_write %{{.*}}, %[[alloc]]
+  // CHECK: %[[res_tensor:.*]] = bufferization.to_tensor %[[alloc]]
+  %0 = vector.transfer_write %v, %A[%c0] : vector<4xf32>, tensor<?xf32>
+
+  // CHECK: return %[[res_tensor]]
+  return %0 : tensor<?xf32>
 }
 
 // -----
@@ -222,8 +220,8 @@ module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
     %alloc_tensor = transform.structured.match ops{["bufferization.alloc_tensor"]} in %arg1
       : (!transform.any_op) -> !transform.op<"bufferization.alloc_tensor">
-    %2, %new = transform.structured.bufferize_to_allocation %alloc_tensor 
-      {alloc_op = "memref.alloca"} 
+    %2, %new = transform.structured.bufferize_to_allocation %alloc_tensor
+      {alloc_op = "memref.alloca"}
         : !transform.op<"bufferization.alloc_tensor">
     transform.yield
   }

@christopherbate
Copy link
Contributor Author

@matthias-springer Let me know if the existing behavior is intentional and I can adjust.

If the existing behavior is intentional (we want one-shot-bufferization to bufferize nested modules), then I can add tests for the filtering to ensure nested modules can be correctly excluded.

I can also take a look at switching some of the walk operations to pre-order, which I think would be much more efficient when nested modules are present but are excluded from analysis/copy-insertion/bufferization.

… bufferize nested symbol tables

The existing OneShotModuleBufferize will analyze and bufferize
operations which are in nested symbol tables (e.g. nested
`builtin.module`, `gpu.module`, or similar operations). This
behavior is untested and likely unintential given other
limitations of OneShotModuleBufferize (`func.call` can't call
into nested symbol tables). This change reverses the existing
behavior so that the operations considered by the analysis and
bufferization exclude any operations in nested symbol table
scopes. Users who desire to bufferize nested modules can still do
so by applying the transformation in a pass pipeline or in a
custom pass. This further enables controlling the order in which
moduels are bufferized as well as allowing use of different
options for different kinds of modules.
@christopherbate christopherbate force-pushed the mlir-module-bufferization-improvements branch from f08cdd3 to 79b6d02 Compare February 24, 2025 22:14
@christopherbate christopherbate merged commit 2646c36 into llvm:main Feb 25, 2025
7 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

mlir:bufferization Bufferization infrastructure mlir

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants