Skip to content

Commit 90e46a4

Browse files
[mlir][bufferization] Support custom types (1/N) (llvm#142986)
Following the addition of TensorLike and BufferLike type interfaces (see 00eaff3), introduce minimal changes required to bufferize a custom tensor operation into a custom buffer operation. To achieve this, new interface methods are added to TensorLike type interface that abstract away the differences between existing (tensor -> memref) and custom conversions. The scope of the changes is intentionally limited (for example, BufferizableOpInterface is untouched) in order to first understand the basics and reach consensus design-wise. --- Notable changes: * mlir::bufferization::getBufferType() returns BufferLikeType (instead of BaseMemRefType) * ToTensorOp / ToBufferOp operate on TensorLikeType / BufferLikeType. Operation argument "memref" renamed to "buffer" * ToTensorOp's tensor type inferring builder is dropped (users now need to provide the tensor type explicitly)
1 parent c743086 commit 90e46a4

27 files changed

+381
-129
lines changed

mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include <optional>
1818

1919
#include "mlir/Dialect/Bufferization/IR/BufferizationEnums.h.inc"
20+
#include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h"
2021

2122
namespace mlir {
2223
class OpBuilder;
@@ -600,7 +601,7 @@ FailureOr<Value> getBuffer(RewriterBase &rewriter, Value value,
600601
/// IR, this function can be used.
601602
///
602603
/// This function is a wrapper around BufferizableOpInterface::getBufferType.
603-
FailureOr<BaseMemRefType> getBufferType(Value value,
604+
FailureOr<BufferLikeType> getBufferType(Value value,
604605
const BufferizationOptions &options);
605606

606607
/// Return the buffer type for a given Value (tensor) after bufferization
@@ -613,7 +614,7 @@ FailureOr<BaseMemRefType> getBufferType(Value value,
613614
/// IR, this function can be used.
614615
///
615616
/// This function is a wrapper around `BufferizableOpInterface::getBufferType`.
616-
FailureOr<BaseMemRefType> getBufferType(Value value,
617+
FailureOr<BufferLikeType> getBufferType(Value value,
617618
const BufferizationOptions &options,
618619
SmallVector<Value> &invocationStack);
619620

@@ -721,6 +722,19 @@ AliasingValueList unknownGetAliasingValues(OpOperand &opOperand);
721722
/// This is the default implementation of
722723
/// BufferizableOpInterface::hasTensorSemantics
723724
bool defaultHasTensorSemantics(Operation *op);
725+
726+
/// This is a helper function used when buffer type is guaranteed to be memref.
727+
/// It performs two actions: failure state checking and an explicit llvm::cast<>
728+
/// from the buffer-like type interface to a BaseMemRefType. This allows easier
729+
/// management of differences in C++ types at the API boundaries. Valid buffer
730+
/// type is casted to the memref type. Otherwise, the failure state is
731+
/// propagated i.e. asMemRefType(mlir::failure()) returns mlir::failure().
732+
FailureOr<BaseMemRefType> asMemRefType(FailureOr<BufferLikeType> bufferType);
733+
734+
/// This function is a free-standing helper that relies on
735+
/// bufferization::TensorLikeTypeInterface to verify the types in tensor and
736+
/// buffer worlds match.
737+
bool typesMatchAfterBufferization(Operation &op, Value tensor, Value buffer);
724738
} // namespace detail
725739

726740
} // namespace bufferization

mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td

Lines changed: 30 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
include "mlir/Dialect/Bufferization/IR/AllocationOpInterface.td"
1313
include "mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.td"
1414
include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td"
15+
include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td"
1516
include "mlir/Dialect/Bufferization/IR/BufferizationBase.td"
1617
include "mlir/Interfaces/DestinationStyleOpInterface.td"
1718
include "mlir/Interfaces/InferTypeOpInterface.td"
@@ -383,20 +384,31 @@ def Bufferization_DeallocTensorOp : Bufferization_Op<"dealloc_tensor",
383384
// ToTensorOp
384385
//===----------------------------------------------------------------------===//
385386

387+
class Bufferization_TensorAndBufferMatch<string tensor, string buffer> : PredOpTrait<
388+
"specified tensor and buffer types match",
389+
CPred<
390+
"::mlir::bufferization::detail::typesMatchAfterBufferization("
391+
"$_op, $" # tensor # ", $" # buffer #")"
392+
>
393+
>;
394+
386395
def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
387396
BufferizableOpInterface,
388397
SameOperandsAndResultShape,
389398
SameOperandsAndResultElementType,
390-
AllElementTypesMatch<["memref", "result"]>
399+
Bufferization_TensorAndBufferMatch<"result", "buffer">
391400
]> {
392-
let summary = "create a tensor from a `memref`";
401+
let summary = "create a buffer-like type from a tensor-like type";
393402
let description = [{
394-
An operation that creates a tensor from a `memref`. The result value is a
395-
tensor whose shape and element type match the memref operand.
403+
An operation that creates a tensor from a buffer. The result value is a
404+
tensor-like type that must match the corresponding buffer-like operand as
405+
per TensorLikeType::verifyCompatibleBufferType(). For builtins (TensorType
406+
and BaseMemRefType), this means that shapes and element types match between
407+
the tensor and the buffer.
396408

397409
The opposite of this op is `to_buffer`. Together, these two ops are
398410
useful for source/target materializations when doing type conversions
399-
involving tensors and memrefs.
411+
involving tensors and buffers.
400412

401413
Example:
402414

@@ -438,19 +450,16 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
438450
away. However, such IR is no longer bufferizable with One-Shot Bufferize.
439451
}];
440452

441-
let arguments = (ins Arg<AnyRankedOrUnrankedMemRef,
453+
let arguments = (ins Arg<Bufferization_BufferLikeTypeInterface,
442454
"the reference to load from",
443-
[MemReadAt<0, FullEffect>]>:$memref,
455+
[MemReadAt<0, FullEffect>]>:$buffer,
444456
UnitAttr:$restrict, UnitAttr:$writable);
445-
let results = (outs AnyTensor:$result);
457+
let results = (outs Bufferization_TensorLikeTypeInterface:$result);
446458

447459
let extraClassDeclaration = [{
448460
/// The result of a to_tensor is always a tensor.
449-
TensorType getType() {
450-
Type resultType = getResult().getType();
451-
if (::llvm::isa<TensorType>(resultType))
452-
return ::llvm::cast<TensorType>(resultType);
453-
return {};
461+
::mlir::bufferization::TensorLikeType getType() {
462+
return getResult().getType();
454463
}
455464

456465
//===------------------------------------------------------------------===//
@@ -468,22 +477,15 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
468477
FailureOr<BaseMemRefType> getBufferType(
469478
Value value, const BufferizationOptions &options,
470479
SmallVector<Value> &invocationStack) {
471-
return ::llvm::cast<BaseMemRefType>(getMemref().getType());
480+
return ::llvm::cast<BaseMemRefType>(getBuffer().getType());
472481
}
473482
}];
474483

475484
let assemblyFormat = [{
476-
$memref (`restrict` $restrict^)? (`writable` $writable^)? attr-dict
477-
`:` type($memref) `to` type($result)
485+
$buffer (`restrict` $restrict^)? (`writable` $writable^)? attr-dict
486+
`:` type($buffer) `to` type($result)
478487
}];
479488

480-
let builders = [
481-
OpBuilder<(ins "Value":$memref, CArg<"bool", "false">:$restrict, CArg<"bool", "false">:$writeable), [{
482-
auto rtt = memref::getTensorTypeFromMemRefType(memref.getType());
483-
build($_builder, $_state, rtt, memref, restrict, writeable);
484-
}]>
485-
];
486-
487489
let hasCanonicalizer = 1;
488490
let hasFolder = 1;
489491
}
@@ -498,10 +500,9 @@ def Bufferization_ToBufferOp : Bufferization_Op<"to_buffer", [
498500
SameOperandsAndResultShape,
499501
SameOperandsAndResultElementType,
500502
Pure,
501-
AllShapesMatch<["memref", "tensor"]>,
502-
AllElementTypesMatch<["memref", "tensor"]>
503+
Bufferization_TensorAndBufferMatch<"tensor", "buffer">
503504
]> {
504-
let summary = "cast a tensor to memref";
505+
let summary = "cast a tensor-like type to buffer-like type";
505506
let description = [{
506507
An operation that returns the future buffer of a `tensor`.
507508

@@ -519,8 +520,8 @@ def Bufferization_ToBufferOp : Bufferization_Op<"to_buffer", [
519520
the returned buffer) will not be written to.
520521
}];
521522

522-
let arguments = (ins AnyTensor:$tensor, UnitAttr:$read_only);
523-
let results = (outs AnyRankedOrUnrankedMemRef:$memref);
523+
let arguments = (ins Bufferization_TensorLikeTypeInterface:$tensor, UnitAttr:$read_only);
524+
let results = (outs Bufferization_BufferLikeTypeInterface:$buffer);
524525

525526
let extraClassDeclaration = [{
526527
//===------------------------------------------------------------------===//
@@ -554,7 +555,7 @@ def Bufferization_ToBufferOp : Bufferization_Op<"to_buffer", [
554555
}];
555556

556557
let assemblyFormat = [{
557-
$tensor (`read_only` $read_only^)? attr-dict `:` type($tensor) `to` type($memref)
558+
$tensor (`read_only` $read_only^)? attr-dict `:` type($tensor) `to` type($buffer)
558559
}];
559560

560561
let hasFolder = 1;

mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,15 @@
1313
// Bufferization Type Interfaces
1414
//===----------------------------------------------------------------------===//
1515

16+
#include "mlir/IR/Diagnostics.h"
17+
#include "mlir/IR/Types.h"
18+
19+
namespace mlir::bufferization {
20+
struct BufferizationOptions;
21+
class BufferizationState;
22+
class BufferLikeType;
23+
} // namespace mlir::bufferization
24+
1625
#include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h.inc"
1726

1827
#endif // MLIR_DIALECT_BUFFERIZATION_IR_BUFFERIZATIONTYPEINTERFACES_H_

mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,30 @@ def Bufferization_TensorLikeTypeInterface
2121
let description = [{
2222
Indicates that this type is a tensor type (similarly to a MLIR builtin
2323
tensor) for bufferization purposes.
24-
25-
The interface currently has no methods as it is used by types to opt into
26-
being supported by the bufferization procedures.
2724
}];
25+
26+
let methods = [
27+
InterfaceMethod<[{
28+
Returns a BufferLike type for this TensorLike type.
29+
}],
30+
/*retTy=*/"::mlir::FailureOr<::mlir::bufferization::BufferLikeType>",
31+
/*methodName=*/"getBufferType",
32+
/*args=*/(ins
33+
"const ::mlir::bufferization::BufferizationOptions &":$options,
34+
"::llvm::function_ref<::mlir::InFlightDiagnostic()>":$emitError
35+
)
36+
>,
37+
InterfaceMethod<[{
38+
Returns whether a BufferLike type is compatible to this TensorLike type.
39+
The BufferLike type is assumed to be created by getBufferType().
40+
}],
41+
/*retTy=*/"::mlir::LogicalResult",
42+
/*methodName=*/"verifyCompatibleBufferType",
43+
/*args=*/(ins
44+
"::mlir::bufferization::BufferLikeType":$bufferType,
45+
"::llvm::function_ref<::mlir::InFlightDiagnostic()>":$emitError)
46+
>
47+
];
2848
}
2949

3050
def Bufferization_BufferLikeTypeInterface

mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,12 +65,13 @@ struct OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel
6565
// The operand was already bufferized. Take its type directly.
6666
callerType = memrefType;
6767
} else {
68-
FailureOr<BaseMemRefType> maybeCallerType =
68+
FailureOr<BufferLikeType> maybeCallerType =
6969
bufferization::getBufferType(opOperand->get(), options,
7070
invocationStack);
7171
if (failed(maybeCallerType))
7272
return failure();
73-
callerType = *maybeCallerType;
73+
assert(isa<BaseMemRefType>(*maybeCallerType) && "expected memref type");
74+
callerType = cast<BaseMemRefType>(*maybeCallerType);
7475
}
7576

7677
if (!bufferType) {

mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -159,8 +159,8 @@ struct SelectOpInterface
159159
// buffers have different types, they differ only in their layout map. Cast
160160
// both of them to the most dynamic MemRef type.
161161
if (trueBuffer.getType() != falseBuffer.getType()) {
162-
auto targetType =
163-
bufferization::getBufferType(selectOp.getResult(), options);
162+
auto targetType = bufferization::detail::asMemRefType(
163+
bufferization::getBufferType(selectOp.getResult(), options));
164164
if (failed(targetType))
165165
return failure();
166166
if (trueBuffer.getType() != *targetType)
@@ -181,10 +181,12 @@ struct SelectOpInterface
181181
SmallVector<Value> &invocationStack) const {
182182
auto selectOp = cast<arith::SelectOp>(op);
183183
assert(value == selectOp.getResult() && "invalid value");
184-
auto trueType = bufferization::getBufferType(selectOp.getTrueValue(),
185-
options, invocationStack);
186-
auto falseType = bufferization::getBufferType(selectOp.getFalseValue(),
187-
options, invocationStack);
184+
auto trueType =
185+
bufferization::detail::asMemRefType(bufferization::getBufferType(
186+
selectOp.getTrueValue(), options, invocationStack));
187+
auto falseType =
188+
bufferization::detail::asMemRefType(bufferization::getBufferType(
189+
selectOp.getFalseValue(), options, invocationStack));
188190
if (failed(trueType) || failed(falseType))
189191
return failure();
190192
if (*trueType == *falseType)

0 commit comments

Comments
 (0)