Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 4 additions & 8 deletions clang/lib/Headers/avx512bf16intrin.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ typedef __bf16 __bfloat16 __attribute__((deprecated("use __bf16 instead")));
/// \returns A float data whose sign field and exponent field keep unchanged,
/// and fraction field is extended to 23 bits.
static __inline__ float __DEFAULT_FN_ATTRS _mm_cvtsbh_ss(__bf16 __A) {
return __builtin_ia32_cvtsbf162ss_32(__A);
return float(__A);
}

/// Convert Two Packed Single Data to One Packed BF16 Data.
Expand Down Expand Up @@ -236,8 +236,7 @@ _mm512_maskz_dpbf16_ps(__mmask16 __U, __m512 __D, __m512bh __A, __m512bh __B) {
/// A 256-bit vector of [16 x bfloat].
/// \returns A 512-bit vector of [16 x float] come from conversion of __A
static __inline__ __m512 __DEFAULT_FN_ATTRS512 _mm512_cvtpbh_ps(__m256bh __A) {
return _mm512_castsi512_ps((__m512i)_mm512_slli_epi32(
(__m512i)_mm512_cvtepi16_epi32((__m256i)__A), 16));
return (__m512)__builtin_convertvector(__A, __v16sf);
}

/// Convert Packed BF16 Data to Packed float Data using zeroing mask.
Expand All @@ -252,8 +251,7 @@ static __inline__ __m512 __DEFAULT_FN_ATTRS512 _mm512_cvtpbh_ps(__m256bh __A) {
/// \returns A 512-bit vector of [16 x float] come from conversion of __A
static __inline__ __m512 __DEFAULT_FN_ATTRS512
_mm512_maskz_cvtpbh_ps(__mmask16 __U, __m256bh __A) {
return _mm512_castsi512_ps((__m512i)_mm512_slli_epi32(
(__m512i)_mm512_maskz_cvtepi16_epi32((__mmask16)__U, (__m256i)__A), 16));
return _mm512_maskz_mov_ps(__U, (__m512)__builtin_convertvector(__A, __v16sf));
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(style) - wrap _mm512_cvtpbh_ps instead of calling __builtin_convertvector directly - same for the other mask/maskz intrinsics

}

/// Convert Packed BF16 Data to Packed float Data using merging mask.
Expand All @@ -270,9 +268,7 @@ _mm512_maskz_cvtpbh_ps(__mmask16 __U, __m256bh __A) {
/// \returns A 512-bit vector of [16 x float] come from conversion of __A
static __inline__ __m512 __DEFAULT_FN_ATTRS512
_mm512_mask_cvtpbh_ps(__m512 __S, __mmask16 __U, __m256bh __A) {
return _mm512_castsi512_ps((__m512i)_mm512_mask_slli_epi32(
(__m512i)__S, (__mmask16)__U,
(__m512i)_mm512_cvtepi16_epi32((__m256i)__A), 16));
return _mm512_mask_mov_ps(__S, __U, (__m512)__builtin_convertvector(__A, __v16sf));
}

#undef __DEFAULT_FN_ATTRS
Expand Down
20 changes: 6 additions & 14 deletions clang/lib/Headers/avx512vlbf16intrin.h
Original file line number Diff line number Diff line change
Expand Up @@ -422,8 +422,7 @@ static __inline__ __bf16 __DEFAULT_FN_ATTRS128 _mm_cvtness_sbh(float __A) {
/// A 128-bit vector of [4 x bfloat].
/// \returns A 128-bit vector of [4 x float] come from conversion of __A
static __inline__ __m128 __DEFAULT_FN_ATTRS128 _mm_cvtpbh_ps(__m128bh __A) {
return _mm_castsi128_ps(
(__m128i)_mm_slli_epi32((__m128i)_mm_cvtepi16_epi32((__m128i)__A), 16));
return (__m128)__builtin_convertvector(__A, __v4sf);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You're missing a shuffle vector to only access the bottom 4 elements of __A:

Suggested change
return (__m128)__builtin_convertvector(__A, __v4sf);
return (__m128)__builtin_convertvector(__builtin_shufflevector(__A, __A, 0, 1, 2, 3), __v4sf);

(will need clang-format)

}

/// Convert Packed BF16 Data to Packed float Data.
Expand All @@ -434,8 +433,7 @@ static __inline__ __m128 __DEFAULT_FN_ATTRS128 _mm_cvtpbh_ps(__m128bh __A) {
/// A 128-bit vector of [8 x bfloat].
/// \returns A 256-bit vector of [8 x float] come from conversion of __A
static __inline__ __m256 __DEFAULT_FN_ATTRS256 _mm256_cvtpbh_ps(__m128bh __A) {
return _mm256_castsi256_ps((__m256i)_mm256_slli_epi32(
(__m256i)_mm256_cvtepi16_epi32((__m128i)__A), 16));
return (__m256)__builtin_convertvector(__A, __v8sf);
}

/// Convert Packed BF16 Data to Packed float Data using zeroing mask.
Expand All @@ -450,8 +448,7 @@ static __inline__ __m256 __DEFAULT_FN_ATTRS256 _mm256_cvtpbh_ps(__m128bh __A) {
/// \returns A 128-bit vector of [4 x float] come from conversion of __A
static __inline__ __m128 __DEFAULT_FN_ATTRS128
_mm_maskz_cvtpbh_ps(__mmask8 __U, __m128bh __A) {
return _mm_castsi128_ps((__m128i)_mm_slli_epi32(
(__m128i)_mm_maskz_cvtepi16_epi32((__mmask8)__U, (__m128i)__A), 16));
return _mm_maskz_mov_ps(__U, (__m128)__builtin_convertvector(__A, __v4sf));
}

/// Convert Packed BF16 Data to Packed float Data using zeroing mask.
Expand All @@ -466,8 +463,7 @@ _mm_maskz_cvtpbh_ps(__mmask8 __U, __m128bh __A) {
/// \returns A 256-bit vector of [8 x float] come from conversion of __A
static __inline__ __m256 __DEFAULT_FN_ATTRS256
_mm256_maskz_cvtpbh_ps(__mmask8 __U, __m128bh __A) {
return _mm256_castsi256_ps((__m256i)_mm256_slli_epi32(
(__m256i)_mm256_maskz_cvtepi16_epi32((__mmask8)__U, (__m128i)__A), 16));
return _mm256_maskz_mov_ps(__U, (__m256)__builtin_convertvector(__A, __v8sf));
}

/// Convert Packed BF16 Data to Packed float Data using merging mask.
Expand All @@ -485,9 +481,7 @@ _mm256_maskz_cvtpbh_ps(__mmask8 __U, __m128bh __A) {
/// \returns A 128-bit vector of [4 x float] come from conversion of __A
static __inline__ __m128 __DEFAULT_FN_ATTRS128
_mm_mask_cvtpbh_ps(__m128 __S, __mmask8 __U, __m128bh __A) {
return _mm_castsi128_ps((__m128i)_mm_mask_slli_epi32(
(__m128i)__S, (__mmask8)__U, (__m128i)_mm_cvtepi16_epi32((__m128i)__A),
16));
return _mm_mask_mov_ps(__S, __U, (__m128)__builtin_convertvector(__A, __v4sf));
}

/// Convert Packed BF16 Data to Packed float Data using merging mask.
Expand All @@ -505,9 +499,7 @@ _mm_mask_cvtpbh_ps(__m128 __S, __mmask8 __U, __m128bh __A) {
/// \returns A 256-bit vector of [8 x float] come from conversion of __A
static __inline__ __m256 __DEFAULT_FN_ATTRS256
_mm256_mask_cvtpbh_ps(__m256 __S, __mmask8 __U, __m128bh __A) {
return _mm256_castsi256_ps((__m256i)_mm256_mask_slli_epi32(
(__m256i)__S, (__mmask8)__U, (__m256i)_mm256_cvtepi16_epi32((__m128i)__A),
16));
return _mm256_mask_mov_ps(__S, __U, (__m256)__builtin_convertvector(__A, __v8sf));
}

#undef __DEFAULT_FN_ATTRS128
Expand Down
Loading