Skip to content

Commit 3c07023

Browse files
committed
fix quant_storage bf16 and gemv cpu
Signed-off-by: jiqing-feng <[email protected]>
1 parent bc8723e commit 3c07023

File tree

2 files changed

+6
-7
lines changed

2 files changed

+6
-7
lines changed

bitsandbytes/backends/cpu/ops.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,9 @@ def _(
203203
# Enable non uint8 dtype
204204
device = A.device
205205
if A.dtype != torch.uint8:
206+
if A.dtype == torch.bfloat16:
207+
# Numpy does not support bfloat16
208+
A = A.view(torch.float16)
206209
bytes_value = A.cpu().numpy().tobytes()
207210
A = torch.frombuffer(bytes_value, dtype=torch.uint8).to(device)
208211

@@ -247,12 +250,8 @@ def _(
247250
blocksize: int,
248251
) -> torch.Tensor:
249252
# Applied from dequantize_4bit
250-
B = B.view(-1, 1)
251-
upper = (B >> 4).to(torch.int64)
252-
lower = (B & 0x0F).to(torch.int64)
253-
blocks = torch.cat((upper, lower), dim=1).reshape(-1, blocksize)
254-
B_dq = code[blocks] * absmax[:, None]
255-
B_dq = B_dq.reshape(-1, *shapeB[1:]).to(A.dtype)
253+
quant_type = "nf4" if code[1] > 0 else "fp4"
254+
B_dq = torch.ops.bitsandbytes.dequantize_4bit.default(B, absmax, blocksize, quant_type, shapeB, A.dtype)
256255

257256
# User called gemv with B.t(), so we need to transpose it back.
258257
# if B.shape[0] == 1:

tests/test_functional.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ def test_dynamic_blockwise_quantization(self, device, dtype, nested, blocksize,
129129
assert abserr < 0.0036
130130
assert relerr < 0.015
131131
else:
132-
assert abserr < 0.00175
132+
assert abserr < 0.0023
133133
assert relerr < 0.012
134134
assert A2.dtype == dtype
135135
# print('signed=', signed, 'nested=', nested, 'rand', blocksize, sum(diffs)/len(diffs))

0 commit comments

Comments
 (0)