diff --git a/CMakeLists.txt b/CMakeLists.txt index 5fe4ffc64..f88ac2b11 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -282,6 +282,9 @@ if (BUILD_CPU) check_cxx_compiler_flag(-mavx512bf16 HAS_AVX512BF16_FLAG) if (HAS_AVX512F_FLAG) target_compile_options(bitsandbytes PRIVATE -mavx512f) + target_compile_options(bitsandbytes PRIVATE -mavx512dq) + target_compile_options(bitsandbytes PRIVATE -mavx512bw) + target_compile_options(bitsandbytes PRIVATE -mavx512vl) endif() if (HAS_AVX512BF16_FLAG) target_compile_options(bitsandbytes PRIVATE -mavx512bf16) diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index 9547f5a93..da168e17b 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -374,10 +374,18 @@ def matmul_4bit( bias: Optional[torch.Tensor] = None, ): assert quant_state is not None - # Change dtype to bfloat16 on CPU + # Change dtype to input dtype on CPU if A.device.type == "cpu": quant_state.dtype = A.dtype + if getattr(quant_state, "packing_format_for_cpu", False): + out = F.gemv_4bit(A, B, out, state=quant_state) + if bias is not None: + out += bias + return out + else: + return MatMul4Bit.apply(A, B, out, bias, quant_state) + if A.numel() == A.shape[-1] and A.requires_grad == False and A.device.type != "hpu": if A.shape[-1] % quant_state.blocksize != 0: warn( diff --git a/bitsandbytes/backends/cpu/ops.py b/bitsandbytes/backends/cpu/ops.py index 25965aec3..def87045c 100644 --- a/bitsandbytes/backends/cpu/ops.py +++ b/bitsandbytes/backends/cpu/ops.py @@ -5,7 +5,7 @@ import torch -from bitsandbytes.functional import get_ptr +from bitsandbytes.functional import get_ptr, has_avx512bf16 from ..._ops import register_kernel from ...cextension import ErrorHandlerMockBNBNativeLibrary, lib @@ -217,3 +217,62 @@ def _( raise ValueError return out + + if has_avx512bf16(): + + @register_kernel("bitsandbytes::gemv_4bit", "cpu") + def _( + A: torch.Tensor, + B: torch.Tensor, + shapeB: Sequence[int], + absmax: torch.Tensor, + code: torch.Tensor, + blocksize: int, + ) -> torch.Tensor: + assert B.dtype == torch.uint8, "Only support uint8 qweight" + dtype = A.dtype + quant_type = "fp4" if code[1] > 0 else "nf4" + # cpu fused op only support bf16 for now. 
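+            # B is expected in the packed [N, K/2] layout produced by
+            # _convert_weight_packed_for_cpu, with absmax already laid out as a
+            # bf16 [K // blocksize, N] tensor; fp16/fp32 activations are upcast
+            # to bf16 here and the result is cast back to the caller's dtype below.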
+ if dtype != torch.bfloat16: + A = A.to(torch.bfloat16) + + final_out_shape = (*A.shape[:-1], shapeB[0]) + A = A.reshape(-1, A.shape[-1]) + out_shape = (*A.shape[:-1], shapeB[0]) + out = torch.empty(out_shape, dtype=A.dtype, device=A.device) + M = A.shape[0] + N = shapeB[0] + K = A.shape[1] + x_strideM = A.stride(0) + out_strideM = out.stride(0) + if quant_type == "fp4": + lib.gemv_4bit_inference_cpu_fp4_bf16( + ct.c_int64(M), + ct.c_int64(N), + ct.c_int64(K), + get_ptr(A), + get_ptr(B), + get_ptr(absmax), + get_ptr(out), + ct.c_int64(blocksize), + ct.c_int64(x_strideM), + ct.c_int64(out_strideM), + ) + elif quant_type == "nf4": + lib.gemv_4bit_inference_cpu_nf4_bf16( + ct.c_int64(M), + ct.c_int64(N), + ct.c_int64(K), + get_ptr(A), + get_ptr(B), + get_ptr(absmax), + get_ptr(out), + ct.c_int64(blocksize), + ct.c_int64(x_strideM), + ct.c_int64(out_strideM), + ) + + if dtype != torch.bfloat16: + out = out.to(dtype) + + return out.reshape(final_out_shape) diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index a006fd8bb..f97d27cca 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -2103,4 +2103,137 @@ def spmm_coo_very_sparse(cooA, B, dequant_stats=None, out=None): return out +def _convert_weight_packed_for_cpu(qweight: torch.Tensor, quant_state: QuantState, block_n: int = 32): + """ + qweight: (K * N / 2) uint8 + return: packed_weight + """ + if qweight.dtype != torch.uint8: + quant_state.original_storage_type = qweight.dtype + qweight = qweight.view(torch.uint8) + quant_state.original_dtype = quant_state.dtype + quant_state.original_nested = quant_state.nested + quant_state.original_qshape = qweight.shape + + qweight = qweight.reshape(-1) + unpacked_w = torch.empty(qweight.shape[0] * 2, dtype=torch.int32, device=qweight.device) + unpacked_w[1::2] = qweight & 0xF + unpacked_w[::2] = qweight >> 4 + qweight_final = unpacked_w.reshape(quant_state.shape).to(torch.uint8) # (*, N, K) + # pack weight: [*, N, K] -> [*, N, K/2] combine low and high bit + assert len(qweight_final.shape) == 2 + N, K = qweight_final.shape[0], qweight_final.shape[1] + assert N % block_n == 0, "N must be divisible by block_n" + assert K % 2 == 0, "K must be even" + BLOCK_N = block_n + BIT_COUNT = 32 # (=32 low +32 high) + new_shape = [N // BLOCK_N, BLOCK_N, K // 2, 2] + out_shape = [N, K // 2] + qw = qweight_final.reshape(new_shape) # (..., N/B, B, K/2, 2) + qw = qw.transpose(-3, -2).contiguous() # (..., N/B, K/2, B, 2) + qw = qw.reshape(-1, BIT_COUNT * 2) # [-1, 64] + high = qw[:, BIT_COUNT:] # high 32 + low = qw[:, :BIT_COUNT] # low 32 + packed = ((high << 4) | low).to(torch.uint8) # combine + final_qweight = packed.reshape(out_shape) + if quant_state.nested: + absmax = dequantize_blockwise(quant_state.absmax, quant_state.state2) + absmax += quant_state.offset + if absmax.dtype != torch.float32: + absmax = absmax.float() + + quant_state.absmax = absmax + quant_state.nested = False + delattr(quant_state, "state2") + + quant_state.absmax = ( + quant_state.absmax.reshape(quant_state.shape[0], quant_state.shape[1] // quant_state.blocksize) + .T.to(torch.bfloat16) + .contiguous() + ) + + quant_state.dtype = torch.bfloat16 + quant_state.packing_format_for_cpu = True + return final_qweight, quant_state + + +def _convert_weight_packed_for_cpu_inverse( + packed_weight: torch.Tensor, + quant_state: QuantState, + block_n: int = 32, +) -> tuple[torch.Tensor, QuantState]: + """ + packed_weight: [N, K/2] uint8, output of `_convert_weight_packed_for_cpu` (final_qweight) + quant_state: 
QuantState that was modified by `_convert_weight_packed_for_cpu` + Returns: + qweight: [*, N, K] uint8, original qweight shape (quant_state.shape) + recovered_state: QuantState with partially restored fields (best-effort inverse) + """ + assert quant_state.packing_format_for_cpu, "only for packing format" + assert packed_weight.dtype == torch.uint8 + assert len(packed_weight.shape) == 2, "packed_weight should be [N, K/2]" + N, K_half = packed_weight.shape + K = K_half * 2 + + # 1) packed [N, K/2] -> [N//BLOCK_N, BLOCK_N, K/2, 2] + BLOCK_N = block_n + BIT_COUNT = 32 # (=32 low + 32 high) + + assert N % BLOCK_N == 0, "N must be divisible by block_n" + assert K % 2 == 0, "K must be even" + + # [N, K/2] -> [-1, 64] (32 low + 32 high) + packed = packed_weight.reshape(-1, BIT_COUNT) # [-1, 64] + # split high/low nibbles + high = (packed >> 4) & 0xF + low = packed & 0xF + # concatenate to [..., 64], first 32 are low, last 32 are high + qw = torch.cat([low, high], dim=-1).to(torch.uint8) # [..., 64] + + # -> [N/BLOCK_N, K/2, BLOCK_N, 2] -> [N, K] + qw = qw.reshape(N // BLOCK_N, K_half, BLOCK_N, 2) # [N/B, K/2, B, 2] + qw = qw.transpose(-3, -2).contiguous() # [N/B, B, K/2, 2] + qw = qw.reshape(N, K) # [N, K] + + qweight = qw # [N, K] + + unpacked_w = qweight.reshape(-1).to(torch.int32) # [K*N] + high4 = (unpacked_w[::2] & 0xF).to(torch.uint8) + low4 = (unpacked_w[1::2] & 0xF).to(torch.uint8) + qweight = (high4 << 4) | low4 # [K*N/2] + + # 2) Best-effort restore of quant_state fields (absmax / dtype / nested flags, etc.) + recovered_state = quant_state + qweight = qweight.to(torch.uint8).reshape(recovered_state.original_qshape) + + # quantize absmax + if recovered_state.original_nested: + absmax = recovered_state.absmax.T.reshape(-1).to(recovered_state.original_dtype) + offset = absmax.mean() + qabsmax, state2 = quantize_blockwise(absmax - offset, blocksize=256) + recovered_state.absmax = qabsmax + recovered_state.offset = offset + recovered_state.state2 = state2 + + recovered_state.dtype = recovered_state.original_dtype + recovered_state.packing_format_for_cpu = False + + if getattr(recovered_state, "original_storage_type", None): + qweight = qweight.view(recovered_state.original_storage_type) + + return qweight, recovered_state + + +def has_avx512bf16(): + """ + Try calling native lib.has_avx512bf16_cpu(). + Return False explicitly if symbol missing or call fails. 
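+    The native symbol only exists when the library was built with AVX512F and
+    AVX512_BF16 enabled, and it additionally probes the running CPU, so False
+    can mean either an unsupported build or unsupported hardware.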
+ """ + try: + support_avx_bf16 = lib.has_avx512bf16_cpu() + except (AttributeError, RuntimeError, OSError): + support_avx_bf16 = False + return support_avx_bf16 + + C = 127.0 diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py index 52c44d215..1c9fac799 100644 --- a/bitsandbytes/nn/modules.py +++ b/bitsandbytes/nn/modules.py @@ -12,7 +12,12 @@ import bitsandbytes as bnb from bitsandbytes.cextension import ROCM_WARP_SIZE_64 -from bitsandbytes.functional import QuantState +from bitsandbytes.functional import ( + QuantState, + _convert_weight_packed_for_cpu, + _convert_weight_packed_for_cpu_inverse, + has_avx512bf16, +) from bitsandbytes.optim import GlobalOptimManager from bitsandbytes.utils import INVERSE_LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING, OutlierTracer @@ -311,9 +316,13 @@ def cpu(self): return self.to(device="cpu") def cuda(self, device: Optional[int | device | str] = None, non_blocking: bool = False): + if getattr(self.quant_state, "packing_format_for_cpu", False): + self.data, self.quant_state = _convert_weight_packed_for_cpu_inverse(self.data, self.quant_state) return self.to(device="cuda" if device is None else device, non_blocking=non_blocking) def xpu(self, device: Optional[int | device | str] = None, non_blocking: bool = False): + if getattr(self.quant_state, "packing_format_for_cpu", False): + self.data, self.quant_state = _convert_weight_packed_for_cpu_inverse(self.data, self.quant_state) return self.to(device="xpu" if device is None else device, non_blocking=non_blocking) @overload @@ -479,6 +488,7 @@ def __init__( self.compute_type_is_set = compute_dtype is not None self.quant_state = None self.quant_storage = quant_storage + self.support_avx512bf16_for_cpu = has_avx512bf16() def set_compute_type(self, x): if x.dtype in [torch.float32, torch.bfloat16]: @@ -506,13 +516,26 @@ def _save_to_state_dict(self, destination, prefix, keep_vars): then fill state_dict with components of quant_state """ super()._save_to_state_dict(destination, prefix, keep_vars) # saving weight and bias - + if getattr(self.weight.quant_state, "packing_format_for_cpu", False): + self.weight.data, self.weight.quant_state = _convert_weight_packed_for_cpu_inverse( + self.weight.data, self.weight.quant_state + ) if getattr(self.weight, "quant_state", None) is not None: for k, v in self.weight.quant_state.as_dict(packed=True).items(): destination[prefix + "weight." 
+ k] = v if keep_vars else v.detach() def forward(self, x: torch.Tensor): fix_4bit_weight_quant_state_from_module(self) + quant_state = self.weight.quant_state + + if ( + not getattr(quant_state, "packing_format_for_cpu", False) + and x.device.type == "cpu" + and self.support_avx512bf16_for_cpu + and not self.training + and x.requires_grad == False + ): + self.weight.data, quant_state = _convert_weight_packed_for_cpu(self.weight.data, quant_state) # weights are cast automatically as Int8Params, but the bias has to be cast manually if self.bias is not None and self.bias.dtype != x.dtype: @@ -527,9 +550,9 @@ def forward(self, x: torch.Tensor): x = x.to(self.compute_dtype) bias = None if self.bias is None else self.bias.to(self.compute_dtype) - weight = self.weight.t() + weight = self.weight if getattr(quant_state, "packing_format_for_cpu", False) else self.weight.t() - return bnb.matmul_4bit(x, weight, bias=bias, quant_state=self.weight.quant_state).to(inp_dtype) + return bnb.matmul_4bit(x, weight, bias=bias, quant_state=quant_state).to(inp_dtype) class LinearFP4(Linear4bit): diff --git a/csrc/cpu_ops.cpp b/csrc/cpu_ops.cpp index 0f0f9cd0a..f569bf681 100644 --- a/csrc/cpu_ops.cpp +++ b/csrc/cpu_ops.cpp @@ -14,38 +14,6 @@ using namespace BinSearch; #if defined(__AVX512F__) #include -#ifdef _MSC_VER -#include - -static inline bool has_avx512f() { - static bool v = [] { - int info[4]; - __cpuidex(info, 7, 0); - return (info[1] & (1 << 16)) != 0; // EBX bit16 AVX512F - }(); - return v; -} - -static inline bool has_avx512bf16() { - static bool v = [] { - int info[4]; - __cpuidex(info, 7, 1); - return (info[0] & (1 << 5)) != 0; // EAX bit5 AVX512_BF16 - }(); - return v; -} -#else -bool has_avx512f() { - static const bool supported_avx512f = __builtin_cpu_supports("avx512f"); - return supported_avx512f; -} - -bool has_avx512bf16() { - static const bool supported_avx512bf16 = __builtin_cpu_supports("avx512bf16"); - return supported_avx512bf16; -} -#endif - inline __m256i cvt_fp32_to_fp16(const __m512 src) { return _mm512_cvtps_ph(src, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); } @@ -258,6 +226,241 @@ void quantize_cpu(float* code, float* A, float* absmax, unsigned char* out, long } } +#if defined(__AVX512F__) && defined(__AVX512BF16__) + +#define CVT_BF16_TO_FP32(a) _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepu16_epi32(a), 16)) + +template struct tinygemm_kernel_nn { + static inline void apply( + const scalar_t*, const unsigned char*, scalar_t*, const scalar_t*, int64_t, int, int64_t, int64_t, int64_t, + int64_t, int64_t + ) { + static_assert(sizeof(scalar_t) == 0, "tinygemm_kernel_nn primary template should never be instantiated"); + } +}; + +template struct tinygemm_kernel_nn { + static inline void apply( + const bf16_t* __restrict__ A, const unsigned char* __restrict__ B, bf16_t* __restrict__ C, + const bf16_t* __restrict__ Bs, int64_t K, int group_size, int64_t lda, int64_t ldb, int64_t ldc, + int64_t strideBz, int64_t strideBs + ) { + static_assert(BLOCK_N % 32 == 0); + constexpr int ROWS = BLOCK_M; // 32 + constexpr int COLS = BLOCK_N / 16; // 2 + + // prefetch distance + constexpr int PREFETCH_SIZE_K = 16 * 4; + + __m512bh va; + __m512bh vb[COLS]; + __m512 vc[ROWS * COLS]; + __m512 vc_master[ROWS * COLS]; + + __m256i mask = _mm256_set1_epi8(0xF); // lower 4 bit + __m256i fifteen = _mm256_set1_epi8(15); + __m512i lut = DATA_TYPE == 1 + ? 
_mm512_set_epi16( + 0x0000, -0x4180, -0x41D5, -0x4100, -0x4155, -0x4080, -0x40D5, -0x4455, 0x0000, 0x3E80, + 0x3E2B, 0x3F00, 0x3EAB, 0x3F80, 0x3F2B, 0x3BAB, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, + 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000 + ) + : _mm512_set_epi16( + 0x0000, 0x3F80, 0x3F39, 0x3F10, 0x3EE2, 0x3EAD, 0x3E7C, 0x3E25, 0x3DA3, 0x0000, -0x4246, + -0x41C3, -0x416E, -0x4136, -0x40FA, -0x40CE, -0x4080, 0x0000, 0x0000, 0x0000, 0x0000, + 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000 + ); + __m512 scales[COLS]; + const int64_t K2 = K >> 1; + const int64_t lda2 = lda >> 1; + const int64_t ldb2 = ldb; // ldb * 2 >> 1; + const int64_t gs2 = group_size >> 1; // 64 / 2 = 32 + const float* a_ptr = reinterpret_cast(A); + + auto loadc = [&](auto i) { + constexpr int col = i % COLS; + vc_master[i] = _mm512_set1_ps(0.f); + }; + Unroll{}(loadc); + + auto pre_compute = [&](auto i, int64_t kgs) { + constexpr int row = i / COLS; + constexpr int col = i % COLS; + vc[i] = _mm512_set1_ps(0.f); // reset accumulator + + // load scales + if constexpr (row == 0 && col % 2 == 0) { + // Bs layout: [K/gs, BLOCK_N] : [strideBs, 1], dtype=bf16 + __m512i tmp = _mm512_loadu_si512(reinterpret_cast(Bs + kgs * strideBs + col * 16)); + scales[col] = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32(tmp, 0)); + scales[col + 1] = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32(tmp, 1)); + } + }; + auto compute = [&](auto i, int64_t k) { + constexpr int row = i / COLS; + constexpr int col = i % COLS; + + if constexpr (col == 0) { + va = (__m512bh)(_mm512_set1_ps(a_ptr[row * lda2 + k])); + } + if constexpr (row == 0 && col % 2 == 0) { + __m256i vb_u4 = _mm256_loadu_si256(reinterpret_cast(B + k * ldb + col * 16)); + + // deinterleave and lookup to BF16 + __m256i vb_i8_lo = vb_u4 & mask; + __m256i vb_i8_hi = _mm256_srli_epi16(vb_u4, 4) & mask; + vb_i8_lo = _mm256_add_epi8(vb_i8_lo, fifteen); + vb_i8_hi = _mm256_add_epi8(vb_i8_hi, fifteen); + vb[col] = (__m512bh)_mm512_permutexvar_epi16(_mm512_cvtepi8_epi16(vb_i8_lo), lut); + vb[col + 1] = (__m512bh)_mm512_permutexvar_epi16(_mm512_cvtepi8_epi16(vb_i8_hi), lut); + + if constexpr (PREFETCH_SIZE_K > 0) { + _mm_prefetch(B + (k + PREFETCH_SIZE_K) * ldb2 + col * 16, _MM_HINT_T0); + } + } + vc[i] = _mm512_dpbf16_ps(vc[i], va, vb[col]); + }; + auto post_compute = [&](auto i, int64_t kgs) { + vc_master[i] = _mm512_fmadd_ps(vc[i], scales[i % COLS], vc_master[i]); + }; + for (int64_t k = 0; k < K2; k += gs2) { + Unroll{}(pre_compute, k / gs2); + for (int64_t k_offset = 0; k_offset < gs2; ++k_offset) { + Unroll{}(compute, k + k_offset); + } + Unroll{}(post_compute, k / gs2); + } + + auto storec = [&](auto i) { + constexpr int row = i / COLS; + constexpr int col = i % COLS; + if constexpr (col % 2 == 0) { + _mm512_storeu_si512( + reinterpret_cast<__m512i*>(C + row * ldc + col * 16), + (__m512i)(_mm512_cvtne2ps_pbh(vc_master[i + 1], vc_master[i])) + ); + } + }; + Unroll{}(storec); + } +}; + +#define LAUNCH_TINYGEMM_KERNEL_NN(MB_SIZE, NB_SIZE, DATA_TYPE) \ + tinygemm_kernel_nn::apply( \ + A + mb_start * lda, B + nb_start, C + mb_start * ldc + nb_start, Bs + nb_start, K, group_size, lda, ldb, ldc, \ + strideBz, strideBs \ + ); + +template +void tinygemm_kernel( + const scalar_t* __restrict__ A, const unsigned char* __restrict__ B, scalar_t* __restrict__ C, + const scalar_t* __restrict__ Bs, scalar_t* __restrict__ Btmp, float* __restrict__ Ctmp, int64_t M, int64_t N, + int64_t K, int group_size, int64_t lda, int64_t 
ldb, int64_t ldc, int64_t strideBz, int64_t strideBs +) { + constexpr int64_t BLOCK_M = 4; + constexpr int64_t BLOCK_N = 64; + const int64_t MB = div_up(M, BLOCK_M); + const int64_t NB = div_up(N, BLOCK_N); + for (int mb = 0; mb < MB; ++mb) { + int64_t mb_start = mb * BLOCK_M; + int64_t mb_size = std::min(BLOCK_M, M - mb_start); + for (int64_t nb = 0; nb < NB; ++nb) { + int64_t nb_start = nb * BLOCK_N; + int64_t nb_size = std::min(BLOCK_N, N - nb_start); + + switch (mb_size << 4 | nb_size >> 4) { + // mb_size = 1 + case 0x12: + LAUNCH_TINYGEMM_KERNEL_NN(1, 32, DATA_TYPE); + break; + case 0x14: + LAUNCH_TINYGEMM_KERNEL_NN(1, 64, DATA_TYPE); + break; + // mb_size = 2 + case 0x22: + LAUNCH_TINYGEMM_KERNEL_NN(2, 32, DATA_TYPE); + break; + case 0x24: + LAUNCH_TINYGEMM_KERNEL_NN(2, 64, DATA_TYPE); + break; + // mb_size = 3 + case 0x32: + LAUNCH_TINYGEMM_KERNEL_NN(3, 32, DATA_TYPE); + break; + case 0x34: + LAUNCH_TINYGEMM_KERNEL_NN(3, 64, DATA_TYPE); + break; + // mb_size = 4 + case 0x42: + LAUNCH_TINYGEMM_KERNEL_NN(4, 32, DATA_TYPE); + break; + case 0x44: + LAUNCH_TINYGEMM_KERNEL_NN(4, 64, DATA_TYPE); + break; + default: { + std::fprintf( + stderr, "[bitsandbytes] Unexpected block size %lldx%lld\n", (long long)mb_size, (long long)nb_size + ); + std::abort(); // or return; if you prefer silent exit + } + } + } + } +} + +template +void gemv_4bit_inference( + int64_t M, int64_t N, int64_t K, const T* __restrict__ x, const unsigned char* __restrict__ w, + const T* __restrict__ absmax, T* __restrict__ out, int64_t blocksize, int64_t x_stride, int64_t out_stride +) { + constexpr int64_t BLOCK_M = block_size_m(); // 32 + constexpr int64_t BLOCK_N = block_size_n(); // 32 + const int64_t MB = div_up(M, BLOCK_M); // (x + y -1)/ y, res = 1 when M <= 32 + const int64_t NB = div_up(N, BLOCK_N); + // TODO: enable brgemm in the future. 
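+    // Work is split over the (MB, NB) grid by parallel_2d; inside each thread the
+    // nb loop is chunked by cache_blocks_nb so that one L2-sized slab of packed
+    // weights is reused across all mb blocks before moving to the next chunk.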
+ // const bool use_brgemm = M > 4; + // const bool use_brgemm_dequant_out = M > 512; + // T* Btmp_start = nullptr; + // l2 cache block for n + int64_t cache_blocks_nb = get_cache_blocks(BLOCK_N * K); + parallel_2d(MB, NB, [&](int64_t begin_mb, int64_t end_mb, int64_t begin_nb, int64_t end_nb) { + // for brgemm, use float32 for accumulate + alignas(64) float Ctmp[BLOCK_M * BLOCK_N]; + alignas(64) T Btmp_inner[BLOCK_N * BLOCK_K]; // BLOCK_K = 128 + for (int64_t nbb = begin_nb; nbb < end_nb; nbb += cache_blocks_nb) { + for (int64_t mb = begin_mb; mb < end_mb; ++mb) { // 0-1 + for (int64_t nb = nbb; nb < std::min(nbb + cache_blocks_nb, end_nb); ++nb) { + int64_t mb_start = mb * BLOCK_M; // 0 + int64_t mb_size = std::min(M - mb_start, BLOCK_M); + int64_t nb_start = nb * BLOCK_N; + int64_t nb_size = std::min(N - nb_start, BLOCK_N); + tinygemm_kernel( + /* A */ x + mb_start * x_stride, + /* B */ w + nb_start * K / 2, // divide by 2 since w is u4 packed in u8, K is w.size(1) * 2 + /* C */ out + mb_start * out_stride + nb_start, + /* Bs */ absmax + nb_start, + /* Btmp */ Btmp_inner, + /* Ctmp */ Ctmp, + /* M */ mb_size, + /* N */ nb_size, + /* K */ K, + /* gs */ blocksize, // group_size + /* lda */ x_stride, + /* ldb */ nb_size, + /* ldc */ out_stride, + /* sBz */ N, + /* sBs */ N + ); + } + } + } + // if (use_brgemm) { + // at::native::cpublas::brgemm_release(); + // } + }); +} +#endif + //============================================================== // TEMPLATE DEFINITIONS //============================================================== @@ -293,14 +496,13 @@ template void dequantizeBlockwise4bitCpu( unsigned char* A, const float* absmax, bf16_t* out, long long blocksize, long long m, long long n ); -// template void gemv_4bit_inference( -// int m, int n, int k, fp16_t* A, unsigned char* B, float* absmax, float* datatype, fp16_t* out, -// int lda, int ldb, int ldc, int blocksize); - -// template void gemv_4bit_inference( -// int m, int n, int k, bf16_t* A, unsigned char* B, float* absmax, float* datatype, bf16_t* out, -// int lda, int ldb, int ldc, int blocksize); - -// template void gemv_4bit_inference( -// int m, int n, int k, float* A, unsigned char* B, float* absmax, float* datatype, float* out, -// int lda, int ldb, int ldc, int blocksize); +#if defined(__AVX512F__) && defined(__AVX512BF16__) +template void gemv_4bit_inference( + int64_t M, int64_t N, int64_t K, const bf16_t* __restrict__ x, const unsigned char* __restrict__ w, + const bf16_t* __restrict__ absmax, bf16_t* __restrict__ out, int64_t blocksize, int64_t x_stride, int64_t out_stride +); +template void gemv_4bit_inference( + int64_t M, int64_t N, int64_t K, const bf16_t* __restrict__ x, const unsigned char* __restrict__ w, + const bf16_t* __restrict__ absmax, bf16_t* __restrict__ out, int64_t blocksize, int64_t x_stride, int64_t out_stride +); +#endif diff --git a/csrc/cpu_ops.h b/csrc/cpu_ops.h index 7040833a0..6803b29f9 100644 --- a/csrc/cpu_ops.h +++ b/csrc/cpu_ops.h @@ -1,9 +1,123 @@ #ifndef BITSANDBYTES_CPU_OPS_H #define BITSANDBYTES_CPU_OPS_H +#include +#include #include #include #include +#include +#include + +#if defined(_OPENMP) +#include +#endif + +// amx-bf16 +#define TILE_M 16 +#define TILE_N 16 +#define TILE_K 32 +// work around compiler internal error +#define BLOCK_K 128 // 4 * TILE_K + +// block size for AMX gemm +constexpr int block_size_m() { return 2 * TILE_M; } + +constexpr int block_size_n() { return 2 * TILE_N; } + +template inline int get_cache_blocks(int chunk_size) { + // L2 2MB and ratio of 50% + const 
int L2_size = 2048 * 1024 >> 1; + return std::max(1, int(L2_size / (chunk_size * sizeof(T)))); +} + +// forced unroll for perf critical path +#if __has_attribute(always_inline) +#define ALWAYS_INLINE __attribute__((__always_inline__)) inline +#else +#define ALWAYS_INLINE inline +#endif + +template struct Unroll { + template ALWAYS_INLINE void operator()(const Func& f, Args... args) const { + Unroll{}(f, args...); + f(std::integral_constant{}, args...); + } +}; + +template <> struct Unroll<1> { + template ALWAYS_INLINE void operator()(const Func& f, Args... args) const { + f(std::integral_constant{}, args...); + } +}; + +template ::value, int>::type = 0> inline T div_up(T x, T y) { + return (x + y - 1) / y; +} + +inline int get_max_threads() { +#if defined(_OPENMP) + return omp_get_max_threads(); +#else + unsigned hc = std::thread::hardware_concurrency(); + return hc == 0 ? 1 : int(hc); +#endif +} + +inline int adjust_num_threads(int m) { + int actual_nth = get_max_threads(); + if (m == 1) + return actual_nth; + return std::max(1, (actual_nth >> 1) * 2); +} + +template inline void parallel_2d(int m, int n, const func_t& f) { + // make sure we have even num_threads + int nth = adjust_num_threads(m); + + // [NOTE] thread blocking: + // + // 1) prefer square block per thread + // 2) use even number of CPU cores + // 3) use all `num_threads` cores + // + // we have: + // TM * TN = T + // BM / TM = BN / TN + // then: + // TM = ((BM / BN) * T) ^ 0.5 + // + float r = float(m) / n; + int nth_m = std::ceil(std::sqrt(r * nth)); + int nth_n = 1; + for (; nth_m > 0; --nth_m) { + nth_n = nth / nth_m; + if (nth_m * nth_n == nth) { + break; + } + } + +#if defined(_OPENMP) +#pragma omp parallel num_threads(nth) + { + int ith = omp_get_thread_num(); + int ith_m = ith / nth_n; + int ith_n = ith % nth_n; + + int thread_block_m = div_up(m, nth_m); + int thread_block_n = div_up(n, nth_n); + + int begin_m = ith_m * thread_block_m; + int end_m = std::min(m, begin_m + thread_block_m); + int begin_n = ith_n * thread_block_n; + int end_n = std::min(n, begin_n + thread_block_n); + + f(begin_m, end_m, begin_n, end_n); + } +#else + f(0, m, 0, n); +#endif +} void quantize_cpu(float* code, float* A, float* absmax, unsigned char* out, long long blocksize, long long n); @@ -22,6 +136,13 @@ static inline bf16_t float_to_bf16(float x) { return bf16_t{static_cast(r >> 16)}; } +static float bf16_to_float(uint16_t bf16) { + uint32_t bits = (uint32_t)bf16 << 16; + float f; + std::memcpy(&f, &bits, sizeof(f)); + return f; +} + static inline fp16_t float_to_fp16(float x) { uint32_t bits; std::memcpy(&bits, &x, 4); @@ -162,4 +283,52 @@ void dequantizeBlockwise4bitCpu( unsigned char* A, const float* absmax, T* out, long long blocksize, long long m, long long n ); +#if defined(__AVX512F__) +#include + +#ifdef _MSC_VER +#include + +static inline bool has_avx512f() { + static bool v = [] { + int info[4]; + __cpuidex(info, 7, 0); + return (info[1] & (1 << 16)) != 0; // EBX bit16 AVX512F + }(); + return v; +} + +#if defined(__AVX512BF16__) +static inline bool has_avx512bf16() { + static bool v = [] { + int info[4]; + __cpuidex(info, 7, 1); + return (info[0] & (1 << 5)) != 0; // EAX bit5 AVX512_BF16 + }(); + return v; +} +#endif +#else +static inline bool has_avx512f() { + static const bool supported_avx512f = __builtin_cpu_supports("avx512f"); + return supported_avx512f; +} + +#if defined(__AVX512BF16__) +static inline bool has_avx512bf16() { + static const bool supported_avx512bf16 = __builtin_cpu_supports("avx512bf16"); + return 
supported_avx512bf16; +} +#endif +#endif +#endif + +#if defined(__AVX512F__) && defined(__AVX512BF16__) +template +void gemv_4bit_inference( + int64_t M, int64_t N, int64_t K, const T* __restrict__ x, const unsigned char* __restrict__ w, + const T* __restrict__ absmax, T* __restrict__ out, int64_t blocksize, int64_t x_stride, int64_t out_stride +); +#endif + #endif diff --git a/csrc/pythonInterface.cpp b/csrc/pythonInterface.cpp index e6b3d0866..07c79fc95 100644 --- a/csrc/pythonInterface.cpp +++ b/csrc/pythonInterface.cpp @@ -891,4 +891,26 @@ void cdequantize_blockwise_cpu_nf4_fp16( ) { dequantizeBlockwise4bitCpu(A, absmax, out, blocksize, m, n); } + +#if defined(__AVX512F__) && defined(__AVX512BF16__) +void gemv_4bit_inference_cpu_fp4_bf16( + int64_t M, int64_t N, int64_t K, const bf16_t* __restrict__ x, const unsigned char* __restrict__ w, + const bf16_t* __restrict__ absmax, bf16_t* __restrict__ out, int64_t blocksize, int64_t x_stride, int64_t out_stride +) { + gemv_4bit_inference(M, N, K, x, w, absmax, out, blocksize, x_stride, out_stride); +} + +void gemv_4bit_inference_cpu_nf4_bf16( + int64_t M, int64_t N, int64_t K, const bf16_t* __restrict__ x, const unsigned char* __restrict__ w, + const bf16_t* __restrict__ absmax, bf16_t* __restrict__ out, int64_t blocksize, int64_t x_stride, int64_t out_stride +) { + gemv_4bit_inference(M, N, K, x, w, absmax, out, blocksize, x_stride, out_stride); +} +#endif +#if defined(__AVX512F__) +bool has_avx512f_cpu() { return has_avx512f(); } +#if defined(__AVX512BF16__) +bool has_avx512bf16_cpu() { return has_avx512bf16(); } +#endif +#endif } diff --git a/tests/test_functional.py b/tests/test_functional.py index d420ff352..55964818c 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -1318,6 +1318,10 @@ def test_gemv_4bit(self, device, dim, dtype, storage_type, quant_storage, double quant_storage=quant_storage, ) C3 = torch.matmul(A, B.t()) + # CPU requires convert weight packed for gemv + if device == "cpu" and F.has_avx512bf16(): + qB, state = F._convert_weight_packed_for_cpu(qB, state) + qB = qB.t() C2 = F.gemv_4bit(A, qB.t(), state=state) A.requires_grad = True C1 = bnb.matmul_4bit(A, qB.t(), state) diff --git a/tests/test_ops.py b/tests/test_ops.py index da589005e..8d9aa5ab2 100644 --- a/tests/test_ops.py +++ b/tests/test_ops.py @@ -219,11 +219,26 @@ def test_gemv_4bit(self, device, dtype, storage_dtype, quant_type, blocksize): out_features = 1024 in_features = 256 + if device == "cpu" and blocksize > in_features: + pytest.skip("CPU implementation only suppoer blocksize <= in_features") + A = torch.randn((1, 1, in_features), dtype=dtype, device=device) B = torch.randn((out_features, in_features), dtype=dtype, device=A.device) B_q, absmax = torch.ops.bitsandbytes.quantize_4bit(B, blocksize, quant_type, storage_dtype) code = bitsandbytes.functional.get_4bit_type(quant_type, device=A.device, blocksize=blocksize) + if device == "cpu" and bitsandbytes.functional.has_avx512bf16(): + state = bitsandbytes.functional.QuantState( + absmax=absmax, + shape=B.shape, + dtype=A.dtype, + blocksize=blocksize, + code=code, + quant_type=quant_type, + ) + B_q, state = bitsandbytes.functional._convert_weight_packed_for_cpu(B_q, state) + B_q = B_q.t() + absmax = state.absmax out = torch.ops.bitsandbytes.gemv_4bit.default(A, B_q, B.shape, absmax, code, blocksize) assert out.device == A.device
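A minimal usage sketch (not part of the patch itself), assuming an AVX512-BF16 capable CPU;
it mirrors the lazy weight conversion that Linear4bit.forward now performs, and the shapes
and blocksize are illustrative only:

    import torch
    import bitsandbytes as bnb
    import bitsandbytes.functional as F

    if F.has_avx512bf16():
        W = torch.randn(1024, 256, dtype=torch.bfloat16)          # [N, K], N % 32 == 0
        qW, state = F.quantize_4bit(W, blocksize=64, quant_type="nf4")
        qW, state = F._convert_weight_packed_for_cpu(qW, state)   # -> packed [N, K/2] uint8
        x = torch.randn(1, 256, dtype=torch.bfloat16)
        y = bnb.matmul_4bit(x, qW, state)                         # dispatches to the fused bf16 CPU gemv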