@@ -14,14 +14,32 @@ using namespace BinSearch;
 #if defined(__AVX512F__)
 #include <immintrin.h>
 
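+// Runtime CPU-feature probes. __builtin_cpu_supports() (a GCC/Clang builtin)
+// queries the CPU once; the function-local static caches the answer, so every
+// call after the first is a plain load.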
+bool has_avx512f() {
+    static const bool supported_avx512f = __builtin_cpu_supports("avx512f");
+    return supported_avx512f;
+}
+
+bool has_avx512bf16() {
+    static const bool supported_avx512bf16 = __builtin_cpu_supports("avx512bf16");
+    return supported_avx512bf16;
+}
+
 inline __m256i cvt_fp32_to_fp16(const __m512 src) {
     return _mm512_cvtps_ph(src, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
 }
 
 inline __m256i cvt_fp32_to_bf16(const __m512 src) {
 #if defined(__AVX512BF16__)
-    return reinterpret_cast<__m256i>(_mm512_cvtneps_pbh(src));
-#else
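+    // Take the native conversion only when the binary was built with
+    // AVX512BF16 *and* the running CPU reports it; otherwise fall through
+    // to the emulation below, which forces NaN lanes to 0xffff.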
+    if (has_avx512bf16()) {
+        return reinterpret_cast<__m256i>(_mm512_cvtneps_pbh(src));
+    }
+#endif
     __m512i value = _mm512_castps_si512(src);
     __m512i nan = _mm512_set1_epi32(0xffff);
     auto mask_value = _mm512_cmp_ps_mask(src, src, _CMP_ORD_Q);
@@ -38,7 +56,6 @@ inline __m256i cvt_fp32_to_bf16(const __m512 src) {
     // Check NaN before converting back to bf16
     t_value = _mm512_mask_blend_epi32(mask_value, nan, t_value);
     return _mm512_cvtusepi32_epi16(t_value);
-#endif
 }
 
 static inline __m512 set_nf4_lut() {
@@ -68,51 +85,64 @@ void dequantizeBlockwise4bitCpu(
         return;
 
 #if defined(__AVX512F__)
-    long long dim_0 = m;
-    long long dim_1 = n;
-    long long input_dim_1 = dim_1 >> 1;
-    long long absmax_dim_1 = dim_1 / blocksize;
-    using Tcomp = float;
-    constexpr auto VEC_LEN = sizeof(__m512i) / sizeof(Tcomp); // 16
-    if (dim_1 % VEC_LEN == 0 && blocksize >= VEC_LEN) {
-        __m512 lut = DATA_TYPE == 1 ? set_fp4_lut() : set_nf4_lut();
-        constexpr auto k_step = VEC_LEN / 2; // 8
-        BNB_OMP_PARALLEL_FOR
-        for (int block_idx = 0; block_idx < dim_0; ++block_idx) {
-            for (int k = 0; k < input_dim_1; k += k_step) {
-                // Load 64 bits of nf4 data and a single scale data
-                uint8_t* p = &A[block_idx * input_dim_1 + k];
-                uint64_t packed;
-                std::memcpy(&packed, p, sizeof(uint64_t));
-                auto scale_idx = k * 2 / blocksize;
-                auto vscales = _mm512_set1_ps((float)absmax[block_idx * absmax_dim_1 + scale_idx]);
-                // unpack nf4 data to 32-bit integers
-                uint64_t high = 0;
-                uint64_t low = 0;
-                for (int i = 0; i < 4; ++i) {
-                    low |= ((packed >> (2 * i * 4)) & 0xf) << ((2 * i + 1) * 8);
-                    low |= ((packed >> ((2 * i + 1) * 4)) & 0xf) << (2 * i * 8);
-                    high |= ((packed >> (2 * i * 4 + 32)) & 0xf) << ((2 * i + 1) * 8);
-                    high |= ((packed >> ((2 * i + 1) * 4 + 32)) & 0xf) << (2 * i * 8);
-                }
-                __m128i packed_128 = _mm_set_epi64x(high, low);
-                __m512i vint32 = _mm512_cvtepu8_epi32(packed_128);
-                // Table look-up
-                __m512 vout = _mm512_permutexvar_ps(vint32, lut);
-                // Apply scale
-                vout = _mm512_mul_ps(vout, vscales);
-                // Store results
-                T* pout = &out[block_idx * dim_1 + k * 2];
-                if constexpr (std::is_same<T, float>()) {
-                    _mm512_storeu_ps(pout, vout);
-                } else if constexpr (std::is_same<T, bf16_t>()) {
-                    _mm256_storeu_si256((__m256i*)pout, cvt_fp32_to_bf16(vout));
-                } else if constexpr (std::is_same<T, fp16_t>()) {
-                    _mm256_storeu_si256((__m256i*)pout, cvt_fp32_to_fp16(vout));
+    if (has_avx512f()) {
+        long long dim_0 = m;
+        long long dim_1 = n;
+        long long input_dim_1 = dim_1 >> 1;
+        long long absmax_dim_1 = dim_1 / blocksize;
+        using Tcomp = float;
+        constexpr auto VEC_LEN = sizeof(__m512i) / sizeof(Tcomp); // 16
+        if (dim_1 % VEC_LEN == 0 && blocksize >= VEC_LEN) {
+            __m512 lut = DATA_TYPE == 1 ? set_fp4_lut() : set_nf4_lut();
+            constexpr auto k_step = VEC_LEN / 2; // 8
+            BNB_OMP_PARALLEL_FOR
+            for (int block_idx = 0; block_idx < dim_0; ++block_idx) {
+                for (int k = 0; k < input_dim_1; k += k_step) {
+                    // Load 64 bits of nf4 data and a single scale data
+                    uint8_t* p = &A[block_idx * input_dim_1 + k];
+                    uint64_t packed;
+                    std::memcpy(&packed, p, sizeof(uint64_t));
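+                    // memcpy expresses an unaligned 8-byte load without
+                    // aliasing UB; optimizers lower it to a single load.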
+                    auto scale_idx = k * 2 / blocksize;
+                    auto vscales = _mm512_set1_ps((float)absmax[block_idx * absmax_dim_1 + scale_idx]);
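+                    // (k counts packed bytes, so k * 2 is the element index;
+                    // one absmax entry scales each blocksize-long block.)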
+                    // unpack nf4 data to 32-bit integers
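+                    // `packed` holds 16 codes, two per byte with the earlier
+                    // element in the high nibble; the loop spreads each code
+                    // into its own byte of low/high, swapping nibble order.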
+                    uint64_t high = 0;
+                    uint64_t low = 0;
+                    for (int i = 0; i < 4; ++i) {
+                        low |= ((packed >> (2 * i * 4)) & 0xf) << ((2 * i + 1) * 8);
+                        low |= ((packed >> ((2 * i + 1) * 4)) & 0xf) << (2 * i * 8);
+                        high |= ((packed >> (2 * i * 4 + 32)) & 0xf) << ((2 * i + 1) * 8);
+                        high |= ((packed >> ((2 * i + 1) * 4 + 32)) & 0xf) << (2 * i * 8);
+                    }
+                    __m128i packed_128 = _mm_set_epi64x(high, low);
+                    __m512i vint32 = _mm512_cvtepu8_epi32(packed_128);
+                    // Table look-up
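+                    // _mm512_permutexvar_ps selects lut[vint32[i]] per lane:
+                    // a vectorized 16-entry lookup of the quantization grid.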
+                    __m512 vout = _mm512_permutexvar_ps(vint32, lut);
+                    // Apply scale
+                    vout = _mm512_mul_ps(vout, vscales);
+                    // Store results
+                    T* pout = &out[block_idx * dim_1 + k * 2];
+                    if constexpr (std::is_same<T, float>()) {
+                        _mm512_storeu_ps(pout, vout);
+                    } else if constexpr (std::is_same<T, bf16_t>()) {
+                        _mm256_storeu_si256((__m256i*)pout, cvt_fp32_to_bf16(vout));
+                    } else if constexpr (std::is_same<T, fp16_t>()) {
+                        _mm256_storeu_si256((__m256i*)pout, cvt_fp32_to_fp16(vout));
+                    }
                 }
             }
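+            // The vector path covered the whole tensor; skip the scalar
+            // fallback below.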
+            return;
         }
-        return;
     }
 #endif
     // Scalar fallback branch