-
Notifications
You must be signed in to change notification settings - Fork 15.2k
[mlir][bufferization] Add tensor-like and buffer-like interfaces #134220
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 2 commits
33e03b3
cb5892f
911096e
253296f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,20 @@ | ||
| //===- BufferizationTypeInterfaces.h - Type Interfaces ----------*- C++ -*-===// | ||
| // | ||
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
| // See https://llvm.org/LICENSE.txt for license information. | ||
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
| // | ||
| //===----------------------------------------------------------------------===// | ||
|
|
||
| #ifndef MLIR_DIALECT_BUFFERIZATION_IR_BUFFERIZATIONTYPEINTERFACES_H_ | ||
| #define MLIR_DIALECT_BUFFERIZATION_IR_BUFFERIZATIONTYPEINTERFACES_H_ | ||
|
|
||
| #include "mlir/IR/BuiltinTypeInterfaces.h" // for ShapedTypeInterface | ||
|
|
||
| //===----------------------------------------------------------------------===// | ||
| // Bufferization Type Interfaces | ||
| //===----------------------------------------------------------------------===// | ||
|
|
||
| #include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h.inc" | ||
|
|
||
| #endif // MLIR_DIALECT_BUFFERIZATION_IR_BUFFERIZATIONTYPEINTERFACES_H_ |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,49 @@ | ||
| //===- BufferizationTypeInterfaces.td - Type Interfaces ----*- tablegen -*-===// | ||
| // | ||
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
| // See https://llvm.org/LICENSE.txt for license information. | ||
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
| // | ||
| //===----------------------------------------------------------------------===// | ||
| // | ||
| // This is the definition file for type interfaces used in Bufferization. | ||
| // | ||
| //===----------------------------------------------------------------------===// | ||
|
|
||
| #ifndef BUFFERIZATION_TYPE_INTERFACES | ||
| #define BUFFERIZATION_TYPE_INTERFACES | ||
|
|
||
| include "mlir/IR/OpBase.td" | ||
| include "mlir/IR/BuiltinTypeInterfaces.td" | ||
|
|
||
| def Bufferization_TensorLikeTypeInterface | ||
| : TypeInterface<"TensorLikeType", [ShapedTypeInterface]> { | ||
| let cppNamespace = "::mlir::bufferization"; | ||
| let description = [{ | ||
| Indicates that the type that attaches this interface can be treated as a | ||
|
||
| tensor type (similarly to a MLIR builtin tensor) during bufferization. | ||
|
|
||
| Implementing this interface means that the type also implements | ||
| ShapedTypeInterface. | ||
|
|
||
| The interface currently has no methods as it is used by types to opt into | ||
| being supported by the bufferization procedures. | ||
| }]; | ||
| } | ||
|
|
||
| def Bufferization_MemRefLikeTypeInterface | ||
|
||
| : TypeInterface<"MemRefLikeType", [ShapedTypeInterface]> { | ||
| let cppNamespace = "::mlir::bufferization"; | ||
| let description = [{ | ||
| Indicates that the type that attaches this interface can be treated as a | ||
| memref type (similarly to a MLIR builtin memref) during bufferization. | ||
|
|
||
| Implementing this interface means that the type also implements | ||
| ShapedTypeInterface. | ||
|
|
||
| The interface currently has no methods as it is used by types to opt into | ||
| being supported by the bufferization procedures. | ||
| }]; | ||
| } | ||
|
|
||
| #endif // BUFFERIZATION_TYPE_INTERFACES | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -6,6 +6,8 @@ | |
| // | ||
| //===----------------------------------------------------------------------===// | ||
|
|
||
| #include "mlir/Dialect/Bufferization/IR/Bufferization.h" | ||
| #include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h" | ||
|
||
| #include "mlir/IR/BuiltinAttributes.h" | ||
| #include "mlir/IR/BuiltinDialect.h" | ||
| #include "mlir/IR/BuiltinOps.h" | ||
|
|
@@ -84,3 +86,78 @@ TEST(InterfaceTest, TestImplicitConversion) { | |
| typeA = typeB; | ||
| EXPECT_EQ(typeA, typeB); | ||
| } | ||
|
|
||
| TEST(InterfaceTest, TestBuiltinTensorIsTensorLikeType) { | ||
| MLIRContext context; | ||
| // Note: attaches external model to builtins | ||
| context.loadDialect<bufferization::BufferizationDialect>(); | ||
|
|
||
| auto builtinRankedTensor = mlir::RankedTensorType::get( | ||
| {1, 2, 3}, mlir::IntegerType::get(&context, 32)); | ||
| EXPECT_TRUE(mlir::isa<bufferization::TensorLikeType>(builtinRankedTensor)); | ||
| EXPECT_FALSE(mlir::isa<bufferization::MemRefLikeType>(builtinRankedTensor)); | ||
|
|
||
| auto builtinUnrankedTensor = | ||
| mlir::UnrankedTensorType::get(mlir::IntegerType::get(&context, 32)); | ||
| EXPECT_TRUE(mlir::isa<bufferization::TensorLikeType>(builtinUnrankedTensor)); | ||
| EXPECT_FALSE(mlir::isa<bufferization::MemRefLikeType>(builtinUnrankedTensor)); | ||
| } | ||
|
|
||
| TEST(InterfaceTest, TestCustomTensorIsTensorLikeType) { | ||
| MLIRContext context; | ||
| context.loadDialect<test::TestDialect>(); | ||
|
|
||
| auto customTensorType = test::TestTensorType::get( | ||
| &context, {1, 2, 3}, mlir::IntegerType::get(&context, 32)); | ||
| EXPECT_TRUE(mlir::isa<bufferization::TensorLikeType>(customTensorType)); | ||
|
|
||
| auto customCloneType = customTensorType.cloneWith( | ||
| ArrayRef<int64_t>{3, 4, 5}, customTensorType.getElementType()); | ||
| EXPECT_EQ(customTensorType.getElementType(), | ||
| customCloneType.getElementType()); | ||
| EXPECT_TRUE(mlir::isa<bufferization::TensorLikeType>(customCloneType)); | ||
| EXPECT_TRUE(mlir::isa<test::TestTensorType>(customCloneType)); | ||
|
|
||
| // user-specified conversions | ||
| bufferization::TensorLikeType baseCopy = customTensorType; | ||
| std::ignore = baseCopy; | ||
| } | ||
|
|
||
| TEST(InterfaceTest, TestBuiltinMemrefIsMemRefLikeType) { | ||
| MLIRContext context; | ||
| // Note: attaches external model to builtins | ||
| context.loadDialect<bufferization::BufferizationDialect>(); | ||
|
|
||
| auto builtinRankedMemref = | ||
| mlir::MemRefType::get({1, 2, 3}, mlir::IntegerType::get(&context, 32)); | ||
| EXPECT_TRUE(mlir::isa<bufferization::MemRefLikeType>(builtinRankedMemref)); | ||
| EXPECT_FALSE(mlir::isa<bufferization::TensorLikeType>(builtinRankedMemref)); | ||
|
|
||
| auto builtinUnrankedMemref = mlir::UnrankedMemRefType::get( | ||
| mlir::IntegerType::get(&context, 32), nullptr); | ||
| EXPECT_TRUE(mlir::isa<bufferization::MemRefLikeType>(builtinUnrankedMemref)); | ||
| EXPECT_FALSE(mlir::isa<bufferization::TensorLikeType>(builtinUnrankedMemref)); | ||
| } | ||
|
|
||
| TEST(InterfaceTest, TestCustomMemrefIsMemRefLikeType) { | ||
| MLIRContext context; | ||
| context.loadDialect<test::TestDialect>(); | ||
|
|
||
| auto customMemrefType = test::TestMemrefType::get( | ||
| &context, {1, 2, 3}, mlir::IntegerType::get(&context, 32), | ||
| mlir::StringAttr::get(&context, "some_memspace")); | ||
| EXPECT_TRUE(mlir::isa<bufferization::MemRefLikeType>(customMemrefType)); | ||
|
|
||
| auto customCloneType = customMemrefType.cloneWith( | ||
| ArrayRef<int64_t>{3, 4, 5}, customMemrefType.getElementType()); | ||
| EXPECT_EQ(customMemrefType.getElementType(), | ||
| customCloneType.getElementType()); | ||
| EXPECT_TRUE(mlir::isa<bufferization::MemRefLikeType>(customCloneType)); | ||
| EXPECT_TRUE(mlir::isa<test::TestMemrefType>(customCloneType)); | ||
| EXPECT_EQ(customMemrefType.getMemSpace(), | ||
| mlir::cast<test::TestMemrefType>(customCloneType).getMemSpace()); | ||
|
|
||
| // user-specified conversions | ||
| bufferization::MemRefLikeType baseCopy = customMemrefType; | ||
| std::ignore = baseCopy; | ||
| } | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
note: this attaches
ShapedTypeInterfacebutTensorType(base class of ranked tensor) also attachesShapedTypeInterface. is there any risk that we can run into trouble due to:There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Does it have to inherit the
ShapedTypeInterface? If not, let's keep this simpler.Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
looking at the code,
TensorLikedoesn't need it (or I haven't seen it).MemRefLikedoes need it: there are.hasRank()and.getShape()API usages (albeit the ones I've seen are inasserts)(also, there are
memref.getMemorySpace()usages - this is an "API" of BaseMemRefType but I guess I could introduce it later once switch to these new type interfaces happens.)overall, my motivation is to have ShapedTypeInterface APIs available to avoid boilerplate of
cast<ShapedTypeInterface>(tensorLike).getBlah(). however, sinceTensorLikedoesn't seem to need it, maybe i can drop it in that one at least? (but then it's going to be "tensor like" - not shaped type, "memref like" - shaped type which is kind of dumb given that memref is "created" from tensor).it feels like - in MLIR - tensor and memref are both shaped types "by design" but as we don't have a generic type interfaces for those, this design has to kind of leak into bufferization.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Which functions require these? The main places to look at are in
BufferizableOpInterface.cpp. There are functions such asbufferization::getMemRefType. These will probably have to become interface methods onTensorTypeInterface. Apart from that, I don't think the bufferization driver itself really needs anything from the shape type interface.So I would recommend to go without
ShapeTypeInterfacefor now. Are there any places where you'd have to insert explicit casts because of this today?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think I only saw the shaped type APIs around here (note NDEBUG):
llvm-project/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h
Lines 88 to 99 in 4a425a4
I don't think I can see such places (well, maybe in the user code but then those places are likely to cast down to the actual type so kind of not an issue).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We're going to need a function such as
MemRefTypeInterface::toTensorType(). It should be possible to reimplement these asserts based on that function. We then useoperator==instead of checking shape, rank, etc.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ack. then let's proceed without enforcing shaped type interface. worst case, can always be added later once the bulk of the code is migrated and new cases are discovered "lazily".
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
removed ShapedTypeInterface propagation.