Commits (changes shown from 70 of 79 commits)
6be1412
add template to support more dtypes
jiqing-feng Oct 28, 2025
252ac0f
update cmake list
jiqing-feng Oct 28, 2025
f98c9e5
fix typo
jiqing-feng Oct 28, 2025
902bf35
fix compile cpu
jiqing-feng Oct 28, 2025
fef8459
make different dtype works
jiqing-feng Oct 29, 2025
55cbaa0
use bf16 on CPU
jiqing-feng Oct 29, 2025
bbef95b
fix state2 dtype
jiqing-feng Oct 29, 2025
e842513
remove torch
jiqing-feng Oct 30, 2025
d4473fa
rm torch
jiqing-feng Oct 30, 2025
dea8dd6
enable float to bf16
jiqing-feng Oct 30, 2025
e9bb4fe
rm dequantizeBlockwise4bitCpu
jiqing-feng Oct 30, 2025
cdc8d5e
fix check
jiqing-feng Oct 30, 2025
baacfac
enable dequant 4bit kernel
jiqing-feng Oct 30, 2025
eec3521
fix typo
jiqing-feng Oct 30, 2025
d7cc1c5
fix typo
jiqing-feng Oct 30, 2025
124b754
fix dequantize
jiqing-feng Oct 30, 2025
0f918c7
fix
jiqing-feng Oct 30, 2025
e1a8b20
fix
jiqing-feng Oct 30, 2025
eab45c8
test
jiqing-feng Oct 30, 2025
d9f5dd8
fix
jiqing-feng Oct 30, 2025
070f8a0
fix
jiqing-feng Oct 30, 2025
a84addf
fix
jiqing-feng Oct 30, 2025
c4bb660
fix
jiqing-feng Oct 30, 2025
4ba13fd
fix
jiqing-feng Oct 30, 2025
c0d05ec
change input param
jiqing-feng Oct 31, 2025
62a16a6
fix typo
jiqing-feng Oct 31, 2025
d9ad828
fix input param
jiqing-feng Oct 31, 2025
09ed6cb
spliut 8bit and 4bit
jiqing-feng Oct 31, 2025
a3f7b61
fix typo
jiqing-feng Oct 31, 2025
4708470
fix typo
jiqing-feng Oct 31, 2025
1dfe9f7
fix input params
jiqing-feng Oct 31, 2025
00289c4
fix input params
jiqing-feng Oct 31, 2025
a2578ba
fix
jiqing-feng Oct 31, 2025
72033dc
fix typo
jiqing-feng Oct 31, 2025
1c20ae8
enable dequant4bit
jiqing-feng Oct 31, 2025
7552fe2
fix
jiqing-feng Oct 31, 2025
8b32a39
fix
jiqing-feng Oct 31, 2025
8f1cc36
fix reverse
jiqing-feng Oct 31, 2025
49d242a
fix dequant 4bit fallback path
jiqing-feng Nov 3, 2025
4a9a6dc
fix fp4 dequant
jiqing-feng Nov 3, 2025
6bcd19e
Merge branch 'main' into cpu_kernel
jiqing-feng Nov 4, 2025
d7e981d
rm _Float16
jiqing-feng Nov 5, 2025
48739b0
tmp codes
jiqing-feng Nov 6, 2025
f784be8
enable gemv
jiqing-feng Nov 7, 2025
92192c9
change to 4bit dequant
jiqing-feng Nov 7, 2025
bd02e71
fix def
jiqing-feng Nov 7, 2025
8520069
fix type
jiqing-feng Nov 7, 2025
e921cbb
fix absmax dtype
jiqing-feng Nov 7, 2025
9b5d97a
fix type
jiqing-feng Nov 7, 2025
fd6cff1
fix compile and type
jiqing-feng Nov 7, 2025
46d6e47
enable gemv
jiqing-feng Nov 7, 2025
3271c30
fix shape
jiqing-feng Nov 7, 2025
176a2b6
fix lib name
jiqing-feng Nov 7, 2025
196984a
debug
jiqing-feng Nov 7, 2025
7652115
update
jiqing-feng Nov 11, 2025
ea0e649
enable gemv 4bit bf16
jiqing-feng Nov 12, 2025
9277d24
enable avx512 check
jiqing-feng Nov 13, 2025
4fb315b
fix check
jiqing-feng Nov 13, 2025
81f1984
fix endif
jiqing-feng Nov 13, 2025
0f78bad
fix format
jiqing-feng Nov 13, 2025
fcb8456
fix format
jiqing-feng Nov 13, 2025
c5e1894
fix def
jiqing-feng Nov 13, 2025
f2029c6
rebase
jiqing-feng Nov 14, 2025
df1d669
fix position
jiqing-feng Nov 14, 2025
bb3ac8d
fix format
jiqing-feng Nov 14, 2025
26b5685
rm duplicated func
jiqing-feng Nov 14, 2025
445725b
Merge branch 'main' into cpu_fused_kernel
jiqing-feng Nov 17, 2025
580010c
rm useless code comments
jiqing-feng Nov 17, 2025
57b89bf
fix out shape
jiqing-feng Nov 19, 2025
302a5fe
Merge branch 'main' into cpu_fused_kernel
jiqing-feng Nov 19, 2025
de5fb9c
fix comments
jiqing-feng Nov 20, 2025
6858a90
add reverse format
jiqing-feng Nov 20, 2025
3b3d609
check avx512bf15
jiqing-feng Nov 20, 2025
fbb911b
fix has_avx512bf16
jiqing-feng Nov 20, 2025
3179b42
fix tests
jiqing-feng Nov 20, 2025
0c88d43
fix absmax shhape
jiqing-feng Nov 20, 2025
feb8ad2
fix compile
jiqing-feng Nov 20, 2025
c6b714d
fix tests
jiqing-feng Nov 20, 2025
5497111
fix test_gemv
jiqing-feng Nov 20, 2025
12 changes: 12 additions & 0 deletions CMakeLists.txt
@@ -280,12 +280,24 @@ if (BUILD_CPU)
include(CheckCXXCompilerFlag)
check_cxx_compiler_flag(-mavx512f HAS_AVX512F_FLAG)
check_cxx_compiler_flag(-mavx512bf16 HAS_AVX512BF16_FLAG)
check_cxx_compiler_flag(-mavx512dq HAS_AVX512DQ)
check_cxx_compiler_flag(-mavx512bw HAS_AVX512BW)
check_cxx_compiler_flag(-mavx512vl HAS_AVX512VL)
if (HAS_AVX512F_FLAG)
target_compile_options(bitsandbytes PRIVATE -mavx512f)
endif()
if (HAS_AVX512BF16_FLAG)
target_compile_options(bitsandbytes PRIVATE -mavx512bf16)
endif()
if(HAS_AVX512DQ)
target_compile_options(bitsandbytes PRIVATE -mavx512dq)
endif()
if(HAS_AVX512BW)
target_compile_options(bitsandbytes PRIVATE -mavx512bw)
endif()
if(HAS_AVX512VL)
target_compile_options(bitsandbytes PRIVATE -mavx512vl)
endif()
target_compile_options(
bitsandbytes PRIVATE
-mprefer-vector-width=256
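These flags only affect code generation when the compiler accepts them; whether the host CPU can actually execute AVX-512 code is a separate runtime question. As a rough companion check (a hypothetical helper, not part of this PR, and Linux-only), one can parse /proc/cpuinfo to see which of the subsets enabled above the machine advertises:

```python
from pathlib import Path


def host_avx512_features() -> set[str]:
    # Note: /proc/cpuinfo reports the BF16 extension as "avx512_bf16".
    wanted = {"avx512f", "avx512_bf16", "avx512dq", "avx512bw", "avx512vl"}
    try:
        cpuinfo = Path("/proc/cpuinfo").read_text()
    except OSError:  # non-Linux or unreadable: report nothing
        return set()
    for line in cpuinfo.splitlines():
        if line.startswith("flags"):  # the first logical CPU is enough
            return wanted & set(line.split(":", 1)[1].split())
    return set()


print(host_avx512_features())
```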
8 changes: 7 additions & 1 deletion bitsandbytes/autograd/_functions.py
@@ -374,10 +374,16 @@ def matmul_4bit(
bias: Optional[torch.Tensor] = None,
):
assert quant_state is not None
# Change dtype to bfloat16 on CPU
# Change dtype to input dtype on CPU
if A.device.type == "cpu":
quant_state.dtype = A.dtype

if getattr(quant_state, "enable_optimized_cpu", False):
out = F.gemv_4bit(A, B, out, state=quant_state)
if bias is not None:
out += bias
return out

if A.numel() == A.shape[-1] and A.requires_grad == False and A.device.type != "hpu":
if A.shape[-1] % quant_state.blocksize != 0:
warn(
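The hunk above adds a fast path: when quant_state.enable_optimized_cpu is set, matmul_4bit routes straight to F.gemv_4bit. To exercise that branch outside of nn.Linear4bit, here is a minimal sketch using only the functional API; the call sequence is an assumption pieced together from this diff and from convert_weight_packed_for_cpu further down, not a test taken from the PR:

```python
import torch
import bitsandbytes as bnb
import bitsandbytes.functional as F

# Quantize a weight to NF4, repack it for the fused CPU kernel, and let
# matmul_4bit take the gemv_4bit fast path via quant_state.enable_optimized_cpu.
W = torch.randn(4096, 4096, dtype=torch.bfloat16)
qweight, quant_state = F.quantize_4bit(W, blocksize=64, quant_type="nf4", compress_statistics=True)

if F.has_avx512bf16():
    packed, quant_state = F.convert_weight_packed_for_cpu(qweight, quant_state)
    quant_state.enable_optimized_cpu = True  # mirrors what Linear4bit.forward() does

    x = torch.randn(1, 4096, dtype=torch.bfloat16)
    y = bnb.matmul_4bit(x, packed, quant_state=quant_state)  # dispatches to F.gemv_4bit
    print(y.shape)  # torch.Size([1, 4096])
```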
61 changes: 60 additions & 1 deletion bitsandbytes/backends/cpu/ops.py
@@ -5,7 +5,7 @@

import torch

from bitsandbytes.functional import get_ptr
from bitsandbytes.functional import get_ptr, has_avx512bf16

from ..._ops import register_kernel
from ...cextension import ErrorHandlerMockBNBNativeLibrary, lib
@@ -217,3 +217,62 @@ def _(
raise ValueError

return out

if has_avx512bf16():

@register_kernel("bitsandbytes::gemv_4bit", "cpu")
def _(
A: torch.Tensor,
B: torch.Tensor,
shapeB: Sequence[int],
absmax: torch.Tensor,
code: torch.Tensor,
blocksize: int,
) -> torch.Tensor:
# Adapted from dequantize_4bit
dtype = A.dtype
quant_type = "fp4" if code[1] > 0 else "nf4"
# cpu fused op only supports bf16 for now.
if dtype != torch.bfloat16:
A = A.to(torch.bfloat16)

final_out_shape = (*A.shape[:-1], shapeB[0])
A = A.reshape(-1, A.shape[-1])
out_shape = (*A.shape[:-1], shapeB[0])
out = torch.empty(out_shape, dtype=A.dtype, device=A.device)
M = A.shape[0]
N = shapeB[0]
K = A.shape[1]
x_strideM = A.stride(0)
out_strideM = out.stride(0)
if quant_type == "fp4":
lib.gemv_4bit_inference_cpu_fp4_bf16(
ct.c_int64(M),
ct.c_int64(N),
ct.c_int64(K),
get_ptr(A),
get_ptr(B),
get_ptr(absmax),
get_ptr(out),
ct.c_int64(blocksize),
ct.c_int64(x_strideM),
ct.c_int64(out_strideM),
)
elif quant_type == "nf4":
lib.gemv_4bit_inference_cpu_nf4_bf16(
ct.c_int64(M),
ct.c_int64(N),
ct.c_int64(K),
get_ptr(A),
get_ptr(B),
get_ptr(absmax),
get_ptr(out),
ct.c_int64(blocksize),
ct.c_int64(x_strideM),
ct.c_int64(out_strideM),
)

if dtype != torch.bfloat16:
out = out.to(dtype)

return out.reshape(final_out_shape)
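As a sanity check on what this fused kernel computes, a rough pure-PyTorch reference is sketched below. It assumes B_std and quant_state are the weight and state before convert_weight_packed_for_cpu is applied (i.e. the standard layout), so the fused path should match it up to bfloat16 rounding:

```python
import torch
import bitsandbytes.functional as F


def gemv_4bit_cpu_reference(A: torch.Tensor, B_std: torch.Tensor, quant_state) -> torch.Tensor:
    # Dequantize the standard-layout 4-bit weight to (N, K), then do a dense
    # bf16 matmul: out[..., n] = sum_k A[..., k] * W[n, k].
    W = F.dequantize_4bit(B_std, quant_state).to(torch.bfloat16)
    return (A.to(torch.bfloat16) @ W.t()).to(A.dtype)
```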
52 changes: 52 additions & 0 deletions bitsandbytes/functional.py
@@ -2103,4 +2103,56 @@ def spmm_coo_very_sparse(cooA, B, dequant_stats=None, out=None):
return out


def convert_weight_packed_for_cpu(qweight: torch.Tensor, quant_state: QuantState, block_n: int = 32):
"""
qweight: (K * N / 2) uint8
return: packed_weight
"""
assert qweight.dtype == torch.uint8, "qweight must be uint8"
qweight = qweight.reshape(-1)
unpacked_w = torch.empty(qweight.shape[0] * 2, dtype=torch.int32, device=qweight.device)
unpacked_w[1::2] = qweight & 0xF
unpacked_w[::2] = qweight >> 4
qweight_final = unpacked_w.reshape(quant_state.shape).to(torch.uint8) # (*, N, K)
# pack weight: [*, N, K] -> [*, N, K/2] combine low and high bit
assert len(qweight_final.shape) == 2
N, K = qweight_final.shape[0], qweight_final.shape[1]
assert N % block_n == 0, "N must be divisible by block_n"
assert K % 2 == 0, "K must be even"
BLOCK_N = block_n
BIT_COUNT = 32 # (=32 low +32 high)
new_shape = [N // BLOCK_N, BLOCK_N, K // 2, 2]
out_shape = [N, K // 2]
qw = qweight_final.reshape(new_shape) # (..., N/B, B, K/2, 2)
qw = qw.transpose(-3, -2).contiguous() # (..., N/B, K/2, B, 2)
qw = qw.reshape(-1, BIT_COUNT * 2) # [-1, 64]
high = qw[:, BIT_COUNT:] # high 32
low = qw[:, :BIT_COUNT] # low 32
packed = ((high << 4) | low).to(torch.uint8) # combine
final_qweight = packed.reshape(out_shape)
if quant_state.nested:
absmax = dequantize_blockwise(quant_state.absmax, quant_state.state2)
absmax += quant_state.offset
if absmax.dtype != torch.float32:
absmax = absmax.float()

quant_state.absmax = (
absmax.reshape(quant_state.shape[0], quant_state.shape[1] // quant_state.blocksize)
.T.to(torch.bfloat16)
.contiguous()
)
quant_state.nested = False
delattr(quant_state, "state2")

quant_state.dtype = torch.bfloat16
return final_qweight, quant_state


def has_avx512bf16():
if hasattr(lib, "has_avx512bf16_cpu") and lib.has_avx512bf16_cpu():
return True
else:
return False


C = 127.0
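The repacking in convert_weight_packed_for_cpu is easiest to reason about through its inverse. Below is a hedged sketch (not part of the PR) that recovers the per-element 4-bit codes from the blocked layout, so a round-trip against the nibbles unpacked from the original qweight (high nibble at even K indices, low nibble at odd ones, as in the function above) can be asserted in a test:

```python
import torch


def unpack_blocked_4bit(packed: torch.Tensor, N: int, K: int, block_n: int = 32) -> torch.Tensor:
    """Inverse of the blocked packing above (a sketch): returns (N, K) uint8 codes in 0..15.

    Rows are grouped into blocks of block_n; byte j of a (block, K/2) group holds a code
    from the lower half of the block in its low nibble and a code from the upper half in
    its high nibble.
    """
    half = block_n // 2
    groups = packed.reshape(N // block_n, K // 2, block_n)
    low = (groups & 0xF).reshape(N // block_n, K // 2, half, 2)   # block rows 0 .. half-1
    high = (groups >> 4).reshape(N // block_n, K // 2, half, 2)   # block rows half .. block_n-1
    codes = torch.empty(N // block_n, block_n, K // 2, 2, dtype=torch.uint8)
    codes[:, :half] = low.permute(0, 2, 1, 3)
    codes[:, half:] = high.permute(0, 2, 1, 3)
    return codes.reshape(N, K)
```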
19 changes: 16 additions & 3 deletions bitsandbytes/nn/modules.py
@@ -12,7 +12,7 @@

import bitsandbytes as bnb
from bitsandbytes.cextension import ROCM_WARP_SIZE_64
from bitsandbytes.functional import QuantState
from bitsandbytes.functional import QuantState, convert_weight_packed_for_cpu, has_avx512bf16
from bitsandbytes.optim import GlobalOptimManager
from bitsandbytes.utils import INVERSE_LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING, OutlierTracer

@@ -479,6 +479,7 @@ def __init__(
self.compute_type_is_set = compute_dtype is not None
self.quant_state = None
self.quant_storage = quant_storage
self.enable_optimized_cpu = False

def set_compute_type(self, x):
if x.dtype in [torch.float32, torch.bfloat16]:
@@ -512,8 +513,20 @@ def _save_to_state_dict(self, destination, prefix, keep_vars):
destination[prefix + "weight." + k] = v if keep_vars else v.detach()

def forward(self, x: torch.Tensor):
quant_state = self.weight.quant_state
fix_4bit_weight_quant_state_from_module(self)

if (
not self.enable_optimized_cpu
and x.device.type == "cpu"
and has_avx512bf16()
and not self.training
and x.requires_grad == False
):
self.weight.data, quant_state = convert_weight_packed_for_cpu(self.weight.data, quant_state)
self.enable_optimized_cpu = True
quant_state.enable_optimized_cpu = True

Comment on lines 531 to 539
Member:
There are a couple of things I'm wondering about:

When we serialize from CPU after running through forward(), we probably still want to be compatible with other devices. I am thinking for when serializing we want to undo this transformation if it's present.

Possibly an edge concern, but if we do a forward pass on CPU and then move to an accelerator, what would happen? I assume the weights are then in the wrong order?

@SunMarc I would appreciate any feedback you might have on this part!

Contributor:
For me, I would prefer that we stick with only one packing format for serialization, and that all other hardware / kernels convert this packing format at initialization or during the forward pass, as we do here.

So we need a way to disable serialization or send a warning when someone tries to do that. This is probably something that we can do in transformers as I think most of the models are serialized from there.

Also, instead of enable_optimized_cpu, maybe we can rename it to packing_format?

> Possibly an edge concern, but if we do a forward pass on CPU and then move to an accelerator, what would happen? I assume the weights are then in the wrong order?

Either we re-convert the weights for cuda (but this opens the door to many conversion function between all packing format) or we just raise an error asking the users to only run the model on one device.

# weights are cast automatically as Int8Params, but the bias has to be cast manually
if self.bias is not None and self.bias.dtype != x.dtype:
self.bias.data = self.bias.data.to(x.dtype)
@@ -527,9 +540,9 @@ def forward(self, x: torch.Tensor):
x = x.to(self.compute_dtype)

bias = None if self.bias is None else self.bias.to(self.compute_dtype)
weight = self.weight.t()
weight = self.weight if getattr(quant_state, "enable_optimized_cpu", False) else self.weight.t()

return bnb.matmul_4bit(x, weight, bias=bias, quant_state=self.weight.quant_state).to(inp_dtype)
return bnb.matmul_4bit(x, weight, bias=bias, quant_state=quant_state).to(inp_dtype)


class LinearFP4(Linear4bit):
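Following the review discussion above, one lightweight option is a pre-save check rather than a new serialization format. The helper below is a hypothetical sketch (not part of this PR) of the warning the reviewers suggest, flagging any Linear4bit whose weight has been repacked for the fused CPU kernel before the state dict is written out:

```python
import warnings

import torch


def warn_if_cpu_packed(model: torch.nn.Module) -> None:
    # Hypothetical pre-save hook: the blocked CPU layout produced by
    # convert_weight_packed_for_cpu is not portable to other devices or to the
    # standard checkpoint format, so warn before serializing it.
    for name, module in model.named_modules():
        if getattr(module, "enable_optimized_cpu", False):
            warnings.warn(
                f"{name}: weight was repacked for the fused CPU kernel; "
                "convert it back to the standard 4-bit layout before saving.",
                UserWarning,
            )


# usage sketch:
# warn_if_cpu_packed(model)
# torch.save(model.state_dict(), "model.pt")
```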