Commit bc8723e
fix quantize blockwise output shape
Signed-off-by: jiqing-feng <[email protected]>
1 parent 83cea6b commit bc8723e

File tree

2 files changed: +56 -26 lines changed
bitsandbytes/backends/cpu/ops.py

Lines changed: 55 additions & 25 deletions
@@ -26,22 +26,42 @@ def _(A: torch.Tensor, B: torch.Tensor):
 @register_kernel("bitsandbytes::quantize_blockwise", "cpu")
 def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor, torch.Tensor]:
     torch._check_is_size(blocksize)
-    torch._check(A.dtype == torch.float32, lambda: f"A must be float32 on cpu, got {A.dtype}")
 
     n = A.numel()
-    blocks = -(n // -blocksize)
-
-    absmax = torch.empty((blocks,), device=A.device, dtype=torch.float32)
-    out = torch.empty_like(A, dtype=torch.uint8)
-
-    lib.cquantize_blockwise_cpu_fp32(
-        get_ptr(code),
-        get_ptr(A),
-        get_ptr(absmax),
-        get_ptr(out),
-        ct.c_longlong(blocksize),
-        ct.c_longlong(n),
-    )
+
+    # Only FP32 has a C++ kernel
+    if A.dtype == torch.float32:
+        blocks = -(n // -blocksize)
+
+        absmax = torch.empty((blocks,), device=A.device, dtype=A.dtype)
+        out = torch.empty_like(A, dtype=torch.uint8)
+
+        lib.cquantize_blockwise_cpu_fp32(
+            get_ptr(code),
+            get_ptr(A),
+            get_ptr(absmax),
+            get_ptr(out),
+            ct.c_longlong(blocksize),
+            ct.c_longlong(n),
+        )
+    else:
+        rem = n % blocksize
+        has_rem = rem > 0
+        blocks = n // blocksize + has_rem
+        absmax = torch.zeros((blocks,), device=A.device, dtype=A.dtype)
+        A_reshaped = A.reshape(n)
+        A_com = A_reshaped[: n - rem]
+        A_com_reshaped = A_com.reshape(n // blocksize, blocksize)
+        absmax[: blocks - has_rem] = torch.abs(A_com_reshaped).max(dim=-1)[0]
+        scaled_A = torch.clamp(A_com_reshaped * (1 / absmax[: blocks - has_rem].view(-1, 1)), -1, 1)
+        scaled_A = scaled_A.reshape(-1)
+        if has_rem:
+            absmax[-1] = torch.abs(A_reshaped[n - rem :]).max()
+            scaled_A_rem = torch.clamp(A_reshaped[n - rem :] * (1 / absmax[-1]), -1, 1)
+            scaled_A = torch.cat([scaled_A, scaled_A_rem], dim=0)
+
+        diff = torch.abs(scaled_A.unsqueeze(-1) - code.to(scaled_A.device))
+        out = torch.argmin(diff, dim=-1).to(torch.uint8).to(scaled_A.device).reshape(A.shape)
 
     return out, absmax
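For tensors that are not FP32, the new branch performs dynamic blockwise quantization entirely in PyTorch: flatten, split into `blocksize` chunks, record each chunk's absmax, scale into [-1, 1], and replace every value with the index of the nearest entry in the `code` table. Below is a minimal, self-contained sketch of that scheme, not the library's implementation; the helper name and the toy 16-entry uniform `code` table are assumptions for illustration only.

```python
import torch


def quantize_blockwise_ref(A: torch.Tensor, code: torch.Tensor, blocksize: int):
    """Toy blockwise quantization: per-block absmax scaling + nearest-code lookup."""
    n = A.numel()
    rem = n % blocksize
    blocks = n // blocksize + (1 if rem else 0)

    flat = A.reshape(n)
    if rem:
        # Zero-pad the ragged tail so the tensor can be viewed as (blocks, blocksize).
        flat = torch.nn.functional.pad(flat, (0, blocksize - rem))
    flat = flat.reshape(blocks, blocksize)

    absmax = flat.abs().max(dim=-1).values               # one scale per block
    scaled = torch.clamp(flat / absmax.view(-1, 1), -1, 1)

    # Index of the nearest code entry for every scaled value.
    idx = torch.argmin((scaled.unsqueeze(-1) - code).abs(), dim=-1)
    out = idx.reshape(-1)[:n].to(torch.uint8).reshape(A.shape)
    return out, absmax


# Hypothetical usage with a toy 16-entry uniform code table.
code = torch.linspace(-1, 1, 16)
A = torch.randn(5, 7, dtype=torch.bfloat16)
q, absmax = quantize_blockwise_ref(A, code, blocksize=8)
print(q.shape, absmax.shape)   # torch.Size([5, 7]) torch.Size([5])
```

Like the diff's fallback, the nearest-entry lookup is a dense |scaled - code| comparison, which is simple but materializes an (n, len(code)) temporary rather than doing the binary-search-style rounding of the C++ kernel.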

@@ -50,18 +70,28 @@ def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor
 def _(A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype) -> torch.Tensor:
     torch._check_is_size(blocksize)
     torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}")
-    torch._check(dtype == torch.float32, lambda: f"dtype must be float32 on cpu, got {dtype}")
 
-    out = torch.empty_like(A, dtype=dtype)
-
-    lib.cdequantize_blockwise_cpu_fp32(
-        get_ptr(code),
-        get_ptr(A),
-        get_ptr(absmax),
-        get_ptr(out),
-        ct.c_longlong(blocksize),
-        ct.c_longlong(A.numel()),
-    )
+    # Only FP32 has a C++ kernel
+    if dtype == torch.float32:
+        out = torch.empty_like(A, dtype=dtype)
+
+        lib.cdequantize_blockwise_cpu_fp32(
+            get_ptr(code),
+            get_ptr(A),
+            get_ptr(absmax),
+            get_ptr(out),
+            ct.c_longlong(blocksize),
+            ct.c_longlong(A.numel()),
+        )
+    else:
+        out = code[A.reshape(-1).int()]
+        blocks = out.shape[-1] // blocksize
+        res = out.shape[-1] % blocksize
+        if res != 0:
+            out = torch.nn.functional.pad(out, (0, blocksize - res), mode="constant", value=0)
+        out = (out.view(-1, blocksize) * absmax.view(-1, 1)).to(dtype).reshape(-1)
+        out = out[: blocks * blocksize + res]
+        out = out.reshape(A.shape)
 
     return out
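The dequantization fallback is the inverse mapping: look each uint8 index up in `code`, then multiply every block by its stored absmax, padding and then trimming around a ragged final block exactly as the diff does. A small self-contained sketch under the same toy assumptions (function and variable names are illustrative, not the library's API):

```python
import torch


def dequantize_blockwise_ref(q: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor,
                             blocksize: int, dtype: torch.dtype) -> torch.Tensor:
    """Toy blockwise dequantization: code lookup + per-block rescale by absmax."""
    flat = code[q.reshape(-1).int()]          # uint8 indices -> code values in [-1, 1]
    n = flat.shape[-1]
    rem = n % blocksize
    if rem:
        # Pad so the tensor can be viewed as (blocks, blocksize); trimmed below.
        flat = torch.nn.functional.pad(flat, (0, blocksize - rem))
    out = (flat.view(-1, blocksize) * absmax.view(-1, 1)).to(dtype).reshape(-1)
    return out[:n].reshape(q.shape)


# Hypothetical usage: 35 random code indices, 5 per-block scales, blocksize 8.
code = torch.linspace(-1, 1, 16)
q = torch.randint(0, 16, (5, 7), dtype=torch.uint8)
absmax = torch.rand(5) * 3
x = dequantize_blockwise_ref(q, absmax, code, blocksize=8, dtype=torch.bfloat16)
print(x.shape, x.dtype)   # torch.Size([5, 7]) torch.bfloat16
```

Reshaping back to `q.shape` at the end is the output-shape fix the commit message refers to: the result matches the shape of the quantized input rather than staying flattened.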

tests/test_functional.py

Lines changed: 1 addition & 1 deletion
@@ -126,7 +126,7 @@ def test_dynamic_blockwise_quantization(self, device, dtype, nested, blocksize,
         abserr = sum(diffs) / len(diffs)
         relerr = sum(reldiffs) / len(reldiffs)
         if signed:
-            assert abserr < 0.0035
+            assert abserr < 0.0036
             assert relerr < 0.015
         else:
             assert abserr < 0.00175
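The signed tolerance moves from 0.0035 to 0.0036 to accommodate the slightly different rounding of the pure-PyTorch fallback path. For reference, a hedged sketch of the kind of error metric the test averages; the real test's loop and variable names may differ, this only illustrates mean absolute and mean relative error over a quantize/dequantize round trip:

```python
import torch


def roundtrip_errors(A: torch.Tensor, A_hat: torch.Tensor, eps: float = 1e-8):
    """Mean absolute and mean relative error between A and its dequantized reconstruction."""
    diff = (A - A_hat).abs()
    abserr = diff.mean().item()                       # checked against 0.0036 for signed codes
    relerr = (diff / (A.abs() + eps)).mean().item()   # checked against 0.015
    return abserr, relerr
```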
