Skip to content

Commit 6e58e99

Browse files
authored
[Clang] [Sema] Reject non-power-of-2 _BitInt matrix element types (#117487)
Essentially, this makes the following ill-formed:

```c++
using mat4 = _BitInt(12) [[clang::matrix_type(3, 3)]];
```

This matches preexisting behaviour for vector types (e.g. `ext_vector_type`), and given that LLVM IR intrinsics for matrices also take vector types, it seems like a sensible thing to do.

This is currently especially problematic since we sometimes lower matrix types to LLVM array types instead, and while e.g. `[4 x i32]` and `<4 x i32>` *probably* have the same or a similar memory layout (though I don’t think it’s sound to rely on that either, see #117486), `[4 x i12]` and `<4 x i12>` definitely don’t.
1 parent 72e8b9a commit 6e58e99

File tree

6 files changed

+141
-24
lines changed

6 files changed

+141
-24
lines changed

clang/docs/MatrixTypes.rst

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,10 @@ program is ill-formed.
3333
Currently, the element type of a matrix is only permitted to be one of the
3434
following types:
3535

36-
* an integer type (as in C23 6.2.5p22), but excluding enumerated types and ``bool``
37-
* the standard floating types ``float`` or ``double``
38-
* a half-precision floating point type, if one is supported on the target
36+
* an integer type (as in C23 6.2.5p22), but excluding enumerated types, ``bool``,
37+
and ``_BitInt`` types whose width is less than ``CHAR_BIT`` or not a power of 2;
38+
* the standard floating types ``float`` or ``double``;
39+
* a half-precision floating point type, if one is supported on the target.
3940

4041
Other types may be supported in the future.
4142

clang/docs/ReleaseNotes.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -416,6 +416,9 @@ Non-comprehensive list of changes in this release
416416
``__builtin_elementwise_bitreverse``, ``__builtin_elementwise_add_sat``,
417417
``__builtin_elementwise_sub_sat``.
418418

419+
- Clang now rejects ``_BitInt`` matrix element types if the bit width is less than ``CHAR_BIT`` or
420+
not a power of two, matching preexisting behaviour for vector types.
421+
419422
New Compiler Flags
420423
------------------
421424

clang/include/clang/Basic/DiagnosticSemaKinds.td

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3233,8 +3233,8 @@ def err_attribute_too_few_arguments : Error<
32333233
"%0 attribute takes at least %1 argument%s1">;
32343234
def err_attribute_invalid_vector_type : Error<"invalid vector element type %0">;
32353235
def err_attribute_invalid_bitint_vector_type : Error<
3236-
"'_BitInt' vector element width must be %select{a power of 2|"
3237-
"at least as wide as 'CHAR_BIT'}0">;
3236+
"'_BitInt' %select{vector|matrix}0 element width must be %select{a power of 2|"
3237+
"at least as wide as 'CHAR_BIT'}1">;
32383238
def err_attribute_invalid_matrix_type : Error<"invalid matrix element type %0">;
32393239
def err_attribute_bad_neon_vector_size : Error<
32403240
"Neon vector size must be 64 or 128 bits">;

clang/lib/Sema/SemaType.cpp

Lines changed: 23 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -2312,6 +2312,17 @@ QualType Sema::BuildArrayType(QualType T, ArraySizeModifier ASM,
23122312
return T;
23132313
}
23142314

2315+
static bool CheckBitIntElementType(Sema &S, SourceLocation AttrLoc,
2316+
const BitIntType *BIT,
2317+
bool ForMatrixType = false) {
2318+
// Only support _BitInt elements with byte-sized power of 2 NumBits.
2319+
unsigned NumBits = BIT->getNumBits();
2320+
if (!llvm::isPowerOf2_32(NumBits) || NumBits < 8)
2321+
return S.Diag(AttrLoc, diag::err_attribute_invalid_bitint_vector_type)
2322+
<< ForMatrixType << (NumBits < 8);
2323+
return false;
2324+
}
2325+
23152326
QualType Sema::BuildVectorType(QualType CurType, Expr *SizeExpr,
23162327
SourceLocation AttrLoc) {
23172328
// The base type must be integer (not Boolean or enumeration) or float, and
@@ -2324,15 +2335,10 @@ QualType Sema::BuildVectorType(QualType CurType, Expr *SizeExpr,
23242335
Diag(AttrLoc, diag::err_attribute_invalid_vector_type) << CurType;
23252336
return QualType();
23262337
}
2327-
// Only support _BitInt elements with byte-sized power of 2 NumBits.
2328-
if (const auto *BIT = CurType->getAs<BitIntType>()) {
2329-
unsigned NumBits = BIT->getNumBits();
2330-
if (!llvm::isPowerOf2_32(NumBits) || NumBits < 8) {
2331-
Diag(AttrLoc, diag::err_attribute_invalid_bitint_vector_type)
2332-
<< (NumBits < 8);
2333-
return QualType();
2334-
}
2335-
}
2338+
2339+
if (const auto *BIT = CurType->getAs<BitIntType>();
2340+
BIT && CheckBitIntElementType(*this, AttrLoc, BIT))
2341+
return QualType();
23362342

23372343
if (SizeExpr->isTypeDependent() || SizeExpr->isValueDependent())
23382344
return Context.getDependentVectorType(CurType, SizeExpr, AttrLoc,
@@ -2402,15 +2408,9 @@ QualType Sema::BuildExtVectorType(QualType T, Expr *ArraySize,
24022408
return QualType();
24032409
}
24042410

2405-
// Only support _BitInt elements with byte-sized power of 2 NumBits.
2406-
if (T->isBitIntType()) {
2407-
unsigned NumBits = T->castAs<BitIntType>()->getNumBits();
2408-
if (!llvm::isPowerOf2_32(NumBits) || NumBits < 8) {
2409-
Diag(AttrLoc, diag::err_attribute_invalid_bitint_vector_type)
2410-
<< (NumBits < 8);
2411-
return QualType();
2412-
}
2413-
}
2411+
if (const auto *BIT = T->getAs<BitIntType>();
2412+
BIT && CheckBitIntElementType(*this, AttrLoc, BIT))
2413+
return QualType();
24142414

24152415
if (!ArraySize->isTypeDependent() && !ArraySize->isValueDependent()) {
24162416
std::optional<llvm::APSInt> vecSize =
@@ -2455,6 +2455,11 @@ QualType Sema::BuildMatrixType(QualType ElementTy, Expr *NumRows, Expr *NumCols,
24552455
return QualType();
24562456
}
24572457

2458+
if (const auto *BIT = ElementTy->getAs<BitIntType>();
2459+
BIT &&
2460+
CheckBitIntElementType(*this, AttrLoc, BIT, /*ForMatrixType=*/true))
2461+
return QualType();
2462+
24582463
if (NumRows->isTypeDependent() || NumCols->isTypeDependent() ||
24592464
NumRows->isValueDependent() || NumCols->isValueDependent())
24602465
return Context.getDependentSizedMatrixType(ElementTy, NumRows, NumCols,
Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
// RUN: %clang_cc1 -fenable-matrix %s -emit-llvm -triple x86_64-unknown-linux -disable-llvm-passes -o - -std=c++11 | FileCheck %s
2+
3+
using i8x3 = _BitInt(8) __attribute__((ext_vector_type(3)));
4+
using i8x3x3 = _BitInt(8) __attribute__((matrix_type(3, 3)));
5+
using i32x3 = _BitInt(32) __attribute__((ext_vector_type(3)));
6+
using i32x3x3 = _BitInt(32) __attribute__((matrix_type(3, 3)));
7+
using i512x3 = _BitInt(512) __attribute__((ext_vector_type(3)));
8+
using i512x3x3 = _BitInt(512) __attribute__((matrix_type(3, 3)));
9+
10+
// CHECK-LABEL: define dso_local i32 @_Z2v1Dv3_DB8_(i32 %a.coerce)
11+
i8x3 v1(i8x3 a) {
12+
// CHECK-NEXT: entry:
13+
// CHECK-NEXT: %retval = alloca <3 x i8>, align 4
14+
// CHECK-NEXT: %a = alloca <3 x i8>, align 4
15+
// CHECK-NEXT: %a.addr = alloca <3 x i8>, align 4
16+
// CHECK-NEXT: store i32 %a.coerce, ptr %a, align 4
17+
// CHECK-NEXT: %loadVec4 = load <4 x i8>, ptr %a, align 4
18+
// CHECK-NEXT: %a1 = shufflevector <4 x i8> %loadVec4, <4 x i8> poison, <3 x i32> <i32 0, i32 1, i32 2>
19+
// CHECK-NEXT: %extractVec = shufflevector <3 x i8> %a1, <3 x i8> poison, <4 x i32> <i32 0, i32 1, i32 2, i32 poison>
20+
// CHECK-NEXT: store <4 x i8> %extractVec, ptr %a.addr, align 4
21+
// CHECK-NEXT: %loadVec42 = load <4 x i8>, ptr %a.addr, align 4
22+
// CHECK-NEXT: %extractVec3 = shufflevector <4 x i8> %loadVec42, <4 x i8> poison, <3 x i32> <i32 0, i32 1, i32 2>
23+
// CHECK-NEXT: %loadVec44 = load <4 x i8>, ptr %a.addr, align 4
24+
// CHECK-NEXT: %extractVec5 = shufflevector <4 x i8> %loadVec44, <4 x i8> poison, <3 x i32> <i32 0, i32 1, i32 2>
25+
// CHECK-NEXT: %add = add <3 x i8> %extractVec3, %extractVec5
26+
// CHECK-NEXT: store <3 x i8> %add, ptr %retval, align 4
27+
// CHECK-NEXT: %0 = load i32, ptr %retval, align 4
28+
// CHECK-NEXT: ret i32 %0
29+
return a + a;
30+
}
31+
32+
// CHECK-LABEL: define dso_local noundef <3 x i32> @_Z2v2Dv3_DB32_(<3 x i32> noundef %a)
33+
i32x3 v2(i32x3 a) {
34+
// CHECK-NEXT: entry:
35+
// CHECK-NEXT: %a.addr = alloca <3 x i32>, align 16
36+
// CHECK-NEXT: %extractVec = shufflevector <3 x i32> %a, <3 x i32> poison, <4 x i32> <i32 0, i32 1, i32 2, i32 poison>
37+
// CHECK-NEXT: store <4 x i32> %extractVec, ptr %a.addr, align 16
38+
// CHECK-NEXT: %loadVec4 = load <4 x i32>, ptr %a.addr, align 16
39+
// CHECK-NEXT: %extractVec1 = shufflevector <4 x i32> %loadVec4, <4 x i32> poison, <3 x i32> <i32 0, i32 1, i32 2>
40+
// CHECK-NEXT: %loadVec42 = load <4 x i32>, ptr %a.addr, align 16
41+
// CHECK-NEXT: %extractVec3 = shufflevector <4 x i32> %loadVec42, <4 x i32> poison, <3 x i32> <i32 0, i32 1, i32 2>
42+
// CHECK-NEXT: %add = add <3 x i32> %extractVec1, %extractVec3
43+
// CHECK-NEXT: ret <3 x i32> %add
44+
return a + a;
45+
}
46+
47+
// CHECK-LABEL: define dso_local noundef <3 x i512> @_Z2v3Dv3_DB512_(ptr noundef byval(<3 x i512>) align 256 %0)
48+
i512x3 v3(i512x3 a) {
49+
// CHECK-NEXT: entry:
50+
// CHECK-NEXT: %a.addr = alloca <3 x i512>, align 256
51+
// CHECK-NEXT: %loadVec4 = load <4 x i512>, ptr %0, align 256
52+
// CHECK-NEXT: %a = shufflevector <4 x i512> %loadVec4, <4 x i512> poison, <3 x i32> <i32 0, i32 1, i32 2>
53+
// CHECK-NEXT: %extractVec = shufflevector <3 x i512> %a, <3 x i512> poison, <4 x i32> <i32 0, i32 1, i32 2, i32 poison>
54+
// CHECK-NEXT: store <4 x i512> %extractVec, ptr %a.addr, align 256
55+
// CHECK-NEXT: %loadVec41 = load <4 x i512>, ptr %a.addr, align 256
56+
// CHECK-NEXT: %extractVec2 = shufflevector <4 x i512> %loadVec41, <4 x i512> poison, <3 x i32> <i32 0, i32 1, i32 2>
57+
// CHECK-NEXT: %loadVec43 = load <4 x i512>, ptr %a.addr, align 256
58+
// CHECK-NEXT: %extractVec4 = shufflevector <4 x i512> %loadVec43, <4 x i512> poison, <3 x i32> <i32 0, i32 1, i32 2>
59+
// CHECK-NEXT: %add = add <3 x i512> %extractVec2, %extractVec4
60+
// CHECK-NEXT: ret <3 x i512> %add
61+
return a + a;
62+
}
63+
64+
// CHECK-LABEL: define dso_local noundef <9 x i8> @_Z2m1u11matrix_typeILm3ELm3EDB8_E(<9 x i8> noundef %a)
65+
i8x3x3 m1(i8x3x3 a) {
66+
// CHECK-NEXT: entry:
67+
// CHECK-NEXT: %a.addr = alloca [9 x i8], align 1
68+
// CHECK-NEXT: store <9 x i8> %a, ptr %a.addr, align 1
69+
// CHECK-NEXT: %0 = load <9 x i8>, ptr %a.addr, align 1
70+
// CHECK-NEXT: %1 = load <9 x i8>, ptr %a.addr, align 1
71+
// CHECK-NEXT: %2 = add <9 x i8> %0, %1
72+
// CHECK-NEXT: ret <9 x i8> %2
73+
return a + a;
74+
}
75+
76+
// CHECK-LABEL: define dso_local noundef <9 x i32> @_Z2m2u11matrix_typeILm3ELm3EDB32_E(<9 x i32> noundef %a)
77+
i32x3x3 m2(i32x3x3 a) {
78+
// CHECK-NEXT: entry:
79+
// CHECK-NEXT: %a.addr = alloca [9 x i32], align 4
80+
// CHECK-NEXT: store <9 x i32> %a, ptr %a.addr, align 4
81+
// CHECK-NEXT: %0 = load <9 x i32>, ptr %a.addr, align 4
82+
// CHECK-NEXT: %1 = load <9 x i32>, ptr %a.addr, align 4
83+
// CHECK-NEXT: %2 = add <9 x i32> %0, %1
84+
// CHECK-NEXT: ret <9 x i32> %2
85+
return a + a;
86+
}
87+
88+
// CHECK-LABEL: define dso_local noundef <9 x i512> @_Z2m3u11matrix_typeILm3ELm3EDB512_E(<9 x i512> noundef %a)
89+
i512x3x3 m3(i512x3x3 a) {
90+
// CHECK-NEXT: entry:
91+
// CHECK-NEXT: %a.addr = alloca [9 x i512], align 8
92+
// CHECK-NEXT: store <9 x i512> %a, ptr %a.addr, align 8
93+
// CHECK-NEXT: %0 = load <9 x i512>, ptr %a.addr, align 8
94+
// CHECK-NEXT: %1 = load <9 x i512>, ptr %a.addr, align 8
95+
// CHECK-NEXT: %2 = add <9 x i512> %0, %1
96+
// CHECK-NEXT: ret <9 x i512> %2
97+
return a + a;
98+
}

clang/test/SemaCXX/matrix-type.cpp

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: %clang_cc1 -fsyntax-only -pedantic -fenable-matrix -std=c++11 -verify -triple x86_64-apple-darwin %s
1+
// RUN: %clang_cc1 -fsyntax-only -fenable-matrix -std=c++11 -verify -triple x86_64-apple-darwin %s
22

33
using matrix_double_t = double __attribute__((matrix_type(6, 6)));
44
using matrix_float_t = float __attribute__((matrix_type(6, 6)));
@@ -29,3 +29,13 @@ void matrix_unsupported_element_type() {
2929
using matrix3_t = bool __attribute__((matrix_type(1, 1))); // expected-error{{invalid matrix element type 'bool'}}
3030
using matrix4_t = TestEnum __attribute__((matrix_type(1, 1))); // expected-error{{invalid matrix element type 'TestEnum'}}
3131
}
32+
33+
void matrix_unsupported_bit_int() {
34+
using m1 = _BitInt(2) __attribute__((matrix_type(4, 4))); // expected-error{{'_BitInt' matrix element width must be at least as wide as 'CHAR_BIT'}}
35+
using m2 = _BitInt(7) __attribute__((matrix_type(4, 4))); // expected-error{{'_BitInt' matrix element width must be at least as wide as 'CHAR_BIT'}}
36+
using m3 = _BitInt(9) __attribute__((matrix_type(4, 4))); // expected-error{{'_BitInt' matrix element width must be a power of 2}}
37+
using m4 = _BitInt(12) __attribute__((matrix_type(4, 4))); // expected-error{{'_BitInt' matrix element width must be a power of 2}}
38+
using m5 = _BitInt(8) __attribute__((matrix_type(4, 4)));
39+
using m6 = _BitInt(64) __attribute__((matrix_type(4, 4)));
40+
using m7 = _BitInt(256) __attribute__((matrix_type(4, 4)));
41+
}

0 commit comments

Comments
 (0)