Skip to content

Commit d56d537

Browse files
authored
[Clang] Allow vector and matrix type attributes for sub-byte _BitInt (#140253)
It is useful for several cases, particularly for 4-bit integers.
1 parent 46ab684 commit d56d537

File tree

8 files changed

+122
-23
lines changed

8 files changed

+122
-23
lines changed

clang/include/clang/Basic/DiagnosticSemaKinds.td

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3257,8 +3257,7 @@ def err_attribute_too_few_arguments : Error<
32573257
"%0 attribute takes at least %1 argument%s1">;
32583258
def err_attribute_invalid_vector_type : Error<"invalid vector element type %0">;
32593259
def err_attribute_invalid_bitint_vector_type : Error<
3260-
"'_BitInt' %select{vector|matrix}0 element width must be %select{a power of 2|"
3261-
"at least as wide as 'CHAR_BIT'}1">;
3260+
"'_BitInt' %select{vector|matrix}0 element width must be a power of 2">;
32623261
def err_attribute_invalid_matrix_type : Error<"invalid matrix element type %0">;
32633262
def err_attribute_bad_neon_vector_size : Error<
32643263
"Neon vector size must be 64 or 128 bits">;

clang/lib/AST/ASTContext.cpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4577,8 +4577,7 @@ QualType ASTContext::getVectorType(QualType vecType, unsigned NumElts,
45774577
assert(vecType->isBuiltinType() ||
45784578
(vecType->isBitIntType() &&
45794579
// Only support _BitInt elements with byte-sized power of 2 NumBits.
4580-
llvm::isPowerOf2_32(vecType->castAs<BitIntType>()->getNumBits()) &&
4581-
vecType->castAs<BitIntType>()->getNumBits() >= 8));
4580+
llvm::isPowerOf2_32(vecType->castAs<BitIntType>()->getNumBits())));
45824581

45834582
// Check if we've already instantiated a vector of this type.
45844583
llvm::FoldingSetNodeID ID;
@@ -4650,8 +4649,7 @@ QualType ASTContext::getExtVectorType(QualType vecType,
46504649
assert(vecType->isBuiltinType() || vecType->isDependentType() ||
46514650
(vecType->isBitIntType() &&
46524651
// Only support _BitInt elements with byte-sized power of 2 NumBits.
4653-
llvm::isPowerOf2_32(vecType->castAs<BitIntType>()->getNumBits()) &&
4654-
vecType->castAs<BitIntType>()->getNumBits() >= 8));
4652+
llvm::isPowerOf2_32(vecType->castAs<BitIntType>()->getNumBits())));
46554653

46564654
// Check if we've already instantiated a vector of this type.
46574655
llvm::FoldingSetNodeID ID;

clang/lib/Sema/SemaType.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2321,9 +2321,9 @@ static bool CheckBitIntElementType(Sema &S, SourceLocation AttrLoc,
23212321
bool ForMatrixType = false) {
23222322
// Only support _BitInt elements with byte-sized power of 2 NumBits.
23232323
unsigned NumBits = BIT->getNumBits();
2324-
if (!llvm::isPowerOf2_32(NumBits) || NumBits < 8)
2324+
if (!llvm::isPowerOf2_32(NumBits))
23252325
return S.Diag(AttrLoc, diag::err_attribute_invalid_bitint_vector_type)
2326-
<< ForMatrixType << (NumBits < 8);
2326+
<< ForMatrixType;
23272327
return false;
23282328
}
23292329

clang/test/CodeGenCXX/ext-int.cpp

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -571,6 +571,41 @@ void VectorTest(uint16_t4 first, uint16_t4 second) {
571571
// CHECK: %[[ADD:.+]] = add <3 x i16> %[[Shuffle]], %[[Shuffle1]]
572572
}
573573

574+
typedef unsigned _BitInt(4) uint4_t4 __attribute__((ext_vector_type(4)));
575+
void VectorTest(uint4_t4 first, uint4_t4 second) {
576+
// LIN64: define{{.*}} void @_Z10VectorTestDv4_DU4_S0_(i32 %{{.+}}, i32 %{{.+}})
577+
// LIN32: define{{.*}} void @_Z10VectorTestDv4_DU4_S0_(<4 x i4> %{{.+}}, <4 x i4> %{{.+}})
578+
// WIN64: define dso_local void @"?VectorTest@@YAXT?$__vector@U?$_UBitInt@$03@__clang@@$03@__clang@@0@Z"(<4 x i4> %{{.+}}, <4 x i4> %{{.+}})
579+
// WIN32: define dso_local void @"?VectorTest@@YAXT?$__vector@U?$_UBitInt@$03@__clang@@$03@__clang@@0@Z"(<4 x i4> inreg %{{.+}}, <4 x i4> inreg %{{.+}})
580+
first.xzw + second.zwx;
581+
// CHECK: %[[Shuffle:.+]] = shufflevector <4 x i4> %{{.+}}, <4 x i4> poison, <3 x i32> <i32 0, i32 2, i32 3>
582+
// CHECK: %[[Shuffle1:.+]] = shufflevector <4 x i4> %{{.+}}, <4 x i4> poison, <3 x i32> <i32 2, i32 3, i32 0>
583+
// CHECK: %[[ADD:.+]] = add <3 x i4> %[[Shuffle]], %[[Shuffle1]]
584+
}
585+
586+
typedef unsigned _BitInt(2) uint2_t2 __attribute__((ext_vector_type(2)));
587+
uint2_t2 TestBitIntVector2x2Alloca(uint2_t2 v1, uint2_t2 v2) {
588+
// LIN64: define dso_local i16 @_Z25TestBitIntVector2x2AllocaDv2_DU2_S0_(i16 %[[V1Coerce:.+]], i16 %[[V2Coerce:.+]])
589+
// LIN64: %[[RetVal:.+]] = alloca <2 x i2>, align 2
590+
// LIN64: %[[V1Addr:.+]] = alloca <2 x i2>, align 2
591+
// LIN64: %[[V2Addr:.+]] = alloca <2 x i2>, align 2
592+
// LIN64: %[[RetValCoerce:.+]] = alloca i16, align 2
593+
// LIN64: call void @llvm.memcpy.p0.p0.i64(ptr align 2 %[[RetValCoerce]], ptr align 2 %[[RetVal]], i64 1, i1 false)
594+
// LIN64: %[[Ret:.+]] = load i16, ptr %[[RetValCoerce]], align 2
595+
// LIN64: ret i16 %[[Ret]]
596+
597+
// LIN32: define dso_local <2 x i2> @_Z25TestBitIntVector2x2AllocaDv2_DU2_S0_(<2 x i2> %{{.+}}, <2 x i2> %{{.+}})
598+
// LIN32: %[[V1Addr:.+]] = alloca <2 x i2>, align 2
599+
// LIN32: %[[V2Addr:.+]] = alloca <2 x i2>, align 2
600+
// LIN32: ret <2 x i2> %[[Ret:.+]]
601+
602+
// WIN: define dso_local <2 x i2> @"?TestBitIntVector2x2Alloca@@YAT?$__vector@U?$_UBitInt@$01@__clang@@$01@__clang@@T12@0@Z"(<2 x i2>{{.*}}, <2 x i2>{{.*}})
603+
// WIN: %[[V1:.+]] = alloca <2 x i2>, align 2
604+
// WIN: %[[V2:.+]] = alloca <2 x i2>, align 2
605+
// WIN: ret <2 x i2> %[[Ret:.+]]
606+
return v1 + v2;
607+
}
608+
574609
// Ensure that these types don't alias the normal int types.
575610
void TBAATest(_BitInt(sizeof(int) * 8) ExtInt,
576611
unsigned _BitInt(sizeof(int) * 8) ExtUInt,

clang/test/CodeGenCXX/matrix-vector-bit-int.cpp

Lines changed: 45 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@ using i32x3 = _BitInt(32) __attribute__((ext_vector_type(3)));
77
using i32x3x3 = _BitInt(32) __attribute__((matrix_type(3, 3)));
88
using i512x3 = _BitInt(512) __attribute__((ext_vector_type(3)));
99
using i512x3x3 = _BitInt(512) __attribute__((matrix_type(3, 3)));
10+
using i4x3 = _BitInt(4) __attribute__((ext_vector_type(3)));
11+
using i4x3x3 = _BitInt(4) __attribute__((matrix_type(3, 3)));
1012

1113
// CHECK-LABEL: define dso_local i32 @_Z2v1Dv3_DB8_(
1214
// CHECK-SAME: i32 [[A_COERCE:%.*]]) #[[ATTR0:[0-9]+]] {
@@ -68,8 +70,34 @@ i512x3 v3(i512x3 a) {
6870
return a + a;
6971
}
7072

73+
// CHECK-LABEL: define dso_local i32 @_Z2v4Dv3_DB4_(
74+
// CHECK-SAME: i32 [[A_COERCE:%.*]]) #[[ATTR0]] {
75+
// CHECK-NEXT: [[ENTRY:.*:]]
76+
// CHECK-NEXT: [[RETVAL:%.*]] = alloca <3 x i4>, align 4
77+
// CHECK-NEXT: [[A:%.*]] = alloca <3 x i4>, align 4
78+
// CHECK-NEXT: [[A_ADDR:%.*]] = alloca <3 x i4>, align 4
79+
// CHECK-NEXT: [[RETVAL_COERCE:%.*]] = alloca i32, align 4
80+
// CHECK-NEXT: store i32 [[A_COERCE]], ptr [[A]], align 4
81+
// CHECK-NEXT: [[LOADVECN:%.*]] = load <4 x i4>, ptr [[A]], align 4
82+
// CHECK-NEXT: [[A1:%.*]] = shufflevector <4 x i4> [[LOADVECN]], <4 x i4> poison, <3 x i32> <i32 0, i32 1, i32 2>
83+
// CHECK-NEXT: [[EXTRACTVEC:%.*]] = shufflevector <3 x i4> [[A1]], <3 x i4> poison, <4 x i32> <i32 0, i32 1, i32 2, i32 poison>
84+
// CHECK-NEXT: store <4 x i4> [[EXTRACTVEC]], ptr [[A_ADDR]], align 4
85+
// CHECK-NEXT: [[LOADVECN2:%.*]] = load <4 x i4>, ptr [[A_ADDR]], align 4
86+
// CHECK-NEXT: [[EXTRACTVEC3:%.*]] = shufflevector <4 x i4> [[LOADVECN2]], <4 x i4> poison, <3 x i32> <i32 0, i32 1, i32 2>
87+
// CHECK-NEXT: [[LOADVECN4:%.*]] = load <4 x i4>, ptr [[A_ADDR]], align 4
88+
// CHECK-NEXT: [[EXTRACTVEC5:%.*]] = shufflevector <4 x i4> [[LOADVECN4]], <4 x i4> poison, <3 x i32> <i32 0, i32 1, i32 2>
89+
// CHECK-NEXT: [[ADD:%.*]] = add <3 x i4> [[EXTRACTVEC3]], [[EXTRACTVEC5]]
90+
// CHECK-NEXT: store <3 x i4> [[ADD]], ptr [[RETVAL]], align 4
91+
// CHECK-NEXT: call void @llvm.memcpy.p0.p0.i64(ptr align 4 [[RETVAL_COERCE]], ptr align 4 [[RETVAL]], i64 2, i1 false)
92+
// CHECK-NEXT: [[TMP0:%.*]] = load i32, ptr [[RETVAL_COERCE]], align 4
93+
// CHECK-NEXT: ret i32 [[TMP0]]
94+
//
95+
i4x3 v4(i4x3 a) {
96+
return a + a;
97+
}
98+
7199
// CHECK-LABEL: define dso_local noundef <9 x i8> @_Z2m1u11matrix_typeILm3ELm3EDB8_E(
72-
// CHECK-SAME: <9 x i8> noundef [[A:%.*]]) #[[ATTR3:[0-9]+]] {
100+
// CHECK-SAME: <9 x i8> noundef [[A:%.*]]) #[[ATTR4:[0-9]+]] {
73101
// CHECK-NEXT: [[ENTRY:.*:]]
74102
// CHECK-NEXT: [[A_ADDR:%.*]] = alloca [9 x i8], align 1
75103
// CHECK-NEXT: store <9 x i8> [[A]], ptr [[A_ADDR]], align 1
@@ -83,7 +111,7 @@ i8x3x3 m1(i8x3x3 a) {
83111
}
84112

85113
// CHECK-LABEL: define dso_local noundef <9 x i32> @_Z2m2u11matrix_typeILm3ELm3EDB32_E(
86-
// CHECK-SAME: <9 x i32> noundef [[A:%.*]]) #[[ATTR4:[0-9]+]] {
114+
// CHECK-SAME: <9 x i32> noundef [[A:%.*]]) #[[ATTR5:[0-9]+]] {
87115
// CHECK-NEXT: [[ENTRY:.*:]]
88116
// CHECK-NEXT: [[A_ADDR:%.*]] = alloca [9 x i32], align 4
89117
// CHECK-NEXT: store <9 x i32> [[A]], ptr [[A_ADDR]], align 4
@@ -97,7 +125,7 @@ i32x3x3 m2(i32x3x3 a) {
97125
}
98126

99127
// CHECK-LABEL: define dso_local noundef <9 x i512> @_Z2m3u11matrix_typeILm3ELm3EDB512_E(
100-
// CHECK-SAME: <9 x i512> noundef [[A:%.*]]) #[[ATTR5:[0-9]+]] {
128+
// CHECK-SAME: <9 x i512> noundef [[A:%.*]]) #[[ATTR6:[0-9]+]] {
101129
// CHECK-NEXT: [[ENTRY:.*:]]
102130
// CHECK-NEXT: [[A_ADDR:%.*]] = alloca [9 x i512], align 8
103131
// CHECK-NEXT: store <9 x i512> [[A]], ptr [[A_ADDR]], align 8
@@ -109,3 +137,17 @@ i32x3x3 m2(i32x3x3 a) {
109137
i512x3x3 m3(i512x3x3 a) {
110138
return a + a;
111139
}
140+
141+
// CHECK-LABEL: define dso_local noundef <9 x i4> @_Z2m4u11matrix_typeILm3ELm3EDB4_E(
142+
// CHECK-SAME: <9 x i4> noundef [[A:%.*]]) #[[ATTR7:[0-9]+]] {
143+
// CHECK-NEXT: [[ENTRY:.*:]]
144+
// CHECK-NEXT: [[A_ADDR:%.*]] = alloca [9 x i4], align 1
145+
// CHECK-NEXT: store <9 x i4> [[A]], ptr [[A_ADDR]], align 1
146+
// CHECK-NEXT: [[TMP0:%.*]] = load <9 x i4>, ptr [[A_ADDR]], align 1
147+
// CHECK-NEXT: [[TMP1:%.*]] = load <9 x i4>, ptr [[A_ADDR]], align 1
148+
// CHECK-NEXT: [[TMP2:%.*]] = add <9 x i4> [[TMP0]], [[TMP1]]
149+
// CHECK-NEXT: ret <9 x i4> [[TMP2]]
150+
//
151+
i4x3x3 m4(i4x3x3 a) {
152+
return a + a;
153+
}

clang/test/SemaCXX/ext-int.cpp

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -84,17 +84,9 @@ struct is_same<T,T> {
8484
};
8585

8686
// Reject vector types:
87-
// expected-error@+1{{'_BitInt' vector element width must be at least as wide as 'CHAR_BIT'}}
88-
typedef _BitInt(2) __attribute__((vector_size(16))) VecTy;
89-
// expected-error@+1{{'_BitInt' vector element width must be at least as wide as 'CHAR_BIT'}}
90-
typedef _BitInt(2) __attribute__((ext_vector_type(32))) OtherVecTy;
91-
// expected-error@+1{{'_BitInt' vector element width must be at least as wide as 'CHAR_BIT'}}
92-
typedef _BitInt(4) __attribute__((vector_size(16))) VecTy2;
93-
// expected-error@+1{{'_BitInt' vector element width must be at least as wide as 'CHAR_BIT'}}
94-
typedef _BitInt(4) __attribute__((ext_vector_type(32))) OtherVecTy2;
95-
// expected-error@+1{{'_BitInt' vector element width must be at least as wide as 'CHAR_BIT'}}
87+
// expected-error@+1{{'_BitInt' vector element width must be a power of 2}}
9688
typedef _BitInt(5) __attribute__((vector_size(16))) VecTy3;
97-
// expected-error@+1{{'_BitInt' vector element width must be at least as wide as 'CHAR_BIT'}}
89+
// expected-error@+1{{'_BitInt' vector element width must be a power of 2}}
9890
typedef _BitInt(5) __attribute__((ext_vector_type(32))) OtherVecTy3;
9991
// expected-error@+1{{'_BitInt' vector element width must be a power of 2}}
10092
typedef _BitInt(37) __attribute__((vector_size(16))) VecTy4;

clang/test/SemaCXX/matrix-type.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,7 @@ void matrix_unsupported_element_type() {
3131
}
3232

3333
void matrix_unsupported_bit_int() {
34-
using m1 = _BitInt(2) __attribute__((matrix_type(4, 4))); // expected-error{{'_BitInt' matrix element width must be at least as wide as 'CHAR_BIT'}}
35-
using m2 = _BitInt(7) __attribute__((matrix_type(4, 4))); // expected-error{{'_BitInt' matrix element width must be at least as wide as 'CHAR_BIT'}}
34+
using m2 = _BitInt(7) __attribute__((matrix_type(4, 4))); // expected-error{{'_BitInt' matrix element width must be a power of 2}}
3635
using m3 = _BitInt(9) __attribute__((matrix_type(4, 4))); // expected-error{{'_BitInt' matrix element width must be a power of 2}}
3736
using m4 = _BitInt(12) __attribute__((matrix_type(4, 4))); // expected-error{{'_BitInt' matrix element width must be a power of 2}}
3837
using m5 = _BitInt(8) __attribute__((matrix_type(4, 4)));

llvm/test/CodeGen/X86/vec-2bit-int.ll

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
2+
; RUN: llc < %s -mtriple=x86_64-- | FileCheck %s
3+
4+
define dso_local <2 x i2> @foo(<2 x i2> %v1, <2 x i2> %v2) {
5+
; CHECK-LABEL: foo:
6+
; CHECK: # %bb.0: # %entry
7+
; CHECK-NEXT: movq %xmm1, %rax
8+
; CHECK-NEXT: andb $3, %al
9+
; CHECK-NEXT: pshufd {{.*#+}} xmm2 = xmm1[2,3,2,3]
10+
; CHECK-NEXT: movq %xmm2, %rcx
11+
; CHECK-NEXT: shlb $2, %cl
12+
; CHECK-NEXT: orb %al, %cl
13+
; CHECK-NEXT: andb $15, %cl
14+
; CHECK-NEXT: movb %cl, -{{[0-9]+}}(%rsp)
15+
; CHECK-NEXT: movq %xmm0, %rax
16+
; CHECK-NEXT: andb $3, %al
17+
; CHECK-NEXT: pshufd {{.*#+}} xmm2 = xmm0[2,3,2,3]
18+
; CHECK-NEXT: movq %xmm2, %rcx
19+
; CHECK-NEXT: shlb $2, %cl
20+
; CHECK-NEXT: orb %al, %cl
21+
; CHECK-NEXT: andb $15, %cl
22+
; CHECK-NEXT: movb %cl, -{{[0-9]+}}(%rsp)
23+
; CHECK-NEXT: paddq %xmm1, %xmm0
24+
; CHECK-NEXT: retq
25+
entry:
26+
%v2.addr = alloca <2 x i2>, align 2
27+
%v1.addr = alloca <2 x i2>, align 2
28+
store <2 x i2> %v2, ptr %v2.addr, align 2
29+
store <2 x i2> %v1, ptr %v1.addr, align 2
30+
%0 = load <2 x i2>, ptr %v1.addr, align 2
31+
%1 = load <2 x i2>, ptr %v2.addr, align 2
32+
%add = add <2 x i2> %0, %1
33+
ret <2 x i2> %add
34+
}

0 commit comments

Comments
 (0)