Skip to content

Commit 9bae051

Browse files
committed
add checks on element attributes for types
1 parent 5cfdb67 commit 9bae051

File tree

3 files changed

+147
-89
lines changed

3 files changed

+147
-89
lines changed

mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp

Lines changed: 67 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
#include "mlir/Interfaces/FunctionImplementation.h"
2525
#include "mlir/Transforms/InliningUtils.h"
2626

27+
#include "llvm/ADT/APFloat.h"
2728
#include "llvm/ADT/TypeSwitch.h"
2829
#include "llvm/IR/Function.h"
2930
#include "llvm/IR/Type.h"
@@ -3187,6 +3188,18 @@ static int64_t getNumElements(Type t) {
31873188
return 1;
31883189
}
31893190

3191+
/// Determine the element type of `type`. Supported types are `VectorType`,
3192+
/// `TensorType`, and `LLVMArrayType`. Everything else is treated as a scalar.
3193+
static Type getElementType(Type type) {
3194+
while (auto arrayType = dyn_cast<LLVM::LLVMArrayType>(type))
3195+
type = arrayType.getElementType();
3196+
if (auto vecType = dyn_cast<VectorType>(type))
3197+
return vecType.getElementType();
3198+
if (auto tenType = dyn_cast<TensorType>(type))
3199+
return tenType.getElementType();
3200+
return type;
3201+
}
3202+
31903203
/// Check if the given type is a scalable vector type or a vector/array type
31913204
/// that contains a nested scalable vector type.
31923205
static bool hasScalableVectorType(Type t) {
@@ -3290,22 +3303,13 @@ LogicalResult LLVM::ConstantOp::verify() {
32903303
return emitOpError() << "expected array attribute of size "
32913304
<< elementTypes.size();
32923305
}
3293-
for (auto elementTy : elementTypes) {
3294-
if (!isa<IntegerType, FloatType, LLVMPPCFP128Type>(elementTy)) {
3295-
return emitOpError() << "expected struct element types to be floating "
3296-
"point type or integer type";
3297-
}
3298-
}
3299-
3300-
for (size_t i = 0; i < elementTypes.size(); ++i) {
3301-
Attribute element = arrayAttr[i];
3302-
if (!isa<IntegerAttr, FloatAttr>(element)) {
3303-
return emitOpError()
3304-
<< "expected struct element attribute types to be floating "
3305-
"point type or integer type";
3306+
for (auto [i, attr, type] : llvm::enumerate(arrayAttr, elementTypes)) {
3307+
if (!isa<IntegerAttr, FloatAttr>(attr)) {
3308+
return emitOpError() << "expected element of array attribute to be "
3309+
"floating point or integer";
33063310
}
3307-
auto elementType = cast<TypedAttr>(element).getType();
3308-
if (elementType != elementTypes[i]) {
3311+
auto attrType = cast<TypedAttr>(attr).getType();
3312+
if (attrType != type) {
33093313
return emitOpError()
33103314
<< "struct element at index " << i << " is of wrong type";
33113315
}
@@ -3317,24 +3321,42 @@ LogicalResult LLVM::ConstantOp::verify() {
33173321
return emitOpError() << "does not support target extension type.";
33183322
}
33193323

3324+
// Check that an attribute whose element type has floating point semantics
3325+
// `attributeFloatSemantics` is compatible with a type whose element type
3326+
// is `constantElementType`.
3327+
//
3328+
// Requirement is that either
3329+
// 1) They have identical floating point types.
3330+
// 2) `constantElementType` is an integer type of the same width as the float
3331+
// attribute. This is to support builtin MLIR float types without LLVM
3332+
// equivalents, see comments in getLLVMConstant for more details.
3333+
auto verifyFloatSemantics =
3334+
[this](const llvm::fltSemantics &attributeFloatSemantics,
3335+
Type constantElementType) -> LogicalResult {
3336+
if (auto floatType = dyn_cast<FloatType>(constantElementType)) {
3337+
if (&floatType.getFloatSemantics() != &attributeFloatSemantics) {
3338+
return emitOpError()
3339+
<< "attribute and type have different float semantics";
3340+
}
3341+
return success();
3342+
}
3343+
unsigned floatWidth = APFloat::getSizeInBits(attributeFloatSemantics);
3344+
if (isa<IntegerType>(constantElementType)) {
3345+
if (!constantElementType.isInteger(floatWidth)) {
3346+
return emitOpError() << "expected integer type of width " << floatWidth;
3347+
}
3348+
return success();
3349+
}
3350+
return success();
3351+
};
3352+
33203353
// Verification of IntegerAttr, FloatAttr, ElementsAttr, ArrayAttr.
3321-
if (auto intAttr = dyn_cast<IntegerAttr>(getValue())) {
3354+
if (isa<IntegerAttr>(getValue())) {
33223355
if (!llvm::isa<IntegerType>(getType()))
33233356
return emitOpError() << "expected integer type";
33243357
} else if (auto floatAttr = dyn_cast<FloatAttr>(getValue())) {
3325-
const llvm::fltSemantics &sem = floatAttr.getValue().getSemantics();
3326-
unsigned floatWidth = APFloat::getSizeInBits(sem);
3327-
if (auto floatTy = dyn_cast<FloatType>(getType())) {
3328-
if (floatTy.getWidth() != floatWidth) {
3329-
return emitOpError() << "expected float type of width " << floatWidth;
3330-
}
3331-
}
3332-
// See the comment for getLLVMConstant for more details about why 8-bit
3333-
// floats can be represented by integers.
3334-
if (isa<IntegerType>(getType()) && !getType().isInteger(floatWidth)) {
3335-
return emitOpError() << "expected integer type of width " << floatWidth;
3336-
}
3337-
} else if (isa<ElementsAttr>(getValue())) {
3358+
return verifyFloatSemantics(floatAttr.getValue().getSemantics(), getType());
3359+
} else if (auto elementsAttr = dyn_cast<ElementsAttr>(getValue())) {
33383360
if (hasScalableVectorType(getType())) {
33393361
// The exact number of elements of a scalable vector is unknown, so we
33403362
// allow only splat attributes.
@@ -3346,13 +3368,23 @@ LogicalResult LLVM::ConstantOp::verify() {
33463368
}
33473369
if (!isa<VectorType, LLVM::LLVMArrayType>(getType()))
33483370
return emitOpError() << "expected vector or array type";
3371+
33493372
// The number of elements of the attribute and the type must match.
3350-
if (auto elementsAttr = dyn_cast<ElementsAttr>(getValue())) {
3351-
int64_t attrNumElements = elementsAttr.getNumElements();
3352-
if (getNumElements(getType()) != attrNumElements)
3353-
return emitOpError()
3354-
<< "type and attribute have a different number of elements: "
3355-
<< getNumElements(getType()) << " vs. " << attrNumElements;
3373+
int64_t attrNumElements = elementsAttr.getNumElements();
3374+
if (getNumElements(getType()) != attrNumElements) {
3375+
return emitOpError()
3376+
<< "type and attribute have a different number of elements: "
3377+
<< getNumElements(getType()) << " vs. " << attrNumElements;
3378+
}
3379+
3380+
Type attrElmType = getElementType(elementsAttr.getType());
3381+
Type resultElmType = getElementType(getType());
3382+
if (auto floatType = dyn_cast<FloatType>(attrElmType)) {
3383+
return verifyFloatSemantics(floatType.getFloatSemantics(), resultElmType);
3384+
}
3385+
if (isa<IntegerType>(attrElmType) && !isa<IntegerType>(resultElmType)) {
3386+
return emitOpError(
3387+
"expected integer element type for integer elements attribute");
33563388
}
33573389
} else if (auto arrayAttr = dyn_cast<ArrayAttr>(getValue())) {
33583390
auto arrayType = dyn_cast<LLVM::LLVMArrayType>(getType());

mlir/test/Dialect/LLVMIR/invalid.mlir

Lines changed: 80 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -418,7 +418,7 @@ llvm.func @struct_two_different_elements() -> !llvm.struct<(f64, f32)> {
418418
// -----
419419

420420
llvm.func @struct_wrong_element_types() -> !llvm.struct<(!llvm.array<2 x f64>, !llvm.array<2 x f64>)> {
421-
// expected-error @+1 {{expected struct element types to be floating point type or integer type}}
421+
// expected-error @+1 {{expected element of array attribute to be floating point or integer}}
422422
%0 = llvm.mlir.constant([dense<[1.0, 1.0]> : tensor<2xf64>, dense<[1.0, 1.0]> : tensor<2xf64>]) : !llvm.struct<(!llvm.array<2 x f64>, !llvm.array<2 x f64>)>
423423
llvm.return %0 : !llvm.struct<(!llvm.array<2 x f64>, !llvm.array<2 x f64>)>
424424
}
@@ -442,18 +442,34 @@ llvm.func @scalable_vec_requires_splat() -> vector<[4]xf64> {
442442

443443
// -----
444444

445-
llvm.func @integer_with_float_type() -> f32 {
445+
llvm.func @int_attr_requires_int_type() -> f32 {
446446
// expected-error @+1 {{expected integer type}}
447447
%0 = llvm.mlir.constant(1 : index) : f32
448448
llvm.return %0 : f32
449449
}
450450

451451
// -----
452452

453-
llvm.func @incompatible_float_attribute_type() -> f32 {
454-
// expected-error @below{{expected float type of width 64}}
455-
%cst = llvm.mlir.constant(1.0 : f64) : f32
456-
llvm.return %cst : f32
453+
llvm.func @vector_int_attr_requires_int_type() -> vector<2xf32> {
454+
// expected-error @+1 {{expected integer element type}}
455+
%0 = llvm.mlir.constant(dense<[1, 2]> : vector<2xi32>) : vector<2xf32>
456+
llvm.return %0 : vector<2xf32>
457+
}
458+
459+
// -----
460+
461+
llvm.func @float_attr_and_type_required_same() -> f16 {
462+
// expected-error @below{{attribute and type have different float semantics}}
463+
%cst = llvm.mlir.constant(1.0 : bf16) : f16
464+
llvm.return %cst : f16
465+
}
466+
467+
// -----
468+
469+
llvm.func @vector_float_attr_and_type_required_same() -> vector<2xf16> {
470+
// expected-error @below{{attribute and type have different float semantics}}
471+
%cst = llvm.mlir.constant(dense<[1.0, 2.0]> : vector<2xbf16>) : vector<2xf16>
472+
llvm.return %cst : vector<2xf16>
457473
}
458474

459475
// -----
@@ -466,6 +482,64 @@ llvm.func @incompatible_integer_type_for_float_attr() -> i32 {
466482

467483
// -----
468484

485+
llvm.func @vector_incompatible_integer_type_for_float_attr() -> vector<2xi8> {
486+
// expected-error @below{{expected integer type of width 16}}
487+
%cst = llvm.mlir.constant(dense<[1.0, 2.0]> : vector<2xf16>) : vector<2xi8>
488+
llvm.return %cst : vector<2xi8>
489+
}
490+
491+
// -----
492+
493+
llvm.func @vector_with_non_vector_type() -> f32 {
494+
// expected-error @below{{expected vector or array type}}
495+
%cst = llvm.mlir.constant(dense<100.0> : vector<1xf64>) : f32
496+
llvm.return %cst : f32
497+
}
498+
499+
// -----
500+
501+
llvm.func @non_array_attr_for_struct() -> !llvm.array<2 x array<2 x array<2 x struct<(i32)>>>> {
502+
// expected-error @below{{expected integer element type for integer elements attribute}}
503+
%0 = llvm.mlir.constant(dense<[[[1, 2], [3, 4]], [[42, 43], [44, 45]]]> : tensor<2x2x2xi32>) : !llvm.array<2 x array<2 x array<2 x struct<(i32)>>>>
504+
llvm.return %0 : !llvm.array<2 x array<2 x array<2 x struct<(i32)>>>>
505+
}
506+
507+
// -----
508+
509+
llvm.func @non_array_attr_for_struct() -> !llvm.array<2 x array<2 x array<2 x struct<(i32, i32, i32)>>>> {
510+
// expected-error @below{{expected integer element type for integer elements attribute}}
511+
%0 = llvm.mlir.constant(dense<[[[1, 2], [3, 4]], [[42, 43], [44, 45]]]> : tensor<2x2x2xi32>) : !llvm.array<2 x array<2 x array<2 x struct<(i32, i32, i32)>>>>
512+
llvm.return %0 : !llvm.array<2 x array<2 x array<2 x struct<(i32, i32, i32)>>>>
513+
}
514+
515+
// -----
516+
517+
llvm.func @invalid_struct_element_type() -> !llvm.struct<(f64, array<2 x i32>)> {
518+
// expected-error @below{{expected element of array attribute to be floating point or integer}}
519+
%0 = llvm.mlir.constant([1.0 : f64, dense<[1, 2]> : tensor<2xi32>]) : !llvm.struct<(f64, array<2 x i32>)>
520+
llvm.return %0 : !llvm.struct<(f64, array<2 x i32>)>
521+
}
522+
523+
// -----
524+
525+
llvm.func @wrong_struct_element_attr_type() -> !llvm.struct<(f64, f64)> {
526+
// expected-error @below{{expected element of array attribute to be floating point or integer}}
527+
%0 = llvm.mlir.constant([dense<[1, 2]> : tensor<2xi32>, 2.0 : f64]) : !llvm.struct<(f64, f64)>
528+
llvm.return %0 : !llvm.struct<(f64, f64)>
529+
}
530+
531+
// -----
532+
533+
llvm.func @struct_wrong_attribute_element_type() -> !llvm.struct<(f64, f64)> {
534+
// expected-error @below{{struct element at index 0 is of wrong type}}
535+
%0 = llvm.mlir.constant([1.0 : f32, 1.0 : f32]) : !llvm.struct<(f64, f64)>
536+
llvm.return %0 : !llvm.struct<(f64, f64)>
537+
}
538+
539+
// -----
540+
541+
// -----
542+
469543
func.func @insertvalue_non_llvm_type(%a : i32, %b : i32) {
470544
// expected-error@+2 {{expected LLVM IR Dialect type}}
471545
llvm.insertvalue %a, %b[0] : tensor<*xi32>

mlir/test/Target/LLVMIR/llvmir-invalid.mlir

Lines changed: 0 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -7,54 +7,6 @@ func.func @foo() {
77

88
// -----
99

10-
llvm.func @vector_with_non_vector_type() -> f32 {
11-
// expected-error @below{{expected vector or array type}}
12-
%cst = llvm.mlir.constant(dense<100.0> : vector<1xf64>) : f32
13-
llvm.return %cst : f32
14-
}
15-
16-
// -----
17-
18-
llvm.func @non_array_attr_for_struct() -> !llvm.array<2 x array<2 x array<2 x struct<(i32)>>>> {
19-
// expected-error @below{{expected an array attribute for a struct constant}}
20-
%0 = llvm.mlir.constant(dense<[[[1, 2], [3, 4]], [[42, 43], [44, 45]]]> : tensor<2x2x2xi32>) : !llvm.array<2 x array<2 x array<2 x struct<(i32)>>>>
21-
llvm.return %0 : !llvm.array<2 x array<2 x array<2 x struct<(i32)>>>>
22-
}
23-
24-
// -----
25-
26-
llvm.func @non_array_attr_for_struct() -> !llvm.array<2 x array<2 x array<2 x struct<(i32, i32, i32)>>>> {
27-
// expected-error @below{{expected an array attribute for a struct constant}}
28-
%0 = llvm.mlir.constant(dense<[[[1, 2], [3, 4]], [[42, 43], [44, 45]]]> : tensor<2x2x2xi32>) : !llvm.array<2 x array<2 x array<2 x struct<(i32, i32, i32)>>>>
29-
llvm.return %0 : !llvm.array<2 x array<2 x array<2 x struct<(i32, i32, i32)>>>>
30-
}
31-
32-
// -----
33-
34-
llvm.func @invalid_struct_element_type() -> !llvm.struct<(f64, array<2 x i32>)> {
35-
// expected-error @below{{expected struct element types to be floating point type or integer type}}
36-
%0 = llvm.mlir.constant([1.0 : f64, dense<[1, 2]> : tensor<2xi32>]) : !llvm.struct<(f64, array<2 x i32>)>
37-
llvm.return %0 : !llvm.struct<(f64, array<2 x i32>)>
38-
}
39-
40-
// -----
41-
42-
llvm.func @wrong_struct_element_attr_type() -> !llvm.struct<(f64, f64)> {
43-
// expected-error @below{{expected struct element attribute types to be floating point type or integer type}}
44-
%0 = llvm.mlir.constant([dense<[1, 2]> : tensor<2xi32>, 2.0 : f64]) : !llvm.struct<(f64, f64)>
45-
llvm.return %0 : !llvm.struct<(f64, f64)>
46-
}
47-
48-
// -----
49-
50-
llvm.func @struct_wrong_attribute_element_type() -> !llvm.struct<(f64, f64)> {
51-
// expected-error @below{{struct element at index 0 is of wrong type}}
52-
%0 = llvm.mlir.constant([1.0 : f32, 1.0 : f32]) : !llvm.struct<(f64, f64)>
53-
llvm.return %0 : !llvm.struct<(f64, f64)>
54-
}
55-
56-
// -----
57-
5810
// expected-error @below{{LLVM attribute 'readonly' does not expect a value}}
5911
llvm.func @passthrough_unexpected_value() attributes {passthrough = [["readonly", "42"]]}
6012

0 commit comments

Comments
 (0)