@@ -300,7 +300,7 @@ static bool hasTensorSignature(func::FuncOp funcOp) {
 /// Store all functions of the `moduleOp` in `orderedFuncOps`, sorted by
 /// callee-caller order (i.e., callees without callers first). Store all
 /// remaining functions (i.e., the ones that call each other recursively) in
-/// `remainingFuncOps`.
+/// `remainingFuncOps`. Does not traverse nested symbol tables.
 ///
 /// Store the map of FuncOp to all its callers in `callerMap`.
 ///
@@ -314,10 +314,10 @@ static LogicalResult getFuncOpsOrderedByCalls(
   DenseMap<func::FuncOp, DenseSet<func::FuncOp>> calledBy;
   // For each FuncOp, the number of func::CallOp it contains.
   DenseMap<func::FuncOp, unsigned> numberCallOpsContainedInFuncOp;
-  WalkResult res = moduleOp.walk([&](func::FuncOp funcOp) -> WalkResult {
+  for (func::FuncOp funcOp : moduleOp.getOps<func::FuncOp>()) {
     // Collect function calls and populate the caller map.
     numberCallOpsContainedInFuncOp[funcOp] = 0;
-    return funcOp.walk([&](func::CallOp callOp) -> WalkResult {
+    WalkResult res = funcOp.walk([&](func::CallOp callOp) -> WalkResult {
       func::FuncOp calledFunction = getCalledFunction(callOp);
       assert(calledFunction && "could not retrieved called func::FuncOp");
       // If the called function does not have any tensors in its signature, then
@@ -331,9 +331,9 @@ static LogicalResult getFuncOpsOrderedByCalls(
       }
       return WalkResult::advance();
     });
-  });
-  if (res.wasInterrupted())
-    return failure();
+    if (res.wasInterrupted())
+      return failure();
+  }

   // Iteratively remove function operations that do not call any of the
   // functions remaining in the callCounter map and add them to ordered list.
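
The core of the change above is the traversal: `Operation::walk` recurses into every nested region, including nested symbol tables, whereas `getOps<func::FuncOp>()` yields only the functions directly owned by the module body. A minimal sketch of the difference, with a hypothetical helper name and the usual MLIR includes assumed:

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BuiltinOps.h"

using namespace mlir;

// Illustrative only: contrasts recursive and shallow function discovery.
static void visitFunctions(ModuleOp moduleOp) {
  // Recursive walk: also visits func.func ops inside nested modules
  // (nested symbol tables), which this pass must now leave untouched.
  moduleOp.walk([](func::FuncOp funcOp) { (void)funcOp; });

  // Shallow iteration: only func.func ops in the module's own body.
  for (func::FuncOp funcOp : moduleOp.getOps<func::FuncOp>())
    (void)funcOp; // visited only for direct children of moduleOp
}

Functionally, the per-function `res` check preserves the early-exit behavior of the old nested walk while guaranteeing that functions inside nested symbol tables are never enqueued for bufferization.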
@@ -498,10 +498,10 @@ mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp,

 void mlir::bufferization::removeBufferizationAttributesInModule(
     ModuleOp moduleOp) {
-  moduleOp.walk([&](func::FuncOp op) {
+  for (auto op : moduleOp.getOps<func::FuncOp>()) {
     for (BlockArgument bbArg : op.getArguments())
       removeBufferizationAttributes(bbArg);
-  });
+  }
 }

 LogicalResult mlir::bufferization::bufferizeModuleOp(
@@ -557,7 +557,7 @@ LogicalResult mlir::bufferization::bufferizeModuleOp(
   // Bufferize all other ops.
   for (Operation &op : llvm::make_early_inc_range(moduleOp.getOps())) {
     // Functions were already bufferized.
-    if (isa<func::FuncOp>(&op))
+    if (isa<func::FuncOp>(&op) || op.hasTrait<OpTrait::SymbolTable>())
       continue;
     if (failed(bufferizeOp(&op, options, statistics)))
       return failure();
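
The same scoping rule appears here in a simpler form: instead of filtering a recursive walk, the loop over the module's top-level operations skips any op that owns its own symbol table. Roughly, the skip condition amounts to the following hypothetical helper, which is not part of the patch:

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/SymbolTable.h"

using namespace mlir;

// func.func ops were already bufferized one by one earlier in
// bufferizeModuleOp; ops carrying the SymbolTable trait (e.g. a nested
// builtin.module) are left for a separate bufferization invocation.
static bool skippedByModuleBufferization(Operation &op) {
  return isa<func::FuncOp>(&op) || op.hasTrait<OpTrait::SymbolTable>();
}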
@@ -52,14 +52,23 @@ mlir::bufferization::insertTensorCopies(Operation *op,
                                         const AnalysisState &state) {
   IRRewriter rewriter(op->getContext());

-  WalkResult result = op->walk([&](Operation *op) {
-    auto bufferizableOp = state.getOptions().dynCastBufferizableOp(op);
+  // It may be more efficient to walk in pre-order here, but the current
+  // implementation visits regions of ops even if they are not allowed or
+  // bufferizable, and existing tests rely on this behavior.
+  // For now, only exclude nested operations if they are in a different symbol
+  // table scope.
+  WalkResult result = op->walk([&](Operation *nestedOp) {
+    if (op->hasTrait<OpTrait::SymbolTable>() &&
+        nestedOp->getParentWithTrait<OpTrait::SymbolTable>() != op)
+      return WalkResult::skip();
+
+    auto bufferizableOp = state.getOptions().dynCastBufferizableOp(nestedOp);
     if (!bufferizableOp)
       return WalkResult::skip();

     // Find inplacability conflicts and resolve them. (Typically with explicit
     // tensor copies in the form of AllocTensorOps.)
-    rewriter.setInsertionPoint(op);
+    rewriter.setInsertionPoint(nestedOp);
     if (failed(bufferizableOp.resolveConflicts(rewriter, state)))
       return WalkResult::interrupt();

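Worth noting: `Operation::walk` defaults to post-order, so the `WalkResult::skip` above cannot prune a nested module's subtree before it is visited; that is what the patch comment about pre-order walking alludes to. Each visited op is therefore tested individually against the walk root. The guard, rewritten as a standalone predicate for clarity (name hypothetical, includes assumed):

#include "mlir/IR/Operation.h"
#include "mlir/IR/SymbolTable.h"

using namespace mlir;

// True iff `nestedOp` lives in a different symbol-table scope than the walk
// root `root`. Only meaningful when `root` is itself a symbol table (e.g. a
// builtin.module): then any op whose nearest symbol-table ancestor is not
// `root` belongs to a nested symbol table.
static bool inNestedSymbolTable(Operation *root, Operation *nestedOp) {
  return root->hasTrait<OpTrait::SymbolTable>() &&
         nestedOp->getParentWithTrait<OpTrait::SymbolTable>() != root;
}

Tensor copies are then inserted only for ops where this predicate is false.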
@@ -796,3 +796,17 @@ func.func @result_type_mismatch(%c: i1) -> tensor<5xf32> {
   return %1 : tensor<5xf32>
 }

+
+// -----
+
+// CHECK-LABEL: @outer_func({{.+}}: memref<
+func.func @outer_func(%t: tensor<5xf32>) -> tensor<5xf32> {
+  return %t : tensor<5xf32>
+}
+
+module @inner_module {
+  // CHECK: @inner_func({{.+}}: tensor<5xf32> {bufferization.writable = false})
+  func.func @inner_func(%t: tensor<5xf32> {bufferization.writable = false}) -> tensor<5xf32> {
+    return %t : tensor<5xf32>
+  }
+}
mlir/test/Dialect/Bufferization/Transforms/transform-ops.mlir (36 changes: 17 additions & 19 deletions)
@@ -111,23 +111,21 @@ module attributes {transform.with_named_sequence} {
   }
 }

-module {
-  // CHECK-LABEL: func @test_function(
-  // CHECK-SAME: %[[A:.*]]: tensor<?xf32>
-  func.func @test_function(%A : tensor<?xf32>, %v : vector<4xf32>) -> (tensor<?xf32>) {
-    %c0 = arith.constant 0 : index
-
-    // CHECK: %[[A_memref:.*]] = bufferization.to_memref %[[A]]
-    // CHECK: %[[dim:.*]] = memref.dim %[[A_memref]]
-    // CHECK: %[[alloc:.*]] = memref.alloc(%[[dim]])
-    // CHECK: memref.copy %[[A_memref]], %[[alloc]]
-    // CHECK: vector.transfer_write %{{.*}}, %[[alloc]]
-    // CHECK: %[[res_tensor:.*]] = bufferization.to_tensor %[[alloc]]
-    %0 = vector.transfer_write %v, %A[%c0] : vector<4xf32>, tensor<?xf32>
-
-    // CHECK: return %[[res_tensor]]
-    return %0 : tensor<?xf32>
-  }
-}
+// CHECK-LABEL: func @test_function(
+// CHECK-SAME: %[[A:.*]]: tensor<?xf32>
+func.func @test_function(%A : tensor<?xf32>, %v : vector<4xf32>) -> (tensor<?xf32>) {
+  %c0 = arith.constant 0 : index
+
+  // CHECK: %[[A_memref:.*]] = bufferization.to_memref %[[A]]
+  // CHECK: %[[dim:.*]] = memref.dim %[[A_memref]]
+  // CHECK: %[[alloc:.*]] = memref.alloc(%[[dim]])
+  // CHECK: memref.copy %[[A_memref]], %[[alloc]]
+  // CHECK: vector.transfer_write %{{.*}}, %[[alloc]]
+  // CHECK: %[[res_tensor:.*]] = bufferization.to_tensor %[[alloc]]
+  %0 = vector.transfer_write %v, %A[%c0] : vector<4xf32>, tensor<?xf32>
+
+  // CHECK: return %[[res_tensor]]
+  return %0 : tensor<?xf32>
+}

 // -----
@@ -222,8 +220,8 @@ module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
     %alloc_tensor = transform.structured.match ops{["bufferization.alloc_tensor"]} in %arg1
       : (!transform.any_op) -> !transform.op<"bufferization.alloc_tensor">
-      %2, %new = transform.structured.bufferize_to_allocation %alloc_tensor
-        {alloc_op = "memref.alloca"}
+    %2, %new = transform.structured.bufferize_to_allocation %alloc_tensor
+      {alloc_op = "memref.alloca"}
       : !transform.op<"bufferization.alloc_tensor">
     transform.yield
   }