Skip to content

Commit e036bfa

Browse files
committed
[Clang] VectorExprEvaluator::VisitCallExpr / InterpretBuiltin - Allow PSLL/PSRA/PSRL var intrinsics to be used in constexpr
Resolves: #169176
1 parent ba98668 commit e036bfa

File tree

13 files changed

+520
-260
lines changed

13 files changed

+520
-260
lines changed

clang/include/clang/Basic/BuiltinsX86.td

Lines changed: 24 additions & 36 deletions
Original file line number | Diff line number | Diff line change
@@ -214,17 +214,6 @@ let Header = "emmintrin.h", Attributes = [NoThrow, RequireDeclaration] in {
214214
def _mm_pause : X86LibBuiltin<"void()">;
215215
}
216216

217-
let Features = "sse2", Attributes = [NoThrow, Const, RequiredVectorWidth<128>] in {
218-
def psraw128 : X86Builtin<"_Vector<8, short>(_Vector<8, short>, _Vector<8, short>)">;
219-
def psrad128 : X86Builtin<"_Vector<4, int>(_Vector<4, int>, _Vector<4, int>)">;
220-
def psrlw128 : X86Builtin<"_Vector<8, short>(_Vector<8, short>, _Vector<8, short>)">;
221-
def psrld128 : X86Builtin<"_Vector<4, int>(_Vector<4, int>, _Vector<4, int>)">;
222-
def psrlq128 : X86Builtin<"_Vector<2, long long int>(_Vector<2, long long int>, _Vector<2, long long int>)">;
223-
def psllw128 : X86Builtin<"_Vector<8, short>(_Vector<8, short>, _Vector<8, short>)">;
224-
def pslld128 : X86Builtin<"_Vector<4, int>(_Vector<4, int>, _Vector<4, int>)">;
225-
def psllq128 : X86Builtin<"_Vector<2, long long int>(_Vector<2, long long int>, _Vector<2, long long int>)">;
226-
}
227-
228217
let Features = "sse2", Attributes = [NoThrow, Const, Constexpr, RequiredVectorWidth<128>] in {
229218
def movmskpd : X86Builtin<"int(_Vector<2, double>)">;
230219
def pmovmskb128 : X86Builtin<"int(_Vector<16, char>)">;
@@ -265,6 +254,15 @@ let Features = "sse2", Attributes = [NoThrow, Const, Constexpr, RequiredVectorWi
265254
def vec_ext_v4sf : X86Builtin<"float(_Vector<4, float>, _Constant int)">;
266255
def vec_ext_v8hi : X86Builtin<"short(_Vector<8, short>, _Constant int)">;
267256
def vec_set_v8hi : X86Builtin<"_Vector<8, short>(_Vector<8, short>, short, _Constant int)">;
257+
258+
def psraw128 : X86Builtin<"_Vector<8, short>(_Vector<8, short>, _Vector<8, short>)">;
259+
def psrad128 : X86Builtin<"_Vector<4, int>(_Vector<4, int>, _Vector<4, int>)">;
260+
def psrlw128 : X86Builtin<"_Vector<8, short>(_Vector<8, short>, _Vector<8, short>)">;
261+
def psrld128 : X86Builtin<"_Vector<4, int>(_Vector<4, int>, _Vector<4, int>)">;
262+
def psrlq128 : X86Builtin<"_Vector<2, long long int>(_Vector<2, long long int>, _Vector<2, long long int>)">;
263+
def psllw128 : X86Builtin<"_Vector<8, short>(_Vector<8, short>, _Vector<8, short>)">;
264+
def pslld128 : X86Builtin<"_Vector<4, int>(_Vector<4, int>, _Vector<4, int>)">;
265+
def psllq128 : X86Builtin<"_Vector<2, long long int>(_Vector<2, long long int>, _Vector<2, long long int>)">;
268266
}
269267

270268
let Features = "sse3", Attributes = [NoThrow] in {
@@ -585,14 +583,6 @@ let Features = "avx2", Attributes = [NoThrow, Const, RequiredVectorWidth<256>] i
585583
def psadbw256
586584
: X86Builtin<
587585
"_Vector<4, long long int>(_Vector<32, char>, _Vector<32, char>)">;
588-
def psllw256 : X86Builtin<"_Vector<16, short>(_Vector<16, short>, _Vector<8, short>)">;
589-
def pslld256 : X86Builtin<"_Vector<8, int>(_Vector<8, int>, _Vector<4, int>)">;
590-
def psllq256 : X86Builtin<"_Vector<4, long long int>(_Vector<4, long long int>, _Vector<2, long long int>)">;
591-
def psraw256 : X86Builtin<"_Vector<16, short>(_Vector<16, short>, _Vector<8, short>)">;
592-
def psrad256 : X86Builtin<"_Vector<8, int>(_Vector<8, int>, _Vector<4, int>)">;
593-
def psrlw256 : X86Builtin<"_Vector<16, short>(_Vector<16, short>, _Vector<8, short>)">;
594-
def psrld256 : X86Builtin<"_Vector<8, int>(_Vector<8, int>, _Vector<4, int>)">;
595-
def psrlq256 : X86Builtin<"_Vector<4, long long int>(_Vector<4, long long int>, _Vector<2, long long int>)">;
596586
def permdf256 : X86Builtin<"_Vector<4, double>(_Vector<4, double>, _Constant int)">;
597587
def permti256 : X86Builtin<"_Vector<4, long long int>(_Vector<4, long long int>, _Vector<4, long long int>, _Constant int)">;
598588
def permdi256 : X86Builtin<"_Vector<4, long long int>(_Vector<4, long long int>, _Constant int)">;
@@ -669,6 +659,15 @@ let Features = "avx2", Attributes = [NoThrow, Const, Constexpr, RequiredVectorWi
669659

670660
def permvarsi256 : X86Builtin<"_Vector<8, int>(_Vector<8, int>, _Vector<8, int>)">;
671661
def permvarsf256 : X86Builtin<"_Vector<8, float>(_Vector<8, float>, _Vector<8, int>)">;
662+
663+
def psllw256 : X86Builtin<"_Vector<16, short>(_Vector<16, short>, _Vector<8, short>)">;
664+
def pslld256 : X86Builtin<"_Vector<8, int>(_Vector<8, int>, _Vector<4, int>)">;
665+
def psllq256 : X86Builtin<"_Vector<4, long long int>(_Vector<4, long long int>, _Vector<2, long long int>)">;
666+
def psraw256 : X86Builtin<"_Vector<16, short>(_Vector<16, short>, _Vector<8, short>)">;
667+
def psrad256 : X86Builtin<"_Vector<8, int>(_Vector<8, int>, _Vector<4, int>)">;
668+
def psrlw256 : X86Builtin<"_Vector<16, short>(_Vector<16, short>, _Vector<8, short>)">;
669+
def psrld256 : X86Builtin<"_Vector<8, int>(_Vector<8, int>, _Vector<4, int>)">;
670+
def psrlq256 : X86Builtin<"_Vector<4, long long int>(_Vector<4, long long int>, _Vector<2, long long int>)">;
672671
}
673672

674673
let Features = "avx2", Attributes = [NoThrow, Const, Constexpr, RequiredVectorWidth<128>] in {
@@ -1930,16 +1929,13 @@ let Features = "avx512vl", Attributes = [NoThrow, Const, Constexpr, RequiredVect
19301929
def prorq256 : X86Builtin<"_Vector<4, long long int>(_Vector<4, long long int>, _Constant int)">;
19311930
}
19321931

1933-
let Features = "avx512bw", Attributes = [NoThrow, Const, RequiredVectorWidth<512>] in {
1934-
def psllw512 : X86Builtin<"_Vector<32, short>(_Vector<32, short>, _Vector<8, short>)">;
1935-
}
1936-
19371932
let Features = "avx512bw", Attributes = [NoThrow, Const, Constexpr, RequiredVectorWidth<512>] in {
19381933
def pmaddubsw512 : X86Builtin<"_Vector<32, short>(_Vector<64, char>, _Vector<64, char>)">;
19391934
def pmaddwd512 : X86Builtin<"_Vector<16, int>(_Vector<32, short>, _Vector<32, short>)">;
19401935
def psllv32hi : X86Builtin<"_Vector<32, short>(_Vector<32, short>, _Vector<32, short>)">;
19411936
def pshufhw512 : X86Builtin<"_Vector<32, short>(_Vector<32, short>, _Constant int)">;
19421937
def pshuflw512 : X86Builtin<"_Vector<32, short>(_Vector<32, short>, _Constant int)">;
1938+
def psllw512 : X86Builtin<"_Vector<32, short>(_Vector<32, short>, _Vector<8, short>)">;
19431939
}
19441940

19451941
let Features = "avx512bw,avx512vl", Attributes = [NoThrow, Const, Constexpr, RequiredVectorWidth<256>] in {
@@ -1995,7 +1991,7 @@ let Features = "avx512vl", Attributes = [NoThrow, Const, Constexpr, RequiredVect
19951991
def psravq256 : X86Builtin<"_Vector<4, long long int>(_Vector<4, long long int>, _Vector<4, long long int>)">;
19961992
}
19971993

1998-
let Features = "avx512bw", Attributes = [NoThrow, Const, RequiredVectorWidth<512>] in {
1994+
let Features = "avx512bw", Attributes = [NoThrow, Const, Constexpr, RequiredVectorWidth<512>] in {
19991995
def psraw512
20001996
: X86Builtin<"_Vector<32, short>(_Vector<32, short>, _Vector<8, short>)">;
20011997
def psrlw512
@@ -2312,25 +2308,17 @@ let Features = "avx512f",
23122308
def psraqi512 : X86Builtin<"_Vector<8, long long int>(_Vector<8, long long int>, int)">;
23132309
}
23142310

2315-
let Features = "avx512vl", Attributes = [NoThrow, Const, RequiredVectorWidth<128>] in {
2311+
let Features = "avx512vl", Attributes = [NoThrow, Const, Constexpr, RequiredVectorWidth<128>] in {
23162312
def psraq128 : X86Builtin<"_Vector<2, long long int>(_Vector<2, long long int>, _Vector<2, long long int>)">;
2317-
}
2318-
2319-
let Features = "avx512vl", Attributes = [NoThrow, Const, RequiredVectorWidth<256>] in {
2320-
def psraq256 : X86Builtin<"_Vector<4, long long int>(_Vector<4, long long int>, _Vector<2, long long int>)">;
2321-
}
2322-
2323-
let Features = "avx512vl",
2324-
Attributes = [NoThrow, Const, Constexpr, RequiredVectorWidth<128>] in {
23252313
def psraqi128 : X86Builtin<"_Vector<2, long long int>(_Vector<2, long long int>, int)">;
23262314
}
23272315

2328-
let Features = "avx512vl",
2329-
Attributes = [NoThrow, Const, Constexpr, RequiredVectorWidth<256>] in {
2316+
let Features = "avx512vl", Attributes = [NoThrow, Const, Constexpr, RequiredVectorWidth<256>] in {
2317+
def psraq256 : X86Builtin<"_Vector<4, long long int>(_Vector<4, long long int>, _Vector<2, long long int>)">;
23302318
def psraqi256 : X86Builtin<"_Vector<4, long long int>(_Vector<4, long long int>, int)">;
23312319
}
23322320

2333-
let Features = "avx512f", Attributes = [NoThrow, Const, RequiredVectorWidth<512>] in {
2321+
let Features = "avx512f", Attributes = [NoThrow, Const, Constexpr, RequiredVectorWidth<512>] in {
23342322
def pslld512 : X86Builtin<"_Vector<16, int>(_Vector<16, int>, _Vector<4, int>)">;
23352323
def psllq512 : X86Builtin<"_Vector<8, long long int>(_Vector<8, long long int>, _Vector<2, long long int>)">;
23362324
def psrad512 : X86Builtin<"_Vector<16, int>(_Vector<16, int>, _Vector<4, int>)">;

clang/lib/AST/ByteCode/InterpBuiltin.cpp

Lines changed: 105 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -3468,6 +3468,69 @@ static bool interp__builtin_ia32_shuffle_generic(
34683468
return true;
34693469
}
34703470

3471+
// Bytecode-interpreter helper for two-operand vector shift builtins in which
// every source element is shifted by one shared count taken from a second
// vector operand.
//
// The count is assembled from the low 64 bits of the count vector. When it is
// smaller than the source element width each element is passed to ShiftOp
// together with the count; otherwise each element is passed to OverflowOp
// together with the element width. In this diff the callers in
// InterpretBuiltin pass ashr/shl/lshr as ShiftOp and either `ashr(Width - 1)`
// or an all-zero value as OverflowOp.
//
// Operands are popped from the interpreter stack (count first, then source)
// and results are written into the pointer left on top of the stack.
// Always returns true.
static bool interp__builtin_ia32_shift_with_count(
3472+
InterpState &S, CodePtr OpPC, const CallExpr *Call,
3473+
llvm::function_ref<APInt(const APInt &, uint64_t)> ShiftOp,
3474+
llvm::function_ref<APInt(const APInt &, unsigned)> OverflowOp) {
3475+
3476+
assert(Call->getNumArgs() == 2);
3477+
3478+
// Popped in reverse argument order: argument 1 (count) before argument 0
// (source).
const Pointer &Count = S.Stk.pop<Pointer>();
3479+
const Pointer &Source = S.Stk.pop<Pointer>();
3480+
3481+
QualType SourceType = Call->getArg(0)->getType();
3482+
QualType CountType = Call->getArg(1)->getType();
3483+
assert(SourceType->isVectorType() && CountType->isVectorType());
3484+
3485+
const auto *SourceVecT = SourceType->castAs<VectorType>();
3486+
const auto *CountVecT = CountType->castAs<VectorType>();
3487+
// NOTE(review): classify() results are dereferenced unchecked; presumably
// the element types are always classifiable integers here - confirm.
PrimType SourceElemT = *S.getContext().classify(SourceVecT->getElementType());
3488+
PrimType CountElemT = *S.getContext().classify(CountVecT->getElementType());
3489+
3490+
// The destination is only peeked, so it stays on the stack after the call.
const Pointer &Dst = S.Stk.peek<Pointer>();
3491+
3492+
// The destination uses the source vector's element width, signedness and
// element count.
unsigned DestEltWidth =
3493+
S.getASTContext().getTypeSize(SourceVecT->getElementType());
3494+
bool IsDestUnsigned = SourceVecT->getElementType()->isUnsignedIntegerType();
3495+
unsigned DestLen = SourceVecT->getNumElements();
3496+
unsigned CountEltWidth =
3497+
S.getASTContext().getTypeSize(CountVecT->getElementType());
3498+
unsigned NumBitsInQWord = 64;
3499+
// Number of leading count elements that together span 64 bits.
unsigned NumCountElts = NumBitsInQWord / CountEltWidth;
3500+
3501+
// Pack those leading count elements into one 64-bit value, element 0 in the
// least significant position.
uint64_t CountLQWord = 0;
3502+
for (unsigned EltIdx = 0; EltIdx != NumCountElts; ++EltIdx) {
3503+
uint64_t Elt = 0;
3504+
// NOTE(review): the cast starts from the element's own integral type; if
// that type is signed, a negative element presumably widens with its upper
// bits set. That would only add high bits to a count that is already
// >= DestEltWidth, so results look unaffected - confirm.
INT_TYPE_SWITCH(CountElemT,
3505+
{ Elt = static_cast<uint64_t>(Count.elem<T>(EltIdx)); });
3506+
CountLQWord |= (Elt << (EltIdx * CountEltWidth));
3507+
}
3508+
3509+
// Apply the same count to every source element.
for (unsigned EltIdx = 0; EltIdx != DestLen; ++EltIdx) {
3510+
APSInt Elt;
3511+
INT_TYPE_SWITCH(SourceElemT, { Elt = Source.elem<T>(EltIdx).toAPSInt(); });
3512+
3513+
APInt Result;
3514+
// In-range counts go to ShiftOp; anything else is delegated to OverflowOp,
// which receives the element width rather than the count.
if (CountLQWord < DestEltWidth) {
3515+
Result = ShiftOp(Elt, CountLQWord);
3516+
} else {
3517+
Result = OverflowOp(Elt, DestEltWidth);
3518+
}
3519+
// Store back using zero- or sign-extension to match the element signedness.
if (IsDestUnsigned) {
3520+
INT_TYPE_SWITCH(SourceElemT, {
3521+
Dst.elem<T>(EltIdx) = T::from(Result.getZExtValue());
3522+
});
3523+
} else {
3524+
INT_TYPE_SWITCH(SourceElemT, {
3525+
Dst.elem<T>(EltIdx) = T::from(Result.getSExtValue());
3526+
});
3527+
}
3528+
}
3529+
3530+
// Mark every destination element as initialized now that all were written.
Dst.initializeAllElements();
3531+
return true;
3532+
}
3533+
34713534
static bool interp__builtin_ia32_shufbitqmb_mask(InterpState &S, CodePtr OpPC,
34723535
const CallExpr *Call) {
34733536

@@ -4826,6 +4889,48 @@ bool InterpretBuiltin(InterpState &S, CodePtr OpPC, const CallExpr *Call,
48264889
case X86::BI__builtin_ia32_phminposuw128:
48274890
return interp__builtin_ia32_phminposuw(S, OpPC, Call);
48284891

4892+
case X86::BI__builtin_ia32_psraq128:
4893+
case X86::BI__builtin_ia32_psraq256:
4894+
case X86::BI__builtin_ia32_psraq512:
4895+
case X86::BI__builtin_ia32_psrad128:
4896+
case X86::BI__builtin_ia32_psrad256:
4897+
case X86::BI__builtin_ia32_psrad512:
4898+
case X86::BI__builtin_ia32_psraw128:
4899+
case X86::BI__builtin_ia32_psraw256:
4900+
case X86::BI__builtin_ia32_psraw512:
4901+
return interp__builtin_ia32_shift_with_count(
4902+
S, OpPC, Call,
4903+
[](const APInt &Elt, uint64_t Count) { return Elt.ashr(Count); },
4904+
[](const APInt &Elt, unsigned Width) { return Elt.ashr(Width - 1); });
4905+
4906+
case X86::BI__builtin_ia32_psllq128:
4907+
case X86::BI__builtin_ia32_psllq256:
4908+
case X86::BI__builtin_ia32_psllq512:
4909+
case X86::BI__builtin_ia32_pslld128:
4910+
case X86::BI__builtin_ia32_pslld256:
4911+
case X86::BI__builtin_ia32_pslld512:
4912+
case X86::BI__builtin_ia32_psllw128:
4913+
case X86::BI__builtin_ia32_psllw256:
4914+
case X86::BI__builtin_ia32_psllw512:
4915+
return interp__builtin_ia32_shift_with_count(
4916+
S, OpPC, Call,
4917+
[](const APInt &Elt, uint64_t Count) { return Elt.shl(Count); },
4918+
[](const APInt &Elt, unsigned Width) { return APInt::getZero(Width); });
4919+
4920+
case X86::BI__builtin_ia32_psrlq128:
4921+
case X86::BI__builtin_ia32_psrlq256:
4922+
case X86::BI__builtin_ia32_psrlq512:
4923+
case X86::BI__builtin_ia32_psrld128:
4924+
case X86::BI__builtin_ia32_psrld256:
4925+
case X86::BI__builtin_ia32_psrld512:
4926+
case X86::BI__builtin_ia32_psrlw128:
4927+
case X86::BI__builtin_ia32_psrlw256:
4928+
case X86::BI__builtin_ia32_psrlw512:
4929+
return interp__builtin_ia32_shift_with_count(
4930+
S, OpPC, Call,
4931+
[](const APInt &Elt, uint64_t Count) { return Elt.lshr(Count); },
4932+
[](const APInt &Elt, unsigned Width) { return APInt::getZero(Width); });
4933+
48294934
case X86::BI__builtin_ia32_pternlogd128_mask:
48304935
case X86::BI__builtin_ia32_pternlogd256_mask:
48314936
case X86::BI__builtin_ia32_pternlogd512_mask:

clang/lib/AST/ExprConstant.cpp

Lines changed: 106 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -12166,6 +12166,52 @@ static bool evalShuffleGeneric(
1216612166
return true;
1216712167
}
1216812168

12169+
// Constant-evaluator helper for two-operand vector shift builtins in which
// every source element is shifted by one shared count taken from a second
// vector operand.
//
// Both call arguments are evaluated as rvalues; if either evaluation fails the
// function returns false and leaves Out untouched. Otherwise the count is
// assembled from the low 64 bits of the count vector. When it is smaller than
// the source element width each element is passed to ShiftOp together with
// the count; otherwise each element is passed to OverflowOp together with the
// element width. The resulting vector is stored in Out and true is returned.
// In this diff the callers in VectorExprEvaluator::VisitCallExpr pass
// ashr/shl/lshr as ShiftOp and either `ashr(Width - 1)` or an all-zero value
// as OverflowOp.
static bool evalShiftWithCount(
12170+
EvalInfo &Info, const CallExpr *Call, APValue &Out,
12171+
llvm::function_ref<APInt(const APInt &, uint64_t)> ShiftOp,
12172+
llvm::function_ref<APInt(const APInt &, unsigned)> OverflowOp) {
12173+
12174+
APValue Source, Count;
12175+
if (!EvaluateAsRValue(Info, Call->getArg(0), Source) ||
12176+
!EvaluateAsRValue(Info, Call->getArg(1), Count))
12177+
return false;
12178+
12179+
// NOTE(review): this arity check runs after getArg(1) has already been used
// above, so it cannot guard those accesses.
assert(Call->getNumArgs() == 2);
12180+
12181+
QualType SourceTy = Call->getArg(0)->getType();
12182+
QualType CountTy = Call->getArg(1)->getType();
12183+
assert(SourceTy->isVectorType() && CountTy->isVectorType());
12184+
12185+
QualType DestEltTy = SourceTy->castAs<VectorType>()->getElementType();
// Element widths are read from element 0 of the evaluated values.
// NOTE(review): assumes both vectors have at least one integer element -
// confirm.
12186+
unsigned DestEltWidth = Source.getVectorElt(0).getInt().getBitWidth();
12187+
unsigned DestLen = Source.getVectorLength();
12188+
bool IsDestUnsigned = DestEltTy->isUnsignedIntegerType();
12189+
unsigned CountEltWidth = Count.getVectorElt(0).getInt().getBitWidth();
12190+
unsigned NumBitsInQWord = 64;
12191+
// Number of leading count elements that together span 64 bits.
unsigned NumCountElts = NumBitsInQWord / CountEltWidth;
12192+
SmallVector<APValue, 64> Result;
12193+
Result.reserve(DestLen);
12194+
12195+
// Pack those leading count elements (zero-extended) into one 64-bit value,
// element 0 in the least significant position.
uint64_t CountLQWord = 0;
12196+
for (unsigned EltIdx = 0; EltIdx != NumCountElts; ++EltIdx) {
12197+
uint64_t Elt = Count.getVectorElt(EltIdx).getInt().getZExtValue();
12198+
CountLQWord |= (Elt << (EltIdx * CountEltWidth));
12199+
}
12200+
12201+
// Apply the same count to every source element; the result elements carry
// the source element type's signedness.
for (unsigned EltIdx = 0; EltIdx != DestLen; ++EltIdx) {
12202+
APInt Elt = Source.getVectorElt(EltIdx).getInt();
12203+
// In-range counts go to ShiftOp; anything else is delegated to OverflowOp,
// which receives the element width rather than the count.
if (CountLQWord < DestEltWidth) {
12204+
Result.push_back(
12205+
APValue(APSInt(ShiftOp(Elt, CountLQWord), IsDestUnsigned)));
12206+
} else {
12207+
Result.push_back(
12208+
APValue(APSInt(OverflowOp(Elt, DestEltWidth), IsDestUnsigned)));
12209+
}
12210+
}
12211+
Out = APValue(Result.data(), Result.size());
12212+
return true;
12213+
}
12214+
1216912215
bool VectorExprEvaluator::VisitCallExpr(const CallExpr *E) {
1217012216
if (!IsConstantEvaluatedBuiltinCall(E))
1217112217
return ExprEvaluatorBaseTy::VisitCallExpr(E);
@@ -13130,6 +13176,66 @@ bool VectorExprEvaluator::VisitCallExpr(const CallExpr *E) {
1313013176
return Success(APValue(Result.data(), Result.size()), E);
1313113177
}
1313213178

13179+
case X86::BI__builtin_ia32_psraq128:
13180+
case X86::BI__builtin_ia32_psraq256:
13181+
case X86::BI__builtin_ia32_psraq512:
13182+
case X86::BI__builtin_ia32_psrad128:
13183+
case X86::BI__builtin_ia32_psrad256:
13184+
case X86::BI__builtin_ia32_psrad512:
13185+
case X86::BI__builtin_ia32_psraw128:
13186+
case X86::BI__builtin_ia32_psraw256:
13187+
case X86::BI__builtin_ia32_psraw512: {
13188+
APValue R;
13189+
if (!evalShiftWithCount(
13190+
Info, E, R,
13191+
[](const APInt &Elt, uint64_t Count) { return Elt.ashr(Count); },
13192+
[](const APInt &Elt, unsigned Width) {
13193+
return Elt.ashr(Width - 1);
13194+
}))
13195+
return false;
13196+
return Success(R, E);
13197+
}
13198+
13199+
case X86::BI__builtin_ia32_psllq128:
13200+
case X86::BI__builtin_ia32_psllq256:
13201+
case X86::BI__builtin_ia32_psllq512:
13202+
case X86::BI__builtin_ia32_pslld128:
13203+
case X86::BI__builtin_ia32_pslld256:
13204+
case X86::BI__builtin_ia32_pslld512:
13205+
case X86::BI__builtin_ia32_psllw128:
13206+
case X86::BI__builtin_ia32_psllw256:
13207+
case X86::BI__builtin_ia32_psllw512: {
13208+
APValue R;
13209+
if (!evalShiftWithCount(
13210+
Info, E, R,
13211+
[](const APInt &Elt, uint64_t Count) { return Elt.shl(Count); },
13212+
[](const APInt &Elt, unsigned Width) {
13213+
return APInt::getZero(Width);
13214+
}))
13215+
return false;
13216+
return Success(R, E);
13217+
}
13218+
13219+
case X86::BI__builtin_ia32_psrlq128:
13220+
case X86::BI__builtin_ia32_psrlq256:
13221+
case X86::BI__builtin_ia32_psrlq512:
13222+
case X86::BI__builtin_ia32_psrld128:
13223+
case X86::BI__builtin_ia32_psrld256:
13224+
case X86::BI__builtin_ia32_psrld512:
13225+
case X86::BI__builtin_ia32_psrlw128:
13226+
case X86::BI__builtin_ia32_psrlw256:
13227+
case X86::BI__builtin_ia32_psrlw512: {
13228+
APValue R;
13229+
if (!evalShiftWithCount(
13230+
Info, E, R,
13231+
[](const APInt &Elt, uint64_t Count) { return Elt.lshr(Count); },
13232+
[](const APInt &Elt, unsigned Width) {
13233+
return APInt::getZero(Width);
13234+
}))
13235+
return false;
13236+
return Success(R, E);
13237+
}
13238+
1313313239
case X86::BI__builtin_ia32_pternlogd128_mask:
1313413240
case X86::BI__builtin_ia32_pternlogd256_mask:
1313513241
case X86::BI__builtin_ia32_pternlogd512_mask:

0 commit comments

Comments
 (0)