Skip to content

Commit 36dad93

Browse files
committed
add runtime check for avx512
Signed-off-by: jiqing-feng <[email protected]>
1 parent af54c9d commit 36dad93

File tree

1 file changed

+58
-45
lines changed

1 file changed

+58
-45
lines changed

csrc/cpu_ops.cpp

Lines changed: 58 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,26 @@ using namespace BinSearch;
1414
#if defined(__AVX512F__)
1515
#include <immintrin.h>
1616

17+
bool has_avx512f() {
18+
static const bool supported_avx512f = __builtin_cpu_supports("avx512f");
19+
return supported_avx512f;
20+
}
21+
22+
bool has_avx512bf16() {
23+
static const bool supported_avx512bf16 = __builtin_cpu_supports("avx512bf16");
24+
return supported_avx512bf16;
25+
}
26+
1727
inline __m256i cvt_fp32_to_fp16(const __m512 src) {
1828
return _mm512_cvtps_ph(src, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
1929
}
2030

2131
inline __m256i cvt_fp32_to_bf16(const __m512 src) {
2232
#if defined(__AVX512BF16__)
23-
return reinterpret_cast<__m256i>(_mm512_cvtneps_pbh(src));
24-
#else
33+
if (has_avx512bf16()) {
34+
return reinterpret_cast<__m256i>(_mm512_cvtneps_pbh(src));
35+
}
36+
#endif
2537
__m512i value = _mm512_castps_si512(src);
2638
__m512i nan = _mm512_set1_epi32(0xffff);
2739
auto mask_value = _mm512_cmp_ps_mask(src, src, _CMP_ORD_Q);
@@ -38,7 +50,6 @@ inline __m256i cvt_fp32_to_bf16(const __m512 src) {
3850
// Check NaN before converting back to bf16
3951
t_value = _mm512_mask_blend_epi32(mask_value, nan, t_value);
4052
return _mm512_cvtusepi32_epi16(t_value);
41-
#endif
4253
}
4354

4455
static inline __m512 set_nf4_lut() {
@@ -68,51 +79,53 @@ void dequantizeBlockwise4bitCpu(
6879
return;
6980

7081
#if defined(__AVX512F__)
71-
long long dim_0 = m;
72-
long long dim_1 = n;
73-
long long input_dim_1 = dim_1 >> 1;
74-
long long absmax_dim_1 = dim_1 / blocksize;
75-
using Tcomp = float;
76-
constexpr auto VEC_LEN = sizeof(__m512i) / sizeof(Tcomp); // 16
77-
if (dim_1 % VEC_LEN == 0 && blocksize >= VEC_LEN) {
78-
__m512 lut = DATA_TYPE == 1 ? set_fp4_lut() : set_nf4_lut();
79-
constexpr auto k_step = VEC_LEN / 2; // 8
80-
BNB_OMP_PARALLEL_FOR
81-
for (int block_idx = 0; block_idx < dim_0; ++block_idx) {
82-
for (int k = 0; k < input_dim_1; k += k_step) {
83-
// Load 64 bits of nf4 data and a single scale data
84-
uint8_t* p = &A[block_idx * input_dim_1 + k];
85-
uint64_t packed;
86-
std::memcpy(&packed, p, sizeof(uint64_t));
87-
auto scale_idx = k * 2 / blocksize;
88-
auto vscales = _mm512_set1_ps((float)absmax[block_idx * absmax_dim_1 + scale_idx]);
89-
// unpack nf4 data to 32-bit integers
90-
uint64_t high = 0;
91-
uint64_t low = 0;
92-
for (int i = 0; i < 4; ++i) {
93-
low |= ((packed >> (2 * i * 4)) & 0xf) << ((2 * i + 1) * 8);
94-
low |= ((packed >> ((2 * i + 1) * 4)) & 0xf) << (2 * i * 8);
95-
high |= ((packed >> (2 * i * 4 + 32)) & 0xf) << ((2 * i + 1) * 8);
96-
high |= ((packed >> ((2 * i + 1) * 4 + 32)) & 0xf) << (2 * i * 8);
97-
}
98-
__m128i packed_128 = _mm_set_epi64x(high, low);
99-
__m512i vint32 = _mm512_cvtepu8_epi32(packed_128);
100-
// Table look-up
101-
__m512 vout = _mm512_permutexvar_ps(vint32, lut);
102-
// Apply scale
103-
vout = _mm512_mul_ps(vout, vscales);
104-
// Store results
105-
T* pout = &out[block_idx * dim_1 + k * 2];
106-
if constexpr (std::is_same<T, float>()) {
107-
_mm512_storeu_ps(pout, vout);
108-
} else if constexpr (std::is_same<T, bf16_t>()) {
109-
_mm256_storeu_si256((__m256i*)pout, cvt_fp32_to_bf16(vout));
110-
} else if constexpr (std::is_same<T, fp16_t>()) {
111-
_mm256_storeu_si256((__m256i*)pout, cvt_fp32_to_fp16(vout));
82+
if (has_avx512f()) {
83+
long long dim_0 = m;
84+
long long dim_1 = n;
85+
long long input_dim_1 = dim_1 >> 1;
86+
long long absmax_dim_1 = dim_1 / blocksize;
87+
using Tcomp = float;
88+
constexpr auto VEC_LEN = sizeof(__m512i) / sizeof(Tcomp); // 16
89+
if (dim_1 % VEC_LEN == 0 && blocksize >= VEC_LEN) {
90+
__m512 lut = DATA_TYPE == 1 ? set_fp4_lut() : set_nf4_lut();
91+
constexpr auto k_step = VEC_LEN / 2; // 8
92+
BNB_OMP_PARALLEL_FOR
93+
for (int block_idx = 0; block_idx < dim_0; ++block_idx) {
94+
for (int k = 0; k < input_dim_1; k += k_step) {
95+
// Load 64 bits of nf4 data and a single scale data
96+
uint8_t* p = &A[block_idx * input_dim_1 + k];
97+
uint64_t packed;
98+
std::memcpy(&packed, p, sizeof(uint64_t));
99+
auto scale_idx = k * 2 / blocksize;
100+
auto vscales = _mm512_set1_ps((float)absmax[block_idx * absmax_dim_1 + scale_idx]);
101+
// unpack nf4 data to 32-bit integers
102+
uint64_t high = 0;
103+
uint64_t low = 0;
104+
for (int i = 0; i < 4; ++i) {
105+
low |= ((packed >> (2 * i * 4)) & 0xf) << ((2 * i + 1) * 8);
106+
low |= ((packed >> ((2 * i + 1) * 4)) & 0xf) << (2 * i * 8);
107+
high |= ((packed >> (2 * i * 4 + 32)) & 0xf) << ((2 * i + 1) * 8);
108+
high |= ((packed >> ((2 * i + 1) * 4 + 32)) & 0xf) << (2 * i * 8);
109+
}
110+
__m128i packed_128 = _mm_set_epi64x(high, low);
111+
__m512i vint32 = _mm512_cvtepu8_epi32(packed_128);
112+
// Table look-up
113+
__m512 vout = _mm512_permutexvar_ps(vint32, lut);
114+
// Apply scale
115+
vout = _mm512_mul_ps(vout, vscales);
116+
// Store results
117+
T* pout = &out[block_idx * dim_1 + k * 2];
118+
if constexpr (std::is_same<T, float>()) {
119+
_mm512_storeu_ps(pout, vout);
120+
} else if constexpr (std::is_same<T, bf16_t>()) {
121+
_mm256_storeu_si256((__m256i*)pout, cvt_fp32_to_bf16(vout));
122+
} else if constexpr (std::is_same<T, fp16_t>()) {
123+
_mm256_storeu_si256((__m256i*)pout, cvt_fp32_to_fp16(vout));
124+
}
112125
}
113126
}
127+
return;
114128
}
115-
return;
116129
}
117130
#endif
118131
// Scalar fallback branch

0 commit comments

Comments
 (0)