Skip to content

Commit 8d7df8b

Browse files
authored
[X86] Allow AVX2 per-element shift intrinsics to be used in constexpr (#154780)
This handles constant folding for the AVX2 per-element shift intrinsics, which handle out of bounds shift amounts (logical result = 0, arithmetic result = signbit splat) AVX512 intrinsics will follow in follow up patches First stage of #154287
1 parent 4ab5efd commit 8d7df8b

File tree

4 files changed

+82
-67
lines changed

4 files changed

+82
-67
lines changed

clang/include/clang/Basic/BuiltinsX86.td

Lines changed: 14 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -627,11 +627,23 @@ let Features = "avx2", Attributes = [NoThrow, Const, RequiredVectorWidth<256>] i
627627
let Features = "avx2", Attributes = [NoThrow, Const, Constexpr, RequiredVectorWidth<256>] in {
628628
def pmuldq256 : X86Builtin<"_Vector<4, long long int>(_Vector<8, int>, _Vector<8, int>)">;
629629
def pmuludq256 : X86Builtin<"_Vector<4, long long int>(_Vector<8, int>, _Vector<8, int>)">;
630-
}
631630

632-
let Features = "avx2", Attributes = [NoThrow, Const, Constexpr, RequiredVectorWidth<256>] in {
633631
def pmulhuw256 : X86Builtin<"_Vector<16, unsigned short>(_Vector<16, unsigned short>, _Vector<16, unsigned short>)">;
634632
def pmulhw256 : X86Builtin<"_Vector<16, short>(_Vector<16, short>, _Vector<16, short>)">;
633+
634+
def psllv8si : X86Builtin<"_Vector<8, int>(_Vector<8, int>, _Vector<8, int>)">;
635+
def psrav8si : X86Builtin<"_Vector<8, int>(_Vector<8, int>, _Vector<8, int>)">;
636+
def psrlv8si : X86Builtin<"_Vector<8, int>(_Vector<8, int>, _Vector<8, int>)">;
637+
def psllv4di : X86Builtin<"_Vector<4, long long int>(_Vector<4, long long int>, _Vector<4, long long int>)">;
638+
def psrlv4di : X86Builtin<"_Vector<4, long long int>(_Vector<4, long long int>, _Vector<4, long long int>)">;
639+
}
640+
641+
let Features = "avx2", Attributes = [NoThrow, Const, Constexpr, RequiredVectorWidth<128>] in {
642+
def psllv4si : X86Builtin<"_Vector<4, int>(_Vector<4, int>, _Vector<4, int>)">;
643+
def psrav4si : X86Builtin<"_Vector<4, int>(_Vector<4, int>, _Vector<4, int>)">;
644+
def psrlv4si : X86Builtin<"_Vector<4, int>(_Vector<4, int>, _Vector<4, int>)">;
645+
def psllv2di : X86Builtin<"_Vector<2, long long int>(_Vector<2, long long int>, _Vector<2, long long int>)">;
646+
def psrlv2di : X86Builtin<"_Vector<2, long long int>(_Vector<2, long long int>, _Vector<2, long long int>)">;
635647
}
636648

637649
let Features = "avx2", Attributes = [NoThrow, RequiredVectorWidth<256>] in {
@@ -654,46 +666,6 @@ let Features = "avx2", Attributes = [NoThrow, RequiredVectorWidth<128>] in {
654666
def maskstoreq : X86Builtin<"void(_Vector<2, long long int *>, _Vector<2, long long int>, _Vector<2, long long int>)">;
655667
}
656668

657-
let Features = "avx2", Attributes = [NoThrow, Const, RequiredVectorWidth<256>] in {
658-
def psllv8si : X86Builtin<"_Vector<8, int>(_Vector<8, int>, _Vector<8, int>)">;
659-
}
660-
661-
let Features = "avx2", Attributes = [NoThrow, Const, RequiredVectorWidth<128>] in {
662-
def psllv4si : X86Builtin<"_Vector<4, int>(_Vector<4, int>, _Vector<4, int>)">;
663-
}
664-
665-
let Features = "avx2", Attributes = [NoThrow, Const, RequiredVectorWidth<256>] in {
666-
def psllv4di : X86Builtin<"_Vector<4, long long int>(_Vector<4, long long int>, _Vector<4, long long int>)">;
667-
}
668-
669-
let Features = "avx2", Attributes = [NoThrow, Const, RequiredVectorWidth<128>] in {
670-
def psllv2di : X86Builtin<"_Vector<2, long long int>(_Vector<2, long long int>, _Vector<2, long long int>)">;
671-
}
672-
673-
let Features = "avx2", Attributes = [NoThrow, Const, RequiredVectorWidth<256>] in {
674-
def psrav8si : X86Builtin<"_Vector<8, int>(_Vector<8, int>, _Vector<8, int>)">;
675-
}
676-
677-
let Features = "avx2", Attributes = [NoThrow, Const, RequiredVectorWidth<128>] in {
678-
def psrav4si : X86Builtin<"_Vector<4, int>(_Vector<4, int>, _Vector<4, int>)">;
679-
}
680-
681-
let Features = "avx2", Attributes = [NoThrow, Const, RequiredVectorWidth<256>] in {
682-
def psrlv8si : X86Builtin<"_Vector<8, int>(_Vector<8, int>, _Vector<8, int>)">;
683-
}
684-
685-
let Features = "avx2", Attributes = [NoThrow, Const, RequiredVectorWidth<128>] in {
686-
def psrlv4si : X86Builtin<"_Vector<4, int>(_Vector<4, int>, _Vector<4, int>)">;
687-
}
688-
689-
let Features = "avx2", Attributes = [NoThrow, Const, RequiredVectorWidth<256>] in {
690-
def psrlv4di : X86Builtin<"_Vector<4, long long int>(_Vector<4, long long int>, _Vector<4, long long int>)">;
691-
}
692-
693-
let Features = "avx2", Attributes = [NoThrow, Const, RequiredVectorWidth<128>] in {
694-
def psrlv2di : X86Builtin<"_Vector<2, long long int>(_Vector<2, long long int>, _Vector<2, long long int>)">;
695-
}
696-
697669
let Features = "avx2", Attributes = [NoThrow, RequiredVectorWidth<128>] in {
698670
def gatherd_pd : X86Builtin<"_Vector<2, double>(_Vector<2, double>, double const *, _Vector<4, int>, _Vector<2, double>, _Constant char)">;
699671
}

clang/lib/AST/ExprConstant.cpp

Lines changed: 48 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11669,13 +11669,24 @@ bool VectorExprEvaluator::VisitCallExpr(const CallExpr *E) {
1166911669
case clang::X86::BI__builtin_ia32_pmulhuw512:
1167011670
case clang::X86::BI__builtin_ia32_pmulhw128:
1167111671
case clang::X86::BI__builtin_ia32_pmulhw256:
11672-
case clang::X86::BI__builtin_ia32_pmulhw512: {
11672+
case clang::X86::BI__builtin_ia32_pmulhw512:
11673+
case clang::X86::BI__builtin_ia32_psllv2di:
11674+
case clang::X86::BI__builtin_ia32_psllv4di:
11675+
case clang::X86::BI__builtin_ia32_psllv4si:
11676+
case clang::X86::BI__builtin_ia32_psllv8si:
11677+
case clang::X86::BI__builtin_ia32_psrav4si:
11678+
case clang::X86::BI__builtin_ia32_psrav8si:
11679+
case clang::X86::BI__builtin_ia32_psrlv2di:
11680+
case clang::X86::BI__builtin_ia32_psrlv4di:
11681+
case clang::X86::BI__builtin_ia32_psrlv4si:
11682+
case clang::X86::BI__builtin_ia32_psrlv8si:{
1167311683
APValue SourceLHS, SourceRHS;
1167411684
if (!EvaluateAsRValue(Info, E->getArg(0), SourceLHS) ||
1167511685
!EvaluateAsRValue(Info, E->getArg(1), SourceRHS))
1167611686
return false;
1167711687

1167811688
QualType DestEltTy = E->getType()->castAs<VectorType>()->getElementType();
11689+
bool DestUnsigned = DestEltTy->isUnsignedIntegerOrEnumerationType();
1167911690
unsigned SourceLen = SourceLHS.getVectorLength();
1168011691
SmallVector<APValue, 4> ResultElements;
1168111692
ResultElements.reserve(SourceLen);
@@ -11687,12 +11698,12 @@ bool VectorExprEvaluator::VisitCallExpr(const CallExpr *E) {
1168711698
case Builtin::BI__builtin_elementwise_add_sat:
1168811699
ResultElements.push_back(APValue(
1168911700
APSInt(LHS.isSigned() ? LHS.sadd_sat(RHS) : LHS.uadd_sat(RHS),
11690-
DestEltTy->isUnsignedIntegerOrEnumerationType())));
11701+
DestUnsigned)));
1169111702
break;
1169211703
case Builtin::BI__builtin_elementwise_sub_sat:
1169311704
ResultElements.push_back(APValue(
1169411705
APSInt(LHS.isSigned() ? LHS.ssub_sat(RHS) : LHS.usub_sat(RHS),
11695-
DestEltTy->isUnsignedIntegerOrEnumerationType())));
11706+
DestUnsigned)));
1169611707
break;
1169711708
case clang::X86::BI__builtin_ia32_pmulhuw128:
1169811709
case clang::X86::BI__builtin_ia32_pmulhuw256:
@@ -11706,6 +11717,40 @@ bool VectorExprEvaluator::VisitCallExpr(const CallExpr *E) {
1170611717
ResultElements.push_back(APValue(APSInt(llvm::APIntOps::mulhs(LHS, RHS),
1170711718
/*isUnsigned=*/false)));
1170811719
break;
11720+
case clang::X86::BI__builtin_ia32_psllv2di:
11721+
case clang::X86::BI__builtin_ia32_psllv4di:
11722+
case clang::X86::BI__builtin_ia32_psllv4si:
11723+
case clang::X86::BI__builtin_ia32_psllv8si:
11724+
if (RHS.uge(RHS.getBitWidth())) {
11725+
ResultElements.push_back(
11726+
APValue(APSInt(APInt::getZero(RHS.getBitWidth()), DestUnsigned)));
11727+
break;
11728+
}
11729+
ResultElements.push_back(
11730+
APValue(APSInt(LHS.shl(RHS.getZExtValue()), DestUnsigned)));
11731+
break;
11732+
case clang::X86::BI__builtin_ia32_psrav4si:
11733+
case clang::X86::BI__builtin_ia32_psrav8si:
11734+
if (RHS.uge(RHS.getBitWidth())) {
11735+
ResultElements.push_back(
11736+
APValue(APSInt(LHS.ashr(RHS.getBitWidth() - 1), DestUnsigned)));
11737+
break;
11738+
}
11739+
ResultElements.push_back(
11740+
APValue(APSInt(LHS.ashr(RHS.getZExtValue()), DestUnsigned)));
11741+
break;
11742+
case clang::X86::BI__builtin_ia32_psrlv2di:
11743+
case clang::X86::BI__builtin_ia32_psrlv4di:
11744+
case clang::X86::BI__builtin_ia32_psrlv4si:
11745+
case clang::X86::BI__builtin_ia32_psrlv8si:
11746+
if (RHS.uge(RHS.getBitWidth())) {
11747+
ResultElements.push_back(
11748+
APValue(APSInt(APInt::getZero(RHS.getBitWidth()), DestUnsigned)));
11749+
break;
11750+
}
11751+
ResultElements.push_back(
11752+
APValue(APSInt(LHS.lshr(RHS.getZExtValue()), DestUnsigned)));
11753+
break;
1170911754
}
1171011755
}
1171111756

clang/lib/Headers/avx2intrin.h

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3721,7 +3721,7 @@ _mm_maskstore_epi64(long long *__X, __m128i __M, __m128i __Y)
37213721
/// A 256-bit vector of [8 x i32] containing the unsigned shift counts (in
37223722
/// bits).
37233723
/// \returns A 256-bit vector of [8 x i32] containing the result.
3724-
static __inline__ __m256i __DEFAULT_FN_ATTRS256
3724+
static __inline__ __m256i __DEFAULT_FN_ATTRS256_CONSTEXPR
37253725
_mm256_sllv_epi32(__m256i __X, __m256i __Y)
37263726
{
37273727
return (__m256i)__builtin_ia32_psllv8si((__v8si)__X, (__v8si)__Y);
@@ -3743,7 +3743,7 @@ _mm256_sllv_epi32(__m256i __X, __m256i __Y)
37433743
/// A 128-bit vector of [4 x i32] containing the unsigned shift counts (in
37443744
/// bits).
37453745
/// \returns A 128-bit vector of [4 x i32] containing the result.
3746-
static __inline__ __m128i __DEFAULT_FN_ATTRS128
3746+
static __inline__ __m128i __DEFAULT_FN_ATTRS128_CONSTEXPR
37473747
_mm_sllv_epi32(__m128i __X, __m128i __Y)
37483748
{
37493749
return (__m128i)__builtin_ia32_psllv4si((__v4si)__X, (__v4si)__Y);
@@ -3765,7 +3765,7 @@ _mm_sllv_epi32(__m128i __X, __m128i __Y)
37653765
/// A 256-bit vector of [4 x i64] containing the unsigned shift counts (in
37663766
/// bits).
37673767
/// \returns A 256-bit vector of [4 x i64] containing the result.
3768-
static __inline__ __m256i __DEFAULT_FN_ATTRS256
3768+
static __inline__ __m256i __DEFAULT_FN_ATTRS256_CONSTEXPR
37693769
_mm256_sllv_epi64(__m256i __X, __m256i __Y)
37703770
{
37713771
return (__m256i)__builtin_ia32_psllv4di((__v4di)__X, (__v4di)__Y);
@@ -3787,7 +3787,7 @@ _mm256_sllv_epi64(__m256i __X, __m256i __Y)
37873787
/// A 128-bit vector of [2 x i64] containing the unsigned shift counts (in
37883788
/// bits).
37893789
/// \returns A 128-bit vector of [2 x i64] containing the result.
3790-
static __inline__ __m128i __DEFAULT_FN_ATTRS128
3790+
static __inline__ __m128i __DEFAULT_FN_ATTRS128_CONSTEXPR
37913791
_mm_sllv_epi64(__m128i __X, __m128i __Y)
37923792
{
37933793
return (__m128i)__builtin_ia32_psllv2di((__v2di)__X, (__v2di)__Y);
@@ -3810,7 +3810,7 @@ _mm_sllv_epi64(__m128i __X, __m128i __Y)
38103810
/// A 256-bit vector of [8 x i32] containing the unsigned shift counts (in
38113811
/// bits).
38123812
/// \returns A 256-bit vector of [8 x i32] containing the result.
3813-
static __inline__ __m256i __DEFAULT_FN_ATTRS256
3813+
static __inline__ __m256i __DEFAULT_FN_ATTRS256_CONSTEXPR
38143814
_mm256_srav_epi32(__m256i __X, __m256i __Y)
38153815
{
38163816
return (__m256i)__builtin_ia32_psrav8si((__v8si)__X, (__v8si)__Y);
@@ -3833,7 +3833,7 @@ _mm256_srav_epi32(__m256i __X, __m256i __Y)
38333833
/// A 128-bit vector of [4 x i32] containing the unsigned shift counts (in
38343834
/// bits).
38353835
/// \returns A 128-bit vector of [4 x i32] containing the result.
3836-
static __inline__ __m128i __DEFAULT_FN_ATTRS128
3836+
static __inline__ __m128i __DEFAULT_FN_ATTRS128_CONSTEXPR
38373837
_mm_srav_epi32(__m128i __X, __m128i __Y)
38383838
{
38393839
return (__m128i)__builtin_ia32_psrav4si((__v4si)__X, (__v4si)__Y);
@@ -3855,7 +3855,7 @@ _mm_srav_epi32(__m128i __X, __m128i __Y)
38553855
/// A 256-bit vector of [8 x i32] containing the unsigned shift counts (in
38563856
/// bits).
38573857
/// \returns A 256-bit vector of [8 x i32] containing the result.
3858-
static __inline__ __m256i __DEFAULT_FN_ATTRS256
3858+
static __inline__ __m256i __DEFAULT_FN_ATTRS256_CONSTEXPR
38593859
_mm256_srlv_epi32(__m256i __X, __m256i __Y)
38603860
{
38613861
return (__m256i)__builtin_ia32_psrlv8si((__v8si)__X, (__v8si)__Y);
@@ -3877,7 +3877,7 @@ _mm256_srlv_epi32(__m256i __X, __m256i __Y)
38773877
/// A 128-bit vector of [4 x i32] containing the unsigned shift counts (in
38783878
/// bits).
38793879
/// \returns A 128-bit vector of [4 x i32] containing the result.
3880-
static __inline__ __m128i __DEFAULT_FN_ATTRS128
3880+
static __inline__ __m128i __DEFAULT_FN_ATTRS128_CONSTEXPR
38813881
_mm_srlv_epi32(__m128i __X, __m128i __Y)
38823882
{
38833883
return (__m128i)__builtin_ia32_psrlv4si((__v4si)__X, (__v4si)__Y);
@@ -3899,7 +3899,7 @@ _mm_srlv_epi32(__m128i __X, __m128i __Y)
38993899
/// A 256-bit vector of [4 x i64] containing the unsigned shift counts (in
39003900
/// bits).
39013901
/// \returns A 256-bit vector of [4 x i64] containing the result.
3902-
static __inline__ __m256i __DEFAULT_FN_ATTRS256
3902+
static __inline__ __m256i __DEFAULT_FN_ATTRS256_CONSTEXPR
39033903
_mm256_srlv_epi64(__m256i __X, __m256i __Y)
39043904
{
39053905
return (__m256i)__builtin_ia32_psrlv4di((__v4di)__X, (__v4di)__Y);
@@ -3921,7 +3921,7 @@ _mm256_srlv_epi64(__m256i __X, __m256i __Y)
39213921
/// A 128-bit vector of [2 x i64] containing the unsigned shift counts (in
39223922
/// bits).
39233923
/// \returns A 128-bit vector of [2 x i64] containing the result.
3924-
static __inline__ __m128i __DEFAULT_FN_ATTRS128
3924+
static __inline__ __m128i __DEFAULT_FN_ATTRS128_CONSTEXPR
39253925
_mm_srlv_epi64(__m128i __X, __m128i __Y)
39263926
{
39273927
return (__m128i)__builtin_ia32_psrlv2di((__v2di)__X, (__v2di)__Y);

0 commit comments

Comments
 (0)