Skip to content

Commit f9256ca

Browse files
authored
[Headers][X86] Allow AVX512 masked arithmetic ss/sd intrinsics to be used in constexpr (#162816)
This PR resolves only the ss/sd part of the AVX512 masked arithmetic intrinsics tracked in #160559.
1 parent 128caa1 commit f9256ca

File tree

9 files changed

+178
-83
lines changed

9 files changed

+178
-83
lines changed

clang/include/clang/Basic/BuiltinsX86.td

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4117,15 +4117,15 @@ let Features = "avx512f", Attributes = [NoThrow, Const, Constexpr, RequiredVecto
41174117
def selectpd_512 : X86Builtin<"_Vector<8, double>(unsigned char, _Vector<8, double>, _Vector<8, double>)">;
41184118
}
41194119

4120-
let Features = "avx512fp16", Attributes = [NoThrow, Const, RequiredVectorWidth<128>] in {
4120+
let Features = "avx512fp16", Attributes = [NoThrow, Const, Constexpr, RequiredVectorWidth<128>] in {
41214121
def selectsh_128 : X86Builtin<"_Vector<8, _Float16>(unsigned char, _Vector<8, _Float16>, _Vector<8, _Float16>)">;
41224122
}
41234123

4124-
let Features = "avx512bf16", Attributes = [NoThrow, Const, RequiredVectorWidth<128>] in {
4124+
let Features = "avx512bf16", Attributes = [NoThrow, Const, Constexpr, RequiredVectorWidth<128>] in {
41254125
def selectsbf_128 : X86Builtin<"_Vector<8, __bf16>(unsigned char, _Vector<8, __bf16>, _Vector<8, __bf16>)">;
41264126
}
41274127

4128-
let Features = "avx512f", Attributes = [NoThrow, Const, RequiredVectorWidth<128>] in {
4128+
let Features = "avx512f", Attributes = [NoThrow, Const, Constexpr, RequiredVectorWidth<128>] in {
41294129
def selectss_128 : X86Builtin<"_Vector<4, float>(unsigned char, _Vector<4, float>, _Vector<4, float>)">;
41304130
def selectsd_128 : X86Builtin<"_Vector<2, double>(unsigned char, _Vector<2, double>, _Vector<2, double>)">;
41314131
}

clang/lib/AST/ByteCode/InterpBuiltin.cpp

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2838,6 +2838,30 @@ static bool interp__builtin_select(InterpState &S, CodePtr OpPC,
28382838
return true;
28392839
}
28402840

2841+
/// Scalar variant of the AVX512 predicated select, with call arguments
2842+
/// (Mask, A, W): Result[0] = (Mask bit 0) ? A[0] : W[0].
2843+
/// Every other element Result[i], i > 0, is copied from W unchanged.
2844+
static bool interp__builtin_select_scalar(InterpState &S,
2845+
const CallExpr *Call) {
2846+
// Element count is read from the vector type of argument 1 (A).
unsigned N =
2847+
Call->getArg(1)->getType()->getAs<VectorType>()->getNumElements();
2848+
2849+
// Operands come off the stack last-argument-first: W (arg 2), A (arg 1),
// then the integer mask (arg 0).
const Pointer &W = S.Stk.pop<Pointer>();
2850+
const Pointer &A = S.Stk.pop<Pointer>();
2851+
APSInt U = popToAPSInt(S, Call->getArg(0));
2852+
// Dst is only peeked, so it stays on the stack after the operands are gone.
// NOTE(review): presumably this is the slot for the call's result vector.
const Pointer &Dst = S.Stk.peek<Pointer>();
2853+
2854+
// Only bit 0 of the mask is consulted; the remaining bits are ignored.
bool TakeA0 = U.getZExtValue() & 1ULL;
2855+
2856+
// TakeA0 doubles as the start index: when the bit is set the copy from W
// begins at element 1, leaving element 0 to be filled from A below; when it
// is clear the copy begins at element 0 and the whole result comes from W.
for (unsigned I = TakeA0; I != N; ++I)
2857+
Dst.elem<Floating>(I) = W.elem<Floating>(I);
2858+
if (TakeA0)
2859+
Dst.elem<Floating>(0) = A.elem<Floating>(0);
2860+
2861+
Dst.initializeAllElements();
2862+
return true;
2863+
}
2864+
28412865
static bool interp__builtin_blend(InterpState &S, CodePtr OpPC,
28422866
const CallExpr *Call) {
28432867
APSInt Mask = popToAPSInt(S, Call->getArg(2));
@@ -4151,6 +4175,11 @@ bool InterpretBuiltin(InterpState &S, CodePtr OpPC, const CallExpr *Call,
41514175
return APInt::getAllOnes(DstBits);
41524176
});
41534177

4178+
case clang::X86::BI__builtin_ia32_selectss_128:
4179+
case clang::X86::BI__builtin_ia32_selectsd_128:
4180+
case clang::X86::BI__builtin_ia32_selectsh_128:
4181+
case clang::X86::BI__builtin_ia32_selectsbf_128:
4182+
return interp__builtin_select_scalar(S, Call);
41544183
case clang::X86::BI__builtin_ia32_vprotbi:
41554184
case clang::X86::BI__builtin_ia32_vprotdi:
41564185
case clang::X86::BI__builtin_ia32_vprotqi:

clang/lib/AST/ExprConstant.cpp

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12202,6 +12202,24 @@ bool VectorExprEvaluator::VisitCallExpr(const CallExpr *E) {
1220212202
return Success(APValue(ResultElements.data(), SourceLen), E);
1220312203
};
1220412204

12205+
// Evaluates a scalar select builtin with call arguments (Mask, A, W) into a
// vector of Len elements: element 0 is A[0] when bit 0 of Mask is set and
// W[0] otherwise; elements 1..Len-1 are always copied from W. Len is supplied
// by the caller. Returns false if any of the three operands fails to evaluate.
auto EvalSelectScalar = [&](unsigned Len) -> bool {
12206+
APSInt Mask;
12207+
APValue AVal, WVal;
12208+
if (!EvaluateInteger(E->getArg(0), Mask, Info) ||
12209+
!EvaluateAsRValue(Info, E->getArg(1), AVal) ||
12210+
!EvaluateAsRValue(Info, E->getArg(2), WVal))
12211+
return false;
12212+
12213+
// Only bit 0 of the mask is consulted; the remaining bits are ignored.
bool TakeA0 = (Mask.getZExtValue() & 1u) != 0;
12214+
SmallVector<APValue, 4> Res;
12215+
Res.reserve(Len);
12216+
// Element 0 is the only one that may come from A.
Res.push_back(TakeA0 ? AVal.getVectorElt(0) : WVal.getVectorElt(0));
12217+
// The remaining elements pass through from W regardless of the mask.
for (unsigned I = 1; I < Len; ++I)
12218+
Res.push_back(WVal.getVectorElt(I));
12219+
APValue V(Res.data(), Res.size());
12220+
return Success(V, E);
12221+
};
};
12222+
1220512223
switch (E->getBuiltinCallee()) {
1220612224
default:
1220712225
return false;
@@ -12505,6 +12523,13 @@ bool VectorExprEvaluator::VisitCallExpr(const CallExpr *E) {
1250512523
return APInt((Src).trunc(DstBits));
1250612524
return APInt::getAllOnes(DstBits);
1250712525
});
12526+
case clang::X86::BI__builtin_ia32_selectss_128:
12527+
return EvalSelectScalar(4);
12528+
case clang::X86::BI__builtin_ia32_selectsd_128:
12529+
return EvalSelectScalar(2);
12530+
case clang::X86::BI__builtin_ia32_selectsh_128:
12531+
case clang::X86::BI__builtin_ia32_selectsbf_128:
12532+
return EvalSelectScalar(8);
1250812533
case clang::X86::BI__builtin_ia32_pmuldq128:
1250912534
case clang::X86::BI__builtin_ia32_pmuldq256:
1251012535
case clang::X86::BI__builtin_ia32_pmuldq512:

clang/lib/Headers/avx10_2bf16intrin.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -221,12 +221,12 @@ static __inline__ __m128bh __DEFAULT_FN_ATTRS128 _mm_move_sbh(__m128bh __a,
221221
return __a;
222222
}
223223

224-
static __inline__ __m128bh __DEFAULT_FN_ATTRS128
224+
static __inline__ __m128bh __DEFAULT_FN_ATTRS128_CONSTEXPR
225225
_mm_mask_move_sbh(__m128bh __W, __mmask8 __U, __m128bh __A, __m128bh __B) {
226226
return __builtin_ia32_selectsbf_128(__U, _mm_move_sbh(__A, __B), __W);
227227
}
228228

229-
static __inline__ __m128bh __DEFAULT_FN_ATTRS128
229+
static __inline__ __m128bh __DEFAULT_FN_ATTRS128_CONSTEXPR
230230
_mm_maskz_move_sbh(__mmask8 __U, __m128bh __A, __m128bh __B) {
231231
return __builtin_ia32_selectsbf_128(__U, _mm_move_sbh(__A, __B),
232232
_mm_setzero_pbh());

clang/lib/Headers/avx512fintrin.h

Lines changed: 32 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1820,14 +1820,14 @@ _mm512_maskz_abs_epi32(__mmask16 __U, __m512i __A) {
18201820
(__v16si)_mm512_setzero_si512());
18211821
}
18221822

1823-
static __inline__ __m128 __DEFAULT_FN_ATTRS128
1824-
_mm_mask_add_ss(__m128 __W, __mmask8 __U,__m128 __A, __m128 __B) {
1823+
static __inline__ __m128 __DEFAULT_FN_ATTRS128_CONSTEXPR
1824+
_mm_mask_add_ss(__m128 __W, __mmask8 __U, __m128 __A, __m128 __B) {
18251825
__A = _mm_add_ss(__A, __B);
18261826
return __builtin_ia32_selectss_128(__U, __A, __W);
18271827
}
18281828

1829-
static __inline__ __m128 __DEFAULT_FN_ATTRS128
1830-
_mm_maskz_add_ss(__mmask8 __U,__m128 __A, __m128 __B) {
1829+
static __inline__ __m128 __DEFAULT_FN_ATTRS128_CONSTEXPR
1830+
_mm_maskz_add_ss(__mmask8 __U, __m128 __A, __m128 __B) {
18311831
__A = _mm_add_ss(__A, __B);
18321832
return __builtin_ia32_selectss_128(__U, __A, _mm_setzero_ps());
18331833
}
@@ -1850,14 +1850,14 @@ _mm_maskz_add_ss(__mmask8 __U,__m128 __A, __m128 __B) {
18501850
(__v4sf)_mm_setzero_ps(), \
18511851
(__mmask8)(U), (int)(R)))
18521852

1853-
static __inline__ __m128d __DEFAULT_FN_ATTRS128
1854-
_mm_mask_add_sd(__m128d __W, __mmask8 __U,__m128d __A, __m128d __B) {
1853+
static __inline__ __m128d __DEFAULT_FN_ATTRS128_CONSTEXPR
1854+
_mm_mask_add_sd(__m128d __W, __mmask8 __U, __m128d __A, __m128d __B) {
18551855
__A = _mm_add_sd(__A, __B);
18561856
return __builtin_ia32_selectsd_128(__U, __A, __W);
18571857
}
18581858

1859-
static __inline__ __m128d __DEFAULT_FN_ATTRS128
1860-
_mm_maskz_add_sd(__mmask8 __U,__m128d __A, __m128d __B) {
1859+
static __inline__ __m128d __DEFAULT_FN_ATTRS128_CONSTEXPR
1860+
_mm_maskz_add_sd(__mmask8 __U, __m128d __A, __m128d __B) {
18611861
__A = _mm_add_sd(__A, __B);
18621862
return __builtin_ia32_selectsd_128(__U, __A, _mm_setzero_pd());
18631863
}
@@ -1935,14 +1935,14 @@ _mm512_maskz_add_ps(__mmask16 __U, __m512 __A, __m512 __B) {
19351935
(__v16sf)_mm512_add_round_ps((A), (B), (R)), \
19361936
(__v16sf)_mm512_setzero_ps()))
19371937

1938-
static __inline__ __m128 __DEFAULT_FN_ATTRS128
1939-
_mm_mask_sub_ss(__m128 __W, __mmask8 __U,__m128 __A, __m128 __B) {
1938+
static __inline__ __m128 __DEFAULT_FN_ATTRS128_CONSTEXPR
1939+
_mm_mask_sub_ss(__m128 __W, __mmask8 __U, __m128 __A, __m128 __B) {
19401940
__A = _mm_sub_ss(__A, __B);
19411941
return __builtin_ia32_selectss_128(__U, __A, __W);
19421942
}
19431943

1944-
static __inline__ __m128 __DEFAULT_FN_ATTRS128
1945-
_mm_maskz_sub_ss(__mmask8 __U,__m128 __A, __m128 __B) {
1944+
static __inline__ __m128 __DEFAULT_FN_ATTRS128_CONSTEXPR
1945+
_mm_maskz_sub_ss(__mmask8 __U, __m128 __A, __m128 __B) {
19461946
__A = _mm_sub_ss(__A, __B);
19471947
return __builtin_ia32_selectss_128(__U, __A, _mm_setzero_ps());
19481948
}
@@ -1964,14 +1964,14 @@ _mm_maskz_sub_ss(__mmask8 __U,__m128 __A, __m128 __B) {
19641964
(__v4sf)_mm_setzero_ps(), \
19651965
(__mmask8)(U), (int)(R)))
19661966

1967-
static __inline__ __m128d __DEFAULT_FN_ATTRS128
1968-
_mm_mask_sub_sd(__m128d __W, __mmask8 __U,__m128d __A, __m128d __B) {
1967+
static __inline__ __m128d __DEFAULT_FN_ATTRS128_CONSTEXPR
1968+
_mm_mask_sub_sd(__m128d __W, __mmask8 __U, __m128d __A, __m128d __B) {
19691969
__A = _mm_sub_sd(__A, __B);
19701970
return __builtin_ia32_selectsd_128(__U, __A, __W);
19711971
}
19721972

1973-
static __inline__ __m128d __DEFAULT_FN_ATTRS128
1974-
_mm_maskz_sub_sd(__mmask8 __U,__m128d __A, __m128d __B) {
1973+
static __inline__ __m128d __DEFAULT_FN_ATTRS128_CONSTEXPR
1974+
_mm_maskz_sub_sd(__mmask8 __U, __m128d __A, __m128d __B) {
19751975
__A = _mm_sub_sd(__A, __B);
19761976
return __builtin_ia32_selectsd_128(__U, __A, _mm_setzero_pd());
19771977
}
@@ -2050,14 +2050,14 @@ _mm512_maskz_sub_ps(__mmask16 __U, __m512 __A, __m512 __B) {
20502050
(__v16sf)_mm512_sub_round_ps((A), (B), (R)), \
20512051
(__v16sf)_mm512_setzero_ps()))
20522052

2053-
static __inline__ __m128 __DEFAULT_FN_ATTRS128
2054-
_mm_mask_mul_ss(__m128 __W, __mmask8 __U,__m128 __A, __m128 __B) {
2053+
static __inline__ __m128 __DEFAULT_FN_ATTRS128_CONSTEXPR
2054+
_mm_mask_mul_ss(__m128 __W, __mmask8 __U, __m128 __A, __m128 __B) {
20552055
__A = _mm_mul_ss(__A, __B);
20562056
return __builtin_ia32_selectss_128(__U, __A, __W);
20572057
}
20582058

2059-
static __inline__ __m128 __DEFAULT_FN_ATTRS128
2060-
_mm_maskz_mul_ss(__mmask8 __U,__m128 __A, __m128 __B) {
2059+
static __inline__ __m128 __DEFAULT_FN_ATTRS128_CONSTEXPR
2060+
_mm_maskz_mul_ss(__mmask8 __U, __m128 __A, __m128 __B) {
20612061
__A = _mm_mul_ss(__A, __B);
20622062
return __builtin_ia32_selectss_128(__U, __A, _mm_setzero_ps());
20632063
}
@@ -2079,14 +2079,14 @@ _mm_maskz_mul_ss(__mmask8 __U,__m128 __A, __m128 __B) {
20792079
(__v4sf)_mm_setzero_ps(), \
20802080
(__mmask8)(U), (int)(R)))
20812081

2082-
static __inline__ __m128d __DEFAULT_FN_ATTRS128
2083-
_mm_mask_mul_sd(__m128d __W, __mmask8 __U,__m128d __A, __m128d __B) {
2082+
static __inline__ __m128d __DEFAULT_FN_ATTRS128_CONSTEXPR
2083+
_mm_mask_mul_sd(__m128d __W, __mmask8 __U, __m128d __A, __m128d __B) {
20842084
__A = _mm_mul_sd(__A, __B);
20852085
return __builtin_ia32_selectsd_128(__U, __A, __W);
20862086
}
20872087

2088-
static __inline__ __m128d __DEFAULT_FN_ATTRS128
2089-
_mm_maskz_mul_sd(__mmask8 __U,__m128d __A, __m128d __B) {
2088+
static __inline__ __m128d __DEFAULT_FN_ATTRS128_CONSTEXPR
2089+
_mm_maskz_mul_sd(__mmask8 __U, __m128d __A, __m128d __B) {
20902090
__A = _mm_mul_sd(__A, __B);
20912091
return __builtin_ia32_selectsd_128(__U, __A, _mm_setzero_pd());
20922092
}
@@ -2165,14 +2165,14 @@ _mm512_maskz_mul_ps(__mmask16 __U, __m512 __A, __m512 __B) {
21652165
(__v16sf)_mm512_mul_round_ps((A), (B), (R)), \
21662166
(__v16sf)_mm512_setzero_ps()))
21672167

2168-
static __inline__ __m128 __DEFAULT_FN_ATTRS128
2169-
_mm_mask_div_ss(__m128 __W, __mmask8 __U,__m128 __A, __m128 __B) {
2168+
static __inline__ __m128 __DEFAULT_FN_ATTRS128_CONSTEXPR
2169+
_mm_mask_div_ss(__m128 __W, __mmask8 __U, __m128 __A, __m128 __B) {
21702170
__A = _mm_div_ss(__A, __B);
21712171
return __builtin_ia32_selectss_128(__U, __A, __W);
21722172
}
21732173

2174-
static __inline__ __m128 __DEFAULT_FN_ATTRS128
2175-
_mm_maskz_div_ss(__mmask8 __U,__m128 __A, __m128 __B) {
2174+
static __inline__ __m128 __DEFAULT_FN_ATTRS128_CONSTEXPR
2175+
_mm_maskz_div_ss(__mmask8 __U, __m128 __A, __m128 __B) {
21762176
__A = _mm_div_ss(__A, __B);
21772177
return __builtin_ia32_selectss_128(__U, __A, _mm_setzero_ps());
21782178
}
@@ -2195,14 +2195,14 @@ _mm_maskz_div_ss(__mmask8 __U,__m128 __A, __m128 __B) {
21952195
(__v4sf)_mm_setzero_ps(), \
21962196
(__mmask8)(U), (int)(R)))
21972197

2198-
static __inline__ __m128d __DEFAULT_FN_ATTRS128
2199-
_mm_mask_div_sd(__m128d __W, __mmask8 __U,__m128d __A, __m128d __B) {
2198+
static __inline__ __m128d __DEFAULT_FN_ATTRS128_CONSTEXPR
2199+
_mm_mask_div_sd(__m128d __W, __mmask8 __U, __m128d __A, __m128d __B) {
22002200
__A = _mm_div_sd(__A, __B);
22012201
return __builtin_ia32_selectsd_128(__U, __A, __W);
22022202
}
22032203

2204-
static __inline__ __m128d __DEFAULT_FN_ATTRS128
2205-
_mm_maskz_div_sd(__mmask8 __U,__m128d __A, __m128d __B) {
2204+
static __inline__ __m128d __DEFAULT_FN_ATTRS128_CONSTEXPR
2205+
_mm_maskz_div_sd(__mmask8 __U, __m128d __A, __m128d __B) {
22062206
__A = _mm_div_sd(__A, __B);
22072207
return __builtin_ia32_selectsd_128(__U, __A, _mm_setzero_pd());
22082208
}

0 commit comments

Comments
 (0)