Skip to content

Commit b05a291

Browse files
Switch from ConversionDialectInterface to TensorLike API
Noteworthy changes: * bufferization::getMemRefType() accepts a TensorType instead of Value to achieve broader applicability * BufferizationOptions::UnknownTypeConverterFn accepts a TensorType instead of Value to allow it being used in the updated getMemRefType()
1 parent 7ef1183 commit b05a291

File tree

12 files changed

+127
-181
lines changed

12 files changed

+127
-181
lines changed

mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -745,8 +745,8 @@ bool defaultHasTensorSemantics(Operation *op);
745745
FailureOr<BaseMemRefType> castToMemRef(FailureOr<BufferLikeType> bufferType);
746746

747747
/// This function is a free-standing helper that relies on
748-
/// bufferization::ConversionInterface to verify the types in tensor and buffer
749-
/// worlds match.
748+
/// bufferization::TensorLikeTypeInterface to verify the types in tensor and
749+
/// buffer worlds match.
750750
bool typesMatchAfterBufferization(Operation &op, Value tensor, Value buffer);
751751
} // namespace detail
752752

mlir/include/mlir/Dialect/Bufferization/IR/BufferizationConversionInterface.h

Lines changed: 0 additions & 64 deletions
This file was deleted.

mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,15 @@
1313
// Bufferization Type Interfaces
1414
//===----------------------------------------------------------------------===//
1515

16+
#include "mlir/IR/Diagnostics.h"
1617
#include "mlir/IR/Types.h"
1718

19+
namespace mlir::bufferization {
20+
struct BufferizationOptions;
21+
class BufferizationState;
22+
class BufferLikeType;
23+
} // namespace mlir::bufferization
24+
1825
#include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h.inc"
1926

2027
#endif // MLIR_DIALECT_BUFFERIZATION_IR_BUFFERIZATIONTYPEINTERFACES_H_

mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,31 @@ def Bufferization_TensorLikeTypeInterface
2121
let description = [{
2222
Indicates that this type is a tensor type (similarly to a MLIR builtin
2323
tensor) for bufferization purposes.
24-
25-
The interface currently has no methods as it is used by types to opt into
26-
being supported by the bufferization procedures.
2724
}];
25+
26+
let methods = [
27+
InterfaceMethod<[{
28+
Returns a BufferLike type for this TensorLike type.
29+
}],
30+
/*retTy=*/"::mlir::FailureOr<::mlir::bufferization::BufferLikeType>",
31+
/*methodName=*/"getBufferType",
32+
/*args=*/(ins
33+
"const ::mlir::bufferization::BufferizationOptions &":$options,
34+
"const ::mlir::bufferization::BufferizationState &":$state,
35+
"::llvm::function_ref<::mlir::InFlightDiagnostic()>":$emitError
36+
)
37+
>,
38+
InterfaceMethod<[{
39+
Returns whether a BufferLike type is compatible to this TensorLike type.
40+
The BufferLike type is assumed to be created by getBufferType().
41+
}],
42+
/*retTy=*/"::mlir::LogicalResult",
43+
/*methodName=*/"verifyCompatibleBufferType",
44+
/*args=*/(ins
45+
"::mlir::bufferization::BufferLikeType":$bufferType,
46+
"::llvm::function_ref<::mlir::InFlightDiagnostic()>":$emitError)
47+
>
48+
];
2849
}
2950

3051
def Bufferization_BufferLikeTypeInterface

mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88

99
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
1010
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
11-
#include "mlir/Dialect/Bufferization/IR/BufferizationConversionInterface.h"
1211
#include "mlir/Dialect/Func/IR/FuncOps.h"
1312
#include "mlir/Dialect/MemRef/IR/MemRef.h"
1413
#include "mlir/Dialect/Tensor/IR/Tensor.h"
@@ -720,11 +719,9 @@ bufferization::getBufferType(Value value, const BufferizationOptions &options,
720719
if (bufferizableOp)
721720
return bufferizableOp.getBufferType(value, options, state, invocationStack);
722721

723-
// Op is not bufferizable, use conversion interface.
724-
bufferization::ConversionInterface iface(value.getContext());
725-
return iface.getBufferType(value, options, state, [&](const Twine &message) {
726-
return op->emitError(message);
727-
});
722+
// Op is not bufferizable.
723+
return cast<TensorLikeType>(value.getType())
724+
.getBufferType(options, state, [&]() { return op->emitError(); });
728725
}
729726

730727
bool bufferization::hasTensorSemantics(Operation *op) {
@@ -1059,10 +1056,8 @@ bool bufferization::detail::typesMatchAfterBufferization(Operation &op,
10591056
assert(isa<TensorLikeType>(tensor.getType()) && "expected TensorLikeType");
10601057
assert(isa<BufferLikeType>(buffer.getType()) && "expected BufferLikeType");
10611058

1062-
// Op is not bufferizable, use conversion interface.
1063-
bufferization::ConversionInterface iface(op.getContext());
1064-
return succeeded(iface.typesMatch(
1065-
cast<TensorLikeType>(tensor.getType()),
1066-
cast<BufferLikeType>(buffer.getType()),
1067-
[&](const Twine &message) { return op.emitError(message); }));
1059+
return mlir::succeeded(
1060+
cast<TensorLikeType>(tensor.getType())
1061+
.verifyCompatibleBufferType(cast<BufferLikeType>(buffer.getType()),
1062+
[&]() { return op.emitError(); }));
10681063
}

mlir/lib/Dialect/Bufferization/IR/BufferizationConversionInterface.cpp

Lines changed: 0 additions & 58 deletions
This file was deleted.

mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,40 @@ struct BufferizationInlinerInterface : public DialectInlinerInterface {
5757
template <typename Tensor>
5858
struct BuiltinTensorExternalModel
5959
: TensorLikeType::ExternalModel<BuiltinTensorExternalModel<Tensor>,
60-
Tensor> {};
60+
Tensor> {
61+
llvm::FailureOr<BufferLikeType> getBufferType(
62+
mlir::Type tensor, const BufferizationOptions &options,
63+
const BufferizationState &state,
64+
llvm::function_ref<mlir::InFlightDiagnostic()> emitError) const {
65+
auto tensorType = cast<TensorType>(tensor);
66+
// Fall back to tensor -> memref conversion.
67+
auto memSpace = options.defaultMemorySpaceFn(tensorType);
68+
if (!memSpace.has_value())
69+
return emitError() << "could not infer memory space";
70+
71+
return cast<BufferLikeType>(
72+
getMemRefType(tensorType, options, /*layout=*/{}, *memSpace));
73+
}
74+
75+
mlir::LogicalResult verifyCompatibleBufferType(
76+
mlir::Type tensor, BufferLikeType bufferType,
77+
llvm::function_ref<mlir::InFlightDiagnostic()> emitError) const {
78+
// Fall back to tensor, memref checking.
79+
assert(isa<TensorType>(tensor) && "expected tensor type");
80+
assert(isa<BaseMemRefType>(bufferType) && "expected memref type");
81+
82+
auto tensorType = cast<ShapedType>(tensor);
83+
auto memrefType = cast<ShapedType>(bufferType);
84+
85+
if (tensorType.getShape() != memrefType.getShape())
86+
return emitError() << "shapes do not match";
87+
88+
if (tensorType.getElementType() != memrefType.getElementType())
89+
return emitError() << "element types do not match";
90+
91+
return mlir::success();
92+
}
93+
};
6194

6295
template <typename MemRef>
6396
struct BuiltinMemRefExternalModel
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
//===- BufferizationTypeInterfaces.cpp - Type Interfaces --------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h"
10+
11+
//===----------------------------------------------------------------------===//
12+
// Bufferization Type Interfaces
13+
//===----------------------------------------------------------------------===//
14+
15+
namespace mlir {
16+
namespace bufferization {
17+
18+
#include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.cpp.inc"
19+
20+
} // namespace bufferization
21+
} // namespace mlir

mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ add_mlir_dialect_library(MLIRBufferizationDialect
66
BufferizationDialect.cpp
77
BufferViewFlowOpInterface.cpp
88
UnstructuredControlFlow.cpp
9-
BufferizationConversionInterface.cpp
9+
BufferizationTypeInterfaces.cpp
1010

1111
ADDITIONAL_HEADER_DIRS
1212
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Bufferization

mlir/test/lib/Dialect/Test/TestDialect.cpp

Lines changed: 0 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
#include "TestTypes.h"
1212
#include "mlir/Bytecode/BytecodeImplementation.h"
1313
#include "mlir/Dialect/Arith/IR/Arith.h"
14-
#include "mlir/Dialect/Bufferization/IR/BufferizationConversionInterface.h"
1514
#include "mlir/Dialect/Func/IR/FuncOps.h"
1615
#include "mlir/Dialect/Tensor/IR/Tensor.h"
1716
#include "mlir/IR/AsmState.h"
@@ -285,44 +284,6 @@ getDynamicCustomParserPrinterOp(TestDialect *dialect) {
285284
verifier, regionVerifier, parser, printer);
286285
}
287286

288-
namespace {
289-
290-
struct TestConverter : bufferization::ConversionDialectInterface {
291-
TestConverter(Dialect *dialect)
292-
: bufferization::ConversionDialectInterface(dialect) {}
293-
294-
FailureOr<bufferization::BufferLikeType>
295-
getBufferType(Value value, const bufferization::BufferizationOptions &options,
296-
const bufferization::BufferizationState &state,
297-
function_ref<InFlightDiagnostic(const Twine &)> emitError)
298-
const override {
299-
auto testTensor = dyn_cast<TestTensorType>(value.getType());
300-
if (!testTensor)
301-
return emitError("expected TestTensorType");
302-
303-
return cast<bufferization::BufferLikeType>(
304-
TestMemrefType::get(value.getContext(), testTensor.getShape(),
305-
testTensor.getElementType(), nullptr));
306-
}
307-
308-
LogicalResult typesMatch(bufferization::TensorLikeType tensor,
309-
bufferization::BufferLikeType buffer,
310-
function_ref<InFlightDiagnostic(const Twine &)>
311-
emitError) const override {
312-
auto testTensor = dyn_cast<TestTensorType>(tensor);
313-
auto testMemref = dyn_cast<TestMemrefType>(buffer);
314-
if (!testTensor || !testMemref)
315-
return emitError("expected TestTensorType and TestMemrefType");
316-
317-
const bool valid =
318-
testTensor.getShape() == testMemref.getShape() &&
319-
testTensor.getElementType() == testMemref.getElementType();
320-
return success(valid);
321-
}
322-
};
323-
324-
} // namespace
325-
326287
//===----------------------------------------------------------------------===//
327288
// TestDialect
328289
//===----------------------------------------------------------------------===//
@@ -372,7 +333,6 @@ void TestDialect::initialize() {
372333
registerDynamicOp(getDynamicCustomParserPrinterOp(this));
373334
registerInterfaces();
374335
allowUnknownOperations();
375-
addInterface<TestConverter>();
376336

377337
// Instantiate our fallback op interface that we'll use on specific
378338
// unregistered op.

0 commit comments

Comments
 (0)