From 1122ffeddcbcd27838386b952849a29d792dc9f1 Mon Sep 17 00:00:00 2001
From: Matthias Springer
Date: Tue, 29 Oct 2024 09:51:11 +0100
Subject: [PATCH] [mlir][bufferization] Add support for non-unique `func.return`

---
 .../FuncBufferizableOpInterfaceImpl.h         |   4 +
 .../FuncBufferizableOpInterfaceImpl.cpp       |  79 ++++----
 .../Transforms/OneShotModuleBufferize.cpp     | 174 +++++++++++++-----
 .../one-shot-module-bufferize-analysis.mlir   |  46 +++++
 .../one-shot-module-bufferize-invalid.mlir    |  22 +--
 .../Transforms/one-shot-module-bufferize.mlir |  25 +++
 6 files changed, 236 insertions(+), 114 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h
index 0b91d3d675b7c..e8e6226460ac7 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h
@@ -12,6 +12,7 @@
 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
 #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "llvm/ADT/SmallVector.h"
 
 namespace mlir {
 class DialectRegistry;
@@ -21,6 +22,9 @@ class FuncOp;
 } // namespace func
 
 namespace bufferization {
+/// Helper function that returns all func.return ops in the given function.
+SmallVector<func::ReturnOp> getReturnOps(func::FuncOp funcOp);
+
 namespace func_ext {
 /// The state of analysis of a FuncOp.
 enum class FuncOpAnalysisState { NotAnalyzed, InProgress, Analyzed };
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
index 11ed434f774a8..c45678f1e4b4d 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
@@ -19,6 +19,15 @@
 #include <optional>
 
 namespace mlir {
+/// Return all func.return ops in the given function.
+SmallVector<func::ReturnOp> bufferization::getReturnOps(func::FuncOp funcOp) {
+  SmallVector<func::ReturnOp> result;
+  for (Block &b : funcOp.getBody())
+    if (auto returnOp = dyn_cast<func::ReturnOp>(b.getTerminator()))
+      result.push_back(returnOp);
+  return result;
+}
+
 namespace bufferization {
 namespace func_ext {
@@ -41,20 +50,6 @@ void FuncAnalysisState::startFunctionAnalysis(FuncOp funcOp) {
 #endif // NDEBUG
 }
 
-/// Return the unique ReturnOp that terminates `funcOp`.
-/// Return nullptr if there is no such unique ReturnOp.
-static func::ReturnOp getAssumedUniqueReturnOp(FuncOp funcOp) {
-  func::ReturnOp returnOp;
-  for (Block &b : funcOp.getBody()) {
-    if (auto candidateOp = dyn_cast<func::ReturnOp>(b.getTerminator())) {
-      if (returnOp)
-        return nullptr;
-      returnOp = candidateOp;
-    }
-  }
-  return returnOp;
-}
-
 /// Return the index-th bufferized function argument type. This assumes that the
 /// specified argument is a tensor. If the tensor is ranked, a layout map may be
 /// specified by the user (as per `options.functionArgTypeConverterFn`).
@@ -391,15 +386,6 @@ struct FuncOpInterface
     getBufferType(op, value, options, invocationStack);
   }
 
-  LogicalResult verifyAnalysis(Operation *op,
-                               const AnalysisState &state) const {
-    auto funcOp = cast<FuncOp>(op);
-    // TODO: func.func with multiple returns are not supported.
-    if (!getAssumedUniqueReturnOp(funcOp) && !funcOp.isExternal())
-      return op->emitOpError("op without unique func.return is not supported");
-    return success();
-  }
-
   /// Rewrite function bbArgs and return values into buffer form. This function
   /// bufferizes the function signature and the ReturnOp. When the entire
   /// function body has been bufferized, function return types can be switched
@@ -446,41 +432,38 @@ struct FuncOpInterface
       return success();
     }
 
-    // TODO: Support functions with multiple returns.
-    func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
-    assert(returnOp && "expected func with single return op");
-    assert(returnOp->getNumOperands() == retTypes.size() &&
-           "incorrect number of return values");
-    Location loc = returnOp.getLoc();
-
     // 1. Bufferize every block.
     for (Block &block : funcOp.getBody())
      if (failed(bufferization::bufferizeBlockSignature(&block, rewriter,
                                                        options)))
        return failure();
 
-    // 2. Bufferize all operands of the return op.
-    SmallVector<Value> returnValues;
-    for (auto [returnVal, bufferizedType] :
-         llvm::zip_equal(returnOp->getOperands(), retTypes)) {
-      auto tensorType = dyn_cast<TensorType>(returnVal.getType());
-      rewriter.setInsertionPoint(returnOp);
-
-      // If not a tensor type just forward it.
-      if (!tensorType) {
-        returnValues.push_back(returnVal);
-        continue;
+    // 2. Bufferize the operands of all return ops.
+    for (func::ReturnOp returnOp : getReturnOps(funcOp)) {
+      assert(returnOp->getNumOperands() == retTypes.size() &&
+             "incorrect number of return values");
+      SmallVector<Value> returnValues;
+      for (auto [returnVal, bufferizedType] :
+           llvm::zip_equal(returnOp->getOperands(), retTypes)) {
+        auto tensorType = dyn_cast<TensorType>(returnVal.getType());
+        rewriter.setInsertionPoint(returnOp);
+
+        // If not a tensor type just forward it.
+        if (!tensorType) {
+          returnValues.push_back(returnVal);
+          continue;
+        }
+
+        // Note: If `inferFunctionResultLayout = true`, casts are later folded
+        // away.
+        Value toMemrefOp = rewriter.create<bufferization::ToMemrefOp>(
+            returnOp.getLoc(), bufferizedType, returnVal);
+        returnValues.push_back(toMemrefOp);
      }
 
-      // Note: If `inferFunctionResultLayout = true`, casts are later folded
-      // away.
-      Value toMemrefOp = rewriter.create<bufferization::ToMemrefOp>(
-          loc, bufferizedType, returnVal);
-      returnValues.push_back(toMemrefOp);
+      returnOp.getOperandsMutable().assign(returnValues);
    }
 
-    returnOp.getOperandsMutable().assign(returnValues);
-
    // 3. Set the new function type.
    funcOp.setType(newFuncType);
    return success();
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
index a492bcdd0f3e3..71ea0fd9d43cd 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
@@ -86,20 +86,6 @@ getOrCreateFuncAnalysisState(OneShotAnalysisState &state) {
   return state.addExtension<FuncAnalysisState>();
 }
 
-/// Return the unique ReturnOp that terminates `funcOp`.
-/// Return nullptr if there is no such unique ReturnOp.
-static func::ReturnOp getAssumedUniqueReturnOp(func::FuncOp funcOp) {
-  func::ReturnOp returnOp;
-  for (Block &b : funcOp.getBody()) {
-    if (auto candidateOp = dyn_cast<func::ReturnOp>(b.getTerminator())) {
-      if (returnOp)
-        return nullptr;
-      returnOp = candidateOp;
-    }
-  }
-  return returnOp;
-}
-
 namespace {
 /// Annotate IR with the results of the analysis. For testing purposes only.
@@ -146,24 +132,80 @@ aliasingFuncOpBBArgsAnalysis(FuncOp funcOp, OneShotAnalysisState &state,
     return success();
   }
 
-  // Support only single return-terminated block in the function.
-  func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
-  assert(returnOp && "expected func with single return op");
-
-  for (OpOperand &returnVal : returnOp->getOpOperands())
-    if (isa<TensorType>(returnVal.get().getType()))
-      for (BlockArgument bbArg : funcOp.getArguments())
-        if (isa<TensorType>(bbArg.getType())) {
-          int64_t returnIdx = returnVal.getOperandNumber();
-          int64_t bbArgIdx = bbArg.getArgNumber();
-          if (state.areEquivalentBufferizedValues(returnVal.get(), bbArg)) {
-            funcState.equivalentFuncArgs[funcOp][returnIdx] = bbArgIdx;
-            if (state.getOptions().testAnalysisOnly)
-              annotateEquivalentReturnBbArg(returnVal, bbArg);
+  // Find all func.return ops.
+  SmallVector<func::ReturnOp> returnOps = getReturnOps(funcOp);
+  assert(!returnOps.empty() && "expected at least one ReturnOp");
+
+  // Build alias sets. Merge all aliases from all func.return ops.
+  for (BlockArgument bbArg : funcOp.getArguments()) {
+    if (isa<TensorType>(bbArg.getType())) {
+      int64_t bbArgIdx = bbArg.getArgNumber();
+      // Store aliases in a set, so that we don't add the same alias twice.
+      SetVector<int64_t> aliases;
+      for (func::ReturnOp returnOp : returnOps) {
+        for (OpOperand &returnVal : returnOp->getOpOperands()) {
+          if (isa<TensorType>(returnVal.get().getType())) {
+            int64_t returnIdx = returnVal.getOperandNumber();
+            if (state.areAliasingBufferizedValues(returnVal.get(), bbArg))
+              aliases.insert(returnIdx);
           }
-          if (state.areAliasingBufferizedValues(returnVal.get(), bbArg))
-            funcState.aliasingReturnVals[funcOp][bbArgIdx].push_back(returnIdx);
         }
+      }
+      for (int64_t alias : aliases)
+        funcState.aliasingReturnVals[funcOp][bbArgIdx].push_back(alias);
+    }
+  }
+
+  // Build equivalence sets.
+  // Helper function that finds an equivalent block argument index for the
+  // given OpOperand. Return std::nullopt if no equivalent block argument could
+  // be found.
+  auto findEquivalentBlockArgIdx =
+      [&](OpOperand &opOperand) -> std::optional<int64_t> {
+    Value v = opOperand.get();
+    if (!isa<TensorType>(v.getType()))
+      return std::nullopt;
+    for (BlockArgument bbArg : funcOp.getArguments()) {
+      if (isa<TensorType>(bbArg.getType())) {
+        if (state.areEquivalentBufferizedValues(v, bbArg)) {
+          if (state.getOptions().testAnalysisOnly)
+            annotateEquivalentReturnBbArg(opOperand, bbArg);
+          return bbArg.getArgNumber();
+        }
+      }
+    }
+    return std::nullopt;
+  };
+
+  int64_t numResults = returnOps.front()->getNumOperands();
+  for (int64_t i = 0; i < numResults; ++i) {
+    // Find the equivalent block argument index for the i-th operand of the
+    // first func.return op.
+    std::optional<int64_t> maybeEquiv =
+        findEquivalentBlockArgIdx(returnOps.front()->getOpOperand(i));
+    if (!maybeEquiv.has_value())
+      continue;
+    int64_t bbArgIdx = *maybeEquiv;
+    bool allEquiv = true;
+
+    // Check if all other func.return ops have the same equivalent block
+    // argument for the i-th operand. In contrast to aliasing information,
+    // which is just "merged", equivalence information must match across all
+    // func.return ops.
+    for (func::ReturnOp returnOp : ArrayRef(returnOps).drop_front()) {
+      std::optional<int64_t> maybeEquiv =
+          findEquivalentBlockArgIdx(returnOp->getOpOperand(i));
+      if (maybeEquiv != bbArgIdx) {
+        allEquiv = false;
+        break;
+      }
+    }
+
+    // All func.return ops have the same equivalent block argument for the i-th
+    // operand.
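+    // Example (illustration with hypothetical IR, using the bbArgs %t0/%t1):
+    //   ^bb1: return %t0, %t0
+    //   ^bb2: return %t0, %t1
+    // Operand #0 resolves to the same bbArg in both returns and is recorded
+    // as equivalent; operand #1 resolves to two different bbArgs, so no
+    // equivalence is recorded for it.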
+    if (allEquiv)
+      funcState.equivalentFuncArgs[funcOp][i] = bbArgIdx;
+  }
 
   return success();
 }
@@ -302,14 +344,6 @@ static LogicalResult getFuncOpsOrderedByCalls(
   // For each FuncOp, the number of func::CallOp it contains.
   DenseMap<func::FuncOp, unsigned> numberCallOpsContainedInFuncOp;
   WalkResult res = moduleOp.walk([&](func::FuncOp funcOp) -> WalkResult {
-    if (!funcOp.getBody().empty()) {
-      func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
-      if (!returnOp)
-        return funcOp->emitError()
-               << "cannot bufferize a FuncOp with tensors and "
-                  "without a unique ReturnOp";
-    }
-
     // Collect function calls and populate the caller map.
     numberCallOpsContainedInFuncOp[funcOp] = 0;
     return funcOp.walk([&](func::CallOp callOp) -> WalkResult {
@@ -351,6 +385,42 @@ static LogicalResult getFuncOpsOrderedByCalls(
   return success();
 }
 
+/// Helper function that extracts the source from a memref.cast. If the given
+/// value is not a memref.cast result, simply returns the given value.
+static Value unpackCast(Value v) {
+  auto castOp = v.getDefiningOp<memref::CastOp>();
+  if (!castOp)
+    return v;
+  return castOp.getSource();
+}
+
+/// Helper function that returns the return types (skipping casts) of the given
+/// func.return ops. This function returns as many types as the return ops have
+/// operands. If the type of the i-th operand is not the same for all
+/// func.return ops, then the i-th returned type is an "empty" type.
+static SmallVector<Type> getReturnTypes(SmallVector<func::ReturnOp> returnOps) {
+  assert(!returnOps.empty() && "expected at least one ReturnOp");
+  int numOperands = returnOps.front()->getNumOperands();
+
+  // Helper function that unpacks memref.cast ops and returns the type.
+  auto getSourceType = [&](Value v) { return unpackCast(v).getType(); };
+
+  SmallVector<Type> result;
+  for (int i = 0; i < numOperands; ++i) {
+    // Get the type of the i-th operand of the first func.return op.
+    Type t = getSourceType(returnOps.front()->getOperand(i));
+
+    // Check if all other func.return ops have a matching operand type.
+    for (int j = 1; j < static_cast<int>(returnOps.size()); ++j)
+      if (getSourceType(returnOps[j]->getOperand(i)) != t)
+        t = Type();
+
+    result.push_back(t);
+  }
+
+  return result;
+}
+
 /// Fold return values that are memref casts and update function return types.
 ///
 /// During FuncOp bufferization, the exact type of the returned memrefs (if any)
 /// is not known yet. Therefore, the bufferization uses memref types with fully
 /// dynamic layout maps as function return types. After bufferizing the
 /// entire function body, a more concise memref type can potentially be used for
 /// the return type of the function.
 static void foldMemRefCasts(func::FuncOp funcOp) {
+  // There is nothing to do for bodiless ops.
   if (funcOp.getBody().empty())
     return;
-  func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
-  SmallVector<Type> resultTypes;
+  // Compute the common result types of all return ops.
+  SmallVector<func::ReturnOp> returnOps = getReturnOps(funcOp);
+  SmallVector<Type> resultTypes = getReturnTypes(returnOps);
 
-  for (OpOperand &operand : returnOp->getOpOperands()) {
-    if (auto castOp = operand.get().getDefiningOp<memref::CastOp>()) {
-      operand.set(castOp.getSource());
-      resultTypes.push_back(castOp.getSource().getType());
-    } else {
-      resultTypes.push_back(operand.get().getType());
+  // Remove direct casts.
+  for (func::ReturnOp returnOp : returnOps) {
+    for (OpOperand &operand : returnOp->getOpOperands()) {
+      // Only fold the cast if a common result type was found for this operand.
+      if (resultTypes[operand.getOperandNumber()]) {
+        operand.set(unpackCast(operand.get()));
+      }
     }
   }
 
+  // Fill in the missing result types that were not the same among all
+  // func.return ops.
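+  // Example (illustration): if one return op yields memref<5xf32, strided<[2]>>
+  // and another yields memref<5xf32, strided<[1], offset: 2>>, there is no
+  // common source type, so the function's declared result type (a fully
+  // dynamic layout) is kept for that position.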
+  for (int i = 0; i < static_cast<int>(resultTypes.size()); ++i) {
+    if (resultTypes[i])
+      continue;
+    resultTypes[i] = funcOp.getFunctionType().getResult(i);
+  }
+
+  // Update the function type.
   auto newFuncType = FunctionType::get(
       funcOp.getContext(), funcOp.getFunctionType().getInputs(), resultTypes);
   funcOp.setType(newFuncType);
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-analysis.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-analysis.mlir
index 3f6d182b57c03..35b28f7ec8391 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-analysis.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-analysis.mlir
@@ -1360,3 +1360,49 @@ func.func @recursive_function(%a: tensor<?xf32>, %b: tensor<?xf32>) -> (tensor<
   %0:2 = call @recursive_function(%a, %b) : (tensor<?xf32>, tensor<?xf32>) -> (tensor<?xf32>, tensor<?xf32>)
   return %0#0, %0#1 : tensor<?xf32>, tensor<?xf32>
 }
+
+// -----
+
+// CHECK-ALIAS-SETS-LABEL: func @multiple_returns(
+func.func @multiple_returns(%c: i1, %t0: tensor<5xf32>, %t1: tensor<5xf32>, %t2: tensor<5xf32>) -> tensor<5xf32> {
+  cf.cond_br %c, ^bb1, ^bb2
+^bb1:
+  return %t0 : tensor<5xf32>
+^bb2:
+  return %t1 : tensor<5xf32>
+}
+
+// CHECK-ALIAS-SETS: func @caller(
+// CHECK-ALIAS-SETS-SAME: %{{.*}}: i1, %[[t0:.*]]: tensor<5xf32> {bufferization.access = "read"}, %[[t1:.*]]: tensor<5xf32> {bufferization.access = "read"}, %[[t2:.*]]: tensor<5xf32> {bufferization.access = "none"})
+func.func @caller(%c: i1, %t0: tensor<5xf32>, %t1: tensor<5xf32>, %t2: tensor<5xf32>) {
+  // Check that alias sets are computed correctly.
+  // CHECK-ALIAS-SETS: %[[result:.*]] = call @multiple_returns
+  // CHECK-ALIAS-SETS-SAME: {__inplace_operands_attr__ = ["none", "true", "true", "true"],
+  // CHECK-ALIAS-SETS-SAME:  __opresult_alias_set_attr__ = [{{\[}}"%[[result]]", "%[[t0]]", "%[[t1]]"]]}
+  %0 = call @multiple_returns(%c, %t0, %t1, %t2) : (i1, tensor<5xf32>, tensor<5xf32>, tensor<5xf32>) -> (tensor<5xf32>)
+  return
+}
+
+// -----
+
+// CHECK-ALIAS-SETS-LABEL: func @multiple_equivalent_returns(
+func.func @multiple_equivalent_returns(%c: i1, %t0: tensor<5xf32>, %t1: tensor<5xf32>, %t2: tensor<5xf32>) -> tensor<5xf32> {
+  cf.cond_br %c, ^bb1, ^bb2
+^bb1:
+  return %t0 : tensor<5xf32>
+^bb2:
+  return %t0 : tensor<5xf32>
+}
+
+// CHECK-ALIAS-SETS: func @caller(
+// CHECK-ALIAS-SETS-SAME: %{{.*}}: i1, %[[t0:.*]]: tensor<5xf32> {bufferization.access = "read"}, %[[t1:.*]]: tensor<5xf32> {bufferization.access = "none"}, %[[t2:.*]]: tensor<5xf32> {bufferization.access = "none"})
+func.func @caller(%c: i1, %t0: tensor<5xf32>, %t1: tensor<5xf32>, %t2: tensor<5xf32>) -> tensor<5xf32> {
+  // Check that equivalence sets are computed correctly.
+  // CHECK-ALIAS-SETS: %[[result:.*]] = call @multiple_equivalent_returns
+  // CHECK-ALIAS-SETS-SAME: {__inplace_operands_attr__ = ["none", "true", "true", "true"],
+  // CHECK-ALIAS-SETS-SAME:  __opresult_alias_set_attr__ = [{{\[}}"%[[result]]", "%[[t0]]"]]}
+  %r = call @multiple_equivalent_returns(%c, %t0, %t1, %t2) : (i1, tensor<5xf32>, tensor<5xf32>, tensor<5xf32>) -> (tensor<5xf32>)
+  // CHECK-ALIAS-SETS: return
+  // CHECK-ALIAS-SETS-SAME: {__equivalent_func_args__ = [1], __inplace_operands_attr__ = ["true"]} %[[result]] : tensor<5xf32>
+  return %r : tensor<5xf32>
+}
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir
index 28ce0735e47b7..d773e1af43a76 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir
@@ -1,24 +1,5 @@
 // RUN: mlir-opt %s -allow-unregistered-dialect -one-shot-bufferize="bufferize-function-boundaries=1" -split-input-file -verify-diagnostics
 
-// expected-error @+1 {{cannot bufferize a FuncOp with tensors and without a unique ReturnOp}}
-func.func @swappy(%cond1 : i1, %cond2 : i1, %t1 : tensor<f32>, %t2 : tensor<f32>)
-    -> (tensor<f32>, tensor<f32>)
-{
-  cf.cond_br %cond1, ^bb1, ^bb2
-
-  ^bb1:
-    %T:2 = scf.if %cond2 -> (tensor<f32>, tensor<f32>) {
-      scf.yield %t1, %t2 : tensor<f32>, tensor<f32>
-    } else {
-      scf.yield %t2, %t1 : tensor<f32>, tensor<f32>
-    }
-    return %T#0, %T#1 : tensor<f32>, tensor<f32>
-  ^bb2:
-    return %t2, %t1 : tensor<f32>, tensor<f32>
-}
-
-// -----
-
 func.func @scf_for(%A : tensor<?xf32>,
                    %B : tensor<?xf32> {bufferization.writable = true},
                    %C : tensor<4xf32>,
@@ -146,7 +127,8 @@ func.func @regression_scf_while() {
 
 // -----
 
-// expected-error @below{{cannot bufferize a FuncOp with tensors and without a unique ReturnOp}}
+// expected-error @below{{could not infer buffer type of block argument}}
+// expected-error @below{{failed to bufferize op}}
 func.func @func_multiple_yields(%t: tensor<5xf32>) -> tensor<5xf32> {
   func.return %t : tensor<5xf32>
 ^bb1(%arg1 : tensor<5xf32>):
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir
index 2b5b863143670..65557a68d243a 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir
@@ -771,3 +771,28 @@ func.func @bar(%t: tensor<5xf32>) -> tensor<5xf32>{
   %0 = call @foo(%t) : (tensor<5xf32>) -> (tensor<5xf32>)
   return %0 : tensor<5xf32>
 }
+
+// -----
+
+// The two func.return operands have different types after bufferization. Make
+// sure that memref.cast ops are inserted.
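+// (Illustration of the expected result: the two subviews below bufferize to
+// distinct strided layouts, so the common function result type falls back to
+// the fully dynamic strided<[?], offset: ?> layout, and each return site
+// casts its operand to it.)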
+
+// CHECK-LABEL: func @result_type_mismatch({{.*}}) -> memref<5xf32, strided<[?], offset: ?>>
+func.func @result_type_mismatch(%c: i1) -> tensor<5xf32> {
+  // CHECK: %[[alloc:.*]] = memref.alloc() {alignment = 64 : i64} : memref<10xf32>
+  %t = tensor.empty() : tensor<10xf32>
+  cf.cond_br %c, ^bb1, ^bb2
+^bb1:
+  // CHECK: %[[m0:.*]] = memref.subview %[[alloc]][0] [5] [2] : memref<10xf32> to memref<5xf32, strided<[2]>>
+  // CHECK: %[[cast0:.*]] = memref.cast %[[m0]] : memref<5xf32, strided<[2]>> to memref<5xf32, strided<[?], offset: ?>>
+  %0 = tensor.extract_slice %t[0][5][2] : tensor<10xf32> to tensor<5xf32>
+  // CHECK: return %[[cast0]] : memref<5xf32, strided<[?], offset: ?>>
+  return %0 : tensor<5xf32>
+^bb2:
+  // CHECK: %[[m1:.*]] = memref.subview %[[alloc]][2] [5] [1] : memref<10xf32> to memref<5xf32, strided<[1], offset: 2>>
+  // CHECK: %[[cast1:.*]] = memref.cast %[[m1]] : memref<5xf32, strided<[1], offset: 2>> to memref<5xf32, strided<[?], offset: ?>>
+  %1 = tensor.extract_slice %t[2][5][1] : tensor<10xf32> to tensor<5xf32>
+  // CHECK: return %[[cast1]] : memref<5xf32, strided<[?], offset: ?>>
+  return %1 : tensor<5xf32>
+}