Skip to content

Commit cc9ad9a

Browse files
authored
[Clang] Add constexpr support for AVX512 permutex2 intrinsics (#165085)
This patch enables compile-time evaluation of the AVX512 permutex2var intrinsics in constexpr contexts. It extends the generic shuffle helpers (`interp__builtin_ia32_shuffle_generic` and `evalShuffleGeneric`) to handle both integer-immediate and vector mask operands. Resolves #161335
1 parent 0314b93 commit cc9ad9a

19 files changed

+1096
-265
lines changed

clang/include/clang/Basic/BuiltinsX86.td

Lines changed: 17 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1765,75 +1765,48 @@ let Features = "avx512vl", Attributes = [NoThrow, RequiredVectorWidth<256>] in {
17651765
def scattersiv8si : X86Builtin<"void(void *, unsigned char, _Vector<8, int>, _Vector<8, int>, _Constant int)">;
17661766
}
17671767

1768-
let Features = "avx512vl", Attributes = [NoThrow, Const, RequiredVectorWidth<128>] in {
1768+
let Features = "avx512vl", Attributes = [NoThrow, Const, Constexpr, RequiredVectorWidth<128>] in {
17691769
def vpermi2vard128 : X86Builtin<"_Vector<4, int>(_Vector<4, int>, _Vector<4, int>, _Vector<4, int>)">;
1770-
}
1771-
1772-
let Features = "avx512vl", Attributes = [NoThrow, Const, RequiredVectorWidth<256>] in {
1773-
def vpermi2vard256 : X86Builtin<"_Vector<8, int>(_Vector<8, int>, _Vector<8, int>, _Vector<8, int>)">;
1774-
}
1775-
1776-
let Features = "avx512f", Attributes = [NoThrow, Const, RequiredVectorWidth<512>] in {
1777-
def vpermi2vard512 : X86Builtin<"_Vector<16, int>(_Vector<16, int>, _Vector<16, int>, _Vector<16, int>)">;
1778-
}
1779-
1780-
let Features = "avx512vl", Attributes = [NoThrow, Const, RequiredVectorWidth<128>] in {
1781-
def vpermi2varpd128 : X86Builtin<"_Vector<2, double>(_Vector<2, double>, _Vector<2, long long int>, _Vector<2, double>)">;
1782-
}
1783-
1784-
let Features = "avx512vl", Attributes = [NoThrow, Const, RequiredVectorWidth<256>] in {
1785-
def vpermi2varpd256 : X86Builtin<"_Vector<4, double>(_Vector<4, double>, _Vector<4, long long int>, _Vector<4, double>)">;
1786-
}
1787-
1788-
let Features = "avx512f", Attributes = [NoThrow, Const, RequiredVectorWidth<512>] in {
1789-
def vpermi2varpd512 : X86Builtin<"_Vector<8, double>(_Vector<8, double>, _Vector<8, long long int>, _Vector<8, double>)">;
1790-
}
1791-
1792-
let Features = "avx512vl", Attributes = [NoThrow, Const, RequiredVectorWidth<128>] in {
1793-
def vpermi2varps128 : X86Builtin<"_Vector<4, float>(_Vector<4, float>, _Vector<4, int>, _Vector<4, float>)">;
1794-
}
1795-
1796-
let Features = "avx512vl", Attributes = [NoThrow, Const, RequiredVectorWidth<256>] in {
1797-
def vpermi2varps256 : X86Builtin<"_Vector<8, float>(_Vector<8, float>, _Vector<8, int>, _Vector<8, float>)">;
1798-
}
1799-
1800-
let Features = "avx512f", Attributes = [NoThrow, Const, RequiredVectorWidth<512>] in {
1801-
def vpermi2varps512 : X86Builtin<"_Vector<16, float>(_Vector<16, float>, _Vector<16, int>, _Vector<16, float>)">;
1802-
}
1803-
1804-
let Features = "avx512vl", Attributes = [NoThrow, Const, RequiredVectorWidth<128>] in {
18051770
def vpermi2varq128 : X86Builtin<"_Vector<2, long long int>(_Vector<2, long long int>, _Vector<2, long long int>, _Vector<2, long long int>)">;
1771+
def vpermi2varps128 : X86Builtin<"_Vector<4, float>(_Vector<4, float>, _Vector<4, int>, _Vector<4, float>)">;
1772+
def vpermi2varpd128 : X86Builtin<"_Vector<2, double>(_Vector<2, double>, _Vector<2, long long int>, _Vector<2, double>)">;
18061773
}
18071774

1808-
let Features = "avx512vl", Attributes = [NoThrow, Const, RequiredVectorWidth<256>] in {
1775+
let Features = "avx512vl", Attributes = [NoThrow, Const, Constexpr, RequiredVectorWidth<256>] in {
1776+
def vpermi2vard256 : X86Builtin<"_Vector<8, int>(_Vector<8, int>, _Vector<8, int>, _Vector<8, int>)">;
18091777
def vpermi2varq256 : X86Builtin<"_Vector<4, long long int>(_Vector<4, long long int>, _Vector<4, long long int>, _Vector<4, long long int>)">;
1778+
def vpermi2varps256 : X86Builtin<"_Vector<8, float>(_Vector<8, float>, _Vector<8, int>, _Vector<8, float>)">;
1779+
def vpermi2varpd256 : X86Builtin<"_Vector<4, double>(_Vector<4, double>, _Vector<4, long long int>, _Vector<4, double>)">;
18101780
}
18111781

1812-
let Features = "avx512f", Attributes = [NoThrow, Const, RequiredVectorWidth<512>] in {
1782+
let Features = "avx512f", Attributes = [NoThrow, Const, Constexpr, RequiredVectorWidth<512>] in {
1783+
def vpermi2vard512 : X86Builtin<"_Vector<16, int>(_Vector<16, int>, _Vector<16, int>, _Vector<16, int>)">;
18131784
def vpermi2varq512 : X86Builtin<"_Vector<8, long long int>(_Vector<8, long long int>, _Vector<8, long long int>, _Vector<8, long long int>)">;
1785+
def vpermi2varps512 : X86Builtin<"_Vector<16, float>(_Vector<16, float>, _Vector<16, int>, _Vector<16, float>)">;
1786+
def vpermi2varpd512 : X86Builtin<"_Vector<8, double>(_Vector<8, double>, _Vector<8, long long int>, _Vector<8, double>)">;
18141787
}
18151788

1816-
let Features = "avx512vbmi,avx512vl", Attributes = [NoThrow, Const, RequiredVectorWidth<128>] in {
1789+
let Features = "avx512vbmi,avx512vl", Attributes = [NoThrow, Const, Constexpr, RequiredVectorWidth<128>] in {
18171790
def vpermi2varqi128 : X86Builtin<"_Vector<16, char>(_Vector<16, char>, _Vector<16, char>, _Vector<16, char>)">;
18181791
}
18191792

1820-
let Features = "avx512vbmi,avx512vl", Attributes = [NoThrow, Const, RequiredVectorWidth<256>] in {
1793+
let Features = "avx512vbmi,avx512vl", Attributes = [NoThrow, Const, Constexpr, RequiredVectorWidth<256>] in {
18211794
def vpermi2varqi256 : X86Builtin<"_Vector<32, char>(_Vector<32, char>, _Vector<32, char>, _Vector<32, char>)">;
18221795
}
18231796

1824-
let Features = "avx512vbmi", Attributes = [NoThrow, Const, RequiredVectorWidth<512>] in {
1797+
let Features = "avx512vbmi", Attributes = [NoThrow, Const, Constexpr, RequiredVectorWidth<512>] in {
18251798
def vpermi2varqi512 : X86Builtin<"_Vector<64, char>(_Vector<64, char>, _Vector<64, char>, _Vector<64, char>)">;
18261799
}
18271800

1828-
let Features = "avx512vl,avx512bw", Attributes = [NoThrow, Const, RequiredVectorWidth<128>] in {
1801+
let Features = "avx512vl,avx512bw", Attributes = [NoThrow, Const, Constexpr, RequiredVectorWidth<128>] in {
18291802
def vpermi2varhi128 : X86Builtin<"_Vector<8, short>(_Vector<8, short>, _Vector<8, short>, _Vector<8, short>)">;
18301803
}
18311804

1832-
let Features = "avx512vl,avx512bw", Attributes = [NoThrow, Const, RequiredVectorWidth<256>] in {
1805+
let Features = "avx512vl,avx512bw", Attributes = [NoThrow, Const, Constexpr, RequiredVectorWidth<256>] in {
18331806
def vpermi2varhi256 : X86Builtin<"_Vector<16, short>(_Vector<16, short>, _Vector<16, short>, _Vector<16, short>)">;
18341807
}
18351808

1836-
let Features = "avx512bw", Attributes = [NoThrow, Const, RequiredVectorWidth<512>] in {
1809+
let Features = "avx512bw", Attributes = [NoThrow, Const, Constexpr, RequiredVectorWidth<512>] in {
18371810
def vpermi2varhi512 : X86Builtin<"_Vector<32, short>(_Vector<32, short>, _Vector<32, short>, _Vector<32, short>)">;
18381811
}
18391812

clang/lib/AST/ByteCode/InterpBuiltin.cpp

Lines changed: 85 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3415,18 +3415,46 @@ static bool interp__builtin_ia32_shuffle_generic(
34153415
GetSourceIndex) {
34163416

34173417
assert(Call->getNumArgs() == 3);
3418-
unsigned ShuffleMask = popToAPSInt(S, Call->getArg(2)).getZExtValue();
3418+
3419+
unsigned ShuffleMask = 0;
3420+
Pointer A, MaskVector, B;
3421+
3422+
QualType Arg2Type = Call->getArg(2)->getType();
3423+
bool IsVectorMask = false;
3424+
if (Arg2Type->isVectorType()) {
3425+
IsVectorMask = true;
3426+
B = S.Stk.pop<Pointer>();
3427+
MaskVector = S.Stk.pop<Pointer>();
3428+
A = S.Stk.pop<Pointer>();
3429+
} else if (Arg2Type->isIntegerType()) {
3430+
ShuffleMask = popToAPSInt(S, Call->getArg(2)).getZExtValue();
3431+
B = S.Stk.pop<Pointer>();
3432+
A = S.Stk.pop<Pointer>();
3433+
} else {
3434+
return false;
3435+
}
34193436

34203437
QualType Arg0Type = Call->getArg(0)->getType();
34213438
const auto *VecT = Arg0Type->castAs<VectorType>();
34223439
PrimType ElemT = *S.getContext().classify(VecT->getElementType());
34233440
unsigned NumElems = VecT->getNumElements();
34243441

3425-
const Pointer &B = S.Stk.pop<Pointer>();
3426-
const Pointer &A = S.Stk.pop<Pointer>();
34273442
const Pointer &Dst = S.Stk.peek<Pointer>();
34283443

3444+
PrimType MaskElemT = PT_Uint32;
3445+
if (IsVectorMask) {
3446+
QualType Arg1Type = Call->getArg(1)->getType();
3447+
const auto *MaskVecT = Arg1Type->castAs<VectorType>();
3448+
QualType MaskElemType = MaskVecT->getElementType();
3449+
MaskElemT = *S.getContext().classify(MaskElemType);
3450+
}
3451+
34293452
for (unsigned DstIdx = 0; DstIdx != NumElems; ++DstIdx) {
3453+
if (IsVectorMask) {
3454+
INT_TYPE_SWITCH(MaskElemT, {
3455+
ShuffleMask = static_cast<unsigned>(MaskVector.elem<T>(DstIdx));
3456+
});
3457+
}
34303458
auto [SrcVecIdx, SrcIdx] = GetSourceIndex(DstIdx, ShuffleMask);
34313459

34323460
if (SrcIdx < 0) {
@@ -4434,6 +4462,60 @@ bool InterpretBuiltin(InterpState &S, CodePtr OpPC, const CallExpr *Call,
44344462
return std::pair<unsigned, int>{0, static_cast<int>(DstIdx)};
44354463
}
44364464
});
4465+
case X86::BI__builtin_ia32_vpermi2varq128:
4466+
case X86::BI__builtin_ia32_vpermi2varpd128:
4467+
return interp__builtin_ia32_shuffle_generic(
4468+
S, OpPC, Call, [](unsigned DstIdx, unsigned ShuffleMask) {
4469+
int Offset = ShuffleMask & 0x1;
4470+
unsigned SrcIdx = (ShuffleMask >> 1) & 0x1;
4471+
return std::pair<unsigned, int>{SrcIdx, Offset};
4472+
});
4473+
case X86::BI__builtin_ia32_vpermi2vard128:
4474+
case X86::BI__builtin_ia32_vpermi2varps128:
4475+
case X86::BI__builtin_ia32_vpermi2varq256:
4476+
case X86::BI__builtin_ia32_vpermi2varpd256:
4477+
return interp__builtin_ia32_shuffle_generic(
4478+
S, OpPC, Call, [](unsigned DstIdx, unsigned ShuffleMask) {
4479+
int Offset = ShuffleMask & 0x3;
4480+
unsigned SrcIdx = (ShuffleMask >> 2) & 0x1;
4481+
return std::pair<unsigned, int>{SrcIdx, Offset};
4482+
});
4483+
case X86::BI__builtin_ia32_vpermi2varhi128:
4484+
case X86::BI__builtin_ia32_vpermi2vard256:
4485+
case X86::BI__builtin_ia32_vpermi2varps256:
4486+
case X86::BI__builtin_ia32_vpermi2varq512:
4487+
case X86::BI__builtin_ia32_vpermi2varpd512:
4488+
return interp__builtin_ia32_shuffle_generic(
4489+
S, OpPC, Call, [](unsigned DstIdx, unsigned ShuffleMask) {
4490+
int Offset = ShuffleMask & 0x7;
4491+
unsigned SrcIdx = (ShuffleMask >> 3) & 0x1;
4492+
return std::pair<unsigned, int>{SrcIdx, Offset};
4493+
});
4494+
case X86::BI__builtin_ia32_vpermi2varqi128:
4495+
case X86::BI__builtin_ia32_vpermi2varhi256:
4496+
case X86::BI__builtin_ia32_vpermi2vard512:
4497+
case X86::BI__builtin_ia32_vpermi2varps512:
4498+
return interp__builtin_ia32_shuffle_generic(
4499+
S, OpPC, Call, [](unsigned DstIdx, unsigned ShuffleMask) {
4500+
int Offset = ShuffleMask & 0xF;
4501+
unsigned SrcIdx = (ShuffleMask >> 4) & 0x1;
4502+
return std::pair<unsigned, int>{SrcIdx, Offset};
4503+
});
4504+
case X86::BI__builtin_ia32_vpermi2varqi256:
4505+
case X86::BI__builtin_ia32_vpermi2varhi512:
4506+
return interp__builtin_ia32_shuffle_generic(
4507+
S, OpPC, Call, [](unsigned DstIdx, unsigned ShuffleMask) {
4508+
int Offset = ShuffleMask & 0x1F;
4509+
unsigned SrcIdx = (ShuffleMask >> 5) & 0x1;
4510+
return std::pair<unsigned, int>{SrcIdx, Offset};
4511+
});
4512+
case X86::BI__builtin_ia32_vpermi2varqi512:
4513+
return interp__builtin_ia32_shuffle_generic(
4514+
S, OpPC, Call, [](unsigned DstIdx, unsigned ShuffleMask) {
4515+
int Offset = ShuffleMask & 0x3F;
4516+
unsigned SrcIdx = (ShuffleMask >> 6) & 0x1;
4517+
return std::pair<unsigned, int>{SrcIdx, Offset};
4518+
});
44374519
case X86::BI__builtin_ia32_pshufb128:
44384520
case X86::BI__builtin_ia32_pshufb256:
44394521
case X86::BI__builtin_ia32_pshufb512:

clang/lib/AST/ExprConstant.cpp

Lines changed: 103 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11628,21 +11628,38 @@ static bool evalShuffleGeneric(
1162811628
if (!VT)
1162911629
return false;
1163011630

11631-
APSInt MaskImm;
11632-
if (!EvaluateInteger(Call->getArg(2), MaskImm, Info))
11633-
return false;
11634-
unsigned ShuffleMask = static_cast<unsigned>(MaskImm.getZExtValue());
11631+
unsigned ShuffleMask = 0;
11632+
APValue A, MaskVector, B;
11633+
bool IsVectorMask = false;
1163511634

11636-
APValue A, B;
11637-
if (!EvaluateAsRValue(Info, Call->getArg(0), A) ||
11638-
!EvaluateAsRValue(Info, Call->getArg(1), B))
11635+
QualType Arg2Type = Call->getArg(2)->getType();
11636+
if (Arg2Type->isVectorType()) {
11637+
IsVectorMask = true;
11638+
if (!EvaluateAsRValue(Info, Call->getArg(0), A) ||
11639+
!EvaluateAsRValue(Info, Call->getArg(1), MaskVector) ||
11640+
!EvaluateAsRValue(Info, Call->getArg(2), B))
11641+
return false;
11642+
} else if (Arg2Type->isIntegerType()) {
11643+
APSInt MaskImm;
11644+
if (!EvaluateInteger(Call->getArg(2), MaskImm, Info))
11645+
return false;
11646+
ShuffleMask = static_cast<unsigned>(MaskImm.getZExtValue());
11647+
if (!EvaluateAsRValue(Info, Call->getArg(0), A) ||
11648+
!EvaluateAsRValue(Info, Call->getArg(1), B))
11649+
return false;
11650+
} else {
1163911651
return false;
11652+
}
1164011653

1164111654
unsigned NumElts = VT->getNumElements();
11642-
SmallVector<APValue, 16> ResultElements;
11655+
SmallVector<APValue, 64> ResultElements;
1164311656
ResultElements.reserve(NumElts);
1164411657

1164511658
for (unsigned DstIdx = 0; DstIdx != NumElts; ++DstIdx) {
11659+
if (IsVectorMask) {
11660+
ShuffleMask = static_cast<unsigned>(
11661+
MaskVector.getVectorElt(DstIdx).getInt().getZExtValue());
11662+
}
1164611663
auto [SrcVecIdx, SrcIdx] = GetSourceIndex(DstIdx, ShuffleMask);
1164711664

1164811665
if (SrcIdx < 0) {
@@ -13080,6 +13097,84 @@ bool VectorExprEvaluator::VisitCallExpr(const CallExpr *E) {
1308013097

1308113098
return Success(APValue(ResultElements.data(), ResultElements.size()), E);
1308213099
}
13100+
case X86::BI__builtin_ia32_vpermi2varq128:
13101+
case X86::BI__builtin_ia32_vpermi2varpd128: {
13102+
APValue R;
13103+
if (!evalShuffleGeneric(Info, E, R,
13104+
[](unsigned DstIdx, unsigned ShuffleMask) {
13105+
int Offset = ShuffleMask & 0x1;
13106+
unsigned SrcIdx = (ShuffleMask >> 1) & 0x1;
13107+
return std::pair<unsigned, int>{SrcIdx, Offset};
13108+
}))
13109+
return false;
13110+
return Success(R, E);
13111+
}
13112+
case X86::BI__builtin_ia32_vpermi2vard128:
13113+
case X86::BI__builtin_ia32_vpermi2varps128:
13114+
case X86::BI__builtin_ia32_vpermi2varq256:
13115+
case X86::BI__builtin_ia32_vpermi2varpd256: {
13116+
APValue R;
13117+
if (!evalShuffleGeneric(Info, E, R,
13118+
[](unsigned DstIdx, unsigned ShuffleMask) {
13119+
int Offset = ShuffleMask & 0x3;
13120+
unsigned SrcIdx = (ShuffleMask >> 2) & 0x1;
13121+
return std::pair<unsigned, int>{SrcIdx, Offset};
13122+
}))
13123+
return false;
13124+
return Success(R, E);
13125+
}
13126+
case X86::BI__builtin_ia32_vpermi2varhi128:
13127+
case X86::BI__builtin_ia32_vpermi2vard256:
13128+
case X86::BI__builtin_ia32_vpermi2varps256:
13129+
case X86::BI__builtin_ia32_vpermi2varq512:
13130+
case X86::BI__builtin_ia32_vpermi2varpd512: {
13131+
APValue R;
13132+
if (!evalShuffleGeneric(Info, E, R,
13133+
[](unsigned DstIdx, unsigned ShuffleMask) {
13134+
int Offset = ShuffleMask & 0x7;
13135+
unsigned SrcIdx = (ShuffleMask >> 3) & 0x1;
13136+
return std::pair<unsigned, int>{SrcIdx, Offset};
13137+
}))
13138+
return false;
13139+
return Success(R, E);
13140+
}
13141+
case X86::BI__builtin_ia32_vpermi2varqi128:
13142+
case X86::BI__builtin_ia32_vpermi2varhi256:
13143+
case X86::BI__builtin_ia32_vpermi2vard512:
13144+
case X86::BI__builtin_ia32_vpermi2varps512: {
13145+
APValue R;
13146+
if (!evalShuffleGeneric(Info, E, R,
13147+
[](unsigned DstIdx, unsigned ShuffleMask) {
13148+
int Offset = ShuffleMask & 0xF;
13149+
unsigned SrcIdx = (ShuffleMask >> 4) & 0x1;
13150+
return std::pair<unsigned, int>{SrcIdx, Offset};
13151+
}))
13152+
return false;
13153+
return Success(R, E);
13154+
}
13155+
case X86::BI__builtin_ia32_vpermi2varqi256:
13156+
case X86::BI__builtin_ia32_vpermi2varhi512: {
13157+
APValue R;
13158+
if (!evalShuffleGeneric(Info, E, R,
13159+
[](unsigned DstIdx, unsigned ShuffleMask) {
13160+
int Offset = ShuffleMask & 0x1F;
13161+
unsigned SrcIdx = (ShuffleMask >> 5) & 0x1;
13162+
return std::pair<unsigned, int>{SrcIdx, Offset};
13163+
}))
13164+
return false;
13165+
return Success(R, E);
13166+
}
13167+
case X86::BI__builtin_ia32_vpermi2varqi512: {
13168+
APValue R;
13169+
if (!evalShuffleGeneric(Info, E, R,
13170+
[](unsigned DstIdx, unsigned ShuffleMask) {
13171+
int Offset = ShuffleMask & 0x3F;
13172+
unsigned SrcIdx = (ShuffleMask >> 6) & 0x1;
13173+
return std::pair<unsigned, int>{SrcIdx, Offset};
13174+
}))
13175+
return false;
13176+
return Success(R, E);
13177+
}
1308313178
}
1308413179
}
1308513180

clang/lib/Headers/avx10_2_512bf16intrin.h

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,12 @@ typedef __bf16 __m512bh_u __attribute__((__vector_size__(64), __aligned__(1)));
2424
__attribute__((__always_inline__, __nodebug__, __target__("avx10.2"), \
2525
__min_vector_width__(512)))
2626

27+
#if defined(__cplusplus) && (__cplusplus >= 201103L)
28+
#define __DEFAULT_FN_ATTRS512_CONSTEXPR __DEFAULT_FN_ATTRS512 constexpr
29+
#else
30+
#define __DEFAULT_FN_ATTRS512_CONSTEXPR __DEFAULT_FN_ATTRS512
31+
#endif
32+
2733
static __inline __m512bh __DEFAULT_FN_ATTRS512 _mm512_setzero_pbh(void) {
2834
return __builtin_bit_cast(__m512bh, _mm512_setzero_ps());
2935
}
@@ -167,7 +173,7 @@ _mm512_mask_blend_pbh(__mmask32 __U, __m512bh __A, __m512bh __W) {
167173
(__v32bf)__A);
168174
}
169175

170-
static __inline__ __m512bh __DEFAULT_FN_ATTRS512
176+
static __inline__ __m512bh __DEFAULT_FN_ATTRS512_CONSTEXPR
171177
_mm512_permutex2var_pbh(__m512bh __A, __m512i __I, __m512bh __B) {
172178
return (__m512bh)__builtin_ia32_vpermi2varhi512((__v32hi)__A, (__v32hi)__I,
173179
(__v32hi)__B);
@@ -555,6 +561,7 @@ static __inline__ __m512bh __DEFAULT_FN_ATTRS512 _mm512_maskz_fnmsub_pbh(
555561
(__v32bf)_mm512_setzero_pbh());
556562
}
557563

564+
#undef __DEFAULT_FN_ATTRS512_CONSTEXPR
558565
#undef __DEFAULT_FN_ATTRS512
559566

560567
#endif

0 commit comments

Comments
 (0)