Commit 9d0f459

Fix nested quant
Parent: e9c79cf

File tree: 4 files changed, +30 −6 lines

bitsandbytes/_ops.py

Lines changed: 3 additions & 0 deletions
@@ -131,6 +131,7 @@ def _(
     )
 
 
+@register_fake("bitsandbytes::quantize_4bit")
 def _(
     A: torch.Tensor, blocksize: int, quant_type: str, quant_storage: torch.dtype
 ) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -147,13 +148,15 @@ def _(
     )
 
 
+@register_fake("bitsandbytes::dequantize_blockwise")
 def _(A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype) -> torch.Tensor:
     return torch.empty_like(A, dtype=dtype)
 
 
 torch.library.define("bitsandbytes::quantize_blockwise", "(Tensor A, Tensor code, int blocksize) -> (Tensor, Tensor)")
 
 
+@register_fake("bitsandbytes::quantize_blockwise")
 def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> Tuple[torch.Tensor, torch.Tensor]:
     n = A.numel()
     blocks = -(n // -blocksize)
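For context: register_fake attaches a "fake" (meta) implementation to a custom operator, telling PyTorch only the output shapes, dtypes, and devices so that fake-tensor tracing (and torch.library.opcheck, used in the new test below) can run the op without invoking its real kernel. A minimal standalone sketch of the pattern, assuming PyTorch >= 2.4 (where torch.library.register_fake is available) and using a made-up "mylib" namespace rather than the library's own:

    import torch
    from typing import Tuple

    # Declare the operator schema, mirroring the bitsandbytes::quantize_blockwise
    # definition shown above (the "mylib" namespace is hypothetical).
    torch.library.define(
        "mylib::quantize_blockwise",
        "(Tensor A, Tensor code, int blocksize) -> (Tensor, Tensor)",
    )

    @torch.library.register_fake("mylib::quantize_blockwise")
    def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> Tuple[torch.Tensor, torch.Tensor]:
        n = A.numel()
        # Ceiling division without floats: -(n // -blocksize) == ceil(n / blocksize).
        blocks = -(n // -blocksize)
        # No data is computed here; only metadata (shape/dtype/device) matters.
        out = torch.empty_like(A, dtype=torch.uint8)
        absmax = torch.empty(blocks, dtype=torch.float32, device=A.device)
        return out, absmax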

bitsandbytes/backends/cuda/ops.py

Lines changed: 1 addition & 1 deletion
@@ -231,7 +231,7 @@ def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> Tuple[torch.Tensor
         else:
             raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
 
-        return out, absmax
+    return out, absmax
 
 
 @register_kernel("bitsandbytes::dequantize_blockwise", "cuda")
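The removed and added lines differ only in indentation: the return out, absmax moves out one level, presumably so that the quantized tensor and its per-block statistics are returned at function scope for every supported dtype rather than from inside the preceding block.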

bitsandbytes/functional.py

Lines changed: 5 additions & 4 deletions
@@ -977,7 +977,7 @@ def dequantize_blockwise(
 
     return torch.ops.bitsandbytes.dequantize_blockwise(
         A,
-        quant_state.absmax,
+        absmax,
         quant_state.code.to(A.device),
         quant_state.blocksize,
         quant_state.dtype,
@@ -1142,8 +1142,9 @@ def quantize_4bit(
 
     if compress_statistics:
         offset = absmax.mean()
-        absmax -= offset
-        qabsmax, state2 = quantize_blockwise(absmax, blocksize=256)
+        # absmax -= offset
+        # qabsmax, state2 = quantize_blockwise(absmax, blocksize=256)
+        qabsmax, state2 = quantize_blockwise(absmax - offset, blocksize=256)
         del absmax
         state = QuantState(
             absmax=qabsmax,
@@ -1249,7 +1250,7 @@ def dequantize_4bit(
     out = torch.ops.bitsandbytes.dequantize_4bit(
         A,
         absmax,
-        blocksize,
+        quant_state.blocksize,
         quant_state.quant_type,
         quant_state.shape,
         quant_state.dtype,
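This is the fix the commit title points at: quantize_4bit previously shifted absmax in place (absmax -= offset) before nested-quantizing the statistics, and the new code performs the subtraction out of place. A plausible motivation, shown here as a standalone sketch with illustrative values (not the library's code), is that an in-place update is visible through every alias of the tensor, and mutating an operator's output in place can also interfere with fake-tensor tracing:

    import torch

    absmax = torch.tensor([1.0, 2.0, 3.0])
    saved = absmax                 # a second reference to the same storage
    offset = absmax.mean()         # tensor(2.)

    shifted = absmax - offset      # out of place: `saved` is untouched
    absmax -= offset               # in place: `saved` now sees the shifted values too

    print(shifted)                 # tensor([-1., 0., 1.])
    print(saved)                   # tensor([-1., 0., 1.]) -- mutated through the alias

The two one-line changes around it look like companions to the same nested-quantization path: dequantize_blockwise now passes the local absmax (presumably the already-decompressed statistics) rather than the still-quantized quant_state.absmax, and dequantize_4bit reads the blocksize from the quant_state instead of a local variable.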

tests/test_ops.py

Lines changed: 21 additions & 1 deletion
@@ -1,7 +1,7 @@
 import pytest
 import torch
 
-import bitsandbytes  # noqa: F401
+import bitsandbytes
 
 
 @pytest.mark.parametrize("device", ["cpu", "cuda"])
@@ -63,3 +63,23 @@ def test_int8_mm_dequant(device):
     assert out.device == A.device
 
     torch.library.opcheck(torch.ops.bitsandbytes.int8_mm_dequant, (A, row_stats, col_stats))
+
+
+@pytest.mark.parametrize("device", ["cpu", "cuda"])
+def test_quantize_blockwise(device):
+    # if device == "cpu":
+    #     pytest.skip("CPU implementation is not available")
+    blocksize = 256
+
+    code = bitsandbytes.functional.create_dynamic_map().to(device)
+    A = torch.randn(1024, 1024, dtype=torch.float16, device=device)
+    out, absmax = torch.ops.bitsandbytes.quantize_blockwise(A, code, blocksize)
+
+    assert out.shape == A.shape
+    assert out.dtype == torch.uint8
+    assert out.device == A.device
+
+    assert absmax.device == A.device
+    assert absmax.dtype == torch.float32
+
+    torch.library.opcheck(torch.ops.bitsandbytes.quantize_blockwise, (A, code, blocksize))
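The new test follows the pattern of the existing ones: call the operator through torch.ops, assert on the metadata a fake implementation must reproduce (shape, dtype, and device of both out and absmax), then finish with torch.library.opcheck, which verifies, among other things, that the registered fake agrees with the real kernel; it therefore exercises the @register_fake registrations added in _ops.py above. The # noqa: F401 is dropped from the import because bitsandbytes.functional.create_dynamic_map is now referenced directly.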
