Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions clang/docs/MatrixTypes.rst
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,10 @@ program is ill-formed.
Currently, the element type of a matrix is only permitted to be one of the
following types:

* an integer type (as in C23 6.2.5p22), but excluding enumerated types and ``bool``
* the standard floating types ``float`` or ``double``
* a half-precision floating point type, if one is supported on the target
* an integer type (as in C23 6.2.5p22), but excluding enumerated types, ``bool``,
and ``_BitInt`` types whose width is not a power of 2;
* the standard floating types ``float`` or ``double``;
* a half-precision floating point type, if one is supported on the target.

Other types may be supported in the future.

Expand Down
3 changes: 3 additions & 0 deletions clang/docs/ReleaseNotes.rst
Original file line number Diff line number Diff line change
Expand Up @@ -371,6 +371,9 @@ Non-comprehensive list of changes in this release
- ``__builtin_reduce_and`` function can now be used in constant expressions.
- ``__builtin_reduce_or`` and ``__builtin_reduce_xor`` functions can now be used in constant expressions.

- Clang now rejects ``_BitInt`` matrix element types if the bit width is less than ``CHAR_WIDTH`` or
not a power of two, matching preexisting behaviour for vector types.

New Compiler Flags
------------------

Expand Down
4 changes: 2 additions & 2 deletions clang/include/clang/Basic/DiagnosticSemaKinds.td
Original file line number Diff line number Diff line change
Expand Up @@ -3233,8 +3233,8 @@ def err_attribute_too_few_arguments : Error<
"%0 attribute takes at least %1 argument%s1">;
def err_attribute_invalid_vector_type : Error<"invalid vector element type %0">;
def err_attribute_invalid_bitint_vector_type : Error<
"'_BitInt' vector element width must be %select{a power of 2|"
"at least as wide as 'CHAR_BIT'}0">;
"'_BitInt' %select{vector|matrix}0 element width must be %select{a power of 2|"
"at least as wide as 'CHAR_BIT'}1">;
def err_attribute_invalid_matrix_type : Error<"invalid matrix element type %0">;
def err_attribute_bad_neon_vector_size : Error<
"Neon vector size must be 64 or 128 bits">;
Expand Down
41 changes: 23 additions & 18 deletions clang/lib/Sema/SemaType.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2312,6 +2312,17 @@ QualType Sema::BuildArrayType(QualType T, ArraySizeModifier ASM,
return T;
}

static bool CheckBitIntElementType(Sema &S, SourceLocation AttrLoc,
const BitIntType *BIT,
bool ForMatrixType = false) {
// Only support _BitInt elements with byte-sized power of 2 NumBits.
unsigned NumBits = BIT->getNumBits();
if (!llvm::isPowerOf2_32(NumBits) || NumBits < 8)
return S.Diag(AttrLoc, diag::err_attribute_invalid_bitint_vector_type)
<< ForMatrixType << (NumBits < 8);
return false;
}

QualType Sema::BuildVectorType(QualType CurType, Expr *SizeExpr,
SourceLocation AttrLoc) {
// The base type must be integer (not Boolean or enumeration) or float, and
Expand All @@ -2324,15 +2335,10 @@ QualType Sema::BuildVectorType(QualType CurType, Expr *SizeExpr,
Diag(AttrLoc, diag::err_attribute_invalid_vector_type) << CurType;
return QualType();
}
// Only support _BitInt elements with byte-sized power of 2 NumBits.
if (const auto *BIT = CurType->getAs<BitIntType>()) {
unsigned NumBits = BIT->getNumBits();
if (!llvm::isPowerOf2_32(NumBits) || NumBits < 8) {
Diag(AttrLoc, diag::err_attribute_invalid_bitint_vector_type)
<< (NumBits < 8);
return QualType();
}
}

if (const auto *BIT = CurType->getAs<BitIntType>();
BIT && CheckBitIntElementType(*this, AttrLoc, BIT))
return QualType();

if (SizeExpr->isTypeDependent() || SizeExpr->isValueDependent())
return Context.getDependentVectorType(CurType, SizeExpr, AttrLoc,
Expand Down Expand Up @@ -2402,15 +2408,9 @@ QualType Sema::BuildExtVectorType(QualType T, Expr *ArraySize,
return QualType();
}

// Only support _BitInt elements with byte-sized power of 2 NumBits.
if (T->isBitIntType()) {
unsigned NumBits = T->castAs<BitIntType>()->getNumBits();
if (!llvm::isPowerOf2_32(NumBits) || NumBits < 8) {
Diag(AttrLoc, diag::err_attribute_invalid_bitint_vector_type)
<< (NumBits < 8);
return QualType();
}
}
if (const auto *BIT = T->getAs<BitIntType>();
BIT && CheckBitIntElementType(*this, AttrLoc, BIT))
return QualType();

if (!ArraySize->isTypeDependent() && !ArraySize->isValueDependent()) {
std::optional<llvm::APSInt> vecSize =
Expand Down Expand Up @@ -2455,6 +2455,11 @@ QualType Sema::BuildMatrixType(QualType ElementTy, Expr *NumRows, Expr *NumCols,
return QualType();
}

if (const auto *BIT = ElementTy->getAs<BitIntType>();
BIT &&
CheckBitIntElementType(*this, AttrLoc, BIT, /*ForMatrixType=*/true))
return QualType();

if (NumRows->isTypeDependent() || NumCols->isTypeDependent() ||
NumRows->isValueDependent() || NumCols->isValueDependent())
return Context.getDependentSizedMatrixType(ElementTy, NumRows, NumCols,
Expand Down
98 changes: 98 additions & 0 deletions clang/test/CodeGenCXX/matrix-vector-bit-int.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
// RUN: %clang_cc1 -fenable-matrix %s -emit-llvm -triple x86_64-unknown-linux -disable-llvm-passes -o - -std=c++11 | FileCheck %s

using i8x3 = _BitInt(8) __attribute__((ext_vector_type(3)));
using i8x3x3 = _BitInt(8) __attribute__((matrix_type(3, 3)));
using i32x3 = _BitInt(32) __attribute__((ext_vector_type(3)));
using i32x3x3 = _BitInt(32) __attribute__((matrix_type(3, 3)));
using i512x3 = _BitInt(512) __attribute__((ext_vector_type(3)));
using i512x3x3 = _BitInt(512) __attribute__((matrix_type(3, 3)));

// CHECK-LABEL: define dso_local i32 @_Z2v1Dv3_DB8_(i32 %a.coerce)
i8x3 v1(i8x3 a) {
// CHECK-NEXT: entry:
// CHECK-NEXT: %retval = alloca <3 x i8>, align 4
// CHECK-NEXT: %a = alloca <3 x i8>, align 4
// CHECK-NEXT: %a.addr = alloca <3 x i8>, align 4
// CHECK-NEXT: store i32 %a.coerce, ptr %a, align 4
// CHECK-NEXT: %loadVec4 = load <4 x i8>, ptr %a, align 4
// CHECK-NEXT: %a1 = shufflevector <4 x i8> %loadVec4, <4 x i8> poison, <3 x i32> <i32 0, i32 1, i32 2>
// CHECK-NEXT: %extractVec = shufflevector <3 x i8> %a1, <3 x i8> poison, <4 x i32> <i32 0, i32 1, i32 2, i32 poison>
// CHECK-NEXT: store <4 x i8> %extractVec, ptr %a.addr, align 4
// CHECK-NEXT: %loadVec42 = load <4 x i8>, ptr %a.addr, align 4
// CHECK-NEXT: %extractVec3 = shufflevector <4 x i8> %loadVec42, <4 x i8> poison, <3 x i32> <i32 0, i32 1, i32 2>
// CHECK-NEXT: %loadVec44 = load <4 x i8>, ptr %a.addr, align 4
// CHECK-NEXT: %extractVec5 = shufflevector <4 x i8> %loadVec44, <4 x i8> poison, <3 x i32> <i32 0, i32 1, i32 2>
// CHECK-NEXT: %add = add <3 x i8> %extractVec3, %extractVec5
// CHECK-NEXT: store <3 x i8> %add, ptr %retval, align 4
// CHECK-NEXT: %0 = load i32, ptr %retval, align 4
// CHECK-NEXT: ret i32 %0
return a + a;
}

// CHECK-LABEL: define dso_local noundef <3 x i32> @_Z2v2Dv3_DB32_(<3 x i32> noundef %a)
i32x3 v2(i32x3 a) {
// CHECK-NEXT: entry:
// CHECK-NEXT: %a.addr = alloca <3 x i32>, align 16
// CHECK-NEXT: %extractVec = shufflevector <3 x i32> %a, <3 x i32> poison, <4 x i32> <i32 0, i32 1, i32 2, i32 poison>
// CHECK-NEXT: store <4 x i32> %extractVec, ptr %a.addr, align 16
// CHECK-NEXT: %loadVec4 = load <4 x i32>, ptr %a.addr, align 16
// CHECK-NEXT: %extractVec1 = shufflevector <4 x i32> %loadVec4, <4 x i32> poison, <3 x i32> <i32 0, i32 1, i32 2>
// CHECK-NEXT: %loadVec42 = load <4 x i32>, ptr %a.addr, align 16
// CHECK-NEXT: %extractVec3 = shufflevector <4 x i32> %loadVec42, <4 x i32> poison, <3 x i32> <i32 0, i32 1, i32 2>
// CHECK-NEXT: %add = add <3 x i32> %extractVec1, %extractVec3
// CHECK-NEXT: ret <3 x i32> %add
return a + a;
}

// CHECK-LABEL: define dso_local noundef <3 x i512> @_Z2v3Dv3_DB512_(ptr noundef byval(<3 x i512>) align 256 %0)
i512x3 v3(i512x3 a) {
// CHECK-NEXT: entry:
// CHECK-NEXT: %a.addr = alloca <3 x i512>, align 256
// CHECK-NEXT: %loadVec4 = load <4 x i512>, ptr %0, align 256
// CHECK-NEXT: %a = shufflevector <4 x i512> %loadVec4, <4 x i512> poison, <3 x i32> <i32 0, i32 1, i32 2>
// CHECK-NEXT: %extractVec = shufflevector <3 x i512> %a, <3 x i512> poison, <4 x i32> <i32 0, i32 1, i32 2, i32 poison>
// CHECK-NEXT: store <4 x i512> %extractVec, ptr %a.addr, align 256
// CHECK-NEXT: %loadVec41 = load <4 x i512>, ptr %a.addr, align 256
// CHECK-NEXT: %extractVec2 = shufflevector <4 x i512> %loadVec41, <4 x i512> poison, <3 x i32> <i32 0, i32 1, i32 2>
// CHECK-NEXT: %loadVec43 = load <4 x i512>, ptr %a.addr, align 256
// CHECK-NEXT: %extractVec4 = shufflevector <4 x i512> %loadVec43, <4 x i512> poison, <3 x i32> <i32 0, i32 1, i32 2>
// CHECK-NEXT: %add = add <3 x i512> %extractVec2, %extractVec4
// CHECK-NEXT: ret <3 x i512> %add
return a + a;
}

// CHECK-LABEL: define dso_local noundef <9 x i8> @_Z2m1u11matrix_typeILm3ELm3EDB8_E(<9 x i8> noundef %a)
i8x3x3 m1(i8x3x3 a) {
// CHECK-NEXT: entry:
// CHECK-NEXT: %a.addr = alloca [9 x i8], align 1
// CHECK-NEXT: store <9 x i8> %a, ptr %a.addr, align 1
// CHECK-NEXT: %0 = load <9 x i8>, ptr %a.addr, align 1
// CHECK-NEXT: %1 = load <9 x i8>, ptr %a.addr, align 1
// CHECK-NEXT: %2 = add <9 x i8> %0, %1
// CHECK-NEXT: ret <9 x i8> %2
return a + a;
}

// CHECK-LABEL: define dso_local noundef <9 x i32> @_Z2m2u11matrix_typeILm3ELm3EDB32_E(<9 x i32> noundef %a)
i32x3x3 m2(i32x3x3 a) {
// CHECK-NEXT: entry:
// CHECK-NEXT: %a.addr = alloca [9 x i32], align 4
// CHECK-NEXT: store <9 x i32> %a, ptr %a.addr, align 4
// CHECK-NEXT: %0 = load <9 x i32>, ptr %a.addr, align 4
// CHECK-NEXT: %1 = load <9 x i32>, ptr %a.addr, align 4
// CHECK-NEXT: %2 = add <9 x i32> %0, %1
// CHECK-NEXT: ret <9 x i32> %2
return a + a;
}

// CHECK-LABEL: define dso_local noundef <9 x i512> @_Z2m3u11matrix_typeILm3ELm3EDB512_E(<9 x i512> noundef %a)
i512x3x3 m3(i512x3x3 a) {
// CHECK-NEXT: entry:
// CHECK-NEXT: %a.addr = alloca [9 x i512], align 8
// CHECK-NEXT: store <9 x i512> %a, ptr %a.addr, align 8
// CHECK-NEXT: %0 = load <9 x i512>, ptr %a.addr, align 8
// CHECK-NEXT: %1 = load <9 x i512>, ptr %a.addr, align 8
// CHECK-NEXT: %2 = add <9 x i512> %0, %1
// CHECK-NEXT: ret <9 x i512> %2
return a + a;
}
12 changes: 11 additions & 1 deletion clang/test/SemaCXX/matrix-type.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: %clang_cc1 -fsyntax-only -pedantic -fenable-matrix -std=c++11 -verify -triple x86_64-apple-darwin %s
// RUN: %clang_cc1 -fsyntax-only -fenable-matrix -std=c++11 -verify -triple x86_64-apple-darwin %s

using matrix_double_t = double __attribute__((matrix_type(6, 6)));
using matrix_float_t = float __attribute__((matrix_type(6, 6)));
Expand Down Expand Up @@ -29,3 +29,13 @@ void matrix_unsupported_element_type() {
using matrix3_t = bool __attribute__((matrix_type(1, 1))); // expected-error{{invalid matrix element type 'bool'}}
using matrix4_t = TestEnum __attribute__((matrix_type(1, 1))); // expected-error{{invalid matrix element type 'TestEnum'}}
}

void matrix_unsupported_bit_int() {
using m1 = _BitInt(2) __attribute__((matrix_type(4, 4))); // expected-error{{'_BitInt' matrix element width must be at least as wide as 'CHAR_BIT'}}
using m2 = _BitInt(7) __attribute__((matrix_type(4, 4))); // expected-error{{'_BitInt' matrix element width must be at least as wide as 'CHAR_BIT'}}
using m3 = _BitInt(9) __attribute__((matrix_type(4, 4))); // expected-error{{'_BitInt' matrix element width must be a power of 2}}
using m4 = _BitInt(12) __attribute__((matrix_type(4, 4))); // expected-error{{'_BitInt' matrix element width must be a power of 2}}
using m5 = _BitInt(8) __attribute__((matrix_type(4, 4)));
using m6 = _BitInt(64) __attribute__((matrix_type(4, 4)));
using m7 = _BitInt(256) __attribute__((matrix_type(4, 4)));
}
Loading