Skip to content

Commit bd73af2

Browse files
committed
[clang] add sqrt{pd|ps}512
Signed-off-by: Shreeyash Pandey <[email protected]>
1 parent bece3db commit bd73af2

File tree

5 files changed

+28
-5
lines changed

5 files changed

+28
-5
lines changed

clang/include/clang/Basic/BuiltinsX86.td

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -950,7 +950,7 @@ let Features = "pku", Attributes = [NoThrow] in {
950950
def wrpkru : X86Builtin<"void(unsigned int)">;
951951
}
952952

953-
let Features = "avx512f", Attributes = [NoThrow, Const, RequiredVectorWidth<512>] in {
953+
let Features = "avx512f", Attributes = [NoThrow, Const, Constexpr, RequiredVectorWidth<512>] in {
954954
def sqrtpd512 : X86Builtin<"_Vector<8, double>(_Vector<8, double>, _Constant int)">;
955955
def sqrtps512 : X86Builtin<"_Vector<16, float>(_Vector<16, float>, _Constant int)">;
956956
}

clang/lib/AST/ByteCode/InterpBuiltin.cpp

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3016,7 +3016,12 @@ static bool interp__builtin_x86_sqrt(InterpState &S, CodePtr OpPC,
30163016
unsigned ID) {
30173017
llvm::errs() << "Entering x86 sqrtpd/ps interpretbuiltin\n";
30183018

3019-
assert(Call->getNumArgs() == 1);
3019+
llvm::errs() << "BI__builtin_ia32_sqrtpd512 " << X86::BI__builtin_ia32_sqrtpd512 << '\n';
3020+
llvm::errs() << "BI__builtin_ia32_sqrtps512 " << X86::BI__builtin_ia32_sqrtps512 << '\n';
3021+
llvm::errs() << "Current ID " << ID << '\n';
3022+
llvm::errs() << "GetNumArgs " << Call->getNumArgs() << '\n';
3023+
unsigned NumArgs = Call->getNumArgs();
3024+
assert(NumArgs == 1 || NumArgs == 2);
30203025
const Expr *ArgExpr = Call->getArg(0);
30213026
QualType ArgTy = ArgExpr->getType();
30223027
QualType ResultTy = Call->getType();
@@ -3033,6 +3038,16 @@ static bool interp__builtin_x86_sqrt(InterpState &S, CodePtr OpPC,
30333038
SemanticsPtr = &S.getContext().getFloatSemantics(ArgTy);
30343039
const llvm::fltSemantics &Semantics = *SemanticsPtr;
30353040

3041+
if (NumArgs == 2) {
3042+
if (!Call->getArg(1)->getType()->isIntegerType()) {
3043+
return false;
3044+
}
3045+
APSInt RoundingMode = popToAPSInt(S, Call->getArg(1));
3046+
if (RoundingMode.getZExtValue() != 4) {
3047+
return false;
3048+
}
3049+
}
3050+
30363051

30373052
// Scalar case
30383053
if (!ArgTy->isVectorType()) {
@@ -3831,6 +3846,8 @@ bool InterpretBuiltin(InterpState &S, CodePtr OpPC, const CallExpr *Call,
38313846
case X86::BI__builtin_ia32_sqrtps:
38323847
case X86::BI__builtin_ia32_sqrtpd256:
38333848
case X86::BI__builtin_ia32_sqrtps256:
3849+
case X86::BI__builtin_ia32_sqrtps512:
3850+
case X86::BI__builtin_ia32_sqrtpd512:
38343851
return interp__builtin_x86_sqrt(S, OpPC, Call, BuiltinID);
38353852

38363853
default:

clang/lib/AST/ExprConstant.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12239,7 +12239,9 @@ bool VectorExprEvaluator::VisitCallExpr(const CallExpr *E) {
1223912239
case X86::BI__builtin_ia32_sqrtpd:
1224012240
case X86::BI__builtin_ia32_sqrtps:
1224112241
case X86::BI__builtin_ia32_sqrtpd256:
12242-
case X86::BI__builtin_ia32_sqrtps256: {
12242+
case X86::BI__builtin_ia32_sqrtps256:
12243+
case X86::BI__builtin_ia32_sqrtps512:
12244+
case X86::BI__builtin_ia32_sqrtpd512: {
1224312245
llvm::errs() << "We are inside sqrtpd/sqrtps\n";
1224412246
APValue Source;
1224512247
if (!EvaluateAsRValue(Info, E->getArg(0), Source))

clang/lib/Headers/avx512fintrin.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1458,7 +1458,7 @@ _mm512_mask_mullox_epi64(__m512i __W, __mmask8 __U, __m512i __A, __m512i __B) {
14581458
(__v8df)_mm512_sqrt_round_pd((A), (R)), \
14591459
(__v8df)_mm512_setzero_pd()))
14601460

1461-
static __inline__ __m512d __DEFAULT_FN_ATTRS512
1461+
static __inline__ __m512d __DEFAULT_FN_ATTRS512_CONSTEXPR
14621462
_mm512_sqrt_pd(__m512d __A)
14631463
{
14641464
return (__m512d)__builtin_ia32_sqrtpd512((__v8df)__A,
@@ -1494,7 +1494,7 @@ _mm512_maskz_sqrt_pd (__mmask8 __U, __m512d __A)
14941494
(__v16sf)_mm512_sqrt_round_ps((A), (R)), \
14951495
(__v16sf)_mm512_setzero_ps()))
14961496

1497-
static __inline__ __m512 __DEFAULT_FN_ATTRS512
1497+
static __inline__ __m512 __DEFAULT_FN_ATTRS512_CONSTEXPR
14981498
_mm512_sqrt_ps(__m512 __A)
14991499
{
15001500
return (__m512)__builtin_ia32_sqrtps512((__v16sf)__A,

clang/test/CodeGen/X86/avx512f-builtins.c

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@ __m512d test_mm512_sqrt_pd(__m512d a)
1818
return _mm512_sqrt_pd(a);
1919
}
2020

21+
TEST_CONSTEXPR(match_m512d(_mm512_sqrt_pd(_mm512_set_pd(16.0, 9.0, 4.0, 1.0, 16.0, 9.0, 4.0, 1.0)), 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0));
22+
2123
__m512d test_mm512_mask_sqrt_pd (__m512d __W, __mmask8 __U, __m512d __A)
2224
{
2325
// CHECK-LABEL: test_mm512_mask_sqrt_pd
@@ -68,6 +70,8 @@ __m512 test_mm512_sqrt_ps(__m512 a)
6870
return _mm512_sqrt_ps(a);
6971
}
7072

73+
TEST_CONSTEXPR(match_m512(_mm512_sqrt_ps(_mm512_set_ps(64.0f, 49.0f, 36.0f, 25.0f, 16.0f, 9.0f, 4.0f, 1.0f, 64.0f, 49.0f, 36.0f, 25.0f, 16.0f, 9.0f, 4.0f, 1.0f)), 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f));
74+
7175
__m512 test_mm512_mask_sqrt_ps(__m512 __W, __mmask16 __U, __m512 __A)
7276
{
7377
// CHECK-LABEL: test_mm512_mask_sqrt_ps

0 commit comments

Comments
 (0)