From 5e5816982024ae26eabc2f72e58d6e8b9fb5ce9c Mon Sep 17 00:00:00 2001 From: Jeremy Kun Date: Wed, 23 Apr 2025 13:45:58 -0700 Subject: [PATCH] add IntegerLikeTypeInterface to enable out-of-tree uses of int attribute parsers --- mlir/include/mlir/IR/BuiltinAttributes.h | 4 +- mlir/include/mlir/IR/BuiltinTypeInterfaces.td | 38 +++++++++++++++++++ mlir/include/mlir/IR/BuiltinTypes.td | 6 ++- mlir/lib/AsmParser/AttributeParser.cpp | 16 ++++---- mlir/lib/IR/AsmPrinter.cpp | 2 +- mlir/lib/IR/AttributeDetail.h | 5 ++- mlir/lib/IR/BuiltinAttributes.cpp | 35 +++++++++-------- mlir/lib/IR/BuiltinTypes.cpp | 10 +++++ 8 files changed, 85 insertions(+), 31 deletions(-) diff --git a/mlir/include/mlir/IR/BuiltinAttributes.h b/mlir/include/mlir/IR/BuiltinAttributes.h index c07ade606a775..005316a737dff 100644 --- a/mlir/include/mlir/IR/BuiltinAttributes.h +++ b/mlir/include/mlir/IR/BuiltinAttributes.h @@ -548,7 +548,9 @@ class DenseElementsAttr : public Attribute { std::enable_if_t::value>; template > FailureOr> tryGetValues() const { - if (!getElementType().isIntOrIndex()) + auto intLikeType = + llvm::dyn_cast(getElementType()); + if (!intLikeType) return failure(); return iterator_range_impl(getType(), raw_int_begin(), raw_int_end()); diff --git a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td index 4a4f818b46c57..3d459f006093a 100644 --- a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td +++ b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td @@ -257,4 +257,42 @@ def ShapedTypeInterface : TypeInterface<"ShapedType"> { }]; } +def IntegerLikeTypeInterface : TypeInterface<"IntegerLikeTypeInterface"> { + let cppNamespace = "::mlir"; + let description = [{ + This type interface is for types that behave like integers. It provides + the API that allows MLIR utilities to treat them the same was as MLIR + treats integer types in settings like parsing and printing. + }]; + + let methods = [ + InterfaceMethod< + /*desc=*/[{ + Returns the storage bit width for this type. + }], + /*retTy=*/"unsigned", + /*methodName=*/"getStorageBitWidth", + /*args=*/(ins) + >, + InterfaceMethod< + /*desc=*/[{ + Returns true if this type is signed. + }], + /*retTy=*/"bool", + /*methodName=*/"isSigned", + /*args=*/(ins), + /*defaultImplementation=*/"return true;" + >, + InterfaceMethod< + /*desc=*/[{ + Returns true if this type is signless. + }], + /*retTy=*/"bool", + /*methodName=*/"isSignless", + /*args=*/(ins), + /*defaultImplementation=*/"return true;" + >, + ]; +} + #endif // MLIR_IR_BUILTINTYPEINTERFACES_TD_ diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td index 771de01fc8d5d..6eb2ec333351a 100644 --- a/mlir/include/mlir/IR/BuiltinTypes.td +++ b/mlir/include/mlir/IR/BuiltinTypes.td @@ -466,7 +466,8 @@ def Builtin_Function : Builtin_Type<"Function", "function"> { //===----------------------------------------------------------------------===// def Builtin_Index : Builtin_Type<"Index", "index", - [VectorElementTypeInterface]> { + [VectorElementTypeInterface, + DeclareTypeInterfaceMethods]> { let summary = "Integer-like type with unknown platform-dependent bit width"; let description = [{ Syntax: @@ -497,7 +498,8 @@ def Builtin_Index : Builtin_Type<"Index", "index", //===----------------------------------------------------------------------===// def Builtin_Integer : Builtin_Type<"Integer", "integer", - [VectorElementTypeInterface]> { + [VectorElementTypeInterface, + DeclareTypeInterfaceMethods]> { let summary = "Integer type with arbitrary precision up to a fixed limit"; let description = [{ Syntax: diff --git a/mlir/lib/AsmParser/AttributeParser.cpp b/mlir/lib/AsmParser/AttributeParser.cpp index 2474e88373e04..a3252cf1964ee 100644 --- a/mlir/lib/AsmParser/AttributeParser.cpp +++ b/mlir/lib/AsmParser/AttributeParser.cpp @@ -16,6 +16,7 @@ #include "mlir/IR/AffineMap.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinDialect.h" +#include "mlir/IR/BuiltinTypeInterfaces.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/DialectResourceBlobManager.h" #include "mlir/IR/IntegerSet.h" @@ -366,8 +367,12 @@ static std::optional buildAttributeAPInt(Type type, bool isNegative, return std::nullopt; // Extend or truncate the bitwidth to the right size. - unsigned width = type.isIndex() ? IndexType::kInternalStorageBitWidth - : type.getIntOrFloatBitWidth(); + unsigned width; + if (auto intLikeType = dyn_cast(type)) { + width = intLikeType.getStorageBitWidth(); + } else { + width = type.getIntOrFloatBitWidth(); + } if (width > result.getBitWidth()) { result = result.zext(width); @@ -425,10 +430,6 @@ Attribute Parser::parseDecOrHexAttr(Type type, bool isNegative) { return FloatAttr::get(floatType, *result); } - if (!isa(type)) - return emitError(loc, "integer literal not valid for specified type"), - nullptr; - if (isNegative && type.isUnsignedInteger()) { emitError(loc, "negative integer literal not valid for unsigned integer type"); @@ -584,7 +585,8 @@ DenseElementsAttr TensorLiteralParser::getAttr(SMLoc loc, ShapedType type) { } // Handle integer and index types. - if (eltType.isIntOrIndex()) { + auto integerLikeType = dyn_cast(eltType); + if (integerLikeType || eltType.isIntOrIndex()) { std::vector intValues; if (failed(getIntAttrElements(loc, eltType, intValues))) return nullptr; diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index 5b5ec841917e7..d74bf5d975f0b 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -2656,7 +2656,7 @@ void AsmPrinter::Impl::printDenseIntOrFPElementsAttr( os << ")"; }); } - } else if (elementType.isIntOrIndex()) { + } else if (isa(elementType)) { auto valueIt = attr.value_begin(); printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) { printDenseIntElement(*(valueIt + index), os, elementType); diff --git a/mlir/lib/IR/AttributeDetail.h b/mlir/lib/IR/AttributeDetail.h index 26d40ac3a38f6..96b269e3b1363 100644 --- a/mlir/lib/IR/AttributeDetail.h +++ b/mlir/lib/IR/AttributeDetail.h @@ -37,8 +37,9 @@ inline size_t getDenseElementBitWidth(Type eltType) { // Align the width for complex to 8 to make storage and interpretation easier. if (ComplexType comp = llvm::dyn_cast(eltType)) return llvm::alignTo<8>(getDenseElementBitWidth(comp.getElementType())) * 2; - if (eltType.isIndex()) - return IndexType::kInternalStorageBitWidth; + if (auto intLikeType = dyn_cast(eltType)) + return intLikeType.getStorageBitWidth(); + return eltType.getIntOrFloatBitWidth(); } diff --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp index e9af1f77a379e..7bc3d9a59a37e 100644 --- a/mlir/lib/IR/BuiltinAttributes.cpp +++ b/mlir/lib/IR/BuiltinAttributes.cpp @@ -10,6 +10,7 @@ #include "AttributeDetail.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/BuiltinDialect.h" +#include "mlir/IR/BuiltinTypeInterfaces.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/DialectResourceBlobManager.h" #include "mlir/IR/IntegerSet.h" @@ -379,22 +380,20 @@ APSInt IntegerAttr::getAPSInt() const { LogicalResult IntegerAttr::verify(function_ref emitError, Type type, APInt value) { - if (IntegerType integerType = llvm::dyn_cast(type)) { - if (integerType.getWidth() != value.getBitWidth()) - return emitError() << "integer type bit width (" << integerType.getWidth() - << ") doesn't match value bit width (" - << value.getBitWidth() << ")"; - return success(); + unsigned width; + if (auto intLikeType = dyn_cast(type)) { + width = intLikeType.getStorageBitWidth(); + } else { + return emitError() << "expected integer-like type"; } - if (llvm::isa(type)) { - if (value.getBitWidth() != IndexType::kInternalStorageBitWidth) - return emitError() - << "value bit width (" << value.getBitWidth() - << ") doesn't match index type internal storage bit width (" - << IndexType::kInternalStorageBitWidth << ")"; - return success(); + + if (width != value.getBitWidth()) { + return emitError() << "integer-like type bit width (" << width + << ") doesn't match value bit width (" + << value.getBitWidth() << ")"; } - return emitError() << "expected integer or index type"; + + return success(); } BoolAttr IntegerAttr::getBoolAttrUnchecked(IntegerType type, bool value) { @@ -1019,7 +1018,7 @@ DenseElementsAttr DenseElementsAttr::get(ShapedType type, /// element type of 'type'. DenseElementsAttr DenseElementsAttr::get(ShapedType type, ArrayRef values) { - assert(type.getElementType().isIntOrIndex()); + assert(isa(type.getElementType())); assert(hasSameNumElementsOrSplat(type, values)); size_t storageBitWidth = getDenseElementStorageWidth(type.getElementType()); return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, values); @@ -1130,11 +1129,11 @@ static bool isValidIntOrFloat(Type type, int64_t dataEltSize, bool isInt, if (type.isIndex()) return true; - auto intType = llvm::dyn_cast(type); + auto intType = llvm::dyn_cast(type); if (!intType) { LLVM_DEBUG(llvm::dbgs() - << "expected integer type when isInt is true, but found " << type - << "\n"); + << "expected integer-like type when isInt is true, but found " + << type << "\n"); return false; } diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp index 3924d082f0628..c38e33dba7e5d 100644 --- a/mlir/lib/IR/BuiltinTypes.cpp +++ b/mlir/lib/IR/BuiltinTypes.cpp @@ -59,6 +59,14 @@ LogicalResult ComplexType::verify(function_ref emitError, return success(); } +//===----------------------------------------------------------------------===// +// Index Type +//===----------------------------------------------------------------------===// + +unsigned IndexType::getStorageBitWidth() const { + return kInternalStorageBitWidth; +} + //===----------------------------------------------------------------------===// // Integer Type //===----------------------------------------------------------------------===// @@ -86,6 +94,8 @@ IntegerType IntegerType::scaleElementBitwidth(unsigned scale) { return IntegerType::get(getContext(), scale * getWidth(), getSignedness()); } +unsigned IntegerType::getStorageBitWidth() const { return getWidth(); } + //===----------------------------------------------------------------------===// // Float Types //===----------------------------------------------------------------------===//