Skip to content

Commit b110b7d

Browse files
authored
[X86][Clang] VectorExprEvaluator::VisitCallExpr / InterpretBuiltin - Allow AVX/AVX512 IFMA madd52 intrinsics to be used in constexpr (#161056)
Resolves #160498
1 parent: 7e59abd · commit: b110b7d

File tree

9 files changed: 687 additions and 117 deletions.

clang/include/clang/Basic/BuiltinsX86.td

Lines changed: 4 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -2137,24 +2137,18 @@ let Features = "avx512vl", Attributes = [NoThrow, RequiredVectorWidth<256>] in {
21372137
def movdqa64store256_mask : X86Builtin<"void(_Vector<4, long long int *>, _Vector<4, long long int>, unsigned char)">;
21382138
}
21392139

2140-
let Features = "avx512ifma", Attributes = [NoThrow, Const, RequiredVectorWidth<512>] in {
2140+
let Features = "avx512ifma", Attributes = [NoThrow, Const, Constexpr, RequiredVectorWidth<512>] in {
21412141
def vpmadd52huq512 : X86Builtin<"_Vector<8, long long int>(_Vector<8, long long int>, _Vector<8, long long int>, _Vector<8, long long int>)">;
21422142
def vpmadd52luq512 : X86Builtin<"_Vector<8, long long int>(_Vector<8, long long int>, _Vector<8, long long int>, _Vector<8, long long int>)">;
21432143
}
21442144

2145-
let Features = "avx512ifma,avx512vl|avxifma", Attributes = [NoThrow, Const, RequiredVectorWidth<128>] in {
2145+
let Features = "avx512ifma,avx512vl|avxifma", Attributes = [NoThrow, Const, Constexpr, RequiredVectorWidth<128>] in {
21462146
def vpmadd52huq128 : X86Builtin<"_Vector<2, long long int>(_Vector<2, long long int>, _Vector<2, long long int>, _Vector<2, long long int>)">;
2147-
}
2148-
2149-
let Features = "avx512ifma,avx512vl|avxifma", Attributes = [NoThrow, Const, RequiredVectorWidth<256>] in {
2150-
def vpmadd52huq256 : X86Builtin<"_Vector<4, long long int>(_Vector<4, long long int>, _Vector<4, long long int>, _Vector<4, long long int>)">;
2151-
}
2152-
2153-
let Features = "avx512ifma,avx512vl|avxifma", Attributes = [NoThrow, Const, RequiredVectorWidth<128>] in {
21542147
def vpmadd52luq128 : X86Builtin<"_Vector<2, long long int>(_Vector<2, long long int>, _Vector<2, long long int>, _Vector<2, long long int>)">;
21552148
}
21562149

2157-
let Features = "avx512ifma,avx512vl|avxifma", Attributes = [NoThrow, Const, RequiredVectorWidth<256>] in {
2150+
let Features = "avx512ifma,avx512vl|avxifma", Attributes = [NoThrow, Const, Constexpr, RequiredVectorWidth<256>] in {
2151+
def vpmadd52huq256 : X86Builtin<"_Vector<4, long long int>(_Vector<4, long long int>, _Vector<4, long long int>, _Vector<4, long long int>)">;
21582152
def vpmadd52luq256 : X86Builtin<"_Vector<4, long long int>(_Vector<4, long long int>, _Vector<4, long long int>, _Vector<4, long long int>)">;
21592153
}
21602154

clang/lib/AST/ByteCode/InterpBuiltin.cpp

Lines changed: 15 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -3818,6 +3818,21 @@ bool InterpretBuiltin(InterpState &S, CodePtr OpPC, const CallExpr *Call,
38183818
return F;
38193819
});
38203820

3821+
case X86::BI__builtin_ia32_vpmadd52luq128:
3822+
case X86::BI__builtin_ia32_vpmadd52luq256:
3823+
case X86::BI__builtin_ia32_vpmadd52luq512:
3824+
return interp__builtin_elementwise_triop(
3825+
S, OpPC, Call, [](const APSInt &A, const APSInt &B, const APSInt &C) {
3826+
return A + (B.trunc(52) * C.trunc(52)).zext(64);
3827+
});
3828+
case X86::BI__builtin_ia32_vpmadd52huq128:
3829+
case X86::BI__builtin_ia32_vpmadd52huq256:
3830+
case X86::BI__builtin_ia32_vpmadd52huq512:
3831+
return interp__builtin_elementwise_triop(
3832+
S, OpPC, Call, [](const APSInt &A, const APSInt &B, const APSInt &C) {
3833+
return A + llvm::APIntOps::mulhu(B.trunc(52), C.trunc(52)).zext(64);
3834+
});
3835+
38213836
case X86::BI__builtin_ia32_vpshldd128:
38223837
case X86::BI__builtin_ia32_vpshldd256:
38233838
case X86::BI__builtin_ia32_vpshldd512:

clang/lib/AST/ExprConstant.cpp

Lines changed: 48 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -11974,6 +11974,54 @@ bool VectorExprEvaluator::VisitCallExpr(const CallExpr *E) {
1197411974

1197511975
return Success(APValue(ResultElements.data(), ResultElements.size()), E);
1197611976
}
11977+
11978+
case X86::BI__builtin_ia32_vpmadd52luq128:
11979+
case X86::BI__builtin_ia32_vpmadd52luq256:
11980+
case X86::BI__builtin_ia32_vpmadd52luq512: {
11981+
APValue A, B, C;
11982+
if (!EvaluateAsRValue(Info, E->getArg(0), A) ||
11983+
!EvaluateAsRValue(Info, E->getArg(1), B) ||
11984+
!EvaluateAsRValue(Info, E->getArg(2), C))
11985+
return false;
11986+
11987+
unsigned ALen = A.getVectorLength();
11988+
SmallVector<APValue, 4> ResultElements;
11989+
ResultElements.reserve(ALen);
11990+
11991+
for (unsigned EltNum = 0; EltNum < ALen; EltNum += 1) {
11992+
APInt AElt = A.getVectorElt(EltNum).getInt();
11993+
APInt BElt = B.getVectorElt(EltNum).getInt().trunc(52);
11994+
APInt CElt = C.getVectorElt(EltNum).getInt().trunc(52);
11995+
APSInt ResElt(AElt + (BElt * CElt).zext(64), false);
11996+
ResultElements.push_back(APValue(ResElt));
11997+
}
11998+
11999+
return Success(APValue(ResultElements.data(), ResultElements.size()), E);
12000+
}
12001+
case X86::BI__builtin_ia32_vpmadd52huq128:
12002+
case X86::BI__builtin_ia32_vpmadd52huq256:
12003+
case X86::BI__builtin_ia32_vpmadd52huq512: {
12004+
APValue A, B, C;
12005+
if (!EvaluateAsRValue(Info, E->getArg(0), A) ||
12006+
!EvaluateAsRValue(Info, E->getArg(1), B) ||
12007+
!EvaluateAsRValue(Info, E->getArg(2), C))
12008+
return false;
12009+
12010+
unsigned ALen = A.getVectorLength();
12011+
SmallVector<APValue, 4> ResultElements;
12012+
ResultElements.reserve(ALen);
12013+
12014+
for (unsigned EltNum = 0; EltNum < ALen; EltNum += 1) {
12015+
APInt AElt = A.getVectorElt(EltNum).getInt();
12016+
APInt BElt = B.getVectorElt(EltNum).getInt().trunc(52);
12017+
APInt CElt = C.getVectorElt(EltNum).getInt().trunc(52);
12018+
APSInt ResElt(AElt + llvm::APIntOps::mulhu(BElt, CElt).zext(64), false);
12019+
ResultElements.push_back(APValue(ResElt));
12020+
}
12021+
12022+
return Success(APValue(ResultElements.data(), ResultElements.size()), E);
12023+
}
12024+
1197712025
case clang::X86::BI__builtin_ia32_vprotbi:
1197812026
case clang::X86::BI__builtin_ia32_vprotdi:
1197912027
case clang::X86::BI__builtin_ia32_vprotqi:

clang/lib/Headers/avx512ifmaintrin.h

Lines changed: 31 additions & 32 deletions
Original file line number | Diff line number | Diff line change
@@ -15,54 +15,53 @@
1515
#define __IFMAINTRIN_H
1616

1717
/* Define the default attributes for the functions in this file. */
18+
#if defined(__cplusplus) && (__cplusplus >= 201103L)
19+
#define __DEFAULT_FN_ATTRS \
20+
constexpr \
21+
__attribute__((__always_inline__, __nodebug__, __target__("avx512ifma"), \
22+
__min_vector_width__(512)))
23+
#else
1824
#define __DEFAULT_FN_ATTRS \
1925
__attribute__((__always_inline__, __nodebug__, __target__("avx512ifma"), \
2026
__min_vector_width__(512)))
27+
#endif
2128

2229
static __inline__ __m512i __DEFAULT_FN_ATTRS
23-
_mm512_madd52hi_epu64 (__m512i __X, __m512i __Y, __m512i __Z)
24-
{
25-
return (__m512i)__builtin_ia32_vpmadd52huq512((__v8di) __X, (__v8di) __Y,
26-
(__v8di) __Z);
30+
_mm512_madd52hi_epu64(__m512i __X, __m512i __Y, __m512i __Z) {
31+
return (__m512i)__builtin_ia32_vpmadd52huq512((__v8di)__X, (__v8di)__Y,
32+
(__v8di)__Z);
2733
}
2834

29-
static __inline__ __m512i __DEFAULT_FN_ATTRS
30-
_mm512_mask_madd52hi_epu64 (__m512i __W, __mmask8 __M, __m512i __X, __m512i __Y)
31-
{
32-
return (__m512i)__builtin_ia32_selectq_512(__M,
33-
(__v8di)_mm512_madd52hi_epu64(__W, __X, __Y),
34-
(__v8di)__W);
35+
static __inline__ __m512i __DEFAULT_FN_ATTRS _mm512_mask_madd52hi_epu64(
36+
__m512i __W, __mmask8 __M, __m512i __X, __m512i __Y) {
37+
return (__m512i)__builtin_ia32_selectq_512(
38+
__M, (__v8di)_mm512_madd52hi_epu64(__W, __X, __Y), (__v8di)__W);
3539
}
3640

37-
static __inline__ __m512i __DEFAULT_FN_ATTRS
38-
_mm512_maskz_madd52hi_epu64 (__mmask8 __M, __m512i __X, __m512i __Y, __m512i __Z)
39-
{
40-
return (__m512i)__builtin_ia32_selectq_512(__M,
41-
(__v8di)_mm512_madd52hi_epu64(__X, __Y, __Z),
42-
(__v8di)_mm512_setzero_si512());
41+
static __inline__ __m512i __DEFAULT_FN_ATTRS _mm512_maskz_madd52hi_epu64(
42+
__mmask8 __M, __m512i __X, __m512i __Y, __m512i __Z) {
43+
return (__m512i)__builtin_ia32_selectq_512(
44+
__M, (__v8di)_mm512_madd52hi_epu64(__X, __Y, __Z),
45+
(__v8di)_mm512_setzero_si512());
4346
}
4447

4548
static __inline__ __m512i __DEFAULT_FN_ATTRS
46-
_mm512_madd52lo_epu64 (__m512i __X, __m512i __Y, __m512i __Z)
47-
{
48-
return (__m512i)__builtin_ia32_vpmadd52luq512((__v8di) __X, (__v8di) __Y,
49-
(__v8di) __Z);
49+
_mm512_madd52lo_epu64(__m512i __X, __m512i __Y, __m512i __Z) {
50+
return (__m512i)__builtin_ia32_vpmadd52luq512((__v8di)__X, (__v8di)__Y,
51+
(__v8di)__Z);
5052
}
5153

52-
static __inline__ __m512i __DEFAULT_FN_ATTRS
53-
_mm512_mask_madd52lo_epu64 (__m512i __W, __mmask8 __M, __m512i __X, __m512i __Y)
54-
{
55-
return (__m512i)__builtin_ia32_selectq_512(__M,
56-
(__v8di)_mm512_madd52lo_epu64(__W, __X, __Y),
57-
(__v8di)__W);
54+
static __inline__ __m512i __DEFAULT_FN_ATTRS _mm512_mask_madd52lo_epu64(
55+
__m512i __W, __mmask8 __M, __m512i __X, __m512i __Y) {
56+
return (__m512i)__builtin_ia32_selectq_512(
57+
__M, (__v8di)_mm512_madd52lo_epu64(__W, __X, __Y), (__v8di)__W);
5858
}
5959

60-
static __inline__ __m512i __DEFAULT_FN_ATTRS
61-
_mm512_maskz_madd52lo_epu64 (__mmask8 __M, __m512i __X, __m512i __Y, __m512i __Z)
62-
{
63-
return (__m512i)__builtin_ia32_selectq_512(__M,
64-
(__v8di)_mm512_madd52lo_epu64(__X, __Y, __Z),
65-
(__v8di)_mm512_setzero_si512());
60+
static __inline__ __m512i __DEFAULT_FN_ATTRS _mm512_maskz_madd52lo_epu64(
61+
__mmask8 __M, __m512i __X, __m512i __Y, __m512i __Z) {
62+
return (__m512i)__builtin_ia32_selectq_512(
63+
__M, (__v8di)_mm512_madd52lo_epu64(__X, __Y, __Z),
64+
(__v8di)_mm512_setzero_si512());
6665
}
6766

6867
#undef __DEFAULT_FN_ATTRS

clang/lib/Headers/avx512ifmavlintrin.h

Lines changed: 46 additions & 46 deletions
Original file line number | Diff line number | Diff line change
@@ -8,13 +8,24 @@
88
*===-----------------------------------------------------------------------===
99
*/
1010
#ifndef __IMMINTRIN_H
11-
#error "Never use <avx512ifmavlintrin.h> directly; include <immintrin.h> instead."
11+
#error \
12+
"Never use <avx512ifmavlintrin.h> directly; include <immintrin.h> instead."
1213
#endif
1314

1415
#ifndef __IFMAVLINTRIN_H
1516
#define __IFMAVLINTRIN_H
1617

1718
/* Define the default attributes for the functions in this file. */
19+
#if defined(__cplusplus) && (__cplusplus >= 201103L)
20+
#define __DEFAULT_FN_ATTRS128 \
21+
constexpr __attribute__((__always_inline__, __nodebug__, \
22+
__target__("avx512ifma,avx512vl"), \
23+
__min_vector_width__(128)))
24+
#define __DEFAULT_FN_ATTRS256 \
25+
constexpr __attribute__((__always_inline__, __nodebug__, \
26+
__target__("avx512ifma,avx512vl"), \
27+
__min_vector_width__(256)))
28+
#else
1829
#define __DEFAULT_FN_ATTRS128 \
1930
__attribute__((__always_inline__, __nodebug__, \
2031
__target__("avx512ifma,avx512vl"), \
@@ -24,6 +35,8 @@
2435
__target__("avx512ifma,avx512vl"), \
2536
__min_vector_width__(256)))
2637

38+
#endif
39+
2740
#define _mm_madd52hi_epu64(X, Y, Z) \
2841
((__m128i)__builtin_ia32_vpmadd52huq128((__v2di)(X), (__v2di)(Y), \
2942
(__v2di)(Z)))
@@ -41,70 +54,57 @@
4154
(__v4di)(Z)))
4255

4356
static __inline__ __m128i __DEFAULT_FN_ATTRS128
44-
_mm_mask_madd52hi_epu64 (__m128i __W, __mmask8 __M, __m128i __X, __m128i __Y)
45-
{
46-
return (__m128i)__builtin_ia32_selectq_128(__M,
47-
(__v2di)_mm_madd52hi_epu64(__W, __X, __Y),
48-
(__v2di)__W);
57+
_mm_mask_madd52hi_epu64(__m128i __W, __mmask8 __M, __m128i __X, __m128i __Y) {
58+
return (__m128i)__builtin_ia32_selectq_128(
59+
__M, (__v2di)_mm_madd52hi_epu64(__W, __X, __Y), (__v2di)__W);
4960
}
5061

5162
static __inline__ __m128i __DEFAULT_FN_ATTRS128
52-
_mm_maskz_madd52hi_epu64 (__mmask8 __M, __m128i __X, __m128i __Y, __m128i __Z)
53-
{
54-
return (__m128i)__builtin_ia32_selectq_128(__M,
55-
(__v2di)_mm_madd52hi_epu64(__X, __Y, __Z),
56-
(__v2di)_mm_setzero_si128());
63+
_mm_maskz_madd52hi_epu64(__mmask8 __M, __m128i __X, __m128i __Y, __m128i __Z) {
64+
return (__m128i)__builtin_ia32_selectq_128(
65+
__M, (__v2di)_mm_madd52hi_epu64(__X, __Y, __Z),
66+
(__v2di)_mm_setzero_si128());
5767
}
5868

59-
static __inline__ __m256i __DEFAULT_FN_ATTRS256
60-
_mm256_mask_madd52hi_epu64 (__m256i __W, __mmask8 __M, __m256i __X, __m256i __Y)
61-
{
62-
return (__m256i)__builtin_ia32_selectq_256(__M,
63-
(__v4di)_mm256_madd52hi_epu64(__W, __X, __Y),
64-
(__v4di)__W);
69+
static __inline__ __m256i __DEFAULT_FN_ATTRS256 _mm256_mask_madd52hi_epu64(
70+
__m256i __W, __mmask8 __M, __m256i __X, __m256i __Y) {
71+
return (__m256i)__builtin_ia32_selectq_256(
72+
__M, (__v4di)_mm256_madd52hi_epu64(__W, __X, __Y), (__v4di)__W);
6573
}
6674

67-
static __inline__ __m256i __DEFAULT_FN_ATTRS256
68-
_mm256_maskz_madd52hi_epu64 (__mmask8 __M, __m256i __X, __m256i __Y, __m256i __Z)
69-
{
70-
return (__m256i)__builtin_ia32_selectq_256(__M,
71-
(__v4di)_mm256_madd52hi_epu64(__X, __Y, __Z),
72-
(__v4di)_mm256_setzero_si256());
75+
static __inline__ __m256i __DEFAULT_FN_ATTRS256 _mm256_maskz_madd52hi_epu64(
76+
__mmask8 __M, __m256i __X, __m256i __Y, __m256i __Z) {
77+
return (__m256i)__builtin_ia32_selectq_256(
78+
__M, (__v4di)_mm256_madd52hi_epu64(__X, __Y, __Z),
79+
(__v4di)_mm256_setzero_si256());
7380
}
7481

7582
static __inline__ __m128i __DEFAULT_FN_ATTRS128
76-
_mm_mask_madd52lo_epu64 (__m128i __W, __mmask8 __M, __m128i __X, __m128i __Y)
77-
{
78-
return (__m128i)__builtin_ia32_selectq_128(__M,
79-
(__v2di)_mm_madd52lo_epu64(__W, __X, __Y),
80-
(__v2di)__W);
83+
_mm_mask_madd52lo_epu64(__m128i __W, __mmask8 __M, __m128i __X, __m128i __Y) {
84+
return (__m128i)__builtin_ia32_selectq_128(
85+
__M, (__v2di)_mm_madd52lo_epu64(__W, __X, __Y), (__v2di)__W);
8186
}
8287

8388
static __inline__ __m128i __DEFAULT_FN_ATTRS128
84-
_mm_maskz_madd52lo_epu64 (__mmask8 __M, __m128i __X, __m128i __Y, __m128i __Z)
85-
{
86-
return (__m128i)__builtin_ia32_selectq_128(__M,
87-
(__v2di)_mm_madd52lo_epu64(__X, __Y, __Z),
88-
(__v2di)_mm_setzero_si128());
89+
_mm_maskz_madd52lo_epu64(__mmask8 __M, __m128i __X, __m128i __Y, __m128i __Z) {
90+
return (__m128i)__builtin_ia32_selectq_128(
91+
__M, (__v2di)_mm_madd52lo_epu64(__X, __Y, __Z),
92+
(__v2di)_mm_setzero_si128());
8993
}
9094

91-
static __inline__ __m256i __DEFAULT_FN_ATTRS256
92-
_mm256_mask_madd52lo_epu64 (__m256i __W, __mmask8 __M, __m256i __X, __m256i __Y)
93-
{
94-
return (__m256i)__builtin_ia32_selectq_256(__M,
95-
(__v4di)_mm256_madd52lo_epu64(__W, __X, __Y),
96-
(__v4di)__W);
95+
static __inline__ __m256i __DEFAULT_FN_ATTRS256 _mm256_mask_madd52lo_epu64(
96+
__m256i __W, __mmask8 __M, __m256i __X, __m256i __Y) {
97+
return (__m256i)__builtin_ia32_selectq_256(
98+
__M, (__v4di)_mm256_madd52lo_epu64(__W, __X, __Y), (__v4di)__W);
9799
}
98100

99-
static __inline__ __m256i __DEFAULT_FN_ATTRS256
100-
_mm256_maskz_madd52lo_epu64 (__mmask8 __M, __m256i __X, __m256i __Y, __m256i __Z)
101-
{
102-
return (__m256i)__builtin_ia32_selectq_256(__M,
103-
(__v4di)_mm256_madd52lo_epu64(__X, __Y, __Z),
104-
(__v4di)_mm256_setzero_si256());
101+
static __inline__ __m256i __DEFAULT_FN_ATTRS256 _mm256_maskz_madd52lo_epu64(
102+
__mmask8 __M, __m256i __X, __m256i __Y, __m256i __Z) {
103+
return (__m256i)__builtin_ia32_selectq_256(
104+
__M, (__v4di)_mm256_madd52lo_epu64(__X, __Y, __Z),
105+
(__v4di)_mm256_setzero_si256());
105106
}
106107

107-
108108
#undef __DEFAULT_FN_ATTRS128
109109
#undef __DEFAULT_FN_ATTRS256
110110

clang/lib/Headers/avxifmaintrin.h

Lines changed: 9 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -15,12 +15,21 @@
1515
#define __AVXIFMAINTRIN_H
1616

1717
/* Define the default attributes for the functions in this file. */
18+
#if defined(__cplusplus) && (__cplusplus >= 201103L)
19+
#define __DEFAULT_FN_ATTRS128 \
20+
constexpr __attribute__((__always_inline__, __nodebug__, \
21+
__target__("avxifma"), __min_vector_width__(128)))
22+
#define __DEFAULT_FN_ATTRS256 \
23+
constexpr __attribute__((__always_inline__, __nodebug__, \
24+
__target__("avxifma"), __min_vector_width__(256)))
25+
#else
1826
#define __DEFAULT_FN_ATTRS128 \
1927
__attribute__((__always_inline__, __nodebug__, __target__("avxifma"), \
2028
__min_vector_width__(128)))
2129
#define __DEFAULT_FN_ATTRS256 \
2230
__attribute__((__always_inline__, __nodebug__, __target__("avxifma"), \
2331
__min_vector_width__(256)))
32+
#endif
2433

2534
// These intrinsics must use VEX encoding.
2635

0 commit comments

Comments
 (0)