Skip to content

Commit a4014ec

Browse files
author
Jason
committed
Rewrote bf16->f32 conversion intrinsics
1 parent cafc064 commit a4014ec

File tree

2 files changed

+10
-22
lines changed

2 files changed

+10
-22
lines changed

clang/lib/Headers/avx512bf16intrin.h

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ typedef __bf16 __bfloat16 __attribute__((deprecated("use __bf16 instead")));
3636
/// \returns A float data whose sign field and exponent field keep unchanged,
3737
/// and fraction field is extended to 23 bits.
3838
static __inline__ float __DEFAULT_FN_ATTRS _mm_cvtsbh_ss(__bf16 __A) {
39-
return __builtin_ia32_cvtsbf162ss_32(__A);
39+
return float(__A);
4040
}
4141

4242
/// Convert Two Packed Single Data to One Packed BF16 Data.
@@ -236,8 +236,7 @@ _mm512_maskz_dpbf16_ps(__mmask16 __U, __m512 __D, __m512bh __A, __m512bh __B) {
236236
/// A 256-bit vector of [16 x bfloat].
237237
/// \returns A 512-bit vector of [16 x float] come from conversion of __A
238238
static __inline__ __m512 __DEFAULT_FN_ATTRS512 _mm512_cvtpbh_ps(__m256bh __A) {
239-
return _mm512_castsi512_ps((__m512i)_mm512_slli_epi32(
240-
(__m512i)_mm512_cvtepi16_epi32((__m256i)__A), 16));
239+
return (__m512)__builtin_convertvector(__A, __v16sf);
241240
}
242241

243242
/// Convert Packed BF16 Data to Packed float Data using zeroing mask.
@@ -252,8 +251,7 @@ static __inline__ __m512 __DEFAULT_FN_ATTRS512 _mm512_cvtpbh_ps(__m256bh __A) {
252251
/// \returns A 512-bit vector of [16 x float] come from conversion of __A
253252
static __inline__ __m512 __DEFAULT_FN_ATTRS512
254253
_mm512_maskz_cvtpbh_ps(__mmask16 __U, __m256bh __A) {
255-
return _mm512_castsi512_ps((__m512i)_mm512_slli_epi32(
256-
(__m512i)_mm512_maskz_cvtepi16_epi32((__mmask16)__U, (__m256i)__A), 16));
254+
return _mm512_maskz_mov_ps(__U, (__m512)__builtin_convertvector(__A, __v16sf));
257255
}
258256

259257
/// Convert Packed BF16 Data to Packed float Data using merging mask.
@@ -270,9 +268,7 @@ _mm512_maskz_cvtpbh_ps(__mmask16 __U, __m256bh __A) {
270268
/// \returns A 512-bit vector of [16 x float] come from conversion of __A
271269
static __inline__ __m512 __DEFAULT_FN_ATTRS512
272270
_mm512_mask_cvtpbh_ps(__m512 __S, __mmask16 __U, __m256bh __A) {
273-
return _mm512_castsi512_ps((__m512i)_mm512_mask_slli_epi32(
274-
(__m512i)__S, (__mmask16)__U,
275-
(__m512i)_mm512_cvtepi16_epi32((__m256i)__A), 16));
271+
return _mm512_mask_mov_ps(__S, __U, (__m512)__builtin_convertvector(__A, __v16sf));
276272
}
277273

278274
#undef __DEFAULT_FN_ATTRS

clang/lib/Headers/avx512vlbf16intrin.h

Lines changed: 6 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -422,8 +422,7 @@ static __inline__ __bf16 __DEFAULT_FN_ATTRS128 _mm_cvtness_sbh(float __A) {
422422
/// A 128-bit vector of [4 x bfloat].
423423
/// \returns A 128-bit vector of [4 x float] come from conversion of __A
424424
static __inline__ __m128 __DEFAULT_FN_ATTRS128 _mm_cvtpbh_ps(__m128bh __A) {
425-
return _mm_castsi128_ps(
426-
(__m128i)_mm_slli_epi32((__m128i)_mm_cvtepi16_epi32((__m128i)__A), 16));
425+
return (__m128)__builtin_convertvector(__A, __v4sf);
427426
}
428427

429428
/// Convert Packed BF16 Data to Packed float Data.
@@ -434,8 +433,7 @@ static __inline__ __m128 __DEFAULT_FN_ATTRS128 _mm_cvtpbh_ps(__m128bh __A) {
434433
/// A 128-bit vector of [8 x bfloat].
435434
/// \returns A 256-bit vector of [8 x float] come from conversion of __A
436435
static __inline__ __m256 __DEFAULT_FN_ATTRS256 _mm256_cvtpbh_ps(__m128bh __A) {
437-
return _mm256_castsi256_ps((__m256i)_mm256_slli_epi32(
438-
(__m256i)_mm256_cvtepi16_epi32((__m128i)__A), 16));
436+
return (__m256)__builtin_convertvector(__A, __v8sf);
439437
}
440438

441439
/// Convert Packed BF16 Data to Packed float Data using zeroing mask.
@@ -450,8 +448,7 @@ static __inline__ __m256 __DEFAULT_FN_ATTRS256 _mm256_cvtpbh_ps(__m128bh __A) {
450448
/// \returns A 128-bit vector of [4 x float] come from conversion of __A
451449
static __inline__ __m128 __DEFAULT_FN_ATTRS128
452450
_mm_maskz_cvtpbh_ps(__mmask8 __U, __m128bh __A) {
453-
return _mm_castsi128_ps((__m128i)_mm_slli_epi32(
454-
(__m128i)_mm_maskz_cvtepi16_epi32((__mmask8)__U, (__m128i)__A), 16));
451+
return __mm_maskz_mov_ps(__U, (__m128)__builtin_convertvector(__A, __v4sf));
455452
}
456453

457454
/// Convert Packed BF16 Data to Packed float Data using zeroing mask.
@@ -466,8 +463,7 @@ _mm_maskz_cvtpbh_ps(__mmask8 __U, __m128bh __A) {
466463
/// \returns A 256-bit vector of [8 x float] come from conversion of __A
467464
static __inline__ __m256 __DEFAULT_FN_ATTRS256
468465
_mm256_maskz_cvtpbh_ps(__mmask8 __U, __m128bh __A) {
469-
return _mm256_castsi256_ps((__m256i)_mm256_slli_epi32(
470-
(__m256i)_mm256_maskz_cvtepi16_epi32((__mmask8)__U, (__m128i)__A), 16));
466+
return __mm256_maskz_mov_ps(__U, (__m256)__builtin_convertvector(__A, __v8sf));
471467
}
472468

473469
/// Convert Packed BF16 Data to Packed float Data using merging mask.
@@ -485,9 +481,7 @@ _mm256_maskz_cvtpbh_ps(__mmask8 __U, __m128bh __A) {
485481
/// \returns A 128-bit vector of [4 x float] come from conversion of __A
486482
static __inline__ __m128 __DEFAULT_FN_ATTRS128
487483
_mm_mask_cvtpbh_ps(__m128 __S, __mmask8 __U, __m128bh __A) {
488-
return _mm_castsi128_ps((__m128i)_mm_mask_slli_epi32(
489-
(__m128i)__S, (__mmask8)__U, (__m128i)_mm_cvtepi16_epi32((__m128i)__A),
490-
16));
484+
return __mm_mask_mov_ps(__S, __U, (__m128)__builtin_convertvector(__A, __v4sf));
491485
}
492486

493487
/// Convert Packed BF16 Data to Packed float Data using merging mask.
@@ -505,9 +499,7 @@ _mm_mask_cvtpbh_ps(__m128 __S, __mmask8 __U, __m128bh __A) {
505499
/// \returns A 256-bit vector of [8 x float] come from conversion of __A
506500
static __inline__ __m256 __DEFAULT_FN_ATTRS256
507501
_mm256_mask_cvtpbh_ps(__m256 __S, __mmask8 __U, __m128bh __A) {
508-
return _mm256_castsi256_ps((__m256i)_mm256_mask_slli_epi32(
509-
(__m256i)__S, (__mmask8)__U, (__m256i)_mm256_cvtepi16_epi32((__m128i)__A),
510-
16));
502+
return __mm256_mask_mov_ps(__S, __U, (__m256)__builtin_convertvector(__A, __v8sf));
511503
}
512504

513505
#undef __DEFAULT_FN_ATTRS128

0 commit comments

Comments
 (0)