
Commit 2646c36

[mlir][bufferization] Change OneShotModuleBufferize to not analyze or bufferize nested symbol tables (llvm#127726)
The existing OneShotModuleBufferize will analyze and bufferize operations that are in nested symbol tables (e.g. nested `builtin.module`, `gpu.module`, or similar operations). This behavior is untested and likely unintentional, given other limitations of OneShotModuleBufferize (`func.call` cannot call into nested symbol tables). This change reverses the existing behavior so that the analysis and bufferization exclude any operations in nested symbol-table scopes. Users who want to bufferize nested modules can still do so by applying the transformation in a pass pipeline or in a custom pass (a sketch follows below). This further enables controlling the order in which modules are bufferized, as well as using different options for different kinds of modules.
1 parent 0be3f13 · commit 2646c36
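
Since the bufferization no longer recurses into nested symbol tables, users who still want nested modules bufferized must visit those modules themselves, for example from a custom pass. Below is a minimal C++ sketch of that idea (not part of this commit); it uses only the public `runOneShotModuleBufferize` entry point, while the helper name `bufferizeAllModules` is invented for illustration:

#include "llvm/ADT/SmallVector.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h"
#include "mlir/IR/BuiltinOps.h"

using namespace mlir;

// Hypothetical helper: bufferize every builtin.module in the tree, innermost
// first. `options.bufferizeFunctionBoundaries` must be set for module
// bufferization.
static LogicalResult bufferizeAllModules(
    ModuleOp root, const bufferization::OneShotBufferizationOptions &options) {
  // Collect the modules up front (the default post-order walk yields inner
  // modules before outer ones) so the IR is not mutated mid-traversal.
  SmallVector<ModuleOp> modules;
  root.walk([&](ModuleOp module) { modules.push_back(module); });
  for (ModuleOp module : modules)
    if (failed(bufferization::runOneShotModuleBufferize(module, options)))
      return failure();
  return success();
}

Because each module is handled independently here, a variant of this helper could also supply different options per module kind (e.g. for `gpu.module` vs. the host module), which is the flexibility the message above alludes to.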

4 files changed (+52, -31 lines)


mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp

Lines changed: 9 additions & 9 deletions
@@ -300,7 +300,7 @@ static bool hasTensorSignature(func::FuncOp funcOp) {
 /// Store all functions of the `moduleOp` in `orderedFuncOps`, sorted by
 /// callee-caller order (i.e., callees without callers first). Store all
 /// remaining functions (i.e., the ones that call each other recursively) in
-/// `remainingFuncOps`.
+/// `remainingFuncOps`. Does not traverse nested symbol tables.
 ///
 /// Store the map of FuncOp to all its callers in `callerMap`.
 ///
@@ -314,10 +314,10 @@ static LogicalResult getFuncOpsOrderedByCalls(
   DenseMap<func::FuncOp, DenseSet<func::FuncOp>> calledBy;
   // For each FuncOp, the number of func::CallOp it contains.
   DenseMap<func::FuncOp, unsigned> numberCallOpsContainedInFuncOp;
-  WalkResult res = moduleOp.walk([&](func::FuncOp funcOp) -> WalkResult {
+  for (func::FuncOp funcOp : moduleOp.getOps<func::FuncOp>()) {
     // Collect function calls and populate the caller map.
     numberCallOpsContainedInFuncOp[funcOp] = 0;
-    return funcOp.walk([&](func::CallOp callOp) -> WalkResult {
+    WalkResult res = funcOp.walk([&](func::CallOp callOp) -> WalkResult {
       func::FuncOp calledFunction = getCalledFunction(callOp);
       assert(calledFunction && "could not retrieved called func::FuncOp");
       // If the called function does not have any tensors in its signature, then
@@ -331,9 +331,9 @@ static LogicalResult getFuncOpsOrderedByCalls(
       }
       return WalkResult::advance();
     });
-  });
-  if (res.wasInterrupted())
-    return failure();
+    if (res.wasInterrupted())
+      return failure();
+  }
 
   // Iteratively remove function operations that do not call any of the
   // functions remaining in the callCounter map and add them to ordered list.
@@ -498,10 +498,10 @@ mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp,
 
 void mlir::bufferization::removeBufferizationAttributesInModule(
     ModuleOp moduleOp) {
-  moduleOp.walk([&](func::FuncOp op) {
+  for (auto op : moduleOp.getOps<func::FuncOp>()) {
     for (BlockArgument bbArg : op.getArguments())
       removeBufferizationAttributes(bbArg);
-  });
+  }
 }
 
 LogicalResult mlir::bufferization::bufferizeModuleOp(
@@ -557,7 +557,7 @@ LogicalResult mlir::bufferization::bufferizeModuleOp(
   // Bufferize all other ops.
   for (Operation &op : llvm::make_early_inc_range(moduleOp.getOps())) {
     // Functions were already bufferized.
-    if (isa<func::FuncOp>(&op))
+    if (isa<func::FuncOp>(&op) || op.hasTrait<OpTrait::SymbolTable>())
       continue;
     if (failed(bufferizeOp(&op, options, statistics)))
       return failure();
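
The recurring change in the hunks above swaps `Operation::walk`, which recurses into every nested region (including nested symbol tables), for `getOps<OpTy>`, which iterates only over the immediate ops of the module body. A short illustrative sketch of the difference; the function name is made up:

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BuiltinOps.h"

using namespace mlir;

// Illustration only: contrasts the traversal before and after this patch.
void traversalDifference(ModuleOp moduleOp) {
  // walk() visits every func.func anywhere beneath moduleOp, including ones
  // inside a nested builtin.module or gpu.module (the old behavior).
  moduleOp.walk([](func::FuncOp funcOp) { (void)funcOp; });

  // getOps<OpTy>() yields only direct children of the module body, so
  // functions in nested symbol tables are never visited (the new behavior).
  for (func::FuncOp funcOp : moduleOp.getOps<func::FuncOp>())
    (void)funcOp;
}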

mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp

Lines changed: 12 additions & 3 deletions
@@ -52,14 +52,23 @@ mlir::bufferization::insertTensorCopies(Operation *op,
                                         const AnalysisState &state) {
   IRRewriter rewriter(op->getContext());
 
-  WalkResult result = op->walk([&](Operation *op) {
-    auto bufferizableOp = state.getOptions().dynCastBufferizableOp(op);
+  // It may be more efficient to walk in pre-order here, but the current
+  // implementation visits regions of ops even if they are not allowed or
+  // bufferizable, and existing tests rely on this behavior.
+  // For now, only exclude nested operations if they are in a different symbol
+  // table scope.
+  WalkResult result = op->walk([&](Operation *nestedOp) {
+    if (op->hasTrait<OpTrait::SymbolTable>() &&
+        nestedOp->getParentWithTrait<OpTrait::SymbolTable>() != op)
+      return WalkResult::skip();
+
+    auto bufferizableOp = state.getOptions().dynCastBufferizableOp(nestedOp);
     if (!bufferizableOp)
       return WalkResult::skip();
 
     // Find inplacability conflicts and resolve them. (Typically with explicit
     // tensor copies in the form of AllocTensorOps.)
-    rewriter.setInsertionPoint(op);
+    rewriter.setInsertionPoint(nestedOp);
     if (failed(bufferizableOp.resolveConflicts(rewriter, state)))
       return WalkResult::interrupt();
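
The guard added above hinges on `getParentWithTrait`, which returns the closest proper ancestor carrying the given trait. A condensed sketch of the predicate, with an invented helper name:

#include "mlir/IR/Operation.h"

using namespace mlir;

// Invented helper: is `nestedOp` in a symbol-table scope other than the one
// opened by `root` itself?
static bool inNestedSymbolTable(Operation *root, Operation *nestedOp) {
  // If `root` opens no symbol-table scope, there is no nested scope to filter.
  if (!root->hasTrait<OpTrait::SymbolTable>())
    return false;
  // For an op directly inside `root`, the nearest symbol-table ancestor is
  // `root` itself; for an op inside a nested module it is that nested module.
  return nestedOp->getParentWithTrait<OpTrait::SymbolTable>() != root;
}

Note that with the default post-order walk used here, an op's regions are visited before the callback runs on the op, so `WalkResult::skip()` merely declines to process the current op; it does not prune the traversal the way it would in a pre-order walk.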

mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir

Lines changed: 14 additions & 0 deletions
@@ -796,3 +796,17 @@ func.func @result_type_mismatch(%c: i1) -> tensor<5xf32> {
   return %1 : tensor<5xf32>
 }
 
+
+// -----
+
+// CHECK-LABEL: @outer_func({{.+}}: memref<
+func.func @outer_func(%t: tensor<5xf32>) -> tensor<5xf32> {
+  return %t : tensor<5xf32>
+}
+
+module @inner_module {
+  // CHECK: @inner_func({{.+}}: tensor<5xf32> {bufferization.writable = false})
+  func.func @inner_func(%t: tensor<5xf32> {bufferization.writable = false}) -> tensor<5xf32> {
+    return %t : tensor<5xf32>
+  }
+}

mlir/test/Dialect/Bufferization/Transforms/transform-ops.mlir

Lines changed: 17 additions & 19 deletions
@@ -111,23 +111,21 @@ module attributes {transform.with_named_sequence} {
   }
 }
 
-module {
-  // CHECK-LABEL: func @test_function(
-  //  CHECK-SAME:   %[[A:.*]]: tensor<?xf32>
-  func.func @test_function(%A : tensor<?xf32>, %v : vector<4xf32>) -> (tensor<?xf32>) {
-    %c0 = arith.constant 0 : index
-
-    // CHECK: %[[A_memref:.*]] = bufferization.to_memref %[[A]]
-    // CHECK: %[[dim:.*]] = memref.dim %[[A_memref]]
-    // CHECK: %[[alloc:.*]] = memref.alloc(%[[dim]])
-    // CHECK: memref.copy %[[A_memref]], %[[alloc]]
-    // CHECK: vector.transfer_write %{{.*}}, %[[alloc]]
-    // CHECK: %[[res_tensor:.*]] = bufferization.to_tensor %[[alloc]]
-    %0 = vector.transfer_write %v, %A[%c0] : vector<4xf32>, tensor<?xf32>
-
-    // CHECK: return %[[res_tensor]]
-    return %0 : tensor<?xf32>
-  }
+// CHECK-LABEL: func @test_function(
+//  CHECK-SAME:   %[[A:.*]]: tensor<?xf32>
+func.func @test_function(%A : tensor<?xf32>, %v : vector<4xf32>) -> (tensor<?xf32>) {
+  %c0 = arith.constant 0 : index
+
+  // CHECK: %[[A_memref:.*]] = bufferization.to_memref %[[A]]
+  // CHECK: %[[dim:.*]] = memref.dim %[[A_memref]]
+  // CHECK: %[[alloc:.*]] = memref.alloc(%[[dim]])
+  // CHECK: memref.copy %[[A_memref]], %[[alloc]]
+  // CHECK: vector.transfer_write %{{.*}}, %[[alloc]]
+  // CHECK: %[[res_tensor:.*]] = bufferization.to_tensor %[[alloc]]
+  %0 = vector.transfer_write %v, %A[%c0] : vector<4xf32>, tensor<?xf32>
+
+  // CHECK: return %[[res_tensor]]
+  return %0 : tensor<?xf32>
 }
 
 // -----
@@ -222,8 +220,8 @@ module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
     %alloc_tensor = transform.structured.match ops{["bufferization.alloc_tensor"]} in %arg1
       : (!transform.any_op) -> !transform.op<"bufferization.alloc_tensor">
-    %2, %new = transform.structured.bufferize_to_allocation %alloc_tensor
-       {alloc_op = "memref.alloca"}
+    %2, %new = transform.structured.bufferize_to_allocation %alloc_tensor
+        {alloc_op = "memref.alloca"}
       : !transform.op<"bufferization.alloc_tensor">
     transform.yield
   }
