Commit 3179b42

fix tests
Signed-off-by: jiqing-feng <[email protected]>
1 parent fbb911b commit 3179b42

File tree

3 files changed: +13 -3 lines changed

bitsandbytes/backends/cpu/ops.py

Lines changed: 1 addition & 1 deletion
@@ -229,7 +229,7 @@ def _(
     code: torch.Tensor,
     blocksize: int,
 ) -> torch.Tensor:
-    # Applied from dequantize_4bit
+    assert B.dtype == torch.uint8, "Only support uint8 qweight"
     dtype = A.dtype
     quant_type = "fp4" if code[1] > 0 else "nf4"
     # cpu fused op only support bf16 for now.
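The replaced comment becomes an explicit precondition: the fused CPU kernel reads the packed weight as raw bytes, so B must be typed uint8. A weight held with a non-uint8 quant_storage can still satisfy this check by reinterpreting its storage, which is what the functional.py change below does. A minimal, hypothetical illustration of that byte-view trick (the packed_bf16 tensor is made up for the example, not taken from the commit):

import torch

# Hypothetical stand-in for a 4-bit weight held with quant_storage=torch.bfloat16.
packed_bf16 = torch.randint(0, 256, (64,), dtype=torch.uint8).view(torch.bfloat16)

# .view(torch.uint8) reinterprets the same storage as raw bytes without copying,
# which is enough to satisfy a uint8-only kernel check.
as_bytes = packed_bf16.view(torch.uint8)
assert as_bytes.dtype == torch.uint8
assert as_bytes.data_ptr() == packed_bf16.data_ptr()  # same memory, no copy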

bitsandbytes/functional.py

Lines changed: 8 additions & 2 deletions
@@ -2108,7 +2108,9 @@ def _convert_weight_packed_for_cpu(qweight: torch.Tensor, quant_state: QuantStat
         qweight: (K * N / 2) uint8
         return: packed_weight
     """
-    assert qweight.dtype == torch.uint8, "qweight must be uint8"
+    if qweight.dtype != torch.uint8:
+        quant_state.original_storage_type = qweight.dtype
+        qweight = qweight.view(torch.uint8)
     quant_state.original_dtype = quant_state.dtype
     quant_state.original_nested = quant_state.nested
     quant_state.original_qshape = qweight.shape
@@ -2200,6 +2202,7 @@ def _convert_weight_packed_for_cpu_inverse(
 
     # 2) Best-effort restore of quant_state fields (absmax / dtype / nested flags, etc.)
     recovered_state = quant_state
+    qweight = qweight.to(torch.uint8).reshape(recovered_state.original_qshape)
 
     # quantize absmax
     if recovered_state.original_nested:
@@ -2213,7 +2216,10 @@ def _convert_weight_packed_for_cpu_inverse(
     recovered_state.dtype = recovered_state.original_dtype
     recovered_state.packing_format_for_cpu = False
 
-    return qweight.to(torch.uint8).reshape(recovered_state.original_qshape), recovered_state
+    if getattr(recovered_state, "original_storage_type", None):
+        qweight = qweight.view(recovered_state.original_storage_type)
+
+    return qweight, recovered_state
 
 
 def has_avx512bf16():
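Taken together, the two functions now form a dtype-preserving round trip: _convert_weight_packed_for_cpu records the weight's original storage dtype and flattened byte shape on the quant state before viewing it as uint8, and _convert_weight_packed_for_cpu_inverse reshapes the bytes back and reapplies the recorded dtype. A minimal sketch of that record-and-restore pattern, using a plain namespace in place of QuantState and toy pack/unpack helpers that are not the library's implementation:

import torch
from types import SimpleNamespace

def pack(qweight, state):
    # Record how the caller stored the 4-bit weight, then flatten to raw bytes.
    if qweight.dtype != torch.uint8:
        state.original_storage_type = qweight.dtype  # e.g. torch.bfloat16 storage
        qweight = qweight.view(torch.uint8)          # reinterpret, no copy
    state.original_qshape = qweight.shape
    return qweight, state

def unpack(qweight, state):
    # Undo the packing: restore the byte shape, then the original storage dtype.
    qweight = qweight.to(torch.uint8).reshape(state.original_qshape)
    if getattr(state, "original_storage_type", None):
        qweight = qweight.view(state.original_storage_type)
    return qweight, state

state = SimpleNamespace()
w = torch.randint(0, 256, (128,), dtype=torch.uint8).view(torch.bfloat16)
packed, state = pack(w, state)
restored, state = unpack(packed, state)
assert restored.dtype == w.dtype and restored.shape == w.shape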

tests/test_functional.py

Lines changed: 4 additions & 0 deletions
@@ -1318,6 +1318,10 @@ def test_gemv_4bit(self, device, dim, dtype, storage_type, quant_storage, double
             quant_storage=quant_storage,
         )
         C3 = torch.matmul(A, B.t())
+        # CPU requires convert weight packed for gemv
+        if device == "cpu" and F.has_avx512bf16():
+            qB, state = F._convert_weight_packed_for_cpu(qB, state)
+            qB = qB.t()
         C2 = F.gemv_4bit(A, qB.t(), state=state)
         A.requires_grad = True
         C1 = bnb.matmul_4bit(A, qB.t(), state)
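Put together, the CPU path exercised by the updated test looks roughly like the sketch below; the shapes, dtypes, and nf4 choice are illustrative, _convert_weight_packed_for_cpu is a private helper, and a bitsandbytes build with CPU support is assumed:

import torch
import bitsandbytes as bnb
import bitsandbytes.functional as F

# Illustrative shapes; the real test sweeps dim, dtype, storage_type and quant_storage.
A = torch.randn(1, 4096, dtype=torch.bfloat16)
B = torch.randn(1024, 4096, dtype=torch.bfloat16)

qB, state = F.quantize_4bit(B, quant_type="nf4", quant_storage=torch.bfloat16)

# On CPUs with AVX512-BF16, the weight must be repacked before the fused gemv.
if F.has_avx512bf16():
    qB, state = F._convert_weight_packed_for_cpu(qB, state)
    qB = qB.t()

C2 = F.gemv_4bit(A, qB.t(), state=state)
C1 = bnb.matmul_4bit(A, qB.t(), state)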
