
Commit f08cdd3

[mlir][bufferization] Change OneShotModuleBufferize to not analyze or bufferize nested symbol tables
The existing OneShotModuleBufferize analyzes and bufferizes operations that live in nested symbol tables (e.g. a nested `builtin.module`, `gpu.module`, or similar operation). This behavior is untested and likely unintentional, given other limitations of OneShotModuleBufferize (`func.call` cannot call into nested symbol tables). This change reverses the existing behavior: the operations considered by the analysis and bufferization now exclude any operations in nested symbol table scopes. Users who want to bufferize nested modules can still do so by applying the transformation in a pass pipeline or in a custom pass, as sketched below. This also makes it possible to control the order in which modules are bufferized and to use different options for different kinds of modules.
1 parent 9bf582f commit f08cdd3
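For downstream users, a minimal sketch of the custom-pass route mentioned above: collect the nested modules and run the module bufferization on each one separately. The helper name and the option choices are illustrative assumptions, not part of this commit:

#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h"
#include "mlir/IR/BuiltinOps.h"
#include "llvm/ADT/SmallVector.h"

using namespace mlir;

// Hypothetical helper: bufferize `root` and every nested builtin.module
// independently, since OneShotModuleBufferize no longer descends into nested
// symbol tables on its own.
static LogicalResult bufferizeAllModules(ModuleOp root) {
  bufferization::OneShotBufferizationOptions options;
  options.bufferizeFunctionBoundaries = true;

  // Post-order walk: innermost modules are collected first; `root` itself is
  // visited last. Collect before mutating so the walk stays valid.
  SmallVector<ModuleOp> modules;
  root.walk([&](ModuleOp m) { modules.push_back(m); });

  // Different kinds of modules could be given different options here.
  for (ModuleOp m : modules)
    if (failed(bufferization::runOneShotModuleBufferize(m, options)))
      return failure();
  return success();
}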

4 files changed (+52, -31 lines)

mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp

Lines changed: 9 additions & 9 deletions
@@ -329,7 +329,7 @@ static bool hasTensorSignature(func::FuncOp funcOp) {
 /// Store all functions of the `moduleOp` in `orderedFuncOps`, sorted by
 /// callee-caller order (i.e., callees without callers first). Store all
 /// remaining functions (i.e., the ones that call each other recursively) in
-/// `remainingFuncOps`.
+/// `remainingFuncOps`. Does not traverse nested symbol tables.
 ///
 /// Store the map of FuncOp to all its callers in `callerMap`.
 ///
@@ -343,10 +343,10 @@ static LogicalResult getFuncOpsOrderedByCalls(
   DenseMap<func::FuncOp, DenseSet<func::FuncOp>> calledBy;
   // For each FuncOp, the number of func::CallOp it contains.
   DenseMap<func::FuncOp, unsigned> numberCallOpsContainedInFuncOp;
-  WalkResult res = moduleOp.walk([&](func::FuncOp funcOp) -> WalkResult {
+  for (func::FuncOp funcOp : moduleOp.getOps<func::FuncOp>()) {
     // Collect function calls and populate the caller map.
     numberCallOpsContainedInFuncOp[funcOp] = 0;
-    return funcOp.walk([&](func::CallOp callOp) -> WalkResult {
+    WalkResult res = funcOp.walk([&](func::CallOp callOp) -> WalkResult {
       func::FuncOp calledFunction = getCalledFunction(callOp);
       assert(calledFunction && "could not retrieved called func::FuncOp");
       // If the called function does not have any tensors in its signature, then
@@ -360,9 +360,9 @@ static LogicalResult getFuncOpsOrderedByCalls(
       }
       return WalkResult::advance();
     });
-  });
-  if (res.wasInterrupted())
-    return failure();
+    if (res.wasInterrupted())
+      return failure();
+  }
 
   // Iteratively remove function operations that do not call any of the
   // functions remaining in the callCounter map and add them to ordered list.
@@ -533,10 +533,10 @@ mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp,
 
 void mlir::bufferization::removeBufferizationAttributesInModule(
    ModuleOp moduleOp) {
-  moduleOp.walk([&](func::FuncOp op) {
+  for (auto op : moduleOp.getOps<func::FuncOp>()) {
     for (BlockArgument bbArg : op.getArguments())
       removeBufferizationAttributes(bbArg);
-  });
+  }
 }
 
 LogicalResult mlir::bufferization::bufferizeModuleOp(
@@ -592,7 +592,7 @@ LogicalResult mlir::bufferization::bufferizeModuleOp(
   // Bufferize all other ops.
   for (Operation &op : llvm::make_early_inc_range(moduleOp.getOps())) {
     // Functions were already bufferized.
-    if (isa<func::FuncOp>(&op))
+    if (isa<func::FuncOp>(&op) || op.hasTrait<OpTrait::SymbolTable>())
       continue;
     if (failed(bufferizeOp(&op, options, statistics)))
       return failure();
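
The recurring mechanical change in these hunks is the switch from a recursive walk, which descends into every nested region, to getOps<OpTy>(), which iterates only over the immediate children of the module body. A minimal illustrative sketch of the difference (not from the patch; the remark texts are made up):

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BuiltinOps.h"

using namespace mlir;

// Illustrative sketch: why getOps<> stops at nested symbol tables while
// walk() does not.
static void contrastTraversals(ModuleOp moduleOp) {
  // Operation::walk recurses into all nested regions, so it also visits
  // func.func ops inside nested symbol tables (e.g. a nested builtin.module).
  moduleOp.walk([](func::FuncOp f) { f.emitRemark("seen by walk"); });

  // getOps<OpTy>() iterates only over the direct children of the module body,
  // so functions inside nested symbol tables are never visited.
  for (func::FuncOp f : moduleOp.getOps<func::FuncOp>())
    f.emitRemark("seen by getOps");
}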

mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp

Lines changed: 12 additions & 3 deletions
@@ -52,14 +52,23 @@ mlir::bufferization::insertTensorCopies(Operation *op,
                                         const AnalysisState &state) {
   IRRewriter rewriter(op->getContext());
 
-  WalkResult result = op->walk([&](Operation *op) {
-    auto bufferizableOp = state.getOptions().dynCastBufferizableOp(op);
+  // It may be more efficient to walk in pre-order here, but the current
+  // implementation visits regions of ops even if they are not allowed or
+  // bufferizable, and existing tests rely on this behavior.
+  // For now, only exclude nested operations if they are in a different symbol
+  // table scope.
+  WalkResult result = op->walk([&](Operation *nestedOp) {
+    if (op->hasTrait<OpTrait::SymbolTable>() &&
+        nestedOp->getParentWithTrait<OpTrait::SymbolTable>() != op)
+      return WalkResult::skip();
+
+    auto bufferizableOp = state.getOptions().dynCastBufferizableOp(nestedOp);
     if (!bufferizableOp)
       return WalkResult::skip();
 
     // Find inplacability conflicts and resolve them. (Typically with explicit
     // tensor copies in the form of AllocTensorOps.)
-    rewriter.setInsertionPoint(op);
+    rewriter.setInsertionPoint(nestedOp);
     if (failed(bufferizableOp.resolveConflicts(rewriter, state)))
       return WalkResult::interrupt();
 

mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir

Lines changed: 14 additions & 0 deletions
@@ -796,3 +796,17 @@ func.func @result_type_mismatch(%c: i1) -> tensor<5xf32> {
   return %1 : tensor<5xf32>
 }
 
+
+// -----
+
+// CHECK-LABEL: @outer_func({{.+}}: memref<
+func.func @outer_func(%t: tensor<5xf32>) -> tensor<5xf32> {
+  return %t : tensor<5xf32>
+}
+
+module @inner_module {
+  // CHECK: @inner_func({{.+}}: tensor<5xf32> {bufferization.writable = false})
+  func.func @inner_func(%t: tensor<5xf32> {bufferization.writable = false}) -> tensor<5xf32> {
+    return %t : tensor<5xf32>
+  }
+}

mlir/test/Dialect/Bufferization/Transforms/transform-ops.mlir

Lines changed: 17 additions & 19 deletions
@@ -111,23 +111,21 @@ module attributes {transform.with_named_sequence} {
   }
 }
 
-module {
-  // CHECK-LABEL: func @test_function(
-  // CHECK-SAME: %[[A:.*]]: tensor<?xf32>
-  func.func @test_function(%A : tensor<?xf32>, %v : vector<4xf32>) -> (tensor<?xf32>) {
-    %c0 = arith.constant 0 : index
-
-    // CHECK: %[[A_memref:.*]] = bufferization.to_memref %[[A]]
-    // CHECK: %[[dim:.*]] = memref.dim %[[A_memref]]
-    // CHECK: %[[alloc:.*]] = memref.alloc(%[[dim]])
-    // CHECK: memref.copy %[[A_memref]], %[[alloc]]
-    // CHECK: vector.transfer_write %{{.*}}, %[[alloc]]
-    // CHECK: %[[res_tensor:.*]] = bufferization.to_tensor %[[alloc]]
-    %0 = vector.transfer_write %v, %A[%c0] : vector<4xf32>, tensor<?xf32>
-
-    // CHECK: return %[[res_tensor]]
-    return %0 : tensor<?xf32>
-  }
+// CHECK-LABEL: func @test_function(
+// CHECK-SAME: %[[A:.*]]: tensor<?xf32>
+func.func @test_function(%A : tensor<?xf32>, %v : vector<4xf32>) -> (tensor<?xf32>) {
+  %c0 = arith.constant 0 : index
+
+  // CHECK: %[[A_memref:.*]] = bufferization.to_memref %[[A]]
+  // CHECK: %[[dim:.*]] = memref.dim %[[A_memref]]
+  // CHECK: %[[alloc:.*]] = memref.alloc(%[[dim]])
+  // CHECK: memref.copy %[[A_memref]], %[[alloc]]
+  // CHECK: vector.transfer_write %{{.*}}, %[[alloc]]
+  // CHECK: %[[res_tensor:.*]] = bufferization.to_tensor %[[alloc]]
+  %0 = vector.transfer_write %v, %A[%c0] : vector<4xf32>, tensor<?xf32>
+
+  // CHECK: return %[[res_tensor]]
+  return %0 : tensor<?xf32>
 }
 
 // -----
@@ -222,8 +220,8 @@ module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
     %alloc_tensor = transform.structured.match ops{["bufferization.alloc_tensor"]} in %arg1
       : (!transform.any_op) -> !transform.op<"bufferization.alloc_tensor">
-      %2, %new = transform.structured.bufferize_to_allocation %alloc_tensor
-        {alloc_op = "memref.alloca"}
+    %2, %new = transform.structured.bufferize_to_allocation %alloc_tensor
+      {alloc_op = "memref.alloca"}
       : !transform.op<"bufferization.alloc_tensor">
     transform.yield
   }
