Skip to content

Commit 6858a90

Browse files
committed
add reverse format
Signed-off-by: jiqing-feng <[email protected]>
1 parent de5fb9c commit 6858a90

File tree

2 files changed

+83
-6
lines changed

2 files changed

+83
-6
lines changed

bitsandbytes/functional.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2109,6 +2109,10 @@ def _convert_weight_packed_for_cpu(qweight: torch.Tensor, quant_state: QuantStat
21092109
return: packed_weight
21102110
"""
21112111
assert qweight.dtype == torch.uint8, "qweight must be uint8"
2112+
quant_state.original_dtype = quant_state.dtype
2113+
quant_state.original_nested = quant_state.nested
2114+
quant_state.original_qshape = qweight.shape
2115+
21122116
qweight = qweight.reshape(-1)
21132117
unpacked_w = torch.empty(qweight.shape[0] * 2, dtype=torch.int32, device=qweight.device)
21142118
unpacked_w[1::2] = qweight & 0xF
@@ -2145,9 +2149,73 @@ def _convert_weight_packed_for_cpu(qweight: torch.Tensor, quant_state: QuantStat
21452149
delattr(quant_state, "state2")
21462150

21472151
quant_state.dtype = torch.bfloat16
2152+
quant_state.packing_format_for_cpu = True
21482153
return final_qweight, quant_state
21492154

21502155

2156+
def _convert_weight_packed_for_cpu_inverse(
    packed_weight: torch.Tensor,
    quant_state: QuantState,
    block_n: int = 32,
) -> tuple[torch.Tensor, QuantState]:
    """
    Undo `_convert_weight_packed_for_cpu`: turn the CPU-packed weight back into the
    original 4-bit storage layout and restore the quant_state fields that the
    forward conversion saved.

    packed_weight: [N, K/2] uint8, output of `_convert_weight_packed_for_cpu` (final_qweight)
    quant_state: QuantState that was modified by `_convert_weight_packed_for_cpu`. It must carry
        `original_dtype`, `original_nested` and `original_qshape`, and it is updated in place.
    block_n: row-block size used for packing; must match the value given to the forward call
    Returns:
        qweight: uint8 tensor with two 4-bit values per byte, reshaped to quant_state.original_qshape
        recovered_state: the same QuantState object with partially restored fields (best-effort inverse)
    """
    # Read the flag defensively, like the callers do, so a state that was never
    # packed fails with the intended assertion instead of an AttributeError.
    assert getattr(quant_state, "packing_format_for_cpu", False), "only for packing format"
    assert packed_weight.dtype == torch.uint8
    assert len(packed_weight.shape) == 2, "packed_weight should be [N, K/2]"
    N, K_half = packed_weight.shape
    K = K_half * 2  # always even: every packed byte carries two 4-bit values

    BLOCK_N = block_n
    BIT_COUNT = 32  # one packed row of 32 bytes holds 32 low + 32 high nibbles

    assert N % BLOCK_N == 0, "N must be divisible by block_n"

    # 1) packed [N, K/2] -> [-1, 32] bytes, then split every byte into its two nibbles
    packed = packed_weight.reshape(-1, BIT_COUNT)
    high = (packed >> 4) & 0xF
    low = packed & 0xF
    # [-1, 64]: the first 32 entries are the low nibbles, the last 32 the high nibbles
    qw = torch.cat([low, high], dim=-1)

    # [-1, 64] -> [N/B, K/2, B, 2] -> [N/B, B, K/2, 2] -> [N, K]
    qw = qw.reshape(N // BLOCK_N, K_half, BLOCK_N, 2)
    qw = qw.transpose(-3, -2).contiguous()
    qw = qw.reshape(N, K)

    # 2) Re-pack two 4-bit values per byte: even positions become the high nibble,
    #    odd positions the low nibble. Values are already <= 0xF and uint8, so no
    #    widening or masking is needed.
    flat = qw.reshape(-1)
    qweight = (flat[::2] << 4) | flat[1::2]  # [K*N/2]

    # 3) Best-effort restore of quant_state fields (absmax / dtype / nested flags, etc.)
    recovered_state = quant_state

    if recovered_state.original_nested:
        # The forward conversion deleted `state2`, so absmax has to be quantized again.
        absmax = recovered_state.absmax.T.reshape(-1).to(recovered_state.original_dtype)
        offset = absmax.mean()
        qabsmax, state2 = quantize_blockwise(absmax - offset, blocksize=256)
        recovered_state.absmax = qabsmax
        recovered_state.offset = offset
        recovered_state.state2 = state2
        # state2/offset are back, so mark the state as nested again; `original_nested`
        # recorded that it was nested before packing.
        recovered_state.nested = True

    recovered_state.dtype = recovered_state.original_dtype
    recovered_state.packing_format_for_cpu = False

    return qweight.reshape(recovered_state.original_qshape), recovered_state
2217+
2218+
21512219
def has_avx512bf16():
21522220
if hasattr(lib, "has_avx512bf16_cpu") and lib.has_avx512bf16_cpu():
21532221
return True

bitsandbytes/nn/modules.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,12 @@
1212

1313
import bitsandbytes as bnb
1414
from bitsandbytes.cextension import ROCM_WARP_SIZE_64
15-
from bitsandbytes.functional import QuantState, _convert_weight_packed_for_cpu, has_avx512bf16
15+
from bitsandbytes.functional import (
16+
QuantState,
17+
_convert_weight_packed_for_cpu,
18+
_convert_weight_packed_for_cpu_inverse,
19+
has_avx512bf16,
20+
)
1621
from bitsandbytes.optim import GlobalOptimManager
1722
from bitsandbytes.utils import INVERSE_LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING, OutlierTracer
1823

@@ -311,9 +316,13 @@ def cpu(self):
311316
return self.to(device="cpu")
312317

313318
def cuda(self, device: Optional[int | device | str] = None, non_blocking: bool = False):
    """
    Move this parameter to a CUDA device.

    If the quant_state is flagged as being in the CPU packing format, the data and
    quant_state are first converted back with `_convert_weight_packed_for_cpu_inverse`.

    device: target device; "cuda" is used when None
    non_blocking: forwarded unchanged to `self.to`
    """
    # Undo the CPU-only packing first (the flag may be absent, hence getattr).
    cpu_packed = getattr(self.quant_state, "packing_format_for_cpu", False)
    if cpu_packed:
        self.data, self.quant_state = _convert_weight_packed_for_cpu_inverse(self.data, self.quant_state)

    target = device if device is not None else "cuda"
    return self.to(device=target, non_blocking=non_blocking)
315322

316323
def xpu(self, device: Optional[int | device | str] = None, non_blocking: bool = False):
    """
    Move this parameter to an XPU device.

    If the quant_state is flagged as being in the CPU packing format, the data and
    quant_state are first converted back with `_convert_weight_packed_for_cpu_inverse`.

    device: target device; "xpu" is used when None
    non_blocking: forwarded unchanged to `self.to`
    """
    # Undo the CPU-only packing first (the flag may be absent, hence getattr).
    cpu_packed = getattr(self.quant_state, "packing_format_for_cpu", False)
    if cpu_packed:
        self.data, self.quant_state = _convert_weight_packed_for_cpu_inverse(self.data, self.quant_state)

    target = device if device is not None else "xpu"
    return self.to(device=target, non_blocking=non_blocking)
318327

319328
@overload
@@ -479,7 +488,6 @@ def __init__(
479488
self.compute_type_is_set = compute_dtype is not None
480489
self.quant_state = None
481490
self.quant_storage = quant_storage
482-
self.packing_format_for_cpu = False
483491

484492
def set_compute_type(self, x):
485493
if x.dtype in [torch.float32, torch.bfloat16]:
@@ -507,7 +515,10 @@ def _save_to_state_dict(self, destination, prefix, keep_vars):
507515
then fill state_dict with components of quant_state
508516
"""
509517
super()._save_to_state_dict(destination, prefix, keep_vars) # saving weight and bias
510-
518+
if getattr(self.weight.quant_state, "packing_format_for_cpu", False):
519+
self.weight.data, self.weight.quant_state = _convert_weight_packed_for_cpu_inverse(
520+
self.weight.data, self.weight.quant_state
521+
)
511522
if getattr(self.weight, "quant_state", None) is not None:
512523
for k, v in self.weight.quant_state.as_dict(packed=True).items():
513524
destination[prefix + "weight." + k] = v if keep_vars else v.detach()
@@ -517,15 +528,13 @@ def forward(self, x: torch.Tensor):
517528
quant_state = self.weight.quant_state
518529

519530
if (
520-
not self.packing_format_for_cpu
531+
not getattr(quant_state, "packing_format_for_cpu", False)
521532
and x.device.type == "cpu"
522533
and has_avx512bf16()
523534
and not self.training
524535
and x.requires_grad == False
525536
):
526537
self.weight.data, quant_state = _convert_weight_packed_for_cpu(self.weight.data, quant_state)
527-
self.packing_format_for_cpu = True
528-
quant_state.packing_format_for_cpu = True
529538

530539
# weights are cast automatically as Int8Params, but the bias has to be cast manually
531540
if self.bias is not None and self.bias.dtype != x.dtype:

0 commit comments

Comments
 (0)