diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
index 77840690e6a26..edd6bcf84f460 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
@@ -300,7 +300,7 @@ static bool hasTensorSignature(func::FuncOp funcOp) {
 /// Store all functions of the `moduleOp` in `orderedFuncOps`, sorted by
 /// callee-caller order (i.e., callees without callers first). Store all
 /// remaining functions (i.e., the ones that call each other recursively) in
-/// `remainingFuncOps`.
+/// `remainingFuncOps`. Does not traverse nested symbol tables.
 ///
 /// Store the map of FuncOp to all its callers in `callerMap`.
 ///
@@ -314,10 +314,10 @@ static LogicalResult getFuncOpsOrderedByCalls(
   DenseMap<func::FuncOp, DenseSet<func::FuncOp>> calledBy;
   // For each FuncOp, the number of func::CallOp it contains.
   DenseMap<func::FuncOp, unsigned> numberCallOpsContainedInFuncOp;
-  WalkResult res = moduleOp.walk([&](func::FuncOp funcOp) -> WalkResult {
+  for (func::FuncOp funcOp : moduleOp.getOps<func::FuncOp>()) {
     // Collect function calls and populate the caller map.
     numberCallOpsContainedInFuncOp[funcOp] = 0;
-    return funcOp.walk([&](func::CallOp callOp) -> WalkResult {
+    WalkResult res = funcOp.walk([&](func::CallOp callOp) -> WalkResult {
       func::FuncOp calledFunction = getCalledFunction(callOp);
       assert(calledFunction && "could not retrieved called func::FuncOp");
       // If the called function does not have any tensors in its signature, then
@@ -331,9 +331,9 @@ static LogicalResult getFuncOpsOrderedByCalls(
       }
       return WalkResult::advance();
     });
-  });
-  if (res.wasInterrupted())
-    return failure();
+    if (res.wasInterrupted())
+      return failure();
+  }
 
   // Iteratively remove function operations that do not call any of the
   // functions remaining in the callCounter map and add them to ordered list.
@@ -498,10 +498,10 @@ mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp,
 
 void mlir::bufferization::removeBufferizationAttributesInModule(
     ModuleOp moduleOp) {
-  moduleOp.walk([&](func::FuncOp op) {
+  for (auto op : moduleOp.getOps<func::FuncOp>()) {
     for (BlockArgument bbArg : op.getArguments())
       removeBufferizationAttributes(bbArg);
-  });
+  }
 }
 
 LogicalResult mlir::bufferization::bufferizeModuleOp(
@@ -557,7 +557,7 @@ LogicalResult mlir::bufferization::bufferizeModuleOp(
   // Bufferize all other ops.
   for (Operation &op : llvm::make_early_inc_range(moduleOp.getOps())) {
     // Functions were already bufferized.
-    if (isa<func::FuncOp>(&op))
+    if (isa<func::FuncOp>(&op) || op.hasTrait<OpTrait::SymbolTable>())
       continue;
     if (failed(bufferizeOp(&op, options, statistics)))
       return failure();
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp b/mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp
index 6db60b75b302b..4326b19f3104d 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp
@@ -52,14 +52,23 @@ mlir::bufferization::insertTensorCopies(Operation *op,
                                         const AnalysisState &state) {
   IRRewriter rewriter(op->getContext());
 
-  WalkResult result = op->walk([&](Operation *op) {
-    auto bufferizableOp = state.getOptions().dynCastBufferizableOp(op);
+  // It may be more efficient to walk in pre-order here, but the current
+  // implementation visits regions of ops even if they are not allowed or
+  // bufferizable, and existing tests rely on this behavior.
+  // For now, only exclude nested operations if they are in a different symbol
+  // table scope.
+  WalkResult result = op->walk([&](Operation *nestedOp) {
+    if (op->hasTrait<OpTrait::SymbolTable>() &&
+        nestedOp->getParentWithTrait<OpTrait::SymbolTable>() != op)
+      return WalkResult::skip();
+
+    auto bufferizableOp = state.getOptions().dynCastBufferizableOp(nestedOp);
     if (!bufferizableOp)
       return WalkResult::skip();
 
     // Find inplacability conflicts and resolve them. (Typically with explicit
     // tensor copies in the form of AllocTensorOps.)
-    rewriter.setInsertionPoint(op);
+    rewriter.setInsertionPoint(nestedOp);
     if (failed(bufferizableOp.resolveConflicts(rewriter, state)))
       return WalkResult::interrupt();
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir
index ec2fb58ee03f8..e7797d4bc50a9 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir
@@ -796,3 +796,17 @@ func.func @result_type_mismatch(%c: i1) -> tensor<5xf32> {
   return %1 : tensor<5xf32>
 }
+
+// -----
+
+// CHECK-LABEL: @outer_func({{.+}}: memref<
+func.func @outer_func(%t: tensor<5xf32>) -> tensor<5xf32> {
+  return %t : tensor<5xf32>
+}
+
+module @inner_module {
+  // CHECK: @inner_func({{.+}}: tensor<5xf32> {bufferization.writable = false})
+  func.func @inner_func(%t: tensor<5xf32> {bufferization.writable = false}) -> tensor<5xf32> {
+    return %t : tensor<5xf32>
+  }
+}
diff --git a/mlir/test/Dialect/Bufferization/Transforms/transform-ops.mlir b/mlir/test/Dialect/Bufferization/Transforms/transform-ops.mlir
index 3c50a9e72d9d9..a2741abbda3b0 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/transform-ops.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/transform-ops.mlir
@@ -111,23 +111,21 @@ module attributes {transform.with_named_sequence} {
   }
 }
 
-module {
-  // CHECK-LABEL: func @test_function(
-  //  CHECK-SAME:   %[[A:.*]]: tensor<?xf32>
-  func.func @test_function(%A : tensor<?xf32>, %v : vector<4xf32>) -> (tensor<?xf32>) {
-    %c0 = arith.constant 0 : index
-
-    // CHECK: %[[A_memref:.*]] = bufferization.to_memref %[[A]]
-    // CHECK: %[[dim:.*]] = memref.dim %[[A_memref]]
-    // CHECK: %[[alloc:.*]] = memref.alloc(%[[dim]])
-    // CHECK: memref.copy %[[A_memref]], %[[alloc]]
-    // CHECK: vector.transfer_write %{{.*}}, %[[alloc]]
-    // CHECK: %[[res_tensor:.*]] = bufferization.to_tensor %[[alloc]]
-    %0 = vector.transfer_write %v, %A[%c0] : vector<4xf32>, tensor<?xf32>
-
-    // CHECK: return %[[res_tensor]]
-    return %0 : tensor<?xf32>
-  }
+// CHECK-LABEL: func @test_function(
+//  CHECK-SAME:   %[[A:.*]]: tensor<?xf32>
+func.func @test_function(%A : tensor<?xf32>, %v : vector<4xf32>) -> (tensor<?xf32>) {
+  %c0 = arith.constant 0 : index
+
+  // CHECK: %[[A_memref:.*]] = bufferization.to_memref %[[A]]
+  // CHECK: %[[dim:.*]] = memref.dim %[[A_memref]]
+  // CHECK: %[[alloc:.*]] = memref.alloc(%[[dim]])
+  // CHECK: memref.copy %[[A_memref]], %[[alloc]]
+  // CHECK: vector.transfer_write %{{.*}}, %[[alloc]]
+  // CHECK: %[[res_tensor:.*]] = bufferization.to_tensor %[[alloc]]
+  %0 = vector.transfer_write %v, %A[%c0] : vector<4xf32>, tensor<?xf32>
+
+  // CHECK: return %[[res_tensor]]
+  return %0 : tensor<?xf32>
 }
 
 // -----
@@ -222,8 +220,8 @@ module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
     %alloc_tensor = transform.structured.match ops{["bufferization.alloc_tensor"]} in %arg1
       : (!transform.any_op) -> !transform.op<"bufferization.alloc_tensor">
-    %2, %new = transform.structured.bufferize_to_allocation %alloc_tensor
-      {alloc_op = "memref.alloca"}
+    %2, %new = transform.structured.bufferize_to_allocation %alloc_tensor
+        {alloc_op = "memref.alloca"}
       : !transform.op<"bufferization.alloc_tensor">
     transform.yield
   }
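
Note on the traversal change (not part of the patch): the common thread is replacing recursive `walk`s with traversals confined to one symbol table scope. Below is a minimal sketch, under the assumption of standard MLIR headers, contrasting the two; `compareTraversals` is a hypothetical helper introduced only for illustration.

```cpp
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BuiltinOps.h"
#include "llvm/Support/raw_ostream.h"

// Hypothetical illustration helper, not from the patch.
static void compareTraversals(mlir::ModuleOp moduleOp) {
  // Operation::walk recurses into all nested regions, so it also visits
  // func.func ops inside nested symbol tables (e.g. `module @inner { ... }`).
  moduleOp.walk([](mlir::func::FuncOp funcOp) {
    llvm::outs() << "walk:   " << funcOp.getSymName() << "\n";
  });
  // getOps<OpTy>() yields only the module's direct children, which is the
  // behavior the patch adopts: nested symbol tables are left untouched.
  for (mlir::func::FuncOp funcOp : moduleOp.getOps<mlir::func::FuncOp>())
    llvm::outs() << "getOps: " << funcOp.getSymName() << "\n";
}
```

`insertTensorCopies` must still recurse into non-isolated regions, so there the patch keeps the `walk` but skips any `nestedOp` whose nearest `OpTrait::SymbolTable` ancestor is not the root op, which achieves the same scoping.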