Skip to content

Commit fe90c52

Browse files
[mlir][bufferization] Use TensorLike, BufferLike type interfaces
The general idea is to replace most of the places that rely on builtin's TensorType / BaseMemRefType with the newly added type interfaces. Thus far, do the bare minimum: refactor (almost) "blindly" the API of the dialect and options, leaving most of the logic "as is". The exceptions are the bufferization.{to_tensor, to_memref} ops that act as "glue" when bufferizing neighbouring operations and the enclosing functions.
1 parent 23e3cbb commit fe90c52

26 files changed

+370
-191
lines changed

mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include <optional>
1818

1919
#include "mlir/Dialect/Bufferization/IR/BufferizationEnums.h.inc"
20+
#include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h"
2021

2122
namespace mlir {
2223
class OpBuilder;
@@ -259,18 +260,18 @@ struct BufferizationOptions {
259260
std::function<LogicalResult(OpBuilder &, Location, Value, Value)>;
260261
/// Initializer function for analysis state.
261262
using AnalysisStateInitFn = std::function<void(AnalysisState &)>;
262-
/// Tensor -> MemRef type converter.
263-
/// Parameters: tensor type, memory space, func op, bufferization options
263+
/// TensorLike -> BufferLike type converter.
264+
/// Parameters: tensor like type, memory space, func op, bufferization options
264265
using FunctionArgTypeConverterFn =
265-
std::function<BaseMemRefType(TensorType, Attribute memorySpace,
266+
std::function<BufferLikeType(TensorLikeType, Attribute memorySpace,
266267
func::FuncOp, const BufferizationOptions &)>;
267-
/// Tensor -> MemRef type converter.
268+
/// TensorLike -> BufferLike type converter.
268269
/// Parameters: Value, memory space, bufferization options
269-
using UnknownTypeConverterFn = std::function<BaseMemRefType(
270+
using UnknownTypeConverterFn = std::function<BufferLikeType(
270271
Value, Attribute memorySpace, const BufferizationOptions &)>;
271272
// Produce a MemorySpace attribute from a tensor type
272273
using DefaultMemorySpaceFn =
273-
std::function<std::optional<Attribute>(TensorType t)>;
274+
std::function<std::optional<Attribute>(TensorLikeType t)>;
274275

275276
BufferizationOptions();
276277

@@ -360,7 +361,7 @@ struct BufferizationOptions {
360361
// Returning std::nullopt will cause bufferization to fail (useful to indicate
361362
// failure to determine memory space for a tensor type).
362363
DefaultMemorySpaceFn defaultMemorySpaceFn =
363-
[](TensorType t) -> std::optional<Attribute> { return Attribute(); };
364+
[](TensorLikeType t) -> std::optional<Attribute> { return Attribute(); };
364365

365366
/// If set to `true`, the analysis is skipped. A buffer is copied before every
366367
/// write. This flag cannot be used together with `testAnalysisOnly = true`.
@@ -600,7 +601,7 @@ FailureOr<Value> getBuffer(RewriterBase &rewriter, Value value,
600601
/// IR, this function can be used.
601602
///
602603
/// This function is a wrapper around BufferizableOpInterface::getBufferType.
603-
FailureOr<BaseMemRefType> getBufferType(Value value,
604+
FailureOr<BufferLikeType> getBufferType(Value value,
604605
const BufferizationOptions &options);
605606

606607
/// Return the buffer type for a given Value (tensor) after bufferization
@@ -613,7 +614,7 @@ FailureOr<BaseMemRefType> getBufferType(Value value,
613614
/// IR, this function can be used.
614615
///
615616
/// This function is a wrapper around `BufferizableOpInterface::getBufferType`.
616-
FailureOr<BaseMemRefType> getBufferType(Value value,
617+
FailureOr<BufferLikeType> getBufferType(Value value,
617618
const BufferizationOptions &options,
618619
SmallVector<Value> &invocationStack);
619620

@@ -693,7 +694,7 @@ AliasingOpOperandList defaultGetAliasingOpOperands(Value value,
693694
/// This is the default implementation of
694695
/// BufferizableOpInterface::getBufferType. Should not be called from other
695696
/// places.
696-
FailureOr<BaseMemRefType>
697+
FailureOr<BufferLikeType>
697698
defaultGetBufferType(Value value, const BufferizationOptions &options,
698699
SmallVector<Value> &invocationStack);
699700

mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -518,7 +518,7 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
518518
Note: This interface method should never be called directly from user
519519
code. Always use `bufferization::getBufferType`.
520520
}],
521-
/*retType=*/"::mlir::FailureOr<::mlir::BaseMemRefType>",
521+
/*retType=*/"::mlir::FailureOr<::mlir::bufferization::BufferLikeType>",
522522
/*methodName=*/"getBufferType",
523523
/*args=*/(ins "::mlir::Value":$value,
524524
"const ::mlir::bufferization::BufferizationOptions &":$options,

mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ include "mlir/Dialect/Bufferization/IR/AllocationOpInterface.td"
1313
include "mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.td"
1414
include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td"
1515
include "mlir/Dialect/Bufferization/IR/BufferizationBase.td"
16+
include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td"
1617
include "mlir/Interfaces/DestinationStyleOpInterface.td"
1718
include "mlir/Interfaces/InferTypeOpInterface.td"
1819
include "mlir/Interfaces/SideEffectInterfaces.td"
@@ -109,7 +110,7 @@ def Bufferization_AllocTensorOp : Bufferization_Op<"alloc_tensor",
109110
AliasingValueList getAliasingValues(
110111
OpOperand &opOperand, const AnalysisState &state);
111112

112-
FailureOr<BaseMemRefType> getBufferType(
113+
FailureOr<BufferLikeType> getBufferType(
113114
Value value, const BufferizationOptions &options,
114115
SmallVector<Value> &invocationStack);
115116

@@ -438,11 +439,11 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
438439
away. However, such IR is no longer bufferizable with One-Shot Bufferize.
439440
}];
440441

441-
let arguments = (ins Arg<AnyRankedOrUnrankedMemRef,
442+
let arguments = (ins Arg<Bufferization_BufferLikeTypeInterface,
442443
"the reference to load from",
443444
[MemReadAt<0, FullEffect>]>:$memref,
444445
UnitAttr:$restrict, UnitAttr:$writable);
445-
let results = (outs AnyTensor:$result);
446+
let results = (outs Bufferization_TensorLikeTypeInterface:$result);
446447

447448
let extraClassDeclaration = [{
448449
/// The result of a to_tensor is always a tensor.
@@ -465,10 +466,10 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
465466

466467
bool isWritable(Value value, const AnalysisState &state);
467468

468-
FailureOr<BaseMemRefType> getBufferType(
469+
FailureOr<BufferLikeType> getBufferType(
469470
Value value, const BufferizationOptions &options,
470471
SmallVector<Value> &invocationStack) {
471-
return ::llvm::cast<BaseMemRefType>(getMemref().getType());
472+
return ::llvm::cast<BufferLikeType>(getMemref().getType());
472473
}
473474
}];
474475

@@ -493,6 +494,7 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
493494
// ToMemrefOp
494495
//===----------------------------------------------------------------------===//
495496

497+
// TODO: rename to "to_buffer"
496498
def Bufferization_ToMemrefOp : Bufferization_Op<"to_memref", [
497499
BufferizableOpInterface,
498500
SameOperandsAndResultShape,
@@ -519,8 +521,9 @@ def Bufferization_ToMemrefOp : Bufferization_Op<"to_memref", [
519521
the returned buffer) will not be written to.
520522
}];
521523

522-
let arguments = (ins AnyTensor:$tensor, UnitAttr:$read_only);
523-
let results = (outs AnyRankedOrUnrankedMemRef:$memref);
524+
let arguments = (ins Bufferization_TensorLikeTypeInterface:$tensor,
525+
UnitAttr:$read_only);
526+
let results = (outs Bufferization_BufferLikeTypeInterface:$memref);
524527

525528
let extraClassDeclaration = [{
526529
//===------------------------------------------------------------------===//

mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
// Bufferization Type Interfaces
1414
//===----------------------------------------------------------------------===//
1515

16+
#include "mlir/IR/Attributes.h" // mlir::Attribute
1617
#include "mlir/IR/Types.h"
1718

1819
#include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h.inc"

mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,10 +33,17 @@ def Bufferization_BufferLikeTypeInterface
3333
let description = [{
3434
Indicates that this type is a buffer type (similarly to a MLIR builtin
3535
memref) for bufferization purposes.
36-
37-
The interface currently has no methods as it is used by types to opt into
38-
being supported by the bufferization procedures.
3936
}];
37+
38+
let methods = [
39+
InterfaceMethod<
40+
/*desc=*/[{
41+
Returns the memory space in which data referred to by this buffer resides.
42+
}],
43+
/*retType=*/"::mlir::Attribute",
44+
/*methodName=*/"getMemorySpace"
45+
>,
46+
];
4047
}
4148

4249
#endif // BUFFERIZATION_TYPE_INTERFACES

mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h

Lines changed: 21 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ template <typename ConcreteModel, typename ConcreteOp>
3232
struct OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel
3333
: public BufferizableOpInterface::ExternalModel<ConcreteModel, ConcreteOp> {
3434

35-
FailureOr<BaseMemRefType>
35+
FailureOr<BufferLikeType>
3636
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
3737
SmallVector<Value> &invocationStack) const {
3838
// Note: The user may want to override this function for OpResults in
@@ -46,7 +46,7 @@ struct OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel
4646
// operand types of all forwarded values. If these are all the same type,
4747
// take that type. Otherwise, take only the memory space and fall back to a
4848
// buffer type with a fully dynamic layout map.
49-
BaseMemRefType bufferType;
49+
BufferLikeType bufferType;
5050
auto tensorType = cast<TensorType>(value.getType());
5151
for (OpOperand *opOperand :
5252
detail::getCallerOpOperands(cast<BlockArgument>(value))) {
@@ -59,13 +59,13 @@ struct OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel
5959
continue;
6060

6161
// Compute the bufferized type of the forwarded operand.
62-
BaseMemRefType callerType;
63-
if (auto memrefType =
64-
dyn_cast<BaseMemRefType>(opOperand->get().getType())) {
62+
BufferLikeType callerType;
63+
if (auto bufferLikeType =
64+
dyn_cast<BufferLikeType>(opOperand->get().getType())) {
6565
// The operand was already bufferized. Take its type directly.
66-
callerType = memrefType;
66+
callerType = bufferLikeType;
6767
} else {
68-
FailureOr<BaseMemRefType> maybeCallerType =
68+
FailureOr<BufferLikeType> maybeCallerType =
6969
bufferization::getBufferType(opOperand->get(), options,
7070
invocationStack);
7171
if (failed(maybeCallerType))
@@ -86,14 +86,20 @@ struct OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel
8686
// of the earlier forwarded operands, fall back to a buffer type with a
8787
// fully dynamic layout map.
8888
#ifndef NDEBUG
89+
assert(mlir::isa<BaseMemRefType>(bufferType) &&
90+
mlir::isa<BaseMemRefType>(callerType) && "expected memrefs");
91+
auto memrefType = mlir::cast<BaseMemRefType>(bufferType);
92+
auto callerMemrefType = mlir::cast<BaseMemRefType>(callerType);
93+
8994
if (auto rankedTensorType = dyn_cast<RankedTensorType>(tensorType)) {
90-
assert(bufferType.hasRank() && callerType.hasRank() &&
95+
assert(memrefType.hasRank() && callerMemrefType.hasRank() &&
9196
"expected ranked memrefs");
92-
assert(llvm::all_equal({bufferType.getShape(), callerType.getShape(),
93-
rankedTensorType.getShape()}) &&
94-
"expected same shape");
97+
assert(
98+
llvm::all_equal({memrefType.getShape(), callerMemrefType.getShape(),
99+
rankedTensorType.getShape()}) &&
100+
"expected same shape");
95101
} else {
96-
assert(!bufferType.hasRank() && !callerType.hasRank() &&
102+
assert(!memrefType.hasRank() && !callerMemrefType.hasRank() &&
97103
"expected unranked memrefs");
98104
}
99105
#endif // NDEBUG
@@ -102,8 +108,9 @@ struct OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel
102108
return op->emitOpError("incoming operands of block argument have "
103109
"inconsistent memory spaces");
104110

105-
bufferType = getMemRefTypeWithFullyDynamicLayout(
106-
tensorType, bufferType.getMemorySpace());
111+
bufferType =
112+
mlir::cast<BufferLikeType>(getMemRefTypeWithFullyDynamicLayout(
113+
tensorType, bufferType.getMemorySpace()));
107114
}
108115

109116
if (!bufferType)

mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ struct ConstantOpInterface
2626
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
2727
const BufferizationOptions &options) const {
2828
auto constantOp = cast<arith::ConstantOp>(op);
29-
auto type = dyn_cast<RankedTensorType>(constantOp.getType());
29+
auto type = dyn_cast<TensorLikeType>(constantOp.getType());
3030

3131
// Only ranked tensors are supported.
3232
if (!type)
@@ -176,7 +176,7 @@ struct SelectOpInterface
176176
return success();
177177
}
178178

179-
FailureOr<BaseMemRefType>
179+
FailureOr<bufferization::BufferLikeType>
180180
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
181181
SmallVector<Value> &invocationStack) const {
182182
auto selectOp = cast<arith::SelectOp>(op);
@@ -195,10 +195,11 @@ struct SelectOpInterface
195195
// If the buffers have different types, they differ only in their layout
196196
// map.
197197
auto memrefType = llvm::cast<MemRefType>(*trueType);
198-
return getMemRefTypeWithFullyDynamicLayout(
199-
RankedTensorType::get(memrefType.getShape(),
200-
memrefType.getElementType()),
201-
memrefType.getMemorySpace());
198+
return mlir::cast<bufferization::BufferLikeType>(
199+
getMemRefTypeWithFullyDynamicLayout(
200+
RankedTensorType::get(memrefType.getShape(),
201+
memrefType.getElementType()),
202+
memrefType.getMemorySpace()));
202203
}
203204
};
204205

0 commit comments

Comments
 (0)