From 346d1d367edf5dc302bce9e831c4f0e570a5b8f0 Mon Sep 17 00:00:00 2001 From: hanhanW Date: Tue, 18 Feb 2025 14:22:12 -0800 Subject: [PATCH] [mlir] Add two clone methods about encoding to RankedTensorType. There are clone methods for shape and element type, but not for encodings. The revision adds two clone method to RankedTensorType: - dropEncoding(): Return a clone of this type without the encoding. - cloneWithEncoding(Attribute encoding): Return a clone of this type with the given new encoding and the same shape and element type as this type. Signed-off-by: hanhanW --- mlir/include/mlir/IR/BuiltinTypes.td | 11 +++++++++++ mlir/unittests/IR/ShapedTypeTest.cpp | 14 ++++++++++++++ 2 files changed, 25 insertions(+) diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td index e5a2ae81da0c9..af474b3e3ec47 100644 --- a/mlir/include/mlir/IR/BuiltinTypes.td +++ b/mlir/include/mlir/IR/BuiltinTypes.td @@ -1035,6 +1035,17 @@ def Builtin_RankedTensor : Builtin_Type<"RankedTensor", "tensor", [ RankedTensorType clone(::mlir::Type elementType) { return ::llvm::cast(cloneWith(getShape(), elementType)); } + + /// Return a clone of this type without the encoding. + RankedTensorType dropEncoding() { + return RankedTensorType::get(getShape(), getElementType()); + } + + /// Return a clone of this type with the given new encoding and the same + /// shape and element type as this type. + RankedTensorType cloneWithEncoding(::mlir::Attribute encoding) { + return RankedTensorType::get(getShape(), getElementType(), encoding); + } }]; let skipDefaultBuilders = 1; let genVerifyDecl = 1; diff --git a/mlir/unittests/IR/ShapedTypeTest.cpp b/mlir/unittests/IR/ShapedTypeTest.cpp index c2900b5aaeeeb..bc4066ed210e8 100644 --- a/mlir/unittests/IR/ShapedTypeTest.cpp +++ b/mlir/unittests/IR/ShapedTypeTest.cpp @@ -282,6 +282,20 @@ TEST(ShapedTypeTest, RankedTensorTypeView) { ASSERT_TRUE(mlir::isa(viewCreated)); view = mlir::cast(viewCreated); EXPECT_EQ(view.getName(), "bob"); + + // Verify encoding clone methods. + EXPECT_EQ(unitEncodingRankedTensorType, + cast(noEncodingRankedTensorType) + .cloneWithEncoding(unitAttr)); + EXPECT_EQ(stringEncodingRankedTensorType, + cast(noEncodingRankedTensorType) + .cloneWithEncoding(stringAttr)); + EXPECT_EQ( + noEncodingRankedTensorType, + cast(unitEncodingRankedTensorType).dropEncoding()); + EXPECT_EQ( + noEncodingRankedTensorType, + cast(stringEncodingRankedTensorType).dropEncoding()); } } // namespace