Commit bae4ff0

TimDettmers and claude committed
Implement k-bit quantization infrastructure with placeholder kernels
Add quantize_blockwise_kbit and dequantize_blockwise_kbit functions that support k-bit quantization where k is a template parameter (2-8 bits). The implementation follows the existing bitsandbytes architecture:

- Python API with k parameter in functional.py
- PyTorch operation registration in _ops.py
- C interface with template demangling in pythonInterface.cpp
- C++ template dispatch in ops.cu
- CUDA placeholder kernels that return 1.0 for all elements
- Full CUDA backend support, CPU throws NotImplementedError
- Comprehensive test suite in test_kbit_quant.py

The infrastructure is complete and tested. Placeholder kernels can be replaced with actual k-bit quantization logic.

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <[email protected]>
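A minimal usage sketch of the new Python API, assuming a CUDA device and the functions added in this commit; with the placeholder kernels the dequantized values are not yet meaningful, so the shapes, dtypes, and QuantState fields are the point here.

import torch
import bitsandbytes.functional as F

A = torch.randn(4096, device="cuda", dtype=torch.float16)

# Quantize to 3 bits in blocks of 256 elements; returns uint8 codes plus a QuantState.
q, state = F.quantize_blockwise_kbit(A, k=3, blocksize=256)

# Round-trip back to fp16; the QuantState carries k, code, blocksize, and absmax.
A_dq = F.dequantize_blockwise_kbit(q, k=3, quant_state=state)
print(q.dtype, A_dq.dtype, state.k)  # torch.uint8 torch.float16 3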
1 parent e54dc12 commit bae4ff0


12 files changed, +1377 -0 lines changed


bitsandbytes/_ops.py

Lines changed: 46 additions & 0 deletions
@@ -273,6 +273,52 @@ def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor
     return out, absmax
 
 
+torch.library.define("bitsandbytes::quantize_blockwise_kbit", "(Tensor A, int k, Tensor code, int blocksize) -> (Tensor, Tensor)")
+
+
+@register_fake("bitsandbytes::quantize_blockwise_kbit")
+def _(A: torch.Tensor, k: int, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor, torch.Tensor]:
+    torch._check_is_size(blocksize)
+    torch._check(k >= 2 and k <= 8, lambda: f"k must be between 2 and 8, got {k}")
+    n = A.numel()
+    blocks = -(n // -blocksize)
+    absmax = torch.empty((blocks,), device=A.device, dtype=torch.float32)
+    out = torch.empty_like(A, dtype=torch.uint8)
+    return out, absmax
+
+
+torch.library.define(
+    "bitsandbytes::dequantize_blockwise_kbit",
+    "(Tensor A, int k, Tensor absmax, Tensor code, int blocksize, ScalarType dtype) -> Tensor",
+)
+
+
+@register_fake("bitsandbytes::dequantize_blockwise_kbit")
+def _(A: torch.Tensor, k: int, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype) -> torch.Tensor:
+    torch._check_is_size(blocksize)
+    torch._check(k >= 2 and k <= 8, lambda: f"k must be between 2 and 8, got {k}")
+    torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}")
+    return torch.empty_like(A, dtype=dtype)
+
+
+torch.library.define(
+    "bitsandbytes::dequantize_blockwise_kbit.out",
+    "(Tensor A, int k, Tensor absmax, Tensor code, int blocksize, ScalarType dtype, Tensor! out) -> ()",
+)
+
+
+@register_fake("bitsandbytes::dequantize_blockwise_kbit.out")
+def _(
+    A: torch.Tensor, k: int, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype, out: torch.Tensor
+):
+    torch._check_is_size(blocksize)
+    torch._check(k >= 2 and k <= 8, lambda: f"k must be between 2 and 8, got {k}")
+    torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}")
+    torch._check(out.shape == A.shape, lambda: f"Expected out.shape == {A.shape}, got {out.shape}")
+    torch._check(out.device == A.device, lambda: f"Expected out.device == {A.device}, got {out.device}")
+    torch._check(out.dtype == dtype, lambda: f"Expected out.dtype == {dtype}, got {out.dtype}")
+
+
 torch.library.define(
     "bitsandbytes::gemv_4bit",
     "(Tensor A, Tensor B, int[] shapeB, Tensor absmax, Tensor code, int blocksize) -> Tensor",
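The fake ("meta") registrations above only describe output shapes and dtypes for tracing; the only arithmetic they encode is the ceiling-division block count. A quick illustration with hypothetical sizes:

# blocks = -(n // -blocksize) is ceiling division: a 4097-element tensor with
# blocksize 4096 needs two absmax entries, one per (possibly partial) block.
n, blocksize = 4097, 4096
blocks = -(n // -blocksize)
assert blocks == 2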

bitsandbytes/backends/cpu/ops.py

Lines changed: 23 additions & 0 deletions
@@ -67,6 +67,29 @@ def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor
     return out, absmax
 
 
+@register_kernel("bitsandbytes::quantize_blockwise_kbit", "cpu")
+def _(A: torch.Tensor, k: int, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor, torch.Tensor]:
+    raise NotImplementedError("K-bit quantization is not implemented for CPU backend")
+
+
+@register_kernel("bitsandbytes::dequantize_blockwise_kbit", "cpu")
+def _(A: torch.Tensor, k: int, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype) -> torch.Tensor:
+    raise NotImplementedError("K-bit dequantization is not implemented for CPU backend")
+
+
+@register_kernel("bitsandbytes::dequantize_blockwise_kbit.out", "cpu")
+def _(
+    A: torch.Tensor,
+    k: int,
+    absmax: torch.Tensor,
+    code: torch.Tensor,
+    blocksize: int,
+    dtype: torch.dtype,
+    out: torch.Tensor,
+) -> None:
+    raise NotImplementedError("K-bit dequantization is not implemented for CPU backend")
+
+
 @register_kernel("bitsandbytes::dequantize_blockwise", "cpu")
 def _(A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype) -> torch.Tensor:
     torch._check_is_size(blocksize)
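What the CPU stubs above mean for callers, sketched under the assumption that importing bitsandbytes registers the ops and that dispatch selects the "cpu" kernel for CPU tensors:

import torch
import bitsandbytes  # importing registers the bitsandbytes::* ops

A = torch.randn(128)                 # CPU tensor, float32
code = torch.linspace(-1.0, 1.0, 8)  # illustrative 3-bit code table (hypothetical)
try:
    torch.ops.bitsandbytes.quantize_blockwise_kbit(A, 3, code, 64)
except NotImplementedError as err:
    print(err)  # "K-bit quantization is not implemented for CPU backend"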

bitsandbytes/backends/cuda/ops.py

Lines changed: 96 additions & 0 deletions
@@ -245,6 +245,102 @@ def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor
     return out, absmax
 
 
+@register_kernel("bitsandbytes::quantize_blockwise_kbit", "cuda")
+def _(A: torch.Tensor, k: int, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor, torch.Tensor]:
+    torch._check(k >= 2 and k <= 8, lambda: f"k must be between 2 and 8, got {k}")
+    torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64])
+    torch._check(A.device.type == "cuda", lambda: "Input tensor must be on CUDA device")
+    torch._check(code.device.type == "cuda", lambda: "Code tensor must be on CUDA device")
+    torch._check(code.dtype == torch.float32, lambda: "Code must be float32")
+    torch._check(A.is_contiguous(), lambda: "A must be contiguous")
+    torch._check(code.is_contiguous(), lambda: "Code must be contiguous")
+
+    n = A.numel()
+    blocks = -(n // -blocksize)
+    absmax = torch.zeros((blocks,), device=A.device, dtype=torch.float32)
+    out = torch.zeros_like(A, dtype=torch.uint8)
+
+    with torch.cuda.device_of(A):
+        args = (
+            get_ptr(code),
+            get_ptr(A),
+            get_ptr(absmax),
+            get_ptr(out),
+            ct.c_int32(blocksize),
+            ct.c_int(A.numel()),
+        )
+
+        # Call the appropriate k-bit function based on dtype and k value
+        if A.dtype == torch.float16:
+            getattr(lib, f"cquantize_blockwise_fp16_k{k}")(*args)
+        elif A.dtype == torch.bfloat16:
+            getattr(lib, f"cquantize_blockwise_bf16_k{k}")(*args)
+        elif A.dtype == torch.float32:
+            getattr(lib, f"cquantize_blockwise_fp32_k{k}")(*args)
+        else:
+            raise ValueError(f"K-bit quantization only supports 16/32-bit floats, but got {A.dtype}")
+
+    return out, absmax
+
+
+@register_kernel("bitsandbytes::dequantize_blockwise_kbit", "cuda")
+def _(A: torch.Tensor, k: int, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype) -> torch.Tensor:
+    torch._check(k >= 2 and k <= 8, lambda: f"k must be between 2 and 8, got {k}")
+    out = torch.empty_like(A, dtype=dtype)
+    _dequantize_blockwise_kbit_impl(A, k, absmax, code, blocksize, dtype, out=out)
+    return out
+
+
+@register_kernel("bitsandbytes::dequantize_blockwise_kbit.out", "cuda")
+def _(
+    A: torch.Tensor,
+    k: int,
+    absmax: torch.Tensor,
+    code: torch.Tensor,
+    blocksize: int,
+    dtype: torch.dtype,
+    out: torch.Tensor,
+) -> None:
+    torch._check(k >= 2 and k <= 8, lambda: f"k must be between 2 and 8, got {k}")
+    torch._check(out.dtype == dtype, lambda: f"Expected out.dtype == {dtype}, got {out.dtype}")
+    torch._check(out.shape == A.shape, lambda: f"Expected out.shape == {A.shape}, got {out.shape}")
+    _dequantize_blockwise_kbit_impl(A, k, absmax, code, blocksize, dtype, out=out)
+
+
+def _dequantize_blockwise_kbit_impl(
+    A: torch.Tensor, k: int, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype, out: torch.Tensor
+) -> None:
+    torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64])
+    torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}")
+    torch._check(
+        dtype in [torch.float16, torch.bfloat16, torch.float32],
+        lambda: f"K-bit dequantization only supports 16/32-bit floats, but got {dtype}",
+    )
+    torch._check(absmax.is_contiguous(), lambda: "Absmax must be contiguous")
+    torch._check(code.is_contiguous(), lambda: "Code must be contiguous")
+
+    with torch.cuda.device_of(A):
+        args = (
+            get_ptr(code),
+            get_ptr(A),
+            get_ptr(absmax),
+            get_ptr(out),
+            ct.c_int32(blocksize),
+            ct.c_int(A.numel()),
+            _get_tensor_stream(A),
+        )
+
+        # Call the appropriate k-bit function based on dtype and k value
+        if dtype == torch.float16:
+            getattr(lib, f"cdequantize_blockwise_fp16_k{k}")(*args)
+        elif dtype == torch.bfloat16:
+            getattr(lib, f"cdequantize_blockwise_bf16_k{k}")(*args)
+        elif dtype == torch.float32:
+            getattr(lib, f"cdequantize_blockwise_fp32_k{k}")(*args)
+        else:
+            raise ValueError(f"K-bit dequantization only supports 16/32-bit floats, but got {dtype}")
+
+
 @register_kernel("bitsandbytes::dequantize_blockwise", "cuda")
 def _(A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype) -> torch.Tensor:
     out = torch.empty_like(A, dtype=dtype)
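The CUDA kernels dispatch by name to per-dtype, per-k C entry points. A small sketch that enumerates the symbols the getattr calls above expect to find; importing lib from bitsandbytes.cextension is an assumption about where the native library handle lives:

from bitsandbytes.cextension import lib

expected = [
    f"c{op}_blockwise_{dt}_k{k}"
    for op in ("quantize", "dequantize")
    for dt in ("fp16", "bf16", "fp32")
    for k in range(2, 9)
]
missing = [name for name in expected if not hasattr(lib, name)]
print(f"{len(expected) - len(missing)}/{len(expected)} k-bit symbols exported")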

bitsandbytes/functional.py

Lines changed: 163 additions & 0 deletions
@@ -407,6 +407,7 @@ class QuantState:
         "nested_blocksize",
         "nested_dtype",
         "nested_offset",
+        "k",
     ]
 
     def __init__(

@@ -419,6 +420,7 @@ def __init__(
         dtype=None,
         offset=None,
         state2=None,
+        k=None,
     ):
         self.absmax = absmax
         self.shape = shape

@@ -428,6 +430,7 @@
         self.quant_type = quant_type
         self.offset = offset
         self.state2 = state2
+        self.k = k
         self.nested = state2 is not None
 
     def __getitem__(self, idx):

@@ -637,6 +640,81 @@ def quantize_blockwise(
     return out, quant_state
 
 
+def quantize_blockwise_kbit(
+    A: torch.Tensor,
+    k: int,
+    code: Optional[torch.Tensor] = None,
+    absmax: Optional[torch.Tensor] = None,
+    out: Optional[torch.Tensor] = None,
+    blocksize=4096,
+    nested=False,
+) -> tuple[torch.Tensor, QuantState]:
+    """Quantize a tensor in blocks using k-bit quantization.
+
+    The input tensor is quantized by dividing it into blocks of `blocksize` values.
+    The absolute maximum value within these blocks is calculated for scaling
+    the k-bit quantization.
+
+    Args:
+        A (`torch.Tensor`): The input tensor. Supports `float16`, `bfloat16`, or `float32` datatypes.
+        k (`int`): The number of bits for quantization (2-8).
+        code (`torch.Tensor`, *optional*):
+            A mapping describing the k-bit data type. If not provided, a linear map is created.
+        absmax (`torch.Tensor`, *optional*): A tensor to use to store the absmax values.
+        out (`torch.Tensor`, *optional*): A tensor to use to store the result.
+        blocksize (`int`, *optional*):
+            The size of the blocks. Defaults to 4096.
+            Valid values are 64, 128, 256, 512, 1024, 2048, and 4096.
+        nested (`bool`, *optional*): Whether to additionally quantize the absmax values. Defaults to False.
+
+    Raises:
+        ValueError: Raised when the input data type or k value is not supported.
+
+    Returns:
+        `Tuple[torch.Tensor, QuantState]`: A tuple containing the quantization results.
+        - `torch.Tensor`: The quantized tensor.
+        - [`QuantState`]: The state object used to undo the quantization.
+    """
+    if k < 2 or k > 8:
+        raise ValueError(f"k must be between 2 and 8, got {k}")
+
+    if code is None:
+        # Create a linear k-bit quantization map
+        code = create_linear_map(signed=True, total_bits=k).to(A.device)
+
+    _out, _absmax = torch.ops.bitsandbytes.quantize_blockwise_kbit.default(
+        A,
+        k,
+        code.to(A.device),
+        blocksize,
+    )
+
+    if nested:
+        offset = _absmax.mean()
+        _absmax -= offset
+        qabsmax, state2 = quantize_blockwise(_absmax, blocksize=blocksize, nested=False)
+        quant_state = QuantState(
+            absmax=qabsmax,
+            code=code.to(A.device, copy=True),
+            blocksize=blocksize,
+            dtype=A.dtype,
+            offset=offset,
+            state2=state2,
+            k=k,
+        )
+    else:
+        quant_state = QuantState(absmax=_absmax, code=code.to(A.device, copy=True), blocksize=blocksize, dtype=A.dtype, k=k)
+
+    # TODO(matthewdouglas): Deprecate out kwarg
+    out = out.copy_(_out) if out is not None else _out
+
+    # TODO(matthewdouglas): Deprecate absmax kwarg
+    if absmax is not None:
+        quant_state.absmax = absmax.copy_(quant_state.absmax)
+
+    return out, quant_state
+
+
 def dequantize_blockwise(
     A: torch.Tensor,
     quant_state: Optional[QuantState] = None,

@@ -714,6 +792,91 @@ def dequantize_blockwise(
     )
 
 
+def dequantize_blockwise_kbit(
+    A: torch.Tensor,
+    k: int,
+    quant_state: Optional[QuantState] = None,
+    absmax: Optional[torch.Tensor] = None,
+    code: Optional[torch.Tensor] = None,
+    out: Optional[torch.Tensor] = None,
+    blocksize: int = 4096,
+    nested=False,
+) -> torch.Tensor:
+    """Dequantize a tensor in blocks using k-bit dequantization.
+
+    The input tensor is dequantized by dividing it into blocks of `blocksize` values.
+    The absolute maximum value within these blocks is used for scaling
+    the k-bit dequantization.
+
+    Args:
+        A (`torch.Tensor`): The quantized input tensor.
+        k (`int`): The number of bits used for quantization (2-8).
+        quant_state ([`QuantState`], *optional*):
+            The quantization state as returned by [`quantize_blockwise_kbit`].
+            Required if `absmax` is not provided.
+        absmax (`torch.Tensor`, *optional*):
+            A tensor containing the scaling values.
+            Required if `quant_state` is not provided and ignored otherwise.
+        code (`torch.Tensor`, *optional*):
+            A mapping describing the k-bit data type. If not provided, a linear map is created.
+            Ignored when `quant_state` is provided.
+        out (`torch.Tensor`, *optional*): A tensor to use to store the result.
+        blocksize (`int`, *optional*):
+            The size of the blocks. Defaults to 4096.
+            Valid values are 64, 128, 256, 512, 1024, 2048, and 4096.
+            Ignored when `quant_state` is provided.
+
+    Raises:
+        ValueError: Raised when the input data type or k value is not supported.
+
+    Returns:
+        `torch.Tensor`:
+            The dequantized tensor. The datatype is indicated by `quant_state.dtype` and defaults to `torch.float32`.
+    """
+    if k < 2 or k > 8:
+        raise ValueError(f"k must be between 2 and 8, got {k}")
+
+    assert quant_state is not None or absmax is not None
+    if code is None and quant_state is None:
+        # Create a linear k-bit quantization map
+        code = create_linear_map(signed=True, total_bits=k).to(A.device)
+
+    if quant_state is None:
+        quant_state = QuantState(absmax=absmax, code=code, blocksize=blocksize, dtype=torch.float32, k=k)
+
+    absmax = quant_state.absmax
+    if quant_state.nested:
+        absmax = dequantize_blockwise(quant_state.absmax, quant_state.state2)
+        absmax += quant_state.offset
+        if absmax.dtype != torch.float32:
+            absmax = absmax.float()
+
+    # Prefer the k stored in quant_state when available
+    if getattr(quant_state, "k", None) is not None:
+        k = quant_state.k
+
+    if out is not None:
+        torch.ops.bitsandbytes.dequantize_blockwise_kbit.out(
+            A,
+            k,
+            absmax,
+            quant_state.code.to(A.device),
+            quant_state.blocksize,
+            quant_state.dtype,
+            out=out,
+        )
+        return out
+
+    return torch.ops.bitsandbytes.dequantize_blockwise_kbit.default(
+        A,
+        k,
+        absmax,
+        quant_state.code.to(A.device),
+        quant_state.blocksize,
+        quant_state.dtype,
+    )
+
+
 def get_4bit_type(typename, device=None, blocksize=64):
     if device is None:
         device = "cuda"
