diff --git a/CMakeLists.txt b/CMakeLists.txt
index f88ac2b11..f36d49adc 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -87,6 +87,7 @@ if (BUILD_CPU)
     set(CMAKE_CXX_STANDARD_REQUIRED ON)
     string(TOLOWER "${CMAKE_SYSTEM_PROCESSOR}" HOST_ARCH)
     find_package(OpenMP)
+    find_package(Torch)
 endif()
 
 if(BUILD_CUDA)
@@ -270,7 +271,12 @@ add_library(bitsandbytes SHARED ${SRC_FILES})
 target_compile_features(bitsandbytes PUBLIC cxx_std_17)
 target_include_directories(bitsandbytes PUBLIC csrc include)
 
+# set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -g")
 if (BUILD_CPU)
+    if (Torch_FOUND)
+        target_link_libraries(bitsandbytes PRIVATE "${TORCH_LIBRARIES}")
+        add_definitions(-DHAS_TORCH)
+    endif()
     if (OpenMP_CXX_FOUND)
         target_link_libraries(bitsandbytes PRIVATE OpenMP::OpenMP_CXX)
         add_definitions(-DHAS_OPENMP)
diff --git a/csrc/cpu_ops.cpp b/csrc/cpu_ops.cpp
index f569bf681..94dc4eaaa 100644
--- a/csrc/cpu_ops.cpp
+++ b/csrc/cpu_ops.cpp
@@ -1,3 +1,7 @@
+#ifdef HAS_TORCH
+#include <ATen/ATen.h>
+#include <ATen/native/CPUBlas.h>
+#endif
 #include <common.h>
 #include <cpu_ops.h>
 #include <thread>
@@ -230,6 +234,95 @@ void quantize_cpu(float* code, float* A, float* absmax, unsigned char* out, long
 
 #define CVT_BF16_TO_FP32(a) _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepu16_epi32(a), 16))
 
+#ifdef HAS_TORCH
+static inline const at::BFloat16* cast_at_bf16(const bf16_t* p) {
+    static_assert(sizeof(bf16_t) == sizeof(at::BFloat16), "bf16_t size mismatch");
+    return reinterpret_cast<const at::BFloat16*>(p);
+}
+
+static inline at::BFloat16* cast_at_bf16(bf16_t* p) {
+    static_assert(sizeof(bf16_t) == sizeof(at::BFloat16), "bf16_t size mismatch");
+    return reinterpret_cast<at::BFloat16*>(p);
+}
+#endif
+
+template <int DATA_TYPE>
+inline void unpack_B(
+    bf16_t* __restrict__ Btmp, const unsigned char* __restrict__ packed_B,
+    const bf16_t* __restrict__ Bs, // scales [K/gs, N] in bf16
+    int64_t N, int64_t K, int blocksize, int64_t ldb, int64_t ldb_tmp, int64_t strideBs
+) {
+    // Dequant: (w - z) * s -> bf16
+    const int64_t K2 = K >> 1;        // 2 weights packed per byte
+    const int64_t gs2 = blocksize >> 1;
+    const int64_t ldb2 = ldb;         // packed leading dimension (bytes)
+    const int64_t ldb_tmp2 = ldb_tmp; // output leading dimension in elements
+    float* btmp_ptr = reinterpret_cast<float*>(Btmp); // direct bf16 storage
+
+    __m256i mask = _mm256_set1_epi8(0xF);   // low nibble
+    __m256i fifteen = _mm256_set1_epi8(15); // shift [-15,15] -> [0,30] for LUT
+    __m512i lut = DATA_TYPE == 1
+        ? _mm512_set_epi16(
+              0x0000, -0x4180, -0x41D5, -0x4100, -0x4155, -0x4080, -0x40D5, -0x4455, 0x0000, 0x3E80,
+              0x3E2B, 0x3F00, 0x3EAB, 0x3F80, 0x3F2B, 0x3BAB, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000,
+              0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000
+          )
+        : _mm512_set_epi16(
+              0x0000, 0x3F80, 0x3F39, 0x3F10, 0x3EE2, 0x3EAD, 0x3E7C, 0x3E25, 0x3DA3, 0x0000, -0x4246,
+              -0x41C3, -0x416E, -0x4136, -0x40FA, -0x40CE, -0x4080, 0x0000, 0x0000, 0x0000, 0x0000,
+              0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000
+          );
+    __m512i s_idx1 = _mm512_set_epi32(15, 15, 14, 14, 13, 13, 12, 12, 11, 11, 10, 10, 9, 9, 8, 8);
+    __m512i s_idx0 = _mm512_set_epi32(7, 7, 6, 6, 5, 5, 4, 4, 3, 3, 2, 2, 1, 1, 0, 0);
+
+    __m512 scale_lo_fp32, scale_hi_fp32;
+    __m512 scales[4];
+
+    for (int64_t n = 0; n < N; n += 32) {
+        for (int64_t k = 0; k < K2; ++k) {
+            if (k % gs2 == 0) {
+                const int64_t kgs = k / gs2;
+                // Load 32 scales (bf16) -> two fp32 vectors (first16, second16)
+                __m512i scales_bf16 = _mm512_loadu_si512(reinterpret_cast<const __m512i*>(Bs + kgs * strideBs + n));
+                scale_lo_fp32 = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32(scales_bf16, 0));
+                scale_hi_fp32 = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32(scales_bf16, 1));
+                scales[0] = _mm512_permutexvar_ps(s_idx0, scale_lo_fp32);
+                scales[1] = _mm512_permutexvar_ps(s_idx1, scale_lo_fp32);
+                scales[2] = _mm512_permutexvar_ps(s_idx0, scale_hi_fp32);
+                scales[3] = _mm512_permutexvar_ps(s_idx1, scale_hi_fp32);
+            }
+
+            // Load packed 32 bytes => 64 int4
+            __m256i w_u4 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(packed_B + k * ldb2 + n));
+
+            // Split nibbles
+            __m256i w_lo = w_u4 & mask;
+            __m256i w_hi = _mm256_srli_epi16(w_u4, 4) & mask;
+
+            // Shift to [0..30] before LUT
+            w_lo = _mm256_add_epi8(w_lo, fifteen);
+            w_hi = _mm256_add_epi8(w_hi, fifteen);
+
+            // Lookup (w - z) -> bf16 using LUT (process 16-byte halves)
+            __m512i w_lo_bf16 = _mm512_permutexvar_epi16(_mm512_cvtepi8_epi16(w_lo), lut);
+            __m512i w_hi_bf16 = _mm512_permutexvar_epi16(_mm512_cvtepi8_epi16(w_hi), lut);
+
+            __m512 w_lo_fp32_0 = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32(w_lo_bf16, 0)) * scales[0];
+            __m512 w_hi_fp32_0 = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32(w_lo_bf16, 1)) * scales[1];
+            __m512 w_lo_fp32_1 = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32(w_hi_bf16, 0)) * scales[2];
+            __m512 w_hi_fp32_1 = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32(w_hi_bf16, 1)) * scales[3];
+
+            // Pack scaled (first 16 cols) then (second 16 cols) to bf16
+            __m512bh packed0 = _mm512_cvtne2ps_pbh(w_hi_fp32_0, w_lo_fp32_0);
+            __m512bh packed1 = _mm512_cvtne2ps_pbh(w_hi_fp32_1, w_lo_fp32_1);
+
+            // Store: two blocks of 16 bf16 (32 elements) per k iteration
+            _mm512_storeu_si512(btmp_ptr + (k * ldb_tmp2 + n + 0), (__m512i)packed0);
+            _mm512_storeu_si512(btmp_ptr + (k * ldb_tmp2 + n + 16), (__m512i)packed1);
+        }
+    }
+}
+
 template <typename scalar_t, int BLOCK_M, int BLOCK_N, int DATA_TYPE> struct tinygemm_kernel_nn {
     static inline void apply(
         const scalar_t*, const unsigned char*, scalar_t*, const scalar_t*, int64_t, int, int64_t, int64_t, int64_t,
@@ -239,11 +332,21 @@ template <typename scalar_t, int BLOCK_M, int BLOCK_N, int DATA_TYPE> struct tin
     }
 };
 
+// The brgemm path is only taken when built with HAS_TORCH; this generic stub is never called.
+template <typename scalar_t, int DATA_TYPE> struct brgemm {
+    static inline void apply(
+        const scalar_t* __restrict__ A, const unsigned char* __restrict__ B, scalar_t* __restrict__ C,
+        const scalar_t* __restrict__ Bs, scalar_t* __restrict__ Btmp, float* __restrict__ Ctmp, int64_t M, int64_t N,
+        int64_t K, int blocksize, int64_t lda, int64_t ldb, int64_t ldc, int64_t strideBs, bool use_brgemm_dequant_out
+    ) {
+        return;
+    }
+};
+
 template <int BLOCK_M, int BLOCK_N, int DATA_TYPE> struct tinygemm_kernel_nn<bf16_t, BLOCK_M, BLOCK_N, DATA_TYPE> {
     static inline void apply(
         const bf16_t* __restrict__ A, const unsigned char* __restrict__ B, bf16_t* __restrict__ C,
-        const bf16_t* __restrict__ Bs, int64_t K, int group_size, int64_t lda, int64_t ldb, int64_t ldc,
-        int64_t strideBz, int64_t strideBs
+        const bf16_t* __restrict__ Bs, int64_t K, int blocksize, int64_t lda, int64_t ldb, int64_t ldc, int64_t strideBs
     ) {
         static_assert(BLOCK_N % 32 == 0);
         constexpr int ROWS = BLOCK_M; // 32
@@ -273,8 +376,8 @@ template <int BLOCK_M, int BLOCK_N, int DATA_TYPE> struct tinygemm_kernel_nn<bf
 
         const int64_t K2 = K >> 1;
         const int64_t lda2 = lda >> 1;
-        const int64_t ldb2 = ldb;            // ldb * 2 >> 1;
-        const int64_t gs2 = group_size >> 1; // 64 / 2 = 32
+        const int64_t ldb2 = ldb;           // ldb * 2 >> 1;
+        const int64_t gs2 = blocksize >> 1; // 64 / 2 = 32
         const float* a_ptr = reinterpret_cast<const float*>(A);
 
         auto loadc = [&](auto i) {
@@ -347,16 +450,107 @@ template <int BLOCK_M, int BLOCK_N, int DATA_TYPE> struct tinygemm_kernel_nn<bf
 
 #define LAUNCH_TINYGEMM_KERNEL_NN(MB_SIZE, NB_SIZE)                                                                    \
     tinygemm_kernel_nn<scalar_t, MB_SIZE, NB_SIZE, DATA_TYPE>::apply(                                                  \
-        A + mb_start * lda, B + nb_start, C + mb_start * ldc + nb_start, Bs + nb_start, K, group_size, lda, ldb, ldc, \
-        strideBz, strideBs                                                                                             \
+        A + mb_start * lda, B + nb_start, C + mb_start * ldc + nb_start, Bs + nb_start, K, blocksize, lda, ldb, ldc,  \
+        strideBs                                                                                                       \
     );
 
+#ifdef HAS_TORCH
+
+inline uint16_t float_to_bf16_round(float x) {
+    uint32_t u;
+    std::memcpy(&u, &x, sizeof(u));
+    uint32_t lsb = (u >> 16) & 1;
+    uint32_t rounding_bias = 0x7fff + lsb;
+    u += rounding_bias;
+    uint16_t hi = static_cast<uint16_t>(u >> 16);
+    // Quiet NaN handling
+    if ((u & 0x7f800000) == 0x7f800000 && (u & 0x007fffff)) {
+        hi = 0xffff;
+    }
+    return hi;
+}
+
+template <typename scalar_t>
+inline void copy_stub(scalar_t* __restrict__ out, const float* __restrict__ input, int64_t size) {
+#if defined(__AVX512BF16__)
+    if (has_avx512bf16()) {
+        int64_t d = 0;
+        const int V = 32;
+        for (; d + V <= size; d += V) {
+            __m512 lo = _mm512_loadu_ps(input + d);
+            __m512 hi = _mm512_loadu_ps(input + d + 16);
+            __m512bh packed = _mm512_cvtne2ps_pbh(hi, lo);
+            _mm512_storeu_si512(reinterpret_cast<__m512i*>(out + d), (__m512i)packed);
+        }
+        for (; d < size; ++d) {
+            if constexpr (std::is_same_v<scalar_t, bf16_t>) {
+                // store raw bf16 bits
+                reinterpret_cast<uint16_t*>(out)[d] = float_to_bf16_round(input[d]);
+            } else {
+                out[d] = static_cast<scalar_t>(input[d]);
+            }
+        }
+        return;
+    }
+#endif
+    for (int64_t d = 0; d < size; ++d) {
+        if constexpr (std::is_same_v<scalar_t, bf16_t>) {
+            reinterpret_cast<uint16_t*>(out)[d] = float_to_bf16_round(input[d]);
+        } else {
+            out[d] = static_cast<scalar_t>(input[d]);
+        }
+    }
+}
+
+template <int DATA_TYPE> struct brgemm<bf16_t, DATA_TYPE> {
+    static inline void apply(
+        const bf16_t* __restrict__ A, const unsigned char* __restrict__ B, bf16_t* __restrict__ C,
+        const bf16_t* __restrict__ Bs, bf16_t* __restrict__ Btmp, float* __restrict__ Ctmp, int64_t M, int64_t N,
+        int64_t K, int blocksize, int64_t lda, int64_t ldb, int64_t ldc, int64_t strideBs, bool use_brgemm_dequant_out
+    ) {
+        constexpr int BLOCK_N = block_size_n();
+        const int ldb_tmp = BLOCK_N;
+        if (use_brgemm_dequant_out) {
+            at::native::cpublas::brgemm(
+                M, N, K, lda, ldb_tmp, BLOCK_N, false, cast_at_bf16(A), cast_at_bf16(Btmp), Ctmp
+            );
+        } else {
+            for (int64_t k = 0; k < K; k += BLOCK_K) {
+                int64_t kb_size = std::min(static_cast<int64_t>(BLOCK_K), K - k);
+                const int64_t kgs = k / blocksize;
+
+                unpack_B<DATA_TYPE>(
+                    Btmp, B + (k >> 1) * ldb, Bs + kgs * strideBs, N, kb_size, blocksize, ldb, ldb_tmp, strideBs
+                );
+
+                const bool add_C = k != 0;
+                at::native::cpublas::brgemm(
+                    M, N, kb_size, lda, ldb_tmp, BLOCK_N, add_C, cast_at_bf16(A + k), cast_at_bf16(Btmp), Ctmp
+                );
+            }
+        }
+
+        // copy from Ctmp to C
+        for (int64_t m = 0; m < M; ++m) {
+            copy_stub(C + m * ldc, Ctmp + m * BLOCK_N, N);
+        }
+    }
+};
+#endif
+
 template <typename scalar_t, int DATA_TYPE>
 void tinygemm_kernel(
     const scalar_t* __restrict__ A, const unsigned char* __restrict__ B, scalar_t* __restrict__ C,
     const scalar_t* __restrict__ Bs, scalar_t* __restrict__ Btmp, float* __restrict__ Ctmp, int64_t M, int64_t N,
-    int64_t K, int group_size, int64_t lda, int64_t ldb, int64_t ldc, int64_t strideBz, int64_t strideBs
+    int64_t K, int blocksize, int64_t lda, int64_t ldb, int64_t ldc, int64_t strideBs, bool brg,
+    bool use_brgemm_dequant_out = false
 ) {
+    if (brg) {
+        brgemm<scalar_t, DATA_TYPE>::apply(
+            A, B, C, Bs, Btmp, Ctmp, M, N, K, blocksize, lda, ldb, ldc, strideBs, use_brgemm_dequant_out
+        );
+        return;
+    }
     constexpr int64_t BLOCK_M = 4;
     constexpr int64_t BLOCK_N = 64;
     const int64_t MB = div_up(M, BLOCK_M);
@@ -417,10 +611,39 @@ void gemv_4bit_inference(
     constexpr int64_t BLOCK_N = block_size_n(); // 32
     const int64_t MB = div_up(M, BLOCK_M);      // (x + y - 1) / y, res = 1 when M <= 32
     const int64_t NB = div_up(N, BLOCK_N);
-    // TODO: enable brgemm in the future.
-    // const bool use_brgemm = M > 4;
-    // const bool use_brgemm_dequant_out = M > 512;
-    // T* Btmp_start = nullptr;
+    // TODO: Find better thresholds.
+    T* Btmp_start = nullptr;
+#ifdef HAS_TORCH
+    const bool use_brgemm = M > 4;
+    const bool use_brgemm_dequant_out = M > 100;
+    if (use_brgemm_dequant_out) {
+        // Layout: contiguous [N*K] elements, 64-byte aligned for AVX512 loads
+        at::Tensor Btmp_t = at::zeros({N, K}, at::dtype(at::kBFloat16));
+        at::BFloat16* Btmp_start_pt = Btmp_t.data_ptr<at::BFloat16>();
+        Btmp_start = reinterpret_cast<T*>(Btmp_start_pt);
+        BNB_OMP_PARALLEL_FOR
+        for (int64_t nb = 0; nb < NB; ++nb) {
+            int64_t nb_start = nb * BLOCK_N;
+            int64_t nb_size = std::min(N - nb_start, BLOCK_N);
+            T* Btmp = Btmp_start + nb_start * K;
+            for (int64_t k = 0; k < K; k += BLOCK_K) {
+                int64_t kb_size = std::min(BLOCK_K, K - k);
+                int64_t kgs = k / blocksize;
+                int64_t strideBs = N;
+                int64_t ldb = nb_size;
+                const T* Bs = absmax + nb_start;
+                const unsigned char* Bw = reinterpret_cast<const unsigned char*>(w + nb_start * K / 2);
+                unpack_B<DATA_TYPE>(
+                    Btmp + k * BLOCK_N, Bw + (k >> 1) * ldb, Bs + kgs * strideBs, nb_size, kb_size, blocksize, ldb,
+                    BLOCK_N, strideBs
+                );
+            }
+        }
+    }
+#else
+    const bool use_brgemm = false;
+    const bool use_brgemm_dequant_out = false;
+#endif
     // l2 cache block for n
     int64_t cache_blocks_nb = get_cache_blocks(BLOCK_N * K);
     parallel_2d(MB, NB, [&](int64_t begin_mb, int64_t end_mb, int64_t begin_nb, int64_t end_nb) {
@@ -439,24 +662,27 @@ void gemv_4bit_inference(
                         /* B */ w + nb_start * K / 2, // divide by 2 since w is u4 packed in u8, K is w.size(1) * 2
                         /* C */ out + mb_start * out_stride + nb_start,
                         /* Bs */ absmax + nb_start,
-                        /* Btmp */ Btmp_inner,
+                        /* Btmp */ use_brgemm_dequant_out ? Btmp_start + nb_start * K : Btmp_inner,
                         /* Ctmp */ Ctmp,
                         /* M */ mb_size,
                         /* N */ nb_size,
                         /* K */ K,
-                        /* gs */ blocksize, // group_size
+                        /* gs */ blocksize, // quantization blocksize
                         /* lda */ x_stride,
                         /* ldb */ nb_size,
                         /* ldc */ out_stride,
-                        /* sBz */ N,
-                        /* sBs */ N
+                        /* sBs */ N,
+                        /* brg */ use_brgemm,
+                        /* dequant choice */ use_brgemm_dequant_out
                     );
                }
            }
        }
-        // if (use_brgemm) {
-        //     at::native::cpublas::brgemm_release();
-        // }
+#ifdef HAS_TORCH
+        if (use_brgemm) {
+            at::native::cpublas::brgemm_release();
+        }
+#endif
    });
 }
 #endif
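
Note (reviewer aid, not part of the patch): the HAS_TORCH path amounts to "dequantize a tile of B, run a bf16 batch-reduce GEMM with fp32 accumulation via at::native::cpublas::brgemm, then round the fp32 accumulator back to bf16 with copy_stub". The scalar sketch below only makes that data flow explicit; it deliberately ignores the AVX-512/VNNI packing and nibble ordering handled by unpack_B, and the names brgemm_block_reference, w4 and code (a 16-entry fp4/nf4 codebook) are illustrative assumptions, not identifiers from this patch.

    #include <cstdint>

    // Illustrative scalar reference (hypothetical helper, not in the patch).
    // One K-block of the brgemm path: each B element is dequantized as
    // code[w] * absmax and multiplied into an fp32 accumulator; the add_C
    // choice mirrors `const bool add_C = k != 0;` above.
    static void brgemm_block_reference(
        const float* A,      // [M, lda] activations (bf16 widened to fp32 for clarity)
        const uint8_t* w4,   // one 4-bit index per element, [kb_size, N] (real packing ignored)
        const float* code,   // 16-entry fp4/nf4 codebook
        const float* absmax, // per-block scales laid out [K / blocksize, N] (strideBs == N)
        float* Ctmp,         // [M, N] fp32 accumulator (copy_stub later rounds it to bf16)
        int64_t M, int64_t N, int64_t kb_size, int64_t k0, int64_t lda, int blocksize
    ) {
        for (int64_t m = 0; m < M; ++m) {
            for (int64_t n = 0; n < N; ++n) {
                float acc = (k0 == 0) ? 0.0f : Ctmp[m * N + n]; // add_C when not the first K-block
                for (int64_t k = 0; k < kb_size; ++k) {
                    float b = code[w4[k * N + n] & 0xF] * absmax[((k0 + k) / blocksize) * N + n];
                    acc += A[m * lda + (k0 + k)] * b;
                }
                Ctmp[m * N + n] = acc;
            }
        }
    }

Calling this once per K-block with k0 = 0, BLOCK_K, 2*BLOCK_K, ... mirrors the accumulate-over-K loop in brgemm<bf16_t, DATA_TYPE>::apply when use_brgemm_dequant_out is false.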