Skip to content

Commit 853399d

Browse files
author
liuzhenya
committed
[Headers][X86] Allow AVX512 masked arithmetic ss/sd intrinsics to be used in constexpr
1 parent 95f2728 commit 853399d

File tree

9 files changed

+175
-83
lines changed

9 files changed

+175
-83
lines changed

clang/include/clang/Basic/BuiltinsX86.td

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4116,15 +4116,15 @@ let Features = "avx512f", Attributes = [NoThrow, Const, Constexpr, RequiredVecto
41164116
def selectpd_512 : X86Builtin<"_Vector<8, double>(unsigned char, _Vector<8, double>, _Vector<8, double>)">;
41174117
}
41184118

4119-
let Features = "avx512fp16", Attributes = [NoThrow, Const, RequiredVectorWidth<128>] in {
4119+
let Features = "avx512fp16", Attributes = [NoThrow, Const, Constexpr, RequiredVectorWidth<128>] in {
41204120
def selectsh_128 : X86Builtin<"_Vector<8, _Float16>(unsigned char, _Vector<8, _Float16>, _Vector<8, _Float16>)">;
41214121
}
41224122

4123-
let Features = "avx512bf16", Attributes = [NoThrow, Const, RequiredVectorWidth<128>] in {
4123+
let Features = "avx512bf16", Attributes = [NoThrow, Const, Constexpr, RequiredVectorWidth<128>] in {
41244124
def selectsbf_128 : X86Builtin<"_Vector<8, __bf16>(unsigned char, _Vector<8, __bf16>, _Vector<8, __bf16>)">;
41254125
}
41264126

4127-
let Features = "avx512f", Attributes = [NoThrow, Const, RequiredVectorWidth<128>] in {
4127+
let Features = "avx512f", Attributes = [NoThrow, Const, Constexpr, RequiredVectorWidth<128>] in {
41284128
def selectss_128 : X86Builtin<"_Vector<4, float>(unsigned char, _Vector<4, float>, _Vector<4, float>)">;
41294129
def selectsd_128 : X86Builtin<"_Vector<2, double>(unsigned char, _Vector<2, double>, _Vector<2, double>)">;
41304130
}

clang/lib/AST/ByteCode/InterpBuiltin.cpp

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2419,6 +2419,27 @@ static bool interp__builtin_elementwise_int_unaryop(
24192419
return false;
24202420
}
24212421

2422+
static bool interp__builtin_select_scalar(InterpState &S,
2423+
const CallExpr *Call) {
2424+
unsigned N =
2425+
Call->getArg(1)->getType()->getAs<VectorType>()->getNumElements();
2426+
2427+
const Pointer &W = S.Stk.pop<Pointer>();
2428+
const Pointer &A = S.Stk.pop<Pointer>();
2429+
APSInt U = popToAPSInt(S, Call->getArg(0));
2430+
const Pointer &Dst = S.Stk.peek<Pointer>();
2431+
2432+
bool TakeA0 = U.getZExtValue() & 1ULL;
2433+
2434+
for (unsigned I = 0; I < N; ++I)
2435+
Dst.elem<Floating>(I) = W.elem<Floating>(I);
2436+
if (TakeA0)
2437+
Dst.elem<Floating>(0) = A.elem<Floating>(0);
2438+
2439+
Dst.initializeAllElements();
2440+
return true;
2441+
}
2442+
24222443
static bool interp__builtin_elementwise_int_binop(
24232444
InterpState &S, CodePtr OpPC, const CallExpr *Call,
24242445
llvm::function_ref<APInt(const APSInt &, const APSInt &)> Fn) {
@@ -4121,6 +4142,11 @@ bool InterpretBuiltin(InterpState &S, CodePtr OpPC, const CallExpr *Call,
41214142
return APInt::getAllOnes(DstBits);
41224143
});
41234144

4145+
case clang::X86::BI__builtin_ia32_selectss_128:
4146+
case clang::X86::BI__builtin_ia32_selectsd_128:
4147+
case clang::X86::BI__builtin_ia32_selectsh_128:
4148+
case clang::X86::BI__builtin_ia32_selectsbf_128:
4149+
return interp__builtin_select_scalar(S, Call);
41244150
case clang::X86::BI__builtin_ia32_vprotbi:
41254151
case clang::X86::BI__builtin_ia32_vprotdi:
41264152
case clang::X86::BI__builtin_ia32_vprotqi:

clang/lib/AST/ExprConstant.cpp

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12202,6 +12202,24 @@ bool VectorExprEvaluator::VisitCallExpr(const CallExpr *E) {
1220212202
return Success(APValue(ResultElements.data(), SourceLen), E);
1220312203
};
1220412204

12205+
auto EvalSelectScalar = [&](unsigned Len) -> bool {
12206+
APSInt Mask;
12207+
APValue AVal, WVal;
12208+
if (!EvaluateInteger(E->getArg(0), Mask, Info) ||
12209+
!EvaluateAsRValue(Info, E->getArg(1), AVal) ||
12210+
!EvaluateAsRValue(Info, E->getArg(2), WVal))
12211+
return false;
12212+
12213+
bool TakeA0 = (Mask.getZExtValue() & 1u) != 0;
12214+
SmallVector<APValue, 4> Res;
12215+
Res.reserve(Len);
12216+
Res.push_back(TakeA0 ? AVal.getVectorElt(0) : WVal.getVectorElt(0));
12217+
for (unsigned i = 1; i < Len; ++i)
12218+
Res.push_back(WVal.getVectorElt(i));
12219+
APValue V(Res.data(), Res.size());
12220+
return Success(V, E);
12221+
};
12222+
1220512223
switch (E->getBuiltinCallee()) {
1220612224
default:
1220712225
return false;
@@ -12505,6 +12523,13 @@ bool VectorExprEvaluator::VisitCallExpr(const CallExpr *E) {
1250512523
return APInt((Src).trunc(DstBits));
1250612524
return APInt::getAllOnes(DstBits);
1250712525
});
12526+
case clang::X86::BI__builtin_ia32_selectss_128:
12527+
return EvalSelectScalar(4);
12528+
case clang::X86::BI__builtin_ia32_selectsd_128:
12529+
return EvalSelectScalar(2);
12530+
case clang::X86::BI__builtin_ia32_selectsh_128:
12531+
case clang::X86::BI__builtin_ia32_selectsbf_128:
12532+
return EvalSelectScalar(8);
1250812533
case clang::X86::BI__builtin_ia32_pmuldq128:
1250912534
case clang::X86::BI__builtin_ia32_pmuldq256:
1251012535
case clang::X86::BI__builtin_ia32_pmuldq512:

clang/lib/Headers/avx10_2bf16intrin.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -221,12 +221,12 @@ static __inline__ __m128bh __DEFAULT_FN_ATTRS128 _mm_move_sbh(__m128bh __a,
221221
return __a;
222222
}
223223

224-
static __inline__ __m128bh __DEFAULT_FN_ATTRS128
224+
static __inline__ __m128bh __DEFAULT_FN_ATTRS128_CONSTEXPR
225225
_mm_mask_move_sbh(__m128bh __W, __mmask8 __U, __m128bh __A, __m128bh __B) {
226226
return __builtin_ia32_selectsbf_128(__U, _mm_move_sbh(__A, __B), __W);
227227
}
228228

229-
static __inline__ __m128bh __DEFAULT_FN_ATTRS128
229+
static __inline__ __m128bh __DEFAULT_FN_ATTRS128_CONSTEXPR
230230
_mm_maskz_move_sbh(__mmask8 __U, __m128bh __A, __m128bh __B) {
231231
return __builtin_ia32_selectsbf_128(__U, _mm_move_sbh(__A, __B),
232232
_mm_setzero_pbh());

clang/lib/Headers/avx512fintrin.h

Lines changed: 32 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1834,14 +1834,14 @@ _mm512_maskz_abs_epi32(__mmask16 __U, __m512i __A) {
18341834
(__v16si)_mm512_setzero_si512());
18351835
}
18361836

1837-
static __inline__ __m128 __DEFAULT_FN_ATTRS128
1838-
_mm_mask_add_ss(__m128 __W, __mmask8 __U,__m128 __A, __m128 __B) {
1837+
static __inline__ __m128 __DEFAULT_FN_ATTRS128_CONSTEXPR
1838+
_mm_mask_add_ss(__m128 __W, __mmask8 __U, __m128 __A, __m128 __B) {
18391839
__A = _mm_add_ss(__A, __B);
18401840
return __builtin_ia32_selectss_128(__U, __A, __W);
18411841
}
18421842

1843-
static __inline__ __m128 __DEFAULT_FN_ATTRS128
1844-
_mm_maskz_add_ss(__mmask8 __U,__m128 __A, __m128 __B) {
1843+
static __inline__ __m128 __DEFAULT_FN_ATTRS128_CONSTEXPR
1844+
_mm_maskz_add_ss(__mmask8 __U, __m128 __A, __m128 __B) {
18451845
__A = _mm_add_ss(__A, __B);
18461846
return __builtin_ia32_selectss_128(__U, __A, _mm_setzero_ps());
18471847
}
@@ -1864,14 +1864,14 @@ _mm_maskz_add_ss(__mmask8 __U,__m128 __A, __m128 __B) {
18641864
(__v4sf)_mm_setzero_ps(), \
18651865
(__mmask8)(U), (int)(R)))
18661866

1867-
static __inline__ __m128d __DEFAULT_FN_ATTRS128
1868-
_mm_mask_add_sd(__m128d __W, __mmask8 __U,__m128d __A, __m128d __B) {
1867+
static __inline__ __m128d __DEFAULT_FN_ATTRS128_CONSTEXPR
1868+
_mm_mask_add_sd(__m128d __W, __mmask8 __U, __m128d __A, __m128d __B) {
18691869
__A = _mm_add_sd(__A, __B);
18701870
return __builtin_ia32_selectsd_128(__U, __A, __W);
18711871
}
18721872

1873-
static __inline__ __m128d __DEFAULT_FN_ATTRS128
1874-
_mm_maskz_add_sd(__mmask8 __U,__m128d __A, __m128d __B) {
1873+
static __inline__ __m128d __DEFAULT_FN_ATTRS128_CONSTEXPR
1874+
_mm_maskz_add_sd(__mmask8 __U, __m128d __A, __m128d __B) {
18751875
__A = _mm_add_sd(__A, __B);
18761876
return __builtin_ia32_selectsd_128(__U, __A, _mm_setzero_pd());
18771877
}
@@ -1949,14 +1949,14 @@ _mm512_maskz_add_ps(__mmask16 __U, __m512 __A, __m512 __B) {
19491949
(__v16sf)_mm512_add_round_ps((A), (B), (R)), \
19501950
(__v16sf)_mm512_setzero_ps()))
19511951

1952-
static __inline__ __m128 __DEFAULT_FN_ATTRS128
1953-
_mm_mask_sub_ss(__m128 __W, __mmask8 __U,__m128 __A, __m128 __B) {
1952+
static __inline__ __m128 __DEFAULT_FN_ATTRS128_CONSTEXPR
1953+
_mm_mask_sub_ss(__m128 __W, __mmask8 __U, __m128 __A, __m128 __B) {
19541954
__A = _mm_sub_ss(__A, __B);
19551955
return __builtin_ia32_selectss_128(__U, __A, __W);
19561956
}
19571957

1958-
static __inline__ __m128 __DEFAULT_FN_ATTRS128
1959-
_mm_maskz_sub_ss(__mmask8 __U,__m128 __A, __m128 __B) {
1958+
static __inline__ __m128 __DEFAULT_FN_ATTRS128_CONSTEXPR
1959+
_mm_maskz_sub_ss(__mmask8 __U, __m128 __A, __m128 __B) {
19601960
__A = _mm_sub_ss(__A, __B);
19611961
return __builtin_ia32_selectss_128(__U, __A, _mm_setzero_ps());
19621962
}
@@ -1978,14 +1978,14 @@ _mm_maskz_sub_ss(__mmask8 __U,__m128 __A, __m128 __B) {
19781978
(__v4sf)_mm_setzero_ps(), \
19791979
(__mmask8)(U), (int)(R)))
19801980

1981-
static __inline__ __m128d __DEFAULT_FN_ATTRS128
1982-
_mm_mask_sub_sd(__m128d __W, __mmask8 __U,__m128d __A, __m128d __B) {
1981+
static __inline__ __m128d __DEFAULT_FN_ATTRS128_CONSTEXPR
1982+
_mm_mask_sub_sd(__m128d __W, __mmask8 __U, __m128d __A, __m128d __B) {
19831983
__A = _mm_sub_sd(__A, __B);
19841984
return __builtin_ia32_selectsd_128(__U, __A, __W);
19851985
}
19861986

1987-
static __inline__ __m128d __DEFAULT_FN_ATTRS128
1988-
_mm_maskz_sub_sd(__mmask8 __U,__m128d __A, __m128d __B) {
1987+
static __inline__ __m128d __DEFAULT_FN_ATTRS128_CONSTEXPR
1988+
_mm_maskz_sub_sd(__mmask8 __U, __m128d __A, __m128d __B) {
19891989
__A = _mm_sub_sd(__A, __B);
19901990
return __builtin_ia32_selectsd_128(__U, __A, _mm_setzero_pd());
19911991
}
@@ -2064,14 +2064,14 @@ _mm512_maskz_sub_ps(__mmask16 __U, __m512 __A, __m512 __B) {
20642064
(__v16sf)_mm512_sub_round_ps((A), (B), (R)), \
20652065
(__v16sf)_mm512_setzero_ps()))
20662066

2067-
static __inline__ __m128 __DEFAULT_FN_ATTRS128
2068-
_mm_mask_mul_ss(__m128 __W, __mmask8 __U,__m128 __A, __m128 __B) {
2067+
static __inline__ __m128 __DEFAULT_FN_ATTRS128_CONSTEXPR
2068+
_mm_mask_mul_ss(__m128 __W, __mmask8 __U, __m128 __A, __m128 __B) {
20692069
__A = _mm_mul_ss(__A, __B);
20702070
return __builtin_ia32_selectss_128(__U, __A, __W);
20712071
}
20722072

2073-
static __inline__ __m128 __DEFAULT_FN_ATTRS128
2074-
_mm_maskz_mul_ss(__mmask8 __U,__m128 __A, __m128 __B) {
2073+
static __inline__ __m128 __DEFAULT_FN_ATTRS128_CONSTEXPR
2074+
_mm_maskz_mul_ss(__mmask8 __U, __m128 __A, __m128 __B) {
20752075
__A = _mm_mul_ss(__A, __B);
20762076
return __builtin_ia32_selectss_128(__U, __A, _mm_setzero_ps());
20772077
}
@@ -2093,14 +2093,14 @@ _mm_maskz_mul_ss(__mmask8 __U,__m128 __A, __m128 __B) {
20932093
(__v4sf)_mm_setzero_ps(), \
20942094
(__mmask8)(U), (int)(R)))
20952095

2096-
static __inline__ __m128d __DEFAULT_FN_ATTRS128
2097-
_mm_mask_mul_sd(__m128d __W, __mmask8 __U,__m128d __A, __m128d __B) {
2096+
static __inline__ __m128d __DEFAULT_FN_ATTRS128_CONSTEXPR
2097+
_mm_mask_mul_sd(__m128d __W, __mmask8 __U, __m128d __A, __m128d __B) {
20982098
__A = _mm_mul_sd(__A, __B);
20992099
return __builtin_ia32_selectsd_128(__U, __A, __W);
21002100
}
21012101

2102-
static __inline__ __m128d __DEFAULT_FN_ATTRS128
2103-
_mm_maskz_mul_sd(__mmask8 __U,__m128d __A, __m128d __B) {
2102+
static __inline__ __m128d __DEFAULT_FN_ATTRS128_CONSTEXPR
2103+
_mm_maskz_mul_sd(__mmask8 __U, __m128d __A, __m128d __B) {
21042104
__A = _mm_mul_sd(__A, __B);
21052105
return __builtin_ia32_selectsd_128(__U, __A, _mm_setzero_pd());
21062106
}
@@ -2179,14 +2179,14 @@ _mm512_maskz_mul_ps(__mmask16 __U, __m512 __A, __m512 __B) {
21792179
(__v16sf)_mm512_mul_round_ps((A), (B), (R)), \
21802180
(__v16sf)_mm512_setzero_ps()))
21812181

2182-
static __inline__ __m128 __DEFAULT_FN_ATTRS128
2183-
_mm_mask_div_ss(__m128 __W, __mmask8 __U,__m128 __A, __m128 __B) {
2182+
static __inline__ __m128 __DEFAULT_FN_ATTRS128_CONSTEXPR
2183+
_mm_mask_div_ss(__m128 __W, __mmask8 __U, __m128 __A, __m128 __B) {
21842184
__A = _mm_div_ss(__A, __B);
21852185
return __builtin_ia32_selectss_128(__U, __A, __W);
21862186
}
21872187

2188-
static __inline__ __m128 __DEFAULT_FN_ATTRS128
2189-
_mm_maskz_div_ss(__mmask8 __U,__m128 __A, __m128 __B) {
2188+
static __inline__ __m128 __DEFAULT_FN_ATTRS128_CONSTEXPR
2189+
_mm_maskz_div_ss(__mmask8 __U, __m128 __A, __m128 __B) {
21902190
__A = _mm_div_ss(__A, __B);
21912191
return __builtin_ia32_selectss_128(__U, __A, _mm_setzero_ps());
21922192
}
@@ -2209,14 +2209,14 @@ _mm_maskz_div_ss(__mmask8 __U,__m128 __A, __m128 __B) {
22092209
(__v4sf)_mm_setzero_ps(), \
22102210
(__mmask8)(U), (int)(R)))
22112211

2212-
static __inline__ __m128d __DEFAULT_FN_ATTRS128
2213-
_mm_mask_div_sd(__m128d __W, __mmask8 __U,__m128d __A, __m128d __B) {
2212+
static __inline__ __m128d __DEFAULT_FN_ATTRS128_CONSTEXPR
2213+
_mm_mask_div_sd(__m128d __W, __mmask8 __U, __m128d __A, __m128d __B) {
22142214
__A = _mm_div_sd(__A, __B);
22152215
return __builtin_ia32_selectsd_128(__U, __A, __W);
22162216
}
22172217

2218-
static __inline__ __m128d __DEFAULT_FN_ATTRS128
2219-
_mm_maskz_div_sd(__mmask8 __U,__m128d __A, __m128d __B) {
2218+
static __inline__ __m128d __DEFAULT_FN_ATTRS128_CONSTEXPR
2219+
_mm_maskz_div_sd(__mmask8 __U, __m128d __A, __m128d __B) {
22202220
__A = _mm_div_sd(__A, __B);
22212221
return __builtin_ia32_selectsd_128(__U, __A, _mm_setzero_pd());
22222222
}

0 commit comments

Comments
 (0)