2424#include " mlir/Interfaces/FunctionImplementation.h"
2525#include " mlir/Transforms/InliningUtils.h"
2626
27+ #include " llvm/ADT/APFloat.h"
2728#include " llvm/ADT/TypeSwitch.h"
2829#include " llvm/IR/Function.h"
2930#include " llvm/IR/Type.h"
@@ -3187,6 +3188,18 @@ static int64_t getNumElements(Type t) {
31873188 return 1 ;
31883189}
31893190
3191+ // / Determine the element type of `type`. Supported types are `VectorType`,
3192+ // / `TensorType`, and `LLVMArrayType`. Everything else is treated as a scalar.
3193+ static Type getElementType (Type type) {
3194+ while (auto arrayType = dyn_cast<LLVM::LLVMArrayType>(type))
3195+ type = arrayType.getElementType ();
3196+ if (auto vecType = dyn_cast<VectorType>(type))
3197+ return vecType.getElementType ();
3198+ if (auto tenType = dyn_cast<TensorType>(type))
3199+ return tenType.getElementType ();
3200+ return type;
3201+ }
3202+
31903203// / Check if the given type is a scalable vector type or a vector/array type
31913204// / that contains a nested scalable vector type.
31923205static bool hasScalableVectorType (Type t) {
@@ -3281,60 +3294,69 @@ LogicalResult LLVM::ConstantOp::verify() {
32813294 }
32823295 if (auto structType = dyn_cast<LLVMStructType>(getType ())) {
32833296 auto arrayAttr = dyn_cast<ArrayAttr>(getValue ());
3284- if (!arrayAttr) {
3285- return emitOpError () << " expected array attribute for a struct constant" ;
3286- }
3297+ if (!arrayAttr)
3298+ return emitOpError () << " expected array attribute for struct type" ;
32873299
32883300 ArrayRef<Type> elementTypes = structType.getBody ();
32893301 if (arrayAttr.size () != elementTypes.size ()) {
32903302 return emitOpError () << " expected array attribute of size "
32913303 << elementTypes.size ();
32923304 }
3293- for (auto elementTy : elementTypes) {
3294- if (!isa<IntegerType, FloatType, LLVMPPCFP128Type>(elementTy )) {
3305+ for (auto [i, attr, type] : llvm::enumerate (arrayAttr, elementTypes) ) {
3306+ if (!type. isSignlessIntOrIndexOrFloat ( )) {
32953307 return emitOpError () << " expected struct element types to be floating "
32963308 " point type or integer type" ;
32973309 }
3298- }
3299-
3300- for (size_t i = 0 ; i < elementTypes.size (); ++i) {
3301- Attribute element = arrayAttr[i];
3302- if (!isa<IntegerAttr, FloatAttr>(element)) {
3303- return emitOpError ()
3304- << " expected struct element attribute types to be floating "
3305- " point type or integer type" ;
3310+ if (!isa<FloatAttr, IntegerAttr>(attr)) {
3311+ return emitOpError () << " expected element of array attribute to be "
3312+ " floating point or integer" ;
33063313 }
3307- auto elementType = cast<TypedAttr>(element).getType ();
3308- if (elementType != elementTypes[i]) {
3314+ if (cast<TypedAttr>(attr).getType () != type)
33093315 return emitOpError ()
33103316 << " struct element at index " << i << " is of wrong type" ;
3311- }
33123317 }
33133318
33143319 return success ();
33153320 }
3316- if (auto targetExtType = dyn_cast<LLVMTargetExtType>(getType ())) {
3321+ if (auto targetExtType = dyn_cast<LLVMTargetExtType>(getType ()))
33173322 return emitOpError () << " does not support target extension type." ;
3318- }
3323+
3324+ // Check that an attribute whose element type has floating point semantics
3325+ // `attributeFloatSemantics` is compatible with a type whose element type
3326+ // is `constantElementType`.
3327+ //
3328+ // Requirement is that either
3329+ // 1) They have identical floating point types.
3330+ // 2) `constantElementType` is an integer type of the same width as the float
3331+ // attribute. This is to support builtin MLIR float types without LLVM
3332+ // equivalents, see comments in getLLVMConstant for more details.
3333+ auto verifyFloatSemantics =
3334+ [this ](const llvm::fltSemantics &attributeFloatSemantics,
3335+ Type constantElementType) -> LogicalResult {
3336+ if (auto floatType = dyn_cast<FloatType>(constantElementType)) {
3337+ if (&floatType.getFloatSemantics () != &attributeFloatSemantics) {
3338+ return emitOpError ()
3339+ << " attribute and type have different float semantics" ;
3340+ }
3341+ return success ();
3342+ }
3343+ unsigned floatWidth = APFloat::getSizeInBits (attributeFloatSemantics);
3344+ if (isa<IntegerType>(constantElementType)) {
3345+ if (!constantElementType.isInteger (floatWidth))
3346+ return emitOpError () << " expected integer type of width " << floatWidth;
3347+
3348+ return success ();
3349+ }
3350+ return success ();
3351+ };
33193352
33203353 // Verification of IntegerAttr, FloatAttr, ElementsAttr, ArrayAttr.
3321- if (auto intAttr = dyn_cast <IntegerAttr>(getValue ())) {
3354+ if (isa <IntegerAttr>(getValue ())) {
33223355 if (!llvm::isa<IntegerType>(getType ()))
33233356 return emitOpError () << " expected integer type" ;
33243357 } else if (auto floatAttr = dyn_cast<FloatAttr>(getValue ())) {
3325- const llvm::fltSemantics &sem = floatAttr.getValue ().getSemantics ();
3326- unsigned floatWidth = APFloat::getSizeInBits (sem);
3327- if (auto floatTy = dyn_cast<FloatType>(getType ())) {
3328- if (floatTy.getWidth () != floatWidth) {
3329- return emitOpError () << " expected float type of width " << floatWidth;
3330- }
3331- }
3332- // See the comment for getLLVMConstant for more details about why 8-bit
3333- // floats can be represented by integers.
3334- if (isa<IntegerType>(getType ()) && !getType ().isInteger (floatWidth)) {
3335- return emitOpError () << " expected integer type of width " << floatWidth;
3336- }
3337- } else if (isa<ElementsAttr>(getValue ())) {
3358+ return verifyFloatSemantics (floatAttr.getValue ().getSemantics (), getType ());
3359+ } else if (auto elementsAttr = dyn_cast<ElementsAttr>(getValue ())) {
33383360 if (hasScalableVectorType (getType ())) {
33393361 // The exact number of elements of a scalable vector is unknown, so we
33403362 // allow only splat attributes.
@@ -3346,18 +3368,32 @@ LogicalResult LLVM::ConstantOp::verify() {
33463368 }
33473369 if (!isa<VectorType, LLVM::LLVMArrayType>(getType ()))
33483370 return emitOpError () << " expected vector or array type" ;
3371+
33493372 // The number of elements of the attribute and the type must match.
3350- if (auto elementsAttr = dyn_cast<ElementsAttr>(getValue ())) {
3351- int64_t attrNumElements = elementsAttr.getNumElements ();
3352- if (getNumElements (getType ()) != attrNumElements)
3353- return emitOpError ()
3354- << " type and attribute have a different number of elements: "
3355- << getNumElements (getType ()) << " vs. " << attrNumElements;
3373+ int64_t attrNumElements = elementsAttr.getNumElements ();
3374+ if (getNumElements (getType ()) != attrNumElements) {
3375+ return emitOpError ()
3376+ << " type and attribute have a different number of elements: "
3377+ << getNumElements (getType ()) << " vs. " << attrNumElements;
3378+ }
3379+
3380+ Type attrElmType = getElementType (elementsAttr.getType ());
3381+ Type resultElmType = getElementType (getType ());
3382+ if (auto floatType = dyn_cast<FloatType>(attrElmType))
3383+ return verifyFloatSemantics (floatType.getFloatSemantics (), resultElmType);
3384+
3385+ if (isa<IntegerType>(attrElmType) && !isa<IntegerType>(resultElmType)) {
3386+ return emitOpError (
3387+ " expected integer element type for integer elements attribute" );
33563388 }
33573389 } else if (auto arrayAttr = dyn_cast<ArrayAttr>(getValue ())) {
3390+
3391+ // The case where the constant is LLVMStructType has already been handled.
33583392 auto arrayType = dyn_cast<LLVM::LLVMArrayType>(getType ());
33593393 if (!arrayType)
3360- return emitOpError () << " expected array type" ;
3394+ return emitOpError ()
3395+ << " expected array or struct type for array attribute" ;
3396+
33613397 // When the attribute is an ArrayAttr, check that its nesting matches the
33623398 // corresponding ArrayType or VectorType nesting.
33633399 return verifyStructArrayConstant (*this , arrayType, arrayAttr, /* dim=*/ 0 );
0 commit comments