6 changes: 6 additions & 0 deletions CMakeLists.txt
@@ -87,6 +87,7 @@ if (BUILD_CPU)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
string(TOLOWER "${CMAKE_SYSTEM_PROCESSOR}" HOST_ARCH)
find_package(OpenMP)
find_package(Torch)
endif()

if(BUILD_CUDA)
@@ -270,7 +271,12 @@ add_library(bitsandbytes SHARED ${SRC_FILES})
target_compile_features(bitsandbytes PUBLIC cxx_std_17)
target_include_directories(bitsandbytes PUBLIC csrc include)

# set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -g")
if (BUILD_CPU)
if (Torch_FOUND)
target_link_libraries(bitsandbytes PRIVATE "${TORCH_LIBRARIES}")
add_definitions(-DHAS_TORCH)
endif()
if (OpenMP_CXX_FOUND)
target_link_libraries(bitsandbytes PRIVATE OpenMP::OpenMP_CXX)
add_definitions(-DHAS_OPENMP)
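
With this change Torch is an optional build dependency: the brgemm path added in csrc/cpu_ops.cpp is only compiled when find_package(Torch) succeeds and HAS_TORCH is defined; otherwise the existing CPU kernels are used unchanged. If CMake does not locate Torch on its own, the usual fix is to point CMAKE_PREFIX_PATH at the CMake files shipped with the installed PyTorch (the directory printed by: python -c "import torch; print(torch.utils.cmake_prefix_path)").
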
262 changes: 244 additions & 18 deletions csrc/cpu_ops.cpp
@@ -1,3 +1,7 @@
#ifdef HAS_TORCH
#include <ATen/ATen.h>
#include <ATen/native/CPUBlas.h>
#endif
#include <BinSearch.h>
#include <cpu_ops.h>
#include <thread>
@@ -230,6 +234,95 @@ void quantize_cpu(float* code, float* A, float* absmax, unsigned char* out, long

#define CVT_BF16_TO_FP32(a) _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepu16_epi32(a), 16))

#ifdef HAS_TORCH
static inline const at::BFloat16* cast_at_bf16(const bf16_t* p) {
static_assert(sizeof(bf16_t) == sizeof(at::BFloat16), "bf16_t size mismatch");
return reinterpret_cast<const at::BFloat16*>(p);
}

static inline at::BFloat16* cast_at_bf16(bf16_t* p) {
static_assert(sizeof(bf16_t) == sizeof(at::BFloat16), "bf16_t size mismatch");
return reinterpret_cast<at::BFloat16*>(p);
}
#endif

template <int DATA_TYPE>
inline void unpack_B(
bf16_t* __restrict__ Btmp, const unsigned char* __restrict__ packed_B,
const bf16_t* __restrict__ Bs, // scales [K/gs, N] in bf16
int64_t N, int64_t K, int blocksize, int64_t ldb, int64_t ldb_tmp, int64_t strideBs
) {
// Dequant: (w - z) * s -> bf16
const int64_t K2 = K >> 1; // 2 weights packed per byte
const int64_t gs2 = blocksize >> 1;
const int64_t ldb2 = ldb; // packed leading dimension (bytes)
const int64_t ldb_tmp2 = ldb_tmp; // output leading dimension in elements
float* btmp_ptr = reinterpret_cast<float*>(Btmp); // direct bf16 storage

__m256i mask = _mm256_set1_epi8(0xF); // low nibble
__m256i fifteen = _mm256_set1_epi8(15); // shift [-15,15] -> [0,30] for LUT
__m512i lut = DATA_TYPE == 1
? _mm512_set_epi16(
0x0000, -0x4180, -0x41D5, -0x4100, -0x4155, -0x4080, -0x40D5, -0x4455, 0x0000, 0x3E80,
0x3E2B, 0x3F00, 0x3EAB, 0x3F80, 0x3F2B, 0x3BAB, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000,
0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000
)
: _mm512_set_epi16(
0x0000, 0x3F80, 0x3F39, 0x3F10, 0x3EE2, 0x3EAD, 0x3E7C, 0x3E25, 0x3DA3, 0x0000, -0x4246,
-0x41C3, -0x416E, -0x4136, -0x40FA, -0x40CE, -0x4080, 0x0000, 0x0000, 0x0000, 0x0000,
0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000
);
__m512i s_idx1 = _mm512_set_epi32(15, 15, 14, 14, 13, 13, 12, 12, 11, 11, 10, 10, 9, 9, 8, 8);
__m512i s_idx0 = _mm512_set_epi32(7, 7, 6, 6, 5, 5, 4, 4, 3, 3, 2, 2, 1, 1, 0, 0);

__m512 scale_lo_fp32, scale_hi_fp32;
__m512 scales[4];

for (int64_t n = 0; n < N; n += 32) {
for (int64_t k = 0; k < K2; ++k) {
if (k % gs2 == 0) {
const int64_t kgs = k / gs2;
// Load 32 scales (bf16) -> two fp32 vectors (first16, second16)
__m512i scales_bf16 = _mm512_loadu_si512(reinterpret_cast<const __m512i*>(Bs + kgs * strideBs + n));
scale_lo_fp32 = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32(scales_bf16, 0));
scale_hi_fp32 = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32(scales_bf16, 1));
scales[0] = _mm512_permutexvar_ps(s_idx0, scale_lo_fp32);
scales[1] = _mm512_permutexvar_ps(s_idx1, scale_lo_fp32);
scales[2] = _mm512_permutexvar_ps(s_idx0, scale_hi_fp32);
scales[3] = _mm512_permutexvar_ps(s_idx1, scale_hi_fp32);
}

// Load packed 32 bytes => 64 int4
__m256i w_u4 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(packed_B + k * ldb2 + n));

// Split nibbles
__m256i w_lo = w_u4 & mask;
__m256i w_hi = _mm256_srli_epi16(w_u4, 4) & mask;

// Shift to [0..30] before LUT
w_lo = _mm256_add_epi8(w_lo, fifteen);
w_hi = _mm256_add_epi8(w_hi, fifteen);

// Lookup (w - z) -> bf16 using LUT (process 16-byte halves)
__m512i w_lo_bf16 = _mm512_permutexvar_epi16(_mm512_cvtepi8_epi16(w_lo), lut);
__m512i w_hi_bf16 = _mm512_permutexvar_epi16(_mm512_cvtepi8_epi16(w_hi), lut);

__m512 w_lo_fp32_0 = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32(w_lo_bf16, 0)) * scales[0];
__m512 w_hi_fp32_0 = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32(w_lo_bf16, 1)) * scales[1];
__m512 w_lo_fp32_1 = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32(w_hi_bf16, 0)) * scales[2];
__m512 w_hi_fp32_1 = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32(w_hi_bf16, 1)) * scales[3];

// Pack scaled (first 16 cols) then (second 16 cols) to bf16
__m512bh packed0 = _mm512_cvtne2ps_pbh(w_hi_fp32_0, w_lo_fp32_0);
__m512bh packed1 = _mm512_cvtne2ps_pbh(w_hi_fp32_1, w_lo_fp32_1);

// Store: two blocks of 16 bf16 (32 elements) per k iteration
_mm512_storeu_si512(btmp_ptr + (k * ldb_tmp2 + n + 0), (__m512i)packed0);
_mm512_storeu_si512(btmp_ptr + (k * ldb_tmp2 + n + 16), (__m512i)packed1);
}
}
}
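
For reference, the per-element operation the vectorized loop above performs is a 4-bit codebook lookup multiplied by the per-group scale. A minimal scalar sketch of that math (illustrative only; `codebook` is a placeholder for the NF4/FP4 table encoded in `lut`, and the interleaved bf16 layout the real kernel writes for brgemm is ignored here):

// Scalar sketch of the dequant math in unpack_B above (not part of the patch).
// `codebook` stands for the 16-entry FP4 (DATA_TYPE == 1) or NF4 table; the real
// kernel holds these values as bf16 bit patterns and converts the product to bf16.
static inline void dequant_byte_ref(
    unsigned char packed,      // two 4-bit codes packed in one byte
    float group_scale,         // absmax of the current quantization block
    const float* codebook,     // 16-entry lookup table
    float& w_lo, float& w_hi
) {
    w_lo = codebook[packed & 0xF] * group_scale;        // low nibble
    w_hi = codebook[(packed >> 4) & 0xF] * group_scale; // high nibble
}
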

template <typename scalar_t, int BLOCK_M, int BLOCK_N, int DATA_TYPE> struct tinygemm_kernel_nn {
static inline void apply(
const scalar_t*, const unsigned char*, scalar_t*, const scalar_t*, int64_t, int, int64_t, int64_t, int64_t,
@@ -239,11 +332,21 @@ template <typename scalar_t, int BLOCK_M, int BLOCK_N, int DATA_TYPE> struct tin
}
};

// Fallback stub: the brgemm path is only used when HAS_TORCH is defined
template <typename scalar_t, int DATA_TYPE> struct brgemm {
static inline void apply(
const scalar_t* __restrict__ A, const unsigned char* __restrict__ B, scalar_t* __restrict__ C,
const scalar_t* __restrict__ Bs, scalar_t* __restrict__ Btmp, float* __restrict__ Ctmp, int64_t M, int64_t N,
int64_t K, int blocksize, int64_t lda, int64_t ldb, int64_t ldc, int64_t strideBs, bool use_brgemm_dequant_out
) {
return;
}
};

template <int BLOCK_M, int BLOCK_N, int DATA_TYPE> struct tinygemm_kernel_nn<bf16_t, BLOCK_M, BLOCK_N, DATA_TYPE> {
static inline void apply(
const bf16_t* __restrict__ A, const unsigned char* __restrict__ B, bf16_t* __restrict__ C,
const bf16_t* __restrict__ Bs, int64_t K, int group_size, int64_t lda, int64_t ldb, int64_t ldc,
int64_t strideBz, int64_t strideBs
const bf16_t* __restrict__ Bs, int64_t K, int blocksize, int64_t lda, int64_t ldb, int64_t ldc, int64_t strideBs
) {
static_assert(BLOCK_N % 32 == 0);
constexpr int ROWS = BLOCK_M; // 32
@@ -273,8 +376,8 @@ template <int BLOCK_M, int BLOCK_N, int DATA_TYPE> struct tinygemm_kernel_nn<bf1
__m512 scales[COLS];
const int64_t K2 = K >> 1;
const int64_t lda2 = lda >> 1;
const int64_t ldb2 = ldb; // ldb * 2 >> 1;
const int64_t gs2 = group_size >> 1; // 64 / 2 = 32
const int64_t ldb2 = ldb; // ldb * 2 >> 1;
const int64_t gs2 = blocksize >> 1; // 64 / 2 = 32
const float* a_ptr = reinterpret_cast<const float*>(A);

auto loadc = [&](auto i) {
@@ -347,16 +450,107 @@ template <int BLOCK_M, int BLOCK_N, int DATA_TYPE> struct tinygemm_kernel_nn<bf1

#define LAUNCH_TINYGEMM_KERNEL_NN(MB_SIZE, NB_SIZE, DATA_TYPE) \
tinygemm_kernel_nn<scalar_t, MB_SIZE, NB_SIZE, DATA_TYPE>::apply( \
A + mb_start * lda, B + nb_start, C + mb_start * ldc + nb_start, Bs + nb_start, K, group_size, lda, ldb, ldc, \
strideBz, strideBs \
A + mb_start * lda, B + nb_start, C + mb_start * ldc + nb_start, Bs + nb_start, K, blocksize, lda, ldb, ldc, \
strideBs \
);

#ifdef HAS_TORCH

inline uint16_t float_to_bf16_round(float x) {
uint32_t u;
std::memcpy(&u, &x, sizeof(u));
uint32_t lsb = (u >> 16) & 1;
uint32_t rounding_bias = 0x7fff + lsb;
u += rounding_bias;
uint16_t hi = static_cast<uint16_t>(u >> 16);
// Quiet NaN handling
if ((u & 0x7f800000) == 0x7f800000 && (u & 0x007fffff)) {
hi = 0xffff;
}
return hi;
}
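
A couple of spot checks on the rounding above (illustrative only, not part of the patch): values exactly halfway between two adjacent bf16 numbers round to the one with an even mantissa, matching the round-to-nearest-even behaviour of _mm512_cvtne2ps_pbh used in the vectorized path.

// Standalone sanity check for the bf16 rounding helper; 1 + 2^-8 and 1 + 3*2^-8
// sit exactly halfway between neighbouring bf16 values near 1.0.
#include <cassert>
inline void bf16_round_spot_checks() {
    assert(float_to_bf16_round(1.0f) == 0x3F80);        // exactly representable
    assert(float_to_bf16_round(1.00390625f) == 0x3F80); // tie: rounds down to even mantissa
    assert(float_to_bf16_round(1.01171875f) == 0x3F82); // tie: rounds up to even mantissa
}
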

template <typename scalar_t>
inline void copy_stub(scalar_t* __restrict__ out, const float* __restrict__ input, int64_t size) {
#if defined(__AVX512BF16__)
if (has_avx512bf16()) {
int64_t d = 0;
const int V = 32;
for (; d + V <= size; d += V) {
__m512 lo = _mm512_loadu_ps(input + d);
__m512 hi = _mm512_loadu_ps(input + d + 16);
__m512bh packed = _mm512_cvtne2ps_pbh(hi, lo);
_mm512_storeu_si512(reinterpret_cast<void*>(out + d), (__m512i)packed);
}
for (; d < size; ++d) {
if constexpr (std::is_same_v<scalar_t, bf16_t>) {
// store raw bf16 bits
reinterpret_cast<uint16_t*>(out)[d] = float_to_bf16_round(input[d]);
} else {
out[d] = static_cast<scalar_t>(input[d]);
}
}
return;
}
#endif
for (int64_t d = 0; d < size; ++d) {
if constexpr (std::is_same_v<scalar_t, bf16_t>) {
reinterpret_cast<uint16_t*>(out)[d] = float_to_bf16_round(input[d]);
} else {
out[d] = static_cast<scalar_t>(input[d]);
}
}
}

template <int DATA_TYPE> struct brgemm<bf16_t, DATA_TYPE> {
static inline void apply(
const bf16_t* __restrict__ A, const unsigned char* __restrict__ B, bf16_t* __restrict__ C,
const bf16_t* __restrict__ Bs, bf16_t* __restrict__ Btmp, float* __restrict__ Ctmp, int64_t M, int64_t N,
int64_t K, int blocksize, int64_t lda, int64_t ldb, int64_t ldc, int64_t strideBs, bool use_brgemm_dequant_out
) {
constexpr int BLOCK_N = block_size_n();
const int ldb_tmp = BLOCK_N;
if (use_brgemm_dequant_out) {
at::native::cpublas::brgemm(
M, N, K, lda, ldb_tmp, BLOCK_N, false, cast_at_bf16(A), cast_at_bf16(Btmp), Ctmp
);
} else {
for (int64_t k = 0; k < K; k += BLOCK_K) {
int64_t kb_size = std::min(static_cast<int64_t>(BLOCK_K), K - k);
const int64_t kgs = k / blocksize;

unpack_B<DATA_TYPE>(
Btmp, B + (k >> 1) * ldb, Bs + kgs * strideBs, N, kb_size, blocksize, ldb, ldb_tmp, strideBs
);

const bool add_C = k != 0;
at::native::cpublas::brgemm(
M, N, kb_size, lda, ldb_tmp, BLOCK_N, add_C, cast_at_bf16(A + k), cast_at_bf16(Btmp), Ctmp
);
}
}

// copy from Ctmp to C
for (int64_t m = 0; m < M; ++m) {
copy_stub<bf16_t>(C + m * ldc, Ctmp + m * BLOCK_N, N);
}
}
};
#endif
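
Summarized, the specialization above computes C = A * dequant(B) with fp32 accumulation in Ctmp and a final bf16 down-convert. A naive scalar statement of that contract (illustrative only; the real path packs Btmp into a brgemm-friendly layout and hands BLOCK_K slices of K to at::native::cpublas::brgemm):

// Scalar reference for the result the brgemm specialization produces (not part of the patch).
static void gemm_4bit_contract_ref(
    const float* A,    // [M, lda] activations (bf16 in the real kernel, widened here for clarity)
    const float* deqB, // [K, N]   dequantized weights, cf. the dequant_byte_ref sketch above
    float* C,          // [M, ldc] result; copy_stub() converts the fp32 rows to bf16
    int64_t M, int64_t N, int64_t K, int64_t lda, int64_t ldc
) {
    for (int64_t m = 0; m < M; ++m) {
        for (int64_t n = 0; n < N; ++n) {
            float acc = 0.0f; // plays the role of one Ctmp entry
            for (int64_t k = 0; k < K; ++k) {
                acc += A[m * lda + k] * deqB[k * N + n];
            }
            C[m * ldc + n] = acc;
        }
    }
}
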

template <typename scalar_t, int DATA_TYPE>
void tinygemm_kernel(
const scalar_t* __restrict__ A, const unsigned char* __restrict__ B, scalar_t* __restrict__ C,
const scalar_t* __restrict__ Bs, scalar_t* __restrict__ Btmp, float* __restrict__ Ctmp, int64_t M, int64_t N,
int64_t K, int group_size, int64_t lda, int64_t ldb, int64_t ldc, int64_t strideBz, int64_t strideBs
int64_t K, int blocksize, int64_t lda, int64_t ldb, int64_t ldc, int64_t strideBs, bool brg,
bool use_brgemm_dequant_out = false
) {
if (brg) {
brgemm<scalar_t, DATA_TYPE>::apply(
A, B, C, Bs, Btmp, Ctmp, M, N, K, blocksize, lda, ldb, ldc, strideBs, use_brgemm_dequant_out
);
return;
}
constexpr int64_t BLOCK_M = 4;
constexpr int64_t BLOCK_N = 64;
const int64_t MB = div_up(M, BLOCK_M);
@@ -417,10 +611,39 @@ void gemv_4bit_inference(
constexpr int64_t BLOCK_N = block_size_n(); // 32
const int64_t MB = div_up(M, BLOCK_M); // (x + y -1)/ y, res = 1 when M <= 32
const int64_t NB = div_up(N, BLOCK_N);
// TODO: enable brgemm in the future.
// const bool use_brgemm = M > 4;
// const bool use_brgemm_dequant_out = M > 512;
// T* Btmp_start = nullptr;
// TODO: Find better threshold.
T* Btmp_start = nullptr;
#ifdef HAS_TORCH
const bool use_brgemm = M > 4;
const bool use_brgemm_dequant_out = M > 100;
if (use_brgemm_dequant_out) {
// Layout: contiguous [N*K] elements, 64-byte aligned for AVX512 loads
at::Tensor Btmp_t = at::zeros({N, K}, at::dtype(at::kBFloat16));
at::BFloat16* Btmp_start_pt = Btmp_t.data_ptr<at::BFloat16>();
Btmp_start = reinterpret_cast<T*>(Btmp_start_pt);
BNB_OMP_PARALLEL_FOR
for (int64_t nb = 0; nb < NB; ++nb) {
int64_t nb_start = nb * BLOCK_N;
int64_t nb_size = std::min<int64_t>(N - nb_start, BLOCK_N);
T* Btmp = Btmp_start + nb_start * K;
for (int64_t k = 0; k < K; k += BLOCK_K) {
int64_t kb_size = std::min<int64_t>(BLOCK_K, K - k);
int64_t kgs = k / blocksize;
int64_t strideBs = N;
int64_t ldb = nb_size;
const T* Bs = absmax + nb_start;
const unsigned char* Bw = reinterpret_cast<const unsigned char*>(w + nb_start * K / 2);
unpack_B<DATA_TYPE>(
Btmp + k * BLOCK_N, Bw + (k >> 1) * ldb, Bs + kgs * strideBs, nb_size, kb_size, blocksize, ldb,
BLOCK_N, strideBs
);
}
}
}
#else
const bool use_brgemm = false;
const bool use_brgemm_dequant_out = false;
#endif
// l2 cache block for n
int64_t cache_blocks_nb = get_cache_blocks<T>(BLOCK_N * K);
parallel_2d(MB, NB, [&](int64_t begin_mb, int64_t end_mb, int64_t begin_nb, int64_t end_nb) {
Expand All @@ -439,24 +662,27 @@ void gemv_4bit_inference(
/* B */ w + nb_start * K / 2, // divide by 2 since w is u4 packed in u8, K is w.size(1) * 2
/* C */ out + mb_start * out_stride + nb_start,
/* Bs */ absmax + nb_start,
/* Btmp */ Btmp_inner,
/* Btmp */ use_brgemm_dequant_out ? Btmp_start + nb_start * K : Btmp_inner,
/* Ctmp */ Ctmp,
/* M */ mb_size,
/* N */ nb_size,
/* K */ K,
/* gs */ blocksize, // group_size
/* gs */ blocksize, // blocksize
/* lda */ x_stride,
/* ldb */ nb_size,
/* ldc */ out_stride,
/* sBz */ N,
/* sBs */ N
/* sBs */ N,
/* brg */ use_brgemm,
/* dequant choice*/ use_brgemm_dequant_out
);
}
}
}
// if (use_brgemm) {
// at::native::cpublas::brgemm_release();
// }
#ifdef HAS_TORCH
if (use_brgemm) {
at::native::cpublas::brgemm_release();
}
#endif
});
}
#endif
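
In summary, the added path dispatches on M: with M <= 4 the existing AVX512 tinygemm_kernel_nn micro-kernels are used and dequantization stays fused per tile; with 4 < M <= 100 the call goes to at::native::cpublas::brgemm, dequantizing B one BLOCK_K slice at a time into Btmp; with M > 100 all of B is dequantized once up front into an N x K bf16 buffer that every row block then reuses. Both thresholds are provisional in the patch itself ("TODO: Find better threshold"), and without HAS_TORCH the brgemm branches compile out entirely.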