Skip to content

Commit f8e7716

Browse files
authored
[Clang] VectorExprEvaluator::VisitCallExpr / InterpretBuiltin - Allow PSLL/PSRA/PSRL var intrinsics to be used in constexpr (#169276)
Resolves #169176
1 parent 634f6e9 commit f8e7716

File tree

15 files changed

+749
-443
lines changed

15 files changed

+749
-443
lines changed

clang/include/clang/Basic/BuiltinsX86.td

Lines changed: 24 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -210,17 +210,6 @@ let Header = "emmintrin.h", Attributes = [NoThrow, RequireDeclaration] in {
210210
def _mm_pause : X86LibBuiltin<"void()">;
211211
}
212212

213-
let Features = "sse2", Attributes = [NoThrow, Const, RequiredVectorWidth<128>] in {
214-
def psraw128 : X86Builtin<"_Vector<8, short>(_Vector<8, short>, _Vector<8, short>)">;
215-
def psrad128 : X86Builtin<"_Vector<4, int>(_Vector<4, int>, _Vector<4, int>)">;
216-
def psrlw128 : X86Builtin<"_Vector<8, short>(_Vector<8, short>, _Vector<8, short>)">;
217-
def psrld128 : X86Builtin<"_Vector<4, int>(_Vector<4, int>, _Vector<4, int>)">;
218-
def psrlq128 : X86Builtin<"_Vector<2, long long int>(_Vector<2, long long int>, _Vector<2, long long int>)">;
219-
def psllw128 : X86Builtin<"_Vector<8, short>(_Vector<8, short>, _Vector<8, short>)">;
220-
def pslld128 : X86Builtin<"_Vector<4, int>(_Vector<4, int>, _Vector<4, int>)">;
221-
def psllq128 : X86Builtin<"_Vector<2, long long int>(_Vector<2, long long int>, _Vector<2, long long int>)">;
222-
}
223-
224213
let Features = "sse2", Attributes = [NoThrow, Const, Constexpr, RequiredVectorWidth<128>] in {
225214
def movmskpd : X86Builtin<"int(_Vector<2, double>)">;
226215
def pmovmskb128 : X86Builtin<"int(_Vector<16, char>)">;
@@ -261,6 +250,15 @@ let Features = "sse2", Attributes = [NoThrow, Const, Constexpr, RequiredVectorWi
261250
def vec_ext_v4sf : X86Builtin<"float(_Vector<4, float>, _Constant int)">;
262251
def vec_ext_v8hi : X86Builtin<"short(_Vector<8, short>, _Constant int)">;
263252
def vec_set_v8hi : X86Builtin<"_Vector<8, short>(_Vector<8, short>, short, _Constant int)">;
253+
254+
def psraw128 : X86Builtin<"_Vector<8, short>(_Vector<8, short>, _Vector<8, short>)">;
255+
def psrad128 : X86Builtin<"_Vector<4, int>(_Vector<4, int>, _Vector<4, int>)">;
256+
def psrlw128 : X86Builtin<"_Vector<8, short>(_Vector<8, short>, _Vector<8, short>)">;
257+
def psrld128 : X86Builtin<"_Vector<4, int>(_Vector<4, int>, _Vector<4, int>)">;
258+
def psrlq128 : X86Builtin<"_Vector<2, long long int>(_Vector<2, long long int>, _Vector<2, long long int>)">;
259+
def psllw128 : X86Builtin<"_Vector<8, short>(_Vector<8, short>, _Vector<8, short>)">;
260+
def pslld128 : X86Builtin<"_Vector<4, int>(_Vector<4, int>, _Vector<4, int>)">;
261+
def psllq128 : X86Builtin<"_Vector<2, long long int>(_Vector<2, long long int>, _Vector<2, long long int>)">;
264262
}
265263

266264
let Features = "sse3", Attributes = [NoThrow] in {
@@ -579,14 +577,6 @@ let Features = "avx2", Attributes = [NoThrow, Const, RequiredVectorWidth<256>] i
579577
def psadbw256
580578
: X86Builtin<
581579
"_Vector<4, long long int>(_Vector<32, char>, _Vector<32, char>)">;
582-
def psllw256 : X86Builtin<"_Vector<16, short>(_Vector<16, short>, _Vector<8, short>)">;
583-
def pslld256 : X86Builtin<"_Vector<8, int>(_Vector<8, int>, _Vector<4, int>)">;
584-
def psllq256 : X86Builtin<"_Vector<4, long long int>(_Vector<4, long long int>, _Vector<2, long long int>)">;
585-
def psraw256 : X86Builtin<"_Vector<16, short>(_Vector<16, short>, _Vector<8, short>)">;
586-
def psrad256 : X86Builtin<"_Vector<8, int>(_Vector<8, int>, _Vector<4, int>)">;
587-
def psrlw256 : X86Builtin<"_Vector<16, short>(_Vector<16, short>, _Vector<8, short>)">;
588-
def psrld256 : X86Builtin<"_Vector<8, int>(_Vector<8, int>, _Vector<4, int>)">;
589-
def psrlq256 : X86Builtin<"_Vector<4, long long int>(_Vector<4, long long int>, _Vector<2, long long int>)">;
590580
def permdf256 : X86Builtin<"_Vector<4, double>(_Vector<4, double>, _Constant int)">;
591581
def permti256 : X86Builtin<"_Vector<4, long long int>(_Vector<4, long long int>, _Vector<4, long long int>, _Constant int)">;
592582
def permdi256 : X86Builtin<"_Vector<4, long long int>(_Vector<4, long long int>, _Constant int)">;
@@ -663,6 +653,15 @@ let Features = "avx2", Attributes = [NoThrow, Const, Constexpr, RequiredVectorWi
663653

664654
def permvarsi256 : X86Builtin<"_Vector<8, int>(_Vector<8, int>, _Vector<8, int>)">;
665655
def permvarsf256 : X86Builtin<"_Vector<8, float>(_Vector<8, float>, _Vector<8, int>)">;
656+
657+
def psllw256 : X86Builtin<"_Vector<16, short>(_Vector<16, short>, _Vector<8, short>)">;
658+
def pslld256 : X86Builtin<"_Vector<8, int>(_Vector<8, int>, _Vector<4, int>)">;
659+
def psllq256 : X86Builtin<"_Vector<4, long long int>(_Vector<4, long long int>, _Vector<2, long long int>)">;
660+
def psraw256 : X86Builtin<"_Vector<16, short>(_Vector<16, short>, _Vector<8, short>)">;
661+
def psrad256 : X86Builtin<"_Vector<8, int>(_Vector<8, int>, _Vector<4, int>)">;
662+
def psrlw256 : X86Builtin<"_Vector<16, short>(_Vector<16, short>, _Vector<8, short>)">;
663+
def psrld256 : X86Builtin<"_Vector<8, int>(_Vector<8, int>, _Vector<4, int>)">;
664+
def psrlq256 : X86Builtin<"_Vector<4, long long int>(_Vector<4, long long int>, _Vector<2, long long int>)">;
666665
}
667666

668667
let Features = "avx2", Attributes = [NoThrow, Const, Constexpr, RequiredVectorWidth<128>] in {
@@ -1926,16 +1925,13 @@ let Features = "avx512vl", Attributes = [NoThrow, Const, Constexpr, RequiredVect
19261925
def prorq256 : X86Builtin<"_Vector<4, long long int>(_Vector<4, long long int>, _Constant int)">;
19271926
}
19281927

1929-
let Features = "avx512bw", Attributes = [NoThrow, Const, RequiredVectorWidth<512>] in {
1930-
def psllw512 : X86Builtin<"_Vector<32, short>(_Vector<32, short>, _Vector<8, short>)">;
1931-
}
1932-
19331928
let Features = "avx512bw", Attributes = [NoThrow, Const, Constexpr, RequiredVectorWidth<512>] in {
19341929
def pmaddubsw512 : X86Builtin<"_Vector<32, short>(_Vector<64, char>, _Vector<64, char>)">;
19351930
def pmaddwd512 : X86Builtin<"_Vector<16, int>(_Vector<32, short>, _Vector<32, short>)">;
19361931
def psllv32hi : X86Builtin<"_Vector<32, short>(_Vector<32, short>, _Vector<32, short>)">;
19371932
def pshufhw512 : X86Builtin<"_Vector<32, short>(_Vector<32, short>, _Constant int)">;
19381933
def pshuflw512 : X86Builtin<"_Vector<32, short>(_Vector<32, short>, _Constant int)">;
1934+
def psllw512 : X86Builtin<"_Vector<32, short>(_Vector<32, short>, _Vector<8, short>)">;
19391935
}
19401936

19411937
let Features = "avx512bw,avx512vl", Attributes = [NoThrow, Const, Constexpr, RequiredVectorWidth<256>] in {
@@ -1991,7 +1987,7 @@ let Features = "avx512vl", Attributes = [NoThrow, Const, Constexpr, RequiredVect
19911987
def psravq256 : X86Builtin<"_Vector<4, long long int>(_Vector<4, long long int>, _Vector<4, long long int>)">;
19921988
}
19931989

1994-
let Features = "avx512bw", Attributes = [NoThrow, Const, RequiredVectorWidth<512>] in {
1990+
let Features = "avx512bw", Attributes = [NoThrow, Const, Constexpr, RequiredVectorWidth<512>] in {
19951991
def psraw512
19961992
: X86Builtin<"_Vector<32, short>(_Vector<32, short>, _Vector<8, short>)">;
19971993
def psrlw512
@@ -2308,25 +2304,17 @@ let Features = "avx512f",
23082304
def psraqi512 : X86Builtin<"_Vector<8, long long int>(_Vector<8, long long int>, int)">;
23092305
}
23102306

2311-
let Features = "avx512vl", Attributes = [NoThrow, Const, RequiredVectorWidth<128>] in {
2307+
let Features = "avx512vl", Attributes = [NoThrow, Const, Constexpr, RequiredVectorWidth<128>] in {
23122308
def psraq128 : X86Builtin<"_Vector<2, long long int>(_Vector<2, long long int>, _Vector<2, long long int>)">;
2313-
}
2314-
2315-
let Features = "avx512vl", Attributes = [NoThrow, Const, RequiredVectorWidth<256>] in {
2316-
def psraq256 : X86Builtin<"_Vector<4, long long int>(_Vector<4, long long int>, _Vector<2, long long int>)">;
2317-
}
2318-
2319-
let Features = "avx512vl",
2320-
Attributes = [NoThrow, Const, Constexpr, RequiredVectorWidth<128>] in {
23212309
def psraqi128 : X86Builtin<"_Vector<2, long long int>(_Vector<2, long long int>, int)">;
23222310
}
23232311

2324-
let Features = "avx512vl",
2325-
Attributes = [NoThrow, Const, Constexpr, RequiredVectorWidth<256>] in {
2312+
let Features = "avx512vl", Attributes = [NoThrow, Const, Constexpr, RequiredVectorWidth<256>] in {
2313+
def psraq256 : X86Builtin<"_Vector<4, long long int>(_Vector<4, long long int>, _Vector<2, long long int>)">;
23262314
def psraqi256 : X86Builtin<"_Vector<4, long long int>(_Vector<4, long long int>, int)">;
23272315
}
23282316

2329-
let Features = "avx512f", Attributes = [NoThrow, Const, RequiredVectorWidth<512>] in {
2317+
let Features = "avx512f", Attributes = [NoThrow, Const, Constexpr, RequiredVectorWidth<512>] in {
23302318
def pslld512 : X86Builtin<"_Vector<16, int>(_Vector<16, int>, _Vector<4, int>)">;
23312319
def psllq512 : X86Builtin<"_Vector<8, long long int>(_Vector<8, long long int>, _Vector<2, long long int>)">;
23322320
def psrad512 : X86Builtin<"_Vector<16, int>(_Vector<16, int>, _Vector<4, int>)">;

clang/lib/AST/ByteCode/InterpBuiltin.cpp

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3468,6 +3468,69 @@ static bool interp__builtin_ia32_shuffle_generic(
34683468
return true;
34693469
}
34703470

3471+
static bool interp__builtin_ia32_shift_with_count(
3472+
InterpState &S, CodePtr OpPC, const CallExpr *Call,
3473+
llvm::function_ref<APInt(const APInt &, uint64_t)> ShiftOp,
3474+
llvm::function_ref<APInt(const APInt &, unsigned)> OverflowOp) {
3475+
3476+
assert(Call->getNumArgs() == 2);
3477+
3478+
const Pointer &Count = S.Stk.pop<Pointer>();
3479+
const Pointer &Source = S.Stk.pop<Pointer>();
3480+
3481+
QualType SourceType = Call->getArg(0)->getType();
3482+
QualType CountType = Call->getArg(1)->getType();
3483+
assert(SourceType->isVectorType() && CountType->isVectorType());
3484+
3485+
const auto *SourceVecT = SourceType->castAs<VectorType>();
3486+
const auto *CountVecT = CountType->castAs<VectorType>();
3487+
PrimType SourceElemT = *S.getContext().classify(SourceVecT->getElementType());
3488+
PrimType CountElemT = *S.getContext().classify(CountVecT->getElementType());
3489+
3490+
const Pointer &Dst = S.Stk.peek<Pointer>();
3491+
3492+
unsigned DestEltWidth =
3493+
S.getASTContext().getTypeSize(SourceVecT->getElementType());
3494+
bool IsDestUnsigned = SourceVecT->getElementType()->isUnsignedIntegerType();
3495+
unsigned DestLen = SourceVecT->getNumElements();
3496+
unsigned CountEltWidth =
3497+
S.getASTContext().getTypeSize(CountVecT->getElementType());
3498+
unsigned NumBitsInQWord = 64;
3499+
unsigned NumCountElts = NumBitsInQWord / CountEltWidth;
3500+
3501+
uint64_t CountLQWord = 0;
3502+
for (unsigned EltIdx = 0; EltIdx != NumCountElts; ++EltIdx) {
3503+
uint64_t Elt = 0;
3504+
INT_TYPE_SWITCH(CountElemT,
3505+
{ Elt = static_cast<uint64_t>(Count.elem<T>(EltIdx)); });
3506+
CountLQWord |= (Elt << (EltIdx * CountEltWidth));
3507+
}
3508+
3509+
for (unsigned EltIdx = 0; EltIdx != DestLen; ++EltIdx) {
3510+
APSInt Elt;
3511+
INT_TYPE_SWITCH(SourceElemT, { Elt = Source.elem<T>(EltIdx).toAPSInt(); });
3512+
3513+
APInt Result;
3514+
if (CountLQWord < DestEltWidth) {
3515+
Result = ShiftOp(Elt, CountLQWord);
3516+
} else {
3517+
Result = OverflowOp(Elt, DestEltWidth);
3518+
}
3519+
if (IsDestUnsigned) {
3520+
INT_TYPE_SWITCH(SourceElemT, {
3521+
Dst.elem<T>(EltIdx) = T::from(Result.getZExtValue());
3522+
});
3523+
} else {
3524+
INT_TYPE_SWITCH(SourceElemT, {
3525+
Dst.elem<T>(EltIdx) = T::from(Result.getSExtValue());
3526+
});
3527+
}
3528+
}
3529+
3530+
Dst.initializeAllElements();
3531+
return true;
3532+
}
3533+
34713534
static bool interp__builtin_ia32_shufbitqmb_mask(InterpState &S, CodePtr OpPC,
34723535
const CallExpr *Call) {
34733536

@@ -4971,6 +5034,48 @@ bool InterpretBuiltin(InterpState &S, CodePtr OpPC, const CallExpr *Call,
49715034
case X86::BI__builtin_ia32_phminposuw128:
49725035
return interp__builtin_ia32_phminposuw(S, OpPC, Call);
49735036

5037+
case X86::BI__builtin_ia32_psraq128:
5038+
case X86::BI__builtin_ia32_psraq256:
5039+
case X86::BI__builtin_ia32_psraq512:
5040+
case X86::BI__builtin_ia32_psrad128:
5041+
case X86::BI__builtin_ia32_psrad256:
5042+
case X86::BI__builtin_ia32_psrad512:
5043+
case X86::BI__builtin_ia32_psraw128:
5044+
case X86::BI__builtin_ia32_psraw256:
5045+
case X86::BI__builtin_ia32_psraw512:
5046+
return interp__builtin_ia32_shift_with_count(
5047+
S, OpPC, Call,
5048+
[](const APInt &Elt, uint64_t Count) { return Elt.ashr(Count); },
5049+
[](const APInt &Elt, unsigned Width) { return Elt.ashr(Width - 1); });
5050+
5051+
case X86::BI__builtin_ia32_psllq128:
5052+
case X86::BI__builtin_ia32_psllq256:
5053+
case X86::BI__builtin_ia32_psllq512:
5054+
case X86::BI__builtin_ia32_pslld128:
5055+
case X86::BI__builtin_ia32_pslld256:
5056+
case X86::BI__builtin_ia32_pslld512:
5057+
case X86::BI__builtin_ia32_psllw128:
5058+
case X86::BI__builtin_ia32_psllw256:
5059+
case X86::BI__builtin_ia32_psllw512:
5060+
return interp__builtin_ia32_shift_with_count(
5061+
S, OpPC, Call,
5062+
[](const APInt &Elt, uint64_t Count) { return Elt.shl(Count); },
5063+
[](const APInt &Elt, unsigned Width) { return APInt::getZero(Width); });
5064+
5065+
case X86::BI__builtin_ia32_psrlq128:
5066+
case X86::BI__builtin_ia32_psrlq256:
5067+
case X86::BI__builtin_ia32_psrlq512:
5068+
case X86::BI__builtin_ia32_psrld128:
5069+
case X86::BI__builtin_ia32_psrld256:
5070+
case X86::BI__builtin_ia32_psrld512:
5071+
case X86::BI__builtin_ia32_psrlw128:
5072+
case X86::BI__builtin_ia32_psrlw256:
5073+
case X86::BI__builtin_ia32_psrlw512:
5074+
return interp__builtin_ia32_shift_with_count(
5075+
S, OpPC, Call,
5076+
[](const APInt &Elt, uint64_t Count) { return Elt.lshr(Count); },
5077+
[](const APInt &Elt, unsigned Width) { return APInt::getZero(Width); });
5078+
49745079
case X86::BI__builtin_ia32_pternlogd128_mask:
49755080
case X86::BI__builtin_ia32_pternlogd256_mask:
49765081
case X86::BI__builtin_ia32_pternlogd512_mask:

clang/lib/AST/ExprConstant.cpp

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12166,6 +12166,52 @@ static bool evalShuffleGeneric(
1216612166
return true;
1216712167
}
1216812168

12169+
static bool evalShiftWithCount(
12170+
EvalInfo &Info, const CallExpr *Call, APValue &Out,
12171+
llvm::function_ref<APInt(const APInt &, uint64_t)> ShiftOp,
12172+
llvm::function_ref<APInt(const APInt &, unsigned)> OverflowOp) {
12173+
12174+
APValue Source, Count;
12175+
if (!EvaluateAsRValue(Info, Call->getArg(0), Source) ||
12176+
!EvaluateAsRValue(Info, Call->getArg(1), Count))
12177+
return false;
12178+
12179+
assert(Call->getNumArgs() == 2);
12180+
12181+
QualType SourceTy = Call->getArg(0)->getType();
12182+
QualType CountTy = Call->getArg(1)->getType();
12183+
assert(SourceTy->isVectorType() && CountTy->isVectorType());
12184+
12185+
QualType DestEltTy = SourceTy->castAs<VectorType>()->getElementType();
12186+
unsigned DestEltWidth = Source.getVectorElt(0).getInt().getBitWidth();
12187+
unsigned DestLen = Source.getVectorLength();
12188+
bool IsDestUnsigned = DestEltTy->isUnsignedIntegerType();
12189+
unsigned CountEltWidth = Count.getVectorElt(0).getInt().getBitWidth();
12190+
unsigned NumBitsInQWord = 64;
12191+
unsigned NumCountElts = NumBitsInQWord / CountEltWidth;
12192+
SmallVector<APValue, 64> Result;
12193+
Result.reserve(DestLen);
12194+
12195+
uint64_t CountLQWord = 0;
12196+
for (unsigned EltIdx = 0; EltIdx != NumCountElts; ++EltIdx) {
12197+
uint64_t Elt = Count.getVectorElt(EltIdx).getInt().getZExtValue();
12198+
CountLQWord |= (Elt << (EltIdx * CountEltWidth));
12199+
}
12200+
12201+
for (unsigned EltIdx = 0; EltIdx != DestLen; ++EltIdx) {
12202+
APInt Elt = Source.getVectorElt(EltIdx).getInt();
12203+
if (CountLQWord < DestEltWidth) {
12204+
Result.push_back(
12205+
APValue(APSInt(ShiftOp(Elt, CountLQWord), IsDestUnsigned)));
12206+
} else {
12207+
Result.push_back(
12208+
APValue(APSInt(OverflowOp(Elt, DestEltWidth), IsDestUnsigned)));
12209+
}
12210+
}
12211+
Out = APValue(Result.data(), Result.size());
12212+
return true;
12213+
}
12214+
1216912215
bool VectorExprEvaluator::VisitCallExpr(const CallExpr *E) {
1217012216
if (!IsConstantEvaluatedBuiltinCall(E))
1217112217
return ExprEvaluatorBaseTy::VisitCallExpr(E);
@@ -13169,6 +13215,66 @@ bool VectorExprEvaluator::VisitCallExpr(const CallExpr *E) {
1316913215
return Success(APValue(Result.data(), Result.size()), E);
1317013216
}
1317113217

13218+
case X86::BI__builtin_ia32_psraq128:
13219+
case X86::BI__builtin_ia32_psraq256:
13220+
case X86::BI__builtin_ia32_psraq512:
13221+
case X86::BI__builtin_ia32_psrad128:
13222+
case X86::BI__builtin_ia32_psrad256:
13223+
case X86::BI__builtin_ia32_psrad512:
13224+
case X86::BI__builtin_ia32_psraw128:
13225+
case X86::BI__builtin_ia32_psraw256:
13226+
case X86::BI__builtin_ia32_psraw512: {
13227+
APValue R;
13228+
if (!evalShiftWithCount(
13229+
Info, E, R,
13230+
[](const APInt &Elt, uint64_t Count) { return Elt.ashr(Count); },
13231+
[](const APInt &Elt, unsigned Width) {
13232+
return Elt.ashr(Width - 1);
13233+
}))
13234+
return false;
13235+
return Success(R, E);
13236+
}
13237+
13238+
case X86::BI__builtin_ia32_psllq128:
13239+
case X86::BI__builtin_ia32_psllq256:
13240+
case X86::BI__builtin_ia32_psllq512:
13241+
case X86::BI__builtin_ia32_pslld128:
13242+
case X86::BI__builtin_ia32_pslld256:
13243+
case X86::BI__builtin_ia32_pslld512:
13244+
case X86::BI__builtin_ia32_psllw128:
13245+
case X86::BI__builtin_ia32_psllw256:
13246+
case X86::BI__builtin_ia32_psllw512: {
13247+
APValue R;
13248+
if (!evalShiftWithCount(
13249+
Info, E, R,
13250+
[](const APInt &Elt, uint64_t Count) { return Elt.shl(Count); },
13251+
[](const APInt &Elt, unsigned Width) {
13252+
return APInt::getZero(Width);
13253+
}))
13254+
return false;
13255+
return Success(R, E);
13256+
}
13257+
13258+
case X86::BI__builtin_ia32_psrlq128:
13259+
case X86::BI__builtin_ia32_psrlq256:
13260+
case X86::BI__builtin_ia32_psrlq512:
13261+
case X86::BI__builtin_ia32_psrld128:
13262+
case X86::BI__builtin_ia32_psrld256:
13263+
case X86::BI__builtin_ia32_psrld512:
13264+
case X86::BI__builtin_ia32_psrlw128:
13265+
case X86::BI__builtin_ia32_psrlw256:
13266+
case X86::BI__builtin_ia32_psrlw512: {
13267+
APValue R;
13268+
if (!evalShiftWithCount(
13269+
Info, E, R,
13270+
[](const APInt &Elt, uint64_t Count) { return Elt.lshr(Count); },
13271+
[](const APInt &Elt, unsigned Width) {
13272+
return APInt::getZero(Width);
13273+
}))
13274+
return false;
13275+
return Success(R, E);
13276+
}
13277+
1317213278
case X86::BI__builtin_ia32_pternlogd128_mask:
1317313279
case X86::BI__builtin_ia32_pternlogd256_mask:
1317413280
case X86::BI__builtin_ia32_pternlogd512_mask:

0 commit comments

Comments
 (0)