Skip to content

Commit 4cd8c11

Browse files
authored
[X86] Replace default _mm512_sqrt_pd/s/h implementations with generic __builtin_elementwise_sqrt (#168057)
Followup to #165682
1 parent eb98b65 commit 4cd8c11

File tree

3 files changed

+23
-41
lines changed

3 files changed

+23
-41
lines changed

clang/lib/Headers/avx512fintrin.h

Lines changed: 14 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1450,26 +1450,19 @@ _mm512_mask_mullox_epi64(__m512i __W, __mmask8 __U, __m512i __A, __m512i __B) {
14501450
(__v8df)_mm512_sqrt_round_pd((A), (R)), \
14511451
(__v8df)_mm512_setzero_pd()))
14521452

1453-
static __inline__ __m512d __DEFAULT_FN_ATTRS512
1454-
_mm512_sqrt_pd(__m512d __A)
1455-
{
1456-
return (__m512d)__builtin_ia32_sqrtpd512((__v8df)__A,
1457-
_MM_FROUND_CUR_DIRECTION);
1453+
static __inline__ __m512d __DEFAULT_FN_ATTRS512 _mm512_sqrt_pd(__m512d __A) {
1454+
return (__m512d)__builtin_elementwise_sqrt((__v8df)__A);
14581455
}
14591456

14601457
static __inline__ __m512d __DEFAULT_FN_ATTRS512
1461-
_mm512_mask_sqrt_pd (__m512d __W, __mmask8 __U, __m512d __A)
1462-
{
1463-
return (__m512d)__builtin_ia32_selectpd_512(__U,
1464-
(__v8df)_mm512_sqrt_pd(__A),
1458+
_mm512_mask_sqrt_pd(__m512d __W, __mmask8 __U, __m512d __A) {
1459+
return (__m512d)__builtin_ia32_selectpd_512(__U, (__v8df)_mm512_sqrt_pd(__A),
14651460
(__v8df)__W);
14661461
}
14671462

14681463
static __inline__ __m512d __DEFAULT_FN_ATTRS512
1469-
_mm512_maskz_sqrt_pd (__mmask8 __U, __m512d __A)
1470-
{
1471-
return (__m512d)__builtin_ia32_selectpd_512(__U,
1472-
(__v8df)_mm512_sqrt_pd(__A),
1464+
_mm512_maskz_sqrt_pd(__mmask8 __U, __m512d __A) {
1465+
return (__m512d)__builtin_ia32_selectpd_512(__U, (__v8df)_mm512_sqrt_pd(__A),
14731466
(__v8df)_mm512_setzero_pd());
14741467
}
14751468

@@ -1486,26 +1479,19 @@ _mm512_maskz_sqrt_pd (__mmask8 __U, __m512d __A)
14861479
(__v16sf)_mm512_sqrt_round_ps((A), (R)), \
14871480
(__v16sf)_mm512_setzero_ps()))
14881481

1489-
static __inline__ __m512 __DEFAULT_FN_ATTRS512
1490-
_mm512_sqrt_ps(__m512 __A)
1491-
{
1492-
return (__m512)__builtin_ia32_sqrtps512((__v16sf)__A,
1493-
_MM_FROUND_CUR_DIRECTION);
1482+
static __inline__ __m512 __DEFAULT_FN_ATTRS512 _mm512_sqrt_ps(__m512 __A) {
1483+
return (__m512)__builtin_elementwise_sqrt((__v16sf)__A);
14941484
}
14951485

1496-
static __inline__ __m512 __DEFAULT_FN_ATTRS512
1497-
_mm512_mask_sqrt_ps(__m512 __W, __mmask16 __U, __m512 __A)
1498-
{
1499-
return (__m512)__builtin_ia32_selectps_512(__U,
1500-
(__v16sf)_mm512_sqrt_ps(__A),
1486+
static __inline__ __m512 __DEFAULT_FN_ATTRS512
1487+
_mm512_mask_sqrt_ps(__m512 __W, __mmask16 __U, __m512 __A) {
1488+
return (__m512)__builtin_ia32_selectps_512(__U, (__v16sf)_mm512_sqrt_ps(__A),
15011489
(__v16sf)__W);
15021490
}
15031491

1504-
static __inline__ __m512 __DEFAULT_FN_ATTRS512
1505-
_mm512_maskz_sqrt_ps( __mmask16 __U, __m512 __A)
1506-
{
1507-
return (__m512)__builtin_ia32_selectps_512(__U,
1508-
(__v16sf)_mm512_sqrt_ps(__A),
1492+
static __inline__ __m512 __DEFAULT_FN_ATTRS512
1493+
_mm512_maskz_sqrt_ps(__mmask16 __U, __m512 __A) {
1494+
return (__m512)__builtin_ia32_selectps_512(__U, (__v16sf)_mm512_sqrt_ps(__A),
15091495
(__v16sf)_mm512_setzero_ps());
15101496
}
15111497

clang/lib/Headers/avx512fp16intrin.h

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1401,24 +1401,20 @@ _mm_maskz_scalef_sh(__mmask8 __U, __m128h __A, __m128h __B) {
14011401
(__v32hf)_mm512_setzero_ph()))
14021402

14031403
static __inline__ __m512h __DEFAULT_FN_ATTRS512 _mm512_sqrt_ph(__m512h __A) {
1404-
return (__m512h)__builtin_ia32_sqrtph512((__v32hf)__A,
1405-
_MM_FROUND_CUR_DIRECTION);
1404+
return (__m512h)__builtin_elementwise_sqrt((__v32hf)__A);
14061405
}
14071406

14081407
static __inline__ __m512h __DEFAULT_FN_ATTRS512
14091408
_mm512_mask_sqrt_ph(__m512h __W, __mmask32 __U, __m512h __A) {
14101409
return (__m512h)__builtin_ia32_selectph_512(
1411-
(__mmask32)(__U),
1412-
(__v32hf)__builtin_ia32_sqrtph512((__A), (_MM_FROUND_CUR_DIRECTION)),
1413-
(__v32hf)(__m512h)(__W));
1410+
(__mmask32)(__U), (__v32hf)_mm512_sqrt_ph(__A), (__v32hf)(__m512h)(__W));
14141411
}
14151412

14161413
static __inline__ __m512h __DEFAULT_FN_ATTRS512
14171414
_mm512_maskz_sqrt_ph(__mmask32 __U, __m512h __A) {
1418-
return (__m512h)__builtin_ia32_selectph_512(
1419-
(__mmask32)(__U),
1420-
(__v32hf)__builtin_ia32_sqrtph512((__A), (_MM_FROUND_CUR_DIRECTION)),
1421-
(__v32hf)_mm512_setzero_ph());
1415+
return (__m512h)__builtin_ia32_selectph_512((__mmask32)(__U),
1416+
(__v32hf)_mm512_sqrt_ph(__A),
1417+
(__v32hf)_mm512_setzero_ph());
14221418
}
14231419

14241420
#define _mm_sqrt_round_sh(A, B, R) \

clang/test/CodeGen/X86/avx512fp16-builtins-constrained.c

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -65,8 +65,8 @@ __m512h test_mm512_sqrt_ph(__m512h x) {
6565
__m512h test_mm512_mask_sqrt_ph (__m512h __W, __mmask32 __U, __m512h __A)
6666
{
6767
// COMMON-LABEL: test_mm512_mask_sqrt_ph
68-
// UNCONSTRAINED: call <32 x half> @llvm.sqrt.v32f16(<32 x half> %{{.*}})
69-
// CONSTRAINED: call <32 x half> @llvm.experimental.constrained.sqrt.v32f16(<32 x half> %{{.*}}, metadata !{{.*}})
68+
// UNCONSTRAINED: call {{.*}}<32 x half> @llvm.sqrt.v32f16(<32 x half> %{{.*}})
69+
// CONSTRAINED: call {{.*}}<32 x half> @llvm.experimental.constrained.sqrt.v32f16(<32 x half> %{{.*}}, metadata !{{.*}})
7070
// CHECK-ASM: vsqrtph %zmm{{.*}},
7171
// COMMONIR: bitcast i32 %{{.*}} to <32 x i1>
7272
// COMMONIR: select <32 x i1> %{{.*}}, <32 x half> %{{.*}}, <32 x half> %{{.*}}
@@ -76,8 +76,8 @@ __m512h test_mm512_mask_sqrt_ph (__m512h __W, __mmask32 __U, __m512h __A)
7676
__m512h test_mm512_maskz_sqrt_ph (__mmask32 __U, __m512h __A)
7777
{
7878
// COMMON-LABEL: test_mm512_maskz_sqrt_ph
79-
// UNCONSTRAINED: call <32 x half> @llvm.sqrt.v32f16(<32 x half> %{{.*}})
80-
// CONSTRAINED: call <32 x half> @llvm.experimental.constrained.sqrt.v32f16(<32 x half> %{{.*}}, metadata !{{.*}})
79+
// UNCONSTRAINED: call {{.*}}<32 x half> @llvm.sqrt.v32f16(<32 x half> %{{.*}})
80+
// CONSTRAINED: call {{.*}}<32 x half> @llvm.experimental.constrained.sqrt.v32f16(<32 x half> %{{.*}}, metadata !{{.*}})
8181
// CHECK-ASM: vsqrtph %zmm{{.*}},
8282
// COMMONIR: bitcast i32 %{{.*}} to <32 x i1>
8383
// COMMONIR: select <32 x i1> %{{.*}}, <32 x half> %{{.*}}, <32 x half> {{.*}}

0 commit comments

Comments
 (0)