24
24
#include " mlir/Interfaces/FunctionImplementation.h"
25
25
#include " mlir/Transforms/InliningUtils.h"
26
26
27
+ #include " llvm/ADT/APFloat.h"
27
28
#include " llvm/ADT/TypeSwitch.h"
28
29
#include " llvm/IR/Function.h"
29
30
#include " llvm/IR/Type.h"
@@ -3187,6 +3188,18 @@ static int64_t getNumElements(Type t) {
3187
3188
return 1 ;
3188
3189
}
3189
3190
3191
+ // / Determine the element type of `type`. Supported types are `VectorType`,
3192
+ // / `TensorType`, and `LLVMArrayType`. Everything else is treated as a scalar.
3193
+ static Type getElementType (Type type) {
3194
+ while (auto arrayType = dyn_cast<LLVM::LLVMArrayType>(type))
3195
+ type = arrayType.getElementType ();
3196
+ if (auto vecType = dyn_cast<VectorType>(type))
3197
+ return vecType.getElementType ();
3198
+ if (auto tenType = dyn_cast<TensorType>(type))
3199
+ return tenType.getElementType ();
3200
+ return type;
3201
+ }
3202
+
3190
3203
// / Check if the given type is a scalable vector type or a vector/array type
3191
3204
// / that contains a nested scalable vector type.
3192
3205
static bool hasScalableVectorType (Type t) {
@@ -3281,60 +3294,69 @@ LogicalResult LLVM::ConstantOp::verify() {
3281
3294
}
3282
3295
if (auto structType = dyn_cast<LLVMStructType>(getType ())) {
3283
3296
auto arrayAttr = dyn_cast<ArrayAttr>(getValue ());
3284
- if (!arrayAttr) {
3285
- return emitOpError () << " expected array attribute for a struct constant" ;
3286
- }
3297
+ if (!arrayAttr)
3298
+ return emitOpError () << " expected array attribute for struct type" ;
3287
3299
3288
3300
ArrayRef<Type> elementTypes = structType.getBody ();
3289
3301
if (arrayAttr.size () != elementTypes.size ()) {
3290
3302
return emitOpError () << " expected array attribute of size "
3291
3303
<< elementTypes.size ();
3292
3304
}
3293
- for (auto elementTy : elementTypes) {
3294
- if (!isa<IntegerType, FloatType, LLVMPPCFP128Type>(elementTy )) {
3305
+ for (auto [i, attr, type] : llvm::enumerate (arrayAttr, elementTypes) ) {
3306
+ if (!type. isSignlessIntOrIndexOrFloat ( )) {
3295
3307
return emitOpError () << " expected struct element types to be floating "
3296
3308
" point type or integer type" ;
3297
3309
}
3298
- }
3299
-
3300
- for (size_t i = 0 ; i < elementTypes.size (); ++i) {
3301
- Attribute element = arrayAttr[i];
3302
- if (!isa<IntegerAttr, FloatAttr>(element)) {
3303
- return emitOpError ()
3304
- << " expected struct element attribute types to be floating "
3305
- " point type or integer type" ;
3310
+ if (!isa<FloatAttr, IntegerAttr>(attr)) {
3311
+ return emitOpError () << " expected element of array attribute to be "
3312
+ " floating point or integer" ;
3306
3313
}
3307
- auto elementType = cast<TypedAttr>(element).getType ();
3308
- if (elementType != elementTypes[i]) {
3314
+ if (cast<TypedAttr>(attr).getType () != type)
3309
3315
return emitOpError ()
3310
3316
<< " struct element at index " << i << " is of wrong type" ;
3311
- }
3312
3317
}
3313
3318
3314
3319
return success ();
3315
3320
}
3316
- if (auto targetExtType = dyn_cast<LLVMTargetExtType>(getType ())) {
3321
+ if (auto targetExtType = dyn_cast<LLVMTargetExtType>(getType ()))
3317
3322
return emitOpError () << " does not support target extension type." ;
3318
- }
3323
+
3324
+ // Check that an attribute whose element type has floating point semantics
3325
+ // `attributeFloatSemantics` is compatible with a type whose element type
3326
+ // is `constantElementType`.
3327
+ //
3328
+ // Requirement is that either
3329
+ // 1) They have identical floating point types.
3330
+ // 2) `constantElementType` is an integer type of the same width as the float
3331
+ // attribute. This is to support builtin MLIR float types without LLVM
3332
+ // equivalents, see comments in getLLVMConstant for more details.
3333
+ auto verifyFloatSemantics =
3334
+ [this ](const llvm::fltSemantics &attributeFloatSemantics,
3335
+ Type constantElementType) -> LogicalResult {
3336
+ if (auto floatType = dyn_cast<FloatType>(constantElementType)) {
3337
+ if (&floatType.getFloatSemantics () != &attributeFloatSemantics) {
3338
+ return emitOpError ()
3339
+ << " attribute and type have different float semantics" ;
3340
+ }
3341
+ return success ();
3342
+ }
3343
+ unsigned floatWidth = APFloat::getSizeInBits (attributeFloatSemantics);
3344
+ if (isa<IntegerType>(constantElementType)) {
3345
+ if (!constantElementType.isInteger (floatWidth))
3346
+ return emitOpError () << " expected integer type of width " << floatWidth;
3347
+
3348
+ return success ();
3349
+ }
3350
+ return success ();
3351
+ };
3319
3352
3320
3353
// Verification of IntegerAttr, FloatAttr, ElementsAttr, ArrayAttr.
3321
- if (auto intAttr = dyn_cast <IntegerAttr>(getValue ())) {
3354
+ if (isa <IntegerAttr>(getValue ())) {
3322
3355
if (!llvm::isa<IntegerType>(getType ()))
3323
3356
return emitOpError () << " expected integer type" ;
3324
3357
} else if (auto floatAttr = dyn_cast<FloatAttr>(getValue ())) {
3325
- const llvm::fltSemantics &sem = floatAttr.getValue ().getSemantics ();
3326
- unsigned floatWidth = APFloat::getSizeInBits (sem);
3327
- if (auto floatTy = dyn_cast<FloatType>(getType ())) {
3328
- if (floatTy.getWidth () != floatWidth) {
3329
- return emitOpError () << " expected float type of width " << floatWidth;
3330
- }
3331
- }
3332
- // See the comment for getLLVMConstant for more details about why 8-bit
3333
- // floats can be represented by integers.
3334
- if (isa<IntegerType>(getType ()) && !getType ().isInteger (floatWidth)) {
3335
- return emitOpError () << " expected integer type of width " << floatWidth;
3336
- }
3337
- } else if (isa<ElementsAttr>(getValue ())) {
3358
+ return verifyFloatSemantics (floatAttr.getValue ().getSemantics (), getType ());
3359
+ } else if (auto elementsAttr = dyn_cast<ElementsAttr>(getValue ())) {
3338
3360
if (hasScalableVectorType (getType ())) {
3339
3361
// The exact number of elements of a scalable vector is unknown, so we
3340
3362
// allow only splat attributes.
@@ -3346,18 +3368,32 @@ LogicalResult LLVM::ConstantOp::verify() {
3346
3368
}
3347
3369
if (!isa<VectorType, LLVM::LLVMArrayType>(getType ()))
3348
3370
return emitOpError () << " expected vector or array type" ;
3371
+
3349
3372
// The number of elements of the attribute and the type must match.
3350
- if (auto elementsAttr = dyn_cast<ElementsAttr>(getValue ())) {
3351
- int64_t attrNumElements = elementsAttr.getNumElements ();
3352
- if (getNumElements (getType ()) != attrNumElements)
3353
- return emitOpError ()
3354
- << " type and attribute have a different number of elements: "
3355
- << getNumElements (getType ()) << " vs. " << attrNumElements;
3373
+ int64_t attrNumElements = elementsAttr.getNumElements ();
3374
+ if (getNumElements (getType ()) != attrNumElements) {
3375
+ return emitOpError ()
3376
+ << " type and attribute have a different number of elements: "
3377
+ << getNumElements (getType ()) << " vs. " << attrNumElements;
3378
+ }
3379
+
3380
+ Type attrElmType = getElementType (elementsAttr.getType ());
3381
+ Type resultElmType = getElementType (getType ());
3382
+ if (auto floatType = dyn_cast<FloatType>(attrElmType))
3383
+ return verifyFloatSemantics (floatType.getFloatSemantics (), resultElmType);
3384
+
3385
+ if (isa<IntegerType>(attrElmType) && !isa<IntegerType>(resultElmType)) {
3386
+ return emitOpError (
3387
+ " expected integer element type for integer elements attribute" );
3356
3388
}
3357
3389
} else if (auto arrayAttr = dyn_cast<ArrayAttr>(getValue ())) {
3390
+
3391
+ // The case where the constant is LLVMStructType has already been handled.
3358
3392
auto arrayType = dyn_cast<LLVM::LLVMArrayType>(getType ());
3359
3393
if (!arrayType)
3360
- return emitOpError () << " expected array type" ;
3394
+ return emitOpError ()
3395
+ << " expected array or struct type for array attribute" ;
3396
+
3361
3397
// When the attribute is an ArrayAttr, check that its nesting matches the
3362
3398
// corresponding ArrayType or VectorType nesting.
3363
3399
return verifyStructArrayConstant (*this , arrayType, arrayAttr, /* dim=*/ 0 );
0 commit comments