Skip to content

Commit 3be54b0

Browse files
andrey-golubev and ermilindwalekar
authored and committed
[mlir][bufferization] Support custom types at function boundaries (#159766)
Support custom types (3/N): allow custom tensor and buffer types in function signatures and at call-sites. This is one of the major building blocks to move in the direction of module-level one-shot-bufferization support. To achieve this, `BufferizationOptions::FunctionArgTypeConverterFn` callback is converted to work with tensor-like and buffer-like types, instead of the builtin counterparts. The default behavior for builtins remains unchanged, while custom types by default go through `TensorLikeType::getBufferType()` which is a general conversion interface.
1 parent 4940643 commit 3be54b0

File tree

5 files changed

+149
-68
lines changed

5 files changed

+149
-68
lines changed

mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -260,12 +260,12 @@ struct BufferizationOptions {
260260
std::function<LogicalResult(OpBuilder &, Location, Value, Value)>;
261261
/// Initializer function for analysis state.
262262
using AnalysisStateInitFn = std::function<void(AnalysisState &)>;
263-
/// Tensor -> MemRef type converter.
264-
/// Parameters: tensor type, memory space, func op, bufferization options
263+
/// Tensor-like -> Buffer-like type conversion.
264+
/// Parameters: tensor-like type, memory space, func op, bufferization options
265265
using FunctionArgTypeConverterFn =
266-
std::function<BaseMemRefType(TensorType, Attribute memorySpace,
266+
std::function<BufferLikeType(TensorLikeType, Attribute memorySpace,
267267
func::FuncOp, const BufferizationOptions &)>;
268-
/// Tensor -> MemRef type converter.
268+
/// Tensor -> MemRef type conversion.
269269
/// Parameters: tensor type, memory space, bufferization options
270270
using UnknownTypeConverterFn = std::function<BaseMemRefType(
271271
TensorType, Attribute memorySpace, const BufferizationOptions &)>;
@@ -335,10 +335,12 @@ struct BufferizationOptions {
335335
/// predictable.
336336
void setFunctionBoundaryTypeConversion(LayoutMapOption layoutMapOption);
337337

338-
/// Type converter from tensors to memrefs. This type converter is used to
339-
/// determine bufferized function argument and result types. By default, a
340-
/// type converter that returns a memref type with a fully dynamic layout map
341-
/// is used.
338+
/// Type conversion from tensors to buffers. This type conversion is used to
339+
/// determine bufferized function argument and result types.
340+
///
341+
/// By default, if tensor is a (builtin) tensor type, it is converted to a
342+
/// memref type with a fully dynamic layout map; if tensor is a (generic)
343+
/// tensor-like type, it is converted using TensorLikeType::getBufferType().
342344
///
343345
/// If `bufferizeFunctionBoundaries` is not set, this function isn't used.
344346
FunctionArgTypeConverterFn functionArgTypeConverterFn = nullptr;
@@ -350,10 +352,9 @@ struct BufferizationOptions {
350352
/// If `bufferizeFunctionBoundaries` is not set, this flag has no effect.
351353
bool inferFunctionResultLayout = true;
352354

353-
/// Type converter from tensors to memrefs. This type converter is used if no
354-
/// memref type could be inferred during bufferization. By default, a type
355-
/// converter that returns a memref type with a fully dynamic layout map is
356-
/// used.
355+
/// Type conversion from tensors to memrefs. This type conversion is used if
356+
/// no memref type could be inferred during bufferization. By default, returns
357+
/// a memref type with a fully dynamic layout map.
357358
UnknownTypeConverterFn unknownTypeConverterFn = nullptr;
358359

359360
// Use during type conversion to determine the memory space for memref based

mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp

Lines changed: 30 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -340,11 +340,21 @@ bool OpFilter::isOpAllowed(Operation *op) const {
340340
namespace {
341341

342342
/// Default function arg type converter: Use a fully dynamic layout map.
343-
BaseMemRefType
344-
defaultFunctionArgTypeConverter(TensorType type, Attribute memorySpace,
343+
BufferLikeType
344+
defaultFunctionArgTypeConverter(TensorLikeType type, Attribute memorySpace,
345345
func::FuncOp funcOp,
346346
const BufferizationOptions &options) {
347-
return getMemRefTypeWithFullyDynamicLayout(type, memorySpace);
347+
if (auto tensorType = mlir::dyn_cast<TensorType>(type)) {
348+
return cast<BufferLikeType>(
349+
getMemRefTypeWithFullyDynamicLayout(tensorType, memorySpace));
350+
}
351+
352+
// If not builtin, fallback to TensorLikeType::getBufferType()
353+
auto bufferType =
354+
type.getBufferType(options, [&]() { return funcOp->emitError(); });
355+
assert(succeeded(bufferType) &&
356+
"a valid buffer is always expected at function boundary");
357+
return *bufferType;
348358
}
349359
/// Default unknown type converter: Use a fully dynamic layout map.
350360
BaseMemRefType
@@ -387,14 +397,25 @@ BufferizationOptions::dynCastBufferizableOp(Value value) const {
387397

388398
void BufferizationOptions::setFunctionBoundaryTypeConversion(
389399
LayoutMapOption layoutMapOption) {
390-
functionArgTypeConverterFn = [=](TensorType tensorType, Attribute memorySpace,
400+
functionArgTypeConverterFn = [=](TensorLikeType type, Attribute memorySpace,
391401
func::FuncOp funcOp,
392402
const BufferizationOptions &options) {
393-
if (layoutMapOption == LayoutMapOption::IdentityLayoutMap)
394-
return bufferization::getMemRefTypeWithStaticIdentityLayout(tensorType,
395-
memorySpace);
396-
return bufferization::getMemRefTypeWithFullyDynamicLayout(tensorType,
397-
memorySpace);
403+
if (auto tensorType = mlir::dyn_cast<TensorType>(type)) {
404+
if (layoutMapOption == LayoutMapOption::IdentityLayoutMap)
405+
return cast<BufferLikeType>(
406+
bufferization::getMemRefTypeWithStaticIdentityLayout(tensorType,
407+
memorySpace));
408+
return cast<BufferLikeType>(
409+
bufferization::getMemRefTypeWithFullyDynamicLayout(tensorType,
410+
memorySpace));
411+
}
412+
413+
// If not builtin, fallback to TensorLikeType::getBufferType()
414+
auto bufferType =
415+
type.getBufferType(options, [&]() { return funcOp->emitError(); });
416+
assert(succeeded(bufferType) &&
417+
"a valid buffer is always expected at function boundary");
418+
return *bufferType;
398419
};
399420
inferFunctionResultLayout =
400421
layoutMapOption == LayoutMapOption::InferLayoutMap;

mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -406,7 +406,7 @@ bufferization::bufferizeBlockSignature(Block *block, RewriterBase &rewriter,
406406
// Compute the new signature.
407407
SmallVector<Type> newTypes;
408408
for (BlockArgument &bbArg : block->getArguments()) {
409-
auto tensorType = dyn_cast<TensorType>(bbArg.getType());
409+
auto tensorType = dyn_cast<TensorLikeType>(bbArg.getType());
410410
if (!tensorType) {
411411
newTypes.push_back(bbArg.getType());
412412
continue;

mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp

Lines changed: 55 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -49,29 +49,47 @@ void FuncAnalysisState::startFunctionAnalysis(FuncOp funcOp) {
4949
#endif // NDEBUG
5050
}
5151

52+
// Note: this is a local adaptor to unify TensorType and TensorLikeType code
53+
// paths that both work with BufferizationOptions.
54+
static mlir::Attribute
55+
getDefaultMemorySpace(const BufferizationOptions &options,
56+
TensorLikeType type) {
57+
if (auto tensorType = dyn_cast<TensorType>(type)) {
58+
return *options.defaultMemorySpaceFn(tensorType);
59+
}
60+
return nullptr;
61+
}
62+
5263
/// Return the index-th bufferized function argument type. This assumes that the
5364
/// specified argument is a tensor. If the tensor is ranked, a layout map may be
5465
/// specified by the user (as per `options.functionArgTypeConverterFn`).
55-
static BaseMemRefType
66+
static BufferLikeType
5667
getBufferizedFunctionArgType(FuncOp funcOp, int64_t index,
5768
const BufferizationOptions &options) {
58-
auto tensorType =
59-
dyn_cast<TensorType>(funcOp.getFunctionType().getInput(index));
60-
assert(tensorType && "expected TensorType");
61-
62-
BaseMemRefType memrefType = options.functionArgTypeConverterFn(
63-
tensorType, *options.defaultMemorySpaceFn(tensorType), funcOp, options);
64-
65-
auto layoutAttr = funcOp.getArgAttrOfType<MemRefLayoutAttrInterface>(
66-
index, BufferizationDialect::kBufferLayoutAttrName);
67-
if (!layoutAttr)
68-
return memrefType;
69-
70-
auto rankedMemrefType = dyn_cast<MemRefType>(memrefType);
71-
assert(rankedMemrefType && "buffer layout not supported on unranked tensors");
72-
return MemRefType::get(rankedMemrefType.getShape(),
73-
rankedMemrefType.getElementType(), layoutAttr,
74-
rankedMemrefType.getMemorySpace());
69+
auto type =
70+
dyn_cast<TensorLikeType>(funcOp.getFunctionType().getInput(index));
71+
assert(type && "expected TensorLikeType");
72+
73+
// Note: For builtin tensors there is additional logic related to layout.
74+
if (auto tensorType = dyn_cast<TensorType>(type)) {
75+
BufferLikeType memrefType = options.functionArgTypeConverterFn(
76+
type, *options.defaultMemorySpaceFn(tensorType), funcOp, options);
77+
78+
auto layoutAttr = funcOp.getArgAttrOfType<MemRefLayoutAttrInterface>(
79+
index, BufferizationDialect::kBufferLayoutAttrName);
80+
if (!layoutAttr)
81+
return memrefType;
82+
83+
auto rankedMemrefType = dyn_cast<MemRefType>(memrefType);
84+
assert(rankedMemrefType &&
85+
"buffer layout not supported on unranked tensors");
86+
return cast<BufferLikeType>(MemRefType::get(
87+
rankedMemrefType.getShape(), rankedMemrefType.getElementType(),
88+
layoutAttr, rankedMemrefType.getMemorySpace()));
89+
}
90+
91+
return options.functionArgTypeConverterFn(type, /*memSpace=*/nullptr, funcOp,
92+
options);
7593
}
7694

7795
/// Return the FuncOp called by `callOp`.
@@ -227,13 +245,13 @@ struct CallOpInterface
227245
FunctionType funcType = funcOp.getFunctionType();
228246
Type resultType =
229247
funcType.getResult(cast<OpResult>(value).getResultNumber());
230-
if (auto bufferizedType = dyn_cast<BaseMemRefType>(resultType))
231-
return cast<BufferLikeType>(bufferizedType);
248+
if (auto bufferizedType = dyn_cast<BufferLikeType>(resultType))
249+
return bufferizedType;
232250

233251
// Otherwise, call the type converter to compute the bufferized type.
234-
auto tensorType = cast<TensorType>(resultType);
252+
auto tensorType = cast<TensorLikeType>(resultType);
235253
return cast<BufferLikeType>(options.functionArgTypeConverterFn(
236-
tensorType, *options.defaultMemorySpaceFn(tensorType), funcOp,
254+
tensorType, getDefaultMemorySpace(options, tensorType), funcOp,
237255
options));
238256
}
239257

@@ -248,7 +266,7 @@ struct CallOpInterface
248266
SmallVector<Type> resultTypes;
249267
for (Value result : callOp.getResults()) {
250268
Type returnType = result.getType();
251-
if (!isa<TensorType>(returnType)) {
269+
if (!isa<TensorLikeType>(returnType)) {
252270
// Non-tensor values are returned.
253271
resultTypes.push_back(returnType);
254272
continue;
@@ -272,7 +290,7 @@ struct CallOpInterface
272290

273291
for (OpOperand &opOperand : callOp->getOpOperands()) {
274292
// Non-tensor operands are just copied.
275-
if (!isa<TensorType>(opOperand.get().getType())) {
293+
if (!isa<TensorLikeType>(opOperand.get().getType())) {
276294
newOperands.push_back(opOperand.get());
277295
continue;
278296
}
@@ -285,8 +303,8 @@ struct CallOpInterface
285303
Value buffer = *maybeBuffer;
286304

287305
// Caller / callee type mismatch is handled with castOrReallocMemRefValue.
288-
auto memRefType = funcType.getInput(opOperand.getOperandNumber());
289-
if (!isa<BaseMemRefType>(memRefType)) {
306+
auto bufferType = funcType.getInput(opOperand.getOperandNumber());
307+
if (!isa<BufferLikeType>(bufferType)) {
290308
// The called function was not bufferized yet. This can happen when
291309
// there cycles in the function call graph. Compute the bufferized
292310
// result type.
@@ -295,7 +313,7 @@ struct CallOpInterface
295313
funcOp.getArgument(opOperand.getOperandNumber()), options);
296314
if (failed(maybeBufferType))
297315
return failure();
298-
memRefType = *maybeBufferType;
316+
bufferType = *maybeBufferType;
299317
}
300318

301319
// Since we don't yet have a clear layout story, to_buffer may
@@ -304,8 +322,8 @@ struct CallOpInterface
304322
// that will either canonicalize away or fail compilation until we can do
305323
// something better. Insert a reallocation + copy if it cannot be
306324
// statically guaranteed that a direct cast would be valid.
307-
if (buffer.getType() != memRefType) {
308-
auto memrefDstType = dyn_cast<MemRefType>(memRefType);
325+
if (buffer.getType() != bufferType) {
326+
auto memrefDstType = dyn_cast<MemRefType>(bufferType);
309327
assert(memrefDstType &&
310328
"buffer layout not supported on unranked tensors");
311329
FailureOr<Value> replacement = bufferization::castOrReallocMemRefValue(
@@ -368,7 +386,7 @@ struct FuncOpInterface
368386
static bool supportsUnstructuredControlFlow() { return true; }
369387

370388
bool hasTensorSemantics(Operation *op) const {
371-
auto isaTensor = llvm::IsaPred<TensorType>;
389+
auto isaTensor = llvm::IsaPred<TensorLikeType>;
372390

373391
// A function has tensor semantics if it has tensor arguments/results.
374392
auto funcOp = cast<FuncOp>(op);
@@ -404,8 +422,8 @@ struct FuncOpInterface
404422

405423
// Function arguments are special.
406424
if (bbArg.getOwner() == &funcOp.getBody().front())
407-
return cast<BufferLikeType>(
408-
getBufferizedFunctionArgType(funcOp, bbArg.getArgNumber(), options));
425+
return getBufferizedFunctionArgType(funcOp, bbArg.getArgNumber(),
426+
options);
409427

410428
return OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel::
411429
getBufferType(op, value, options, state, invocationStack);
@@ -428,7 +446,7 @@ struct FuncOpInterface
428446
SmallVector<Type> argTypes;
429447
for (const auto &it : llvm::enumerate(funcType.getInputs())) {
430448
Type argType = it.value();
431-
if (isa<TensorType>(argType)) {
449+
if (isa<TensorLikeType>(argType)) {
432450
argTypes.push_back(
433451
getBufferizedFunctionArgType(funcOp, it.index(), options));
434452
continue;
@@ -439,9 +457,9 @@ struct FuncOpInterface
439457
// Compute the result types.
440458
SmallVector<Type> retTypes;
441459
for (Type resultType : funcType.getResults()) {
442-
if (auto tensorType = dyn_cast<TensorType>(resultType)) {
443-
BaseMemRefType resultType = options.functionArgTypeConverterFn(
444-
tensorType, *options.defaultMemorySpaceFn(tensorType), funcOp,
460+
if (auto tensorType = dyn_cast<TensorLikeType>(resultType)) {
461+
BufferLikeType resultType = options.functionArgTypeConverterFn(
462+
tensorType, getDefaultMemorySpace(options, tensorType), funcOp,
445463
options);
446464
retTypes.push_back(resultType);
447465
continue;
@@ -471,7 +489,7 @@ struct FuncOpInterface
471489
SmallVector<Value> returnValues;
472490
for (auto [returnVal, bufferizedType] :
473491
llvm::zip_equal(returnOp->getOperands(), retTypes)) {
474-
auto tensorType = dyn_cast<TensorType>(returnVal.getType());
492+
auto tensorType = dyn_cast<TensorLikeType>(returnVal.getType());
475493
rewriter.setInsertionPoint(returnOp);
476494

477495
// If not a tensor type just forward it.

mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir

Lines changed: 50 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -796,17 +796,58 @@ func.func @result_type_mismatch(%c: i1) -> tensor<5xf32> {
796796
return %1 : tensor<5xf32>
797797
}
798798

799-
800799
// -----
801800

802-
// CHECK-LABEL: @outer_func({{.+}}: memref<
803-
func.func @outer_func(%t: tensor<5xf32>) -> tensor<5xf32> {
804-
return %t : tensor<5xf32>
801+
// CHECK: func.func @custom_types(
802+
// CHECK-SAME: %[[arg:.*]]: !test.test_memref<[4, 4], f64>
803+
// CHECK-SAME: ) -> (!test.test_memref<[4, 8], f64>,
804+
// CHECK-SAME: !test.test_memref<[4, 8], f64>)
805+
func.func @custom_types(%arg: !test.test_tensor<[4, 4], f64>)
806+
-> (!test.test_tensor<[4, 8], f64>, !test.test_tensor<[4, 8], f64>) {
807+
// CHECK: %[[out1:.*]] = "test.dummy_memref_op"(%[[arg]]) :
808+
// CHECK-SAME: (!test.test_memref<[4, 4], f64>) -> !test.test_memref<[4, 8], f64>
809+
%out1 = "test.dummy_tensor_op"(%arg) : (!test.test_tensor<[4, 4], f64>)
810+
-> !test.test_tensor<[4, 8], f64>
811+
812+
// CHECK: %[[alloc:.*]] = "test.create_memref_op"
813+
// CHECK: %[[out2:.*]] = "test.dummy_memref_op"(%[[alloc]])
814+
// CHECK-SAME: (!test.test_memref<[4, 4], f64>) -> !test.test_memref<[4, 8], f64>
815+
%alloc = "test.create_tensor_op"() : () -> !test.test_tensor<[4, 4], f64>
816+
%out2 = "test.dummy_tensor_op"(%alloc) : (!test.test_tensor<[4, 4], f64>)
817+
-> !test.test_tensor<[4, 8], f64>
818+
819+
// CHECK: return %[[out1]], %[[out2]]
820+
return %out1, %out2 :
821+
!test.test_tensor<[4, 8], f64>, !test.test_tensor<[4, 8], f64>
805822
}
806823

807-
module @inner_module {
808-
// CHECK: @inner_func({{.+}}: tensor<5xf32> {bufferization.writable = false})
809-
func.func @inner_func(%t: tensor<5xf32> {bufferization.writable = false}) -> tensor<5xf32> {
810-
return %t : tensor<5xf32>
811-
}
824+
// -----
825+
826+
// CHECK: func.func @custom_types_foo(
827+
// CHECK-SAME: %[[arg:.*]]: !test.test_memref<[4, 4], f64>
828+
// CHECK-SAME: ) -> !test.test_memref<[4, 4], f64>
829+
func.func @custom_types_foo(%arg: !test.test_tensor<[4, 4], f64>)
830+
-> !test.test_tensor<[4, 4], f64> {
831+
// CHECK: %[[out:.*]] = "test.dummy_memref_op"(%[[arg]])
832+
%out = "test.dummy_tensor_op"(%arg) : (!test.test_tensor<[4, 4], f64>)
833+
-> !test.test_tensor<[4, 4], f64>
834+
// CHECK: return %[[out]]
835+
return %out : !test.test_tensor<[4, 4], f64>
836+
}
837+
838+
// CHECK: func.func @custom_types_bar(
839+
// CHECK-SAME: %[[arg:.*]]: !test.test_memref<[4, 4], f64>
840+
// CHECK-SAME: ) -> !test.test_memref<[4, 8], f64>
841+
func.func @custom_types_bar(%arg: !test.test_tensor<[4, 4], f64>)
842+
-> !test.test_tensor<[4, 8], f64> {
843+
// CHECK: %[[call:.*]] = call @custom_types_foo(%[[arg]])
844+
%call = func.call @custom_types_foo(%arg) : (!test.test_tensor<[4, 4], f64>)
845+
-> !test.test_tensor<[4, 4], f64>
846+
847+
// CHECK: %[[out:.*]] = "test.dummy_memref_op"(%[[call]])
848+
%out = "test.dummy_tensor_op"(%call) : (!test.test_tensor<[4, 4], f64>)
849+
-> !test.test_tensor<[4, 8], f64>
850+
851+
// CHECK: return %[[out]]
852+
return %out : !test.test_tensor<[4, 8], f64>
812853
}

0 commit comments

Comments
 (0)