Skip to content

Commit 6edc1fa

Browse files
authored
[mlir][llvm dialect] Verify element type of nested types (#148975)
Before this PR, this was valid ``` %0 = llvm.mlir.constant(dense<[1, 2]> : vector<2xi32>) : vector<2xf32> ``` but this was not: ``` %0 = llvm.mlir.constant(1 : i32) : f32 ``` because only scalar types were checked for compatibility, not the element types of nested types. Another additional check that this PR adds is to verify the float semantics. Before this PR, ``` %cst = llvm.mlir.constant(1.0 : bf16) : f16 ``` was considered valid (because bf16 and f16 both have 16 bits), but with this PR it is not considered valid. This PR also moves all tests on the verifier of the llvm constant op into a single file. To summarize the state after this PR. Invalid: ```mlir %0 = llvm.mlir.constant(dense<[128, 1024]> : vector<2xi32>) : vector<2xf32> %0 = llvm.mlir.constant(dense<[128., 1024.]> : vector<2xbf16>) : vector<2xf16> ``` Valid: ```mlir %0 = llvm.mlir.constant(dense<[128., 1024.]> : vector<2xf32>) : vector<2xi32> %0 = llvm.mlir.constant(dense<[128, 1024]> : vector<2xi64>) : vector<2xi8> ``` and identical valid/invalid cases for the scalar cases.
1 parent 401b5cc commit 6edc1fa

File tree

3 files changed

+183
-114
lines changed

3 files changed

+183
-114
lines changed

mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp

Lines changed: 75 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
#include "mlir/Interfaces/FunctionImplementation.h"
2525
#include "mlir/Transforms/InliningUtils.h"
2626

27+
#include "llvm/ADT/APFloat.h"
2728
#include "llvm/ADT/TypeSwitch.h"
2829
#include "llvm/IR/Function.h"
2930
#include "llvm/IR/Type.h"
@@ -3187,6 +3188,18 @@ static int64_t getNumElements(Type t) {
31873188
return 1;
31883189
}
31893190

3191+
/// Determine the element type of `type`. Supported types are `VectorType`,
3192+
/// `TensorType`, and `LLVMArrayType`. Everything else is treated as a scalar.
3193+
static Type getElementType(Type type) {
3194+
while (auto arrayType = dyn_cast<LLVM::LLVMArrayType>(type))
3195+
type = arrayType.getElementType();
3196+
if (auto vecType = dyn_cast<VectorType>(type))
3197+
return vecType.getElementType();
3198+
if (auto tenType = dyn_cast<TensorType>(type))
3199+
return tenType.getElementType();
3200+
return type;
3201+
}
3202+
31903203
/// Check if the given type is a scalable vector type or a vector/array type
31913204
/// that contains a nested scalable vector type.
31923205
static bool hasScalableVectorType(Type t) {
@@ -3281,60 +3294,69 @@ LogicalResult LLVM::ConstantOp::verify() {
32813294
}
32823295
if (auto structType = dyn_cast<LLVMStructType>(getType())) {
32833296
auto arrayAttr = dyn_cast<ArrayAttr>(getValue());
3284-
if (!arrayAttr) {
3285-
return emitOpError() << "expected array attribute for a struct constant";
3286-
}
3297+
if (!arrayAttr)
3298+
return emitOpError() << "expected array attribute for struct type";
32873299

32883300
ArrayRef<Type> elementTypes = structType.getBody();
32893301
if (arrayAttr.size() != elementTypes.size()) {
32903302
return emitOpError() << "expected array attribute of size "
32913303
<< elementTypes.size();
32923304
}
3293-
for (auto elementTy : elementTypes) {
3294-
if (!isa<IntegerType, FloatType, LLVMPPCFP128Type>(elementTy)) {
3305+
for (auto [i, attr, type] : llvm::enumerate(arrayAttr, elementTypes)) {
3306+
if (!type.isSignlessIntOrIndexOrFloat()) {
32953307
return emitOpError() << "expected struct element types to be floating "
32963308
"point type or integer type";
32973309
}
3298-
}
3299-
3300-
for (size_t i = 0; i < elementTypes.size(); ++i) {
3301-
Attribute element = arrayAttr[i];
3302-
if (!isa<IntegerAttr, FloatAttr>(element)) {
3303-
return emitOpError()
3304-
<< "expected struct element attribute types to be floating "
3305-
"point type or integer type";
3310+
if (!isa<FloatAttr, IntegerAttr>(attr)) {
3311+
return emitOpError() << "expected element of array attribute to be "
3312+
"floating point or integer";
33063313
}
3307-
auto elementType = cast<TypedAttr>(element).getType();
3308-
if (elementType != elementTypes[i]) {
3314+
if (cast<TypedAttr>(attr).getType() != type)
33093315
return emitOpError()
33103316
<< "struct element at index " << i << " is of wrong type";
3311-
}
33123317
}
33133318

33143319
return success();
33153320
}
3316-
if (auto targetExtType = dyn_cast<LLVMTargetExtType>(getType())) {
3321+
if (auto targetExtType = dyn_cast<LLVMTargetExtType>(getType()))
33173322
return emitOpError() << "does not support target extension type.";
3318-
}
3323+
3324+
// Check that an attribute whose element type has floating point semantics
3325+
// `attributeFloatSemantics` is compatible with a type whose element type
3326+
// is `constantElementType`.
3327+
//
3328+
// Requirement is that either
3329+
// 1) They have identical floating point types.
3330+
// 2) `constantElementType` is an integer type of the same width as the float
3331+
// attribute. This is to support builtin MLIR float types without LLVM
3332+
// equivalents, see comments in getLLVMConstant for more details.
3333+
auto verifyFloatSemantics =
3334+
[this](const llvm::fltSemantics &attributeFloatSemantics,
3335+
Type constantElementType) -> LogicalResult {
3336+
if (auto floatType = dyn_cast<FloatType>(constantElementType)) {
3337+
if (&floatType.getFloatSemantics() != &attributeFloatSemantics) {
3338+
return emitOpError()
3339+
<< "attribute and type have different float semantics";
3340+
}
3341+
return success();
3342+
}
3343+
unsigned floatWidth = APFloat::getSizeInBits(attributeFloatSemantics);
3344+
if (isa<IntegerType>(constantElementType)) {
3345+
if (!constantElementType.isInteger(floatWidth))
3346+
return emitOpError() << "expected integer type of width " << floatWidth;
3347+
3348+
return success();
3349+
}
3350+
return success();
3351+
};
33193352

33203353
// Verification of IntegerAttr, FloatAttr, ElementsAttr, ArrayAttr.
3321-
if (auto intAttr = dyn_cast<IntegerAttr>(getValue())) {
3354+
if (isa<IntegerAttr>(getValue())) {
33223355
if (!llvm::isa<IntegerType>(getType()))
33233356
return emitOpError() << "expected integer type";
33243357
} else if (auto floatAttr = dyn_cast<FloatAttr>(getValue())) {
3325-
const llvm::fltSemantics &sem = floatAttr.getValue().getSemantics();
3326-
unsigned floatWidth = APFloat::getSizeInBits(sem);
3327-
if (auto floatTy = dyn_cast<FloatType>(getType())) {
3328-
if (floatTy.getWidth() != floatWidth) {
3329-
return emitOpError() << "expected float type of width " << floatWidth;
3330-
}
3331-
}
3332-
// See the comment for getLLVMConstant for more details about why 8-bit
3333-
// floats can be represented by integers.
3334-
if (isa<IntegerType>(getType()) && !getType().isInteger(floatWidth)) {
3335-
return emitOpError() << "expected integer type of width " << floatWidth;
3336-
}
3337-
} else if (isa<ElementsAttr>(getValue())) {
3358+
return verifyFloatSemantics(floatAttr.getValue().getSemantics(), getType());
3359+
} else if (auto elementsAttr = dyn_cast<ElementsAttr>(getValue())) {
33383360
if (hasScalableVectorType(getType())) {
33393361
// The exact number of elements of a scalable vector is unknown, so we
33403362
// allow only splat attributes.
@@ -3346,18 +3368,32 @@ LogicalResult LLVM::ConstantOp::verify() {
33463368
}
33473369
if (!isa<VectorType, LLVM::LLVMArrayType>(getType()))
33483370
return emitOpError() << "expected vector or array type";
3371+
33493372
// The number of elements of the attribute and the type must match.
3350-
if (auto elementsAttr = dyn_cast<ElementsAttr>(getValue())) {
3351-
int64_t attrNumElements = elementsAttr.getNumElements();
3352-
if (getNumElements(getType()) != attrNumElements)
3353-
return emitOpError()
3354-
<< "type and attribute have a different number of elements: "
3355-
<< getNumElements(getType()) << " vs. " << attrNumElements;
3373+
int64_t attrNumElements = elementsAttr.getNumElements();
3374+
if (getNumElements(getType()) != attrNumElements) {
3375+
return emitOpError()
3376+
<< "type and attribute have a different number of elements: "
3377+
<< getNumElements(getType()) << " vs. " << attrNumElements;
3378+
}
3379+
3380+
Type attrElmType = getElementType(elementsAttr.getType());
3381+
Type resultElmType = getElementType(getType());
3382+
if (auto floatType = dyn_cast<FloatType>(attrElmType))
3383+
return verifyFloatSemantics(floatType.getFloatSemantics(), resultElmType);
3384+
3385+
if (isa<IntegerType>(attrElmType) && !isa<IntegerType>(resultElmType)) {
3386+
return emitOpError(
3387+
"expected integer element type for integer elements attribute");
33563388
}
33573389
} else if (auto arrayAttr = dyn_cast<ArrayAttr>(getValue())) {
3390+
3391+
// The case where the constant is LLVMStructType has already been handled.
33583392
auto arrayType = dyn_cast<LLVM::LLVMArrayType>(getType());
33593393
if (!arrayType)
3360-
return emitOpError() << "expected array type";
3394+
return emitOpError()
3395+
<< "expected array or struct type for array attribute";
3396+
33613397
// When the attribute is an ArrayAttr, check that its nesting matches the
33623398
// corresponding ArrayType or VectorType nesting.
33633399
return verifyStructArrayConstant(*this, arrayType, arrayAttr, /*dim=*/0);

mlir/test/Dialect/LLVMIR/invalid.mlir

Lines changed: 108 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -394,7 +394,7 @@ llvm.func @array_attribute_two_different_types() -> !llvm.struct<(f64, f64)> {
394394
// -----
395395

396396
llvm.func @struct_wrong_attribute_type() -> !llvm.struct<(f64, f64)> {
397-
// expected-error @+1 {{expected array attribute}}
397+
// expected-error @+1 {{expected array attribute for struct type}}
398398
%0 = llvm.mlir.constant(1.0 : f64) : !llvm.struct<(f64, f64)>
399399
llvm.return %0 : !llvm.struct<(f64, f64)>
400400
}
@@ -439,6 +439,111 @@ llvm.func @scalable_vec_requires_splat() -> vector<[4]xf64> {
439439
llvm.return %0 : vector<[4]xf64>
440440
}
441441

442+
443+
// -----
444+
445+
llvm.func @int_attr_requires_int_type() -> f32 {
446+
// expected-error @below{{expected integer type}}
447+
%0 = llvm.mlir.constant(1 : index) : f32
448+
llvm.return %0 : f32
449+
}
450+
451+
// -----
452+
453+
llvm.func @vector_int_attr_requires_int_type() -> vector<2xf32> {
454+
// expected-error @below{{expected integer element type}}
455+
%0 = llvm.mlir.constant(dense<[1, 2]> : vector<2xi32>) : vector<2xf32>
456+
llvm.return %0 : vector<2xf32>
457+
}
458+
459+
// -----
460+
461+
llvm.func @float_attr_and_type_required_same() -> f16 {
462+
// expected-error @below{{attribute and type have different float semantics}}
463+
%cst = llvm.mlir.constant(1.0 : bf16) : f16
464+
llvm.return %cst : f16
465+
}
466+
467+
// -----
468+
469+
llvm.func @vector_float_attr_and_type_required_same() -> vector<2xf16> {
470+
// expected-error @below{{attribute and type have different float semantics}}
471+
%cst = llvm.mlir.constant(dense<[1.0, 2.0]> : vector<2xbf16>) : vector<2xf16>
472+
llvm.return %cst : vector<2xf16>
473+
}
474+
475+
// -----
476+
477+
llvm.func @incompatible_integer_type_for_float_attr() -> i32 {
478+
// expected-error @below{{expected integer type of width 16}}
479+
%cst = llvm.mlir.constant(1.0 : f16) : i32
480+
llvm.return %cst : i32
481+
}
482+
483+
// -----
484+
485+
llvm.func @vector_incompatible_integer_type_for_float_attr() -> vector<2xi8> {
486+
// expected-error @below{{expected integer type of width 16}}
487+
%cst = llvm.mlir.constant(dense<[1.0, 2.0]> : vector<2xf16>) : vector<2xi8>
488+
llvm.return %cst : vector<2xi8>
489+
}
490+
491+
// -----
492+
493+
llvm.func @vector_with_non_vector_type() -> f32 {
494+
// expected-error @below{{expected vector or array type}}
495+
%cst = llvm.mlir.constant(dense<100.0> : vector<1xf64>) : f32
496+
llvm.return %cst : f32
497+
}
498+
499+
// -----
500+
501+
llvm.func @array_attr_with_invalid_type() -> i32 {
502+
// expected-error @below{{expected array or struct type for array attribute}}
503+
%0 = llvm.mlir.constant([1 : i32]) : i32
504+
llvm.return %0 : i32
505+
}
506+
507+
// -----
508+
509+
llvm.func @elements_attribute_incompatible_nested_array_struct1_type() -> !llvm.array<2 x array<2 x array<2 x struct<(i32)>>>> {
510+
// expected-error @below{{expected integer element type for integer elements attribute}}
511+
%0 = llvm.mlir.constant(dense<[[[1, 2], [3, 4]], [[42, 43], [44, 45]]]> : tensor<2x2x2xi32>) : !llvm.array<2 x array<2 x array<2 x struct<(i32)>>>>
512+
llvm.return %0 : !llvm.array<2 x array<2 x array<2 x struct<(i32)>>>>
513+
}
514+
515+
// -----
516+
517+
llvm.func @elements_attribute_incompatible_nested_array_struct3_type() -> !llvm.array<2 x array<2 x array<2 x struct<(i32, i32, i32)>>>> {
518+
// expected-error @below{{expected integer element type for integer elements attribute}}
519+
%0 = llvm.mlir.constant(dense<[[[1, 2], [3, 4]], [[42, 43], [44, 45]]]> : tensor<2x2x2xi32>) : !llvm.array<2 x array<2 x array<2 x struct<(i32, i32, i32)>>>>
520+
llvm.return %0 : !llvm.array<2 x array<2 x array<2 x struct<(i32, i32, i32)>>>>
521+
}
522+
523+
// -----
524+
525+
llvm.func @invalid_struct_element_type() -> !llvm.struct<(f64, array<2 x i32>)> {
526+
// expected-error @below{{expected struct element types to be floating point type or integer type}}
527+
%0 = llvm.mlir.constant([1.0 : f64, dense<[1, 2]> : tensor<2xi32>]) : !llvm.struct<(f64, array<2 x i32>)>
528+
llvm.return %0 : !llvm.struct<(f64, array<2 x i32>)>
529+
}
530+
531+
// -----
532+
533+
llvm.func @wrong_struct_element_attr_type() -> !llvm.struct<(f64, f64)> {
534+
// expected-error @below{{expected element of array attribute to be floating point or integer}}
535+
%0 = llvm.mlir.constant([dense<[1, 2]> : tensor<2xi32>, 2.0 : f64]) : !llvm.struct<(f64, f64)>
536+
llvm.return %0 : !llvm.struct<(f64, f64)>
537+
}
538+
539+
// -----
540+
541+
llvm.func @struct_wrong_attribute_element_type() -> !llvm.struct<(f64, f64)> {
542+
// expected-error @below{{struct element at index 0 is of wrong type}}
543+
%0 = llvm.mlir.constant([1.0 : f32, 1.0 : f32]) : !llvm.struct<(f64, f64)>
544+
llvm.return %0 : !llvm.struct<(f64, f64)>
545+
}
546+
442547
// -----
443548

444549
func.func @insertvalue_non_llvm_type(%a : i32, %b : i32) {
@@ -484,13 +589,13 @@ func.func @extractvalue_invalid_type(%a : !llvm.array<4 x vector<8xf32>>) -> !ll
484589
return %b : !llvm.array<4 x vector<8xf32>>
485590
}
486591

487-
488592
// -----
489593

490594
func.func @extractvalue_non_llvm_type(%a : i32, %b : tensor<*xi32>) {
491595
// expected-error@+2 {{expected LLVM IR Dialect type}}
492596
llvm.extractvalue %b[0] : tensor<*xi32>
493597
}
598+
494599
// -----
495600

496601
func.func @extractvalue_struct_out_of_bounds() {
@@ -659,6 +764,7 @@ func.func @atomicrmw_scalable_vector(%ptr : !llvm.ptr, %f32_vec : vector<[2]xf32
659764
%0 = llvm.atomicrmw fadd %ptr, %f32_vec unordered : !llvm.ptr, vector<[2]xf32>
660765
llvm.return
661766
}
767+
662768
// -----
663769

664770
func.func @atomicrmw_vector_expected_float(%ptr : !llvm.ptr, %i32_vec : vector<3xi32>) {
@@ -1667,7 +1773,6 @@ func.func @tma_load(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !
16671773
return
16681774
}
16691775

1670-
16711776
// -----
16721777

16731778
func.func @tma_load(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !llvm.ptr<3>, %crd0: i32, %crd1: i32, %crd2: i32, %crd3: i32, %off0: i16, %off1: i16, %ctamask : i16, %cacheHint : i64, %p : i1) {

mlir/test/Target/LLVMIR/llvmir-invalid.mlir

Lines changed: 0 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -7,78 +7,6 @@ func.func @foo() {
77

88
// -----
99

10-
llvm.func @vector_with_non_vector_type() -> f32 {
11-
// expected-error @below{{expected vector or array type}}
12-
%cst = llvm.mlir.constant(dense<100.0> : vector<1xf64>) : f32
13-
llvm.return %cst : f32
14-
}
15-
16-
// -----
17-
18-
llvm.func @non_array_attr_for_struct() -> !llvm.array<2 x array<2 x array<2 x struct<(i32)>>>> {
19-
// expected-error @below{{expected an array attribute for a struct constant}}
20-
%0 = llvm.mlir.constant(dense<[[[1, 2], [3, 4]], [[42, 43], [44, 45]]]> : tensor<2x2x2xi32>) : !llvm.array<2 x array<2 x array<2 x struct<(i32)>>>>
21-
llvm.return %0 : !llvm.array<2 x array<2 x array<2 x struct<(i32)>>>>
22-
}
23-
24-
// -----
25-
26-
llvm.func @non_array_attr_for_struct() -> !llvm.array<2 x array<2 x array<2 x struct<(i32, i32, i32)>>>> {
27-
// expected-error @below{{expected an array attribute for a struct constant}}
28-
%0 = llvm.mlir.constant(dense<[[[1, 2], [3, 4]], [[42, 43], [44, 45]]]> : tensor<2x2x2xi32>) : !llvm.array<2 x array<2 x array<2 x struct<(i32, i32, i32)>>>>
29-
llvm.return %0 : !llvm.array<2 x array<2 x array<2 x struct<(i32, i32, i32)>>>>
30-
}
31-
32-
// -----
33-
34-
llvm.func @invalid_struct_element_type() -> !llvm.struct<(f64, array<2 x i32>)> {
35-
// expected-error @below{{expected struct element types to be floating point type or integer type}}
36-
%0 = llvm.mlir.constant([1.0 : f64, dense<[1, 2]> : tensor<2xi32>]) : !llvm.struct<(f64, array<2 x i32>)>
37-
llvm.return %0 : !llvm.struct<(f64, array<2 x i32>)>
38-
}
39-
40-
// -----
41-
42-
llvm.func @wrong_struct_element_attr_type() -> !llvm.struct<(f64, f64)> {
43-
// expected-error @below{{expected struct element attribute types to be floating point type or integer type}}
44-
%0 = llvm.mlir.constant([dense<[1, 2]> : tensor<2xi32>, 2.0 : f64]) : !llvm.struct<(f64, f64)>
45-
llvm.return %0 : !llvm.struct<(f64, f64)>
46-
}
47-
48-
// -----
49-
50-
llvm.func @struct_wrong_attribute_element_type() -> !llvm.struct<(f64, f64)> {
51-
// expected-error @below{{struct element at index 0 is of wrong type}}
52-
%0 = llvm.mlir.constant([1.0 : f32, 1.0 : f32]) : !llvm.struct<(f64, f64)>
53-
llvm.return %0 : !llvm.struct<(f64, f64)>
54-
}
55-
56-
// -----
57-
58-
llvm.func @integer_with_float_type() -> f32 {
59-
// expected-error @+1 {{expected integer type}}
60-
%0 = llvm.mlir.constant(1 : index) : f32
61-
llvm.return %0 : f32
62-
}
63-
64-
// -----
65-
66-
llvm.func @incompatible_float_attribute_type() -> f32 {
67-
// expected-error @below{{expected float type of width 64}}
68-
%cst = llvm.mlir.constant(1.0 : f64) : f32
69-
llvm.return %cst : f32
70-
}
71-
72-
// -----
73-
74-
llvm.func @incompatible_integer_type_for_float_attr() -> i32 {
75-
// expected-error @below{{expected integer type of width 16}}
76-
%cst = llvm.mlir.constant(1.0 : f16) : i32
77-
llvm.return %cst : i32
78-
}
79-
80-
// -----
81-
8210
// expected-error @below{{LLVM attribute 'readonly' does not expect a value}}
8311
llvm.func @passthrough_unexpected_value() attributes {passthrough = [["readonly", "42"]]}
8412

0 commit comments

Comments
 (0)