diff --git a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
index 6e91d3b89a7c7..11ed434f774a8 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
@@ -207,11 +207,18 @@ struct CallOpInterface
     FuncOp funcOp = getCalledFunction(callOp);
     assert(funcOp && "expected CallOp to a FuncOp");
 
-    // The callee was already bufferized, so we can directly take the type from
+    // If the callee was already bufferized, we can directly take the type from
     // its signature.
     FunctionType funcType = funcOp.getFunctionType();
-    return cast<BaseMemRefType>(
-        funcType.getResult(cast<OpResult>(value).getResultNumber()));
+    Type resultType =
+        funcType.getResult(cast<OpResult>(value).getResultNumber());
+    if (auto bufferizedType = dyn_cast<BaseMemRefType>(resultType))
+      return bufferizedType;
+
+    // Otherwise, call the type converter to compute the bufferized type.
+    auto tensorType = cast<TensorType>(resultType);
+    return options.functionArgTypeConverterFn(
+        tensorType, *options.defaultMemorySpaceFn(tensorType), funcOp, options);
   }
 
   /// All function arguments are writable. It is the responsibility of the
@@ -261,6 +268,18 @@ struct CallOpInterface
       // Caller / callee type mismatch is handled with castOrReallocMemRefValue.
       auto memRefType = funcType.getInput(opOperand.getOperandNumber());
 
+      if (!isa<BaseMemRefType>(memRefType)) {
+        // The called function was not bufferized yet. This can happen when
+        // there are cycles in the function call graph. Compute the bufferized
+        // result type.
+        FailureOr<BaseMemRefType> maybeMemRefType =
+            bufferization::getBufferType(
+                funcOp.getArgument(opOperand.getOperandNumber()), options);
+        if (failed(maybeMemRefType))
+          return failure();
+        memRefType = *maybeMemRefType;
+      }
+
       // Since we don't yet have a clear layout story, to_memref may
       // conservatively turn tensors into more dynamic memref than necessary.
       // If the memref type of the callee fails, introduce an extra memref.cast
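To make the first hunk concrete: `getBufferType` is queried for the result of a `func.call`, and it previously assumed the callee's signature had already been bufferized. With recursive calls the callee may still be tensor-typed at that point, so the patch falls back to the function-boundary type converter. A minimal sketch of the triggering IR (same shape as the test added at the end of this patch); with the default `functionArgTypeConverterFn`, the buffer type inferred for `%0` is the fully dynamic `memref<5xf32, strided<[?], offset: ?>>`:

```mlir
// @foo calls itself, so @foo's signature still contains tensor types when
// the buffer type of %0 is computed; the type-converter fallback is taken.
func.func @foo(%t: tensor<5xf32>) -> tensor<5xf32> {
  %0 = call @foo(%t) : (tensor<5xf32>) -> (tensor<5xf32>)
  return %0 : tensor<5xf32>
}
```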
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
index 0a4072605c265..a492bcdd0f3e3 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
@@ -285,14 +285,17 @@ static bool hasTensorSignature(func::FuncOp funcOp) {
 }
 
 /// Store all functions of the `moduleOp` in `orderedFuncOps`, sorted by
-/// callee-caller order (i.e. callees without callers first).
+/// callee-caller order (i.e., callees without callers first). Store all
+/// remaining functions (i.e., the ones that call each other recursively) in
+/// `remainingFuncOps`.
+///
 /// Store the map of FuncOp to all its callers in `callerMap`.
-/// Return `failure()` if a cycle of calls is detected or if we are unable to
-/// retrieve the called FuncOp from any func::CallOp.
-static LogicalResult
-getFuncOpsOrderedByCalls(ModuleOp moduleOp,
-                         SmallVectorImpl<func::FuncOp> &orderedFuncOps,
-                         FuncCallerMap &callerMap) {
+///
+/// Return `failure()` if we are unable to retrieve the called FuncOp from
+/// any func::CallOp.
+static LogicalResult getFuncOpsOrderedByCalls(
+    ModuleOp moduleOp, SmallVectorImpl<func::FuncOp> &orderedFuncOps,
+    SmallVectorImpl<func::FuncOp> &remainingFuncOps, FuncCallerMap &callerMap) {
   // For each FuncOp, the set of functions called by it (i.e. the union of
   // symbols of all nested func::CallOp).
   DenseMap<func::FuncOp, DenseSet<func::FuncOp>> calledBy;
@@ -326,19 +329,25 @@ getFuncOpsOrderedByCalls(ModuleOp moduleOp,
       });
   if (res.wasInterrupted())
     return failure();
+
   // Iteratively remove function operations that do not call any of the
-  // functions remaining in the callCounter map and add them to the worklist.
+  // functions remaining in the callCounter map and add them to the ordered list.
   while (!numberCallOpsContainedInFuncOp.empty()) {
     auto it = llvm::find_if(numberCallOpsContainedInFuncOp,
                             [](auto entry) { return entry.getSecond() == 0; });
     if (it == numberCallOpsContainedInFuncOp.end())
-      return moduleOp.emitOpError(
-          "expected callgraph to be free of circular dependencies.");
+      break;
     orderedFuncOps.push_back(it->getFirst());
     for (auto callee : calledBy[it->getFirst()])
       numberCallOpsContainedInFuncOp[callee]--;
     numberCallOpsContainedInFuncOp.erase(it);
   }
+
+  // Put all other functions in the list of remaining functions. These are
+  // functions that call each other circularly.
+  for (auto it : numberCallOpsContainedInFuncOp)
+    remainingFuncOps.push_back(it.first);
+
   return success();
 }
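A sketch of the new ordering on a hypothetical module (function names invented for illustration): `@leaf` and `@caller` form the acyclic part of the call graph and end up in `orderedFuncOps` (callees first), while `@a` and `@b` call each other and are moved to `remainingFuncOps` instead of triggering the old hard error:

```mlir
// orderedFuncOps   = [@leaf, @caller]   (callee before caller)
// remainingFuncOps = [@a, @b]           (circular; unspecified order)
func.func @leaf(%t: tensor<5xf32>) -> tensor<5xf32> {
  return %t : tensor<5xf32>
}
func.func @caller(%t: tensor<5xf32>) -> tensor<5xf32> {
  %0 = call @leaf(%t) : (tensor<5xf32>) -> (tensor<5xf32>)
  return %0 : tensor<5xf32>
}
func.func @a(%t: tensor<5xf32>) -> tensor<5xf32> {
  %0 = call @b(%t) : (tensor<5xf32>) -> (tensor<5xf32>)
  return %0 : tensor<5xf32>
}
func.func @b(%t: tensor<5xf32>) -> tensor<5xf32> {
  %0 = call @a(%t) : (tensor<5xf32>) -> (tensor<5xf32>)
  return %0 : tensor<5xf32>
}
```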
@@ -378,16 +387,23 @@ mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp,
          "expected that function boundary bufferization is activated");
   FuncAnalysisState &funcState = getOrCreateFuncAnalysisState(state);
 
-  // A list of functions in the order in which they are analyzed + bufferized.
+  // A list of non-circular functions in the order in which they are analyzed
+  // and bufferized.
   SmallVector<func::FuncOp> orderedFuncOps;
+  // A list of all other functions, i.e., functions that call each other
+  // recursively. For these, we analyze the function body but not the function
+  // boundary.
+  SmallVector<func::FuncOp> remainingFuncOps;
 
   // A mapping of FuncOps to their callers.
   FuncCallerMap callerMap;
 
-  if (failed(getFuncOpsOrderedByCalls(moduleOp, orderedFuncOps, callerMap)))
+  if (failed(getFuncOpsOrderedByCalls(moduleOp, orderedFuncOps,
+                                      remainingFuncOps, callerMap)))
     return failure();
 
-  // Analyze ops.
+  // Analyze functions in order, starting with functions that do not call any
+  // other functions.
   for (func::FuncOp funcOp : orderedFuncOps) {
     if (!state.getOptions().isOpAllowed(funcOp))
       continue;
@@ -411,6 +427,25 @@ mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp,
     funcState.analyzedFuncOps[funcOp] = FuncOpAnalysisState::Analyzed;
   }
 
+  // Analyze all other functions. All function boundary analyses are skipped.
+  for (func::FuncOp funcOp : remainingFuncOps) {
+    if (!state.getOptions().isOpAllowed(funcOp))
+      continue;
+
+    // Gather equivalence info for CallOps.
+    equivalenceAnalysis(funcOp, state, funcState);
+
+    // Analyze funcOp.
+    if (failed(analyzeOp(funcOp, state, statistics)))
+      return failure();
+
+    // TODO: We currently skip all function argument analyses for functions
+    // that call each other circularly. These analyses do not support recursive
+    // calls yet. The `BufferizableOpInterface` implementations of `func`
+    // dialect ops return conservative results in the absence of analysis
+    // information.
+  }
+
   return success();
 }
@@ -429,14 +464,26 @@ LogicalResult mlir::bufferization::bufferizeModuleOp(
          "expected that function boundary bufferization is activated");
   IRRewriter rewriter(moduleOp.getContext());
 
-  // A list of functions in the order in which they are analyzed + bufferized.
+  // A list of non-circular functions in the order in which they are analyzed
+  // and bufferized.
   SmallVector<func::FuncOp> orderedFuncOps;
+  // A list of all other functions, i.e., functions that call each other
+  // recursively. For these, we analyze the function body but not the function
+  // boundary.
+  SmallVector<func::FuncOp> remainingFuncOps;
 
   // A mapping of FuncOps to their callers.
   FuncCallerMap callerMap;
 
-  if (failed(getFuncOpsOrderedByCalls(moduleOp, orderedFuncOps, callerMap)))
+  // Try to bufferize functions in calling order, i.e., first bufferize
+  // functions that do not call other functions. This allows us to infer
+  // accurate buffer types for function return values. Functions that call
+  // each other recursively are bufferized in an unspecified order at the end;
+  // for those, we may use unnecessarily "complex" (layout-wise) buffer types.
+  if (failed(getFuncOpsOrderedByCalls(moduleOp, orderedFuncOps,
+                                      remainingFuncOps, callerMap)))
     return failure();
+  llvm::append_range(orderedFuncOps, remainingFuncOps);
 
   // Bufferize functions.
   for (func::FuncOp funcOp : orderedFuncOps) {
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-analysis.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-analysis.mlir
index 42d9cc00d3ff5..3f6d182b57c03 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-analysis.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-analysis.mlir
@@ -1348,3 +1348,15 @@ func.func @private_func_aliasing(%t: tensor<?xf32>) -> f32 {
   %2 = tensor.extract %1[%c0] : tensor<6xf32>
   return %2 : f32
 }
+
+// -----
+
+// CHECK-LABEL: func @recursive_function
+func.func @recursive_function(%a: tensor<?xf32>, %b: tensor<?xf32>) -> (tensor<?xf32>, tensor<?xf32>) {
+  // The analysis does not support recursive function calls and is conservative
+  // around them.
+  // CHECK: call @recursive_function
+  // CHECK-SAME: {__inplace_operands_attr__ = ["false", "false"]}
+  %0:2 = call @recursive_function(%a, %b) : (tensor<?xf32>, tensor<?xf32>) -> (tensor<?xf32>, tensor<?xf32>)
+  return %0#0, %0#1 : tensor<?xf32>, tensor<?xf32>
+}
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir
index 2829eafb7c1c5..28ce0735e47b7 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir
@@ -19,20 +19,6 @@ func.func @swappy(%cond1 : i1, %cond2 : i1, %t1 : tensor<f32>, %t2 : tensor<f32>)
 
 // -----
 
-// expected-error @-3 {{expected callgraph to be free of circular dependencies}}
-
-func.func @foo(%t: tensor<5xf32>) -> tensor<5xf32> {
-  %0 = call @bar(%t) : (tensor<5xf32>) -> (tensor<5xf32>)
-  return %0 : tensor<5xf32>
-}
-
-func.func @bar(%t: tensor<5xf32>) -> tensor<5xf32>{
-  %0 = call @foo(%t) : (tensor<5xf32>) -> (tensor<5xf32>)
-  return %0 : tensor<5xf32>
-}
-
-// -----
-
 func.func @scf_for(%A : tensor<?xf32>,
                    %B : tensor<?xf32> {bufferization.writable = true},
                    %C : tensor<4xf32>,
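The removed `-invalid` test no longer applies: a circular call graph is not an error anymore. Instead, as the analysis test above shows, recursive call operands are conservatively marked `["false", "false"]` (out-of-place). A sketch (hypothetical, simplified from the tests in the next file) of what such an out-of-place operand turns into after bufferization; the operand is copied into a fresh allocation and cast to the callee's dynamic-layout type:

```mlir
func.func @example(%arg0: memref<5xf32, strided<[?], offset: ?>>)
    -> memref<5xf32, strided<[?], offset: ?>> {
  // Out-of-place operand: copy into a fresh buffer before the call.
  %alloc = memref.alloc() : memref<5xf32>
  memref.copy %arg0, %alloc : memref<5xf32, strided<[?], offset: ?>> to memref<5xf32>
  // Cast to the fully dynamic layout expected by the (recursive) callee.
  %cast = memref.cast %alloc : memref<5xf32> to memref<5xf32, strided<[?], offset: ?>>
  %0 = call @example(%cast)
      : (memref<5xf32, strided<[?], offset: ?>>) -> memref<5xf32, strided<[?], offset: ?>>
  return %0 : memref<5xf32, strided<[?], offset: ?>>
}
```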
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir
index d31b43477beb9..2b5b863143670 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir
@@ -722,3 +722,52 @@ func.func @bar(%t: tensor<5xf32>, %m: memref<5xf32>) -> memref<5xf32> {
   %0 = func.call @foo(%m) : (memref<5xf32>) -> (memref<5xf32>)
   return %0 : memref<5xf32>
 }
+
+// -----
+
+// A recursive function.
+
+// CHECK-LABEL: func.func @foo(
+// CHECK-SAME:     %[[arg0:.*]]: memref<5xf32, strided<[?], offset: ?>>) -> memref<5xf32, strided<[?], offset: ?>> {
+func.func @foo(%t: tensor<5xf32>) -> tensor<5xf32> {
+  // We are conservative around recursive functions. The analysis cannot handle
+  // them, so we have to assume the op operand of the call op bufferizes to a
+  // memory read and write. This causes a copy in this test case.
+  // CHECK: %[[copy:.*]] = memref.alloc() {alignment = 64 : i64} : memref<5xf32>
+  // CHECK: memref.copy %[[arg0]], %[[copy]]
+  // CHECK: %[[cast:.*]] = memref.cast %[[copy]] : memref<5xf32> to memref<5xf32, strided<[?], offset: ?>>
+  // CHECK: %[[call:.*]] = call @foo(%[[cast]])
+  %0 = call @foo(%t) : (tensor<5xf32>) -> (tensor<5xf32>)
+
+  // CHECK: memref.load %[[arg0]]
+  %c0 = arith.constant 0 : index
+  %extr = tensor.extract %t[%c0] : tensor<5xf32>
+  vector.print %extr : f32
+
+  // CHECK: return %[[call]]
+  return %0 : tensor<5xf32>
+}
+
+// -----
+
+// Two functions calling each other recursively.
+
+// CHECK-LABEL: func.func @foo(
+// CHECK-SAME:     %[[arg0:.*]]: memref<5xf32, strided<[?], offset: ?>>) -> memref<5xf32, strided<[?], offset: ?>> {
+// CHECK:   %[[call:.*]] = call @bar(%[[arg0]]) : (memref<5xf32, strided<[?], offset: ?>>) -> memref<5xf32, strided<[?], offset: ?>>
+// CHECK:   return %[[call]]
+// CHECK: }
+func.func @foo(%t: tensor<5xf32>) -> tensor<5xf32> {
+  %0 = call @bar(%t) : (tensor<5xf32>) -> (tensor<5xf32>)
+  return %0 : tensor<5xf32>
+}
+
+// CHECK-LABEL: func.func @bar(
+// CHECK-SAME:     %[[arg0:.*]]: memref<5xf32, strided<[?], offset: ?>>) -> memref<5xf32, strided<[?], offset: ?>> {
+// CHECK:   %[[call:.*]] = call @foo(%[[arg0]]) : (memref<5xf32, strided<[?], offset: ?>>) -> memref<5xf32, strided<[?], offset: ?>>
+// CHECK:   return %[[call]]
+// CHECK: }
+func.func @bar(%t: tensor<5xf32>) -> tensor<5xf32>{
+  %0 = call @foo(%t) : (tensor<5xf32>) -> (tensor<5xf32>)
+  return %0 : tensor<5xf32>
+}
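For reference, a standalone input that exercises the new behavior end to end: mutual recursion like the following used to be rejected with "expected callgraph to be free of circular dependencies" and now bufferizes. The invocation in the RUN comment mirrors the existing RUN lines of this test file (the exact flags in the test file are authoritative; `@ping`/`@pong` are invented names):

```mlir
// RUN: mlir-opt %s -one-shot-bufferize="bufferize-function-boundaries" -split-input-file

func.func @ping(%t: tensor<5xf32>) -> tensor<5xf32> {
  %0 = call @pong(%t) : (tensor<5xf32>) -> (tensor<5xf32>)
  return %0 : tensor<5xf32>
}
func.func @pong(%t: tensor<5xf32>) -> tensor<5xf32> {
  %0 = call @ping(%t) : (tensor<5xf32>) -> (tensor<5xf32>)
  return %0 : tensor<5xf32>
}
```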