
Commit 7d7c74c

Merge branch 'main' into main
2 parents 3e53783 + 6a07ffe commit 7d7c74c


6 files changed: +638 -16 lines changed


.github/workflows/tests.yml

Lines changed: 1 addition & 1 deletion
@@ -355,7 +355,7 @@ jobs:
           pypi_index: "https://download.pytorch.org/whl/cu128"
         - cuda_version: "12.9.1"
           torch_version: "2.8.0"
-          pypi_index: "https://download.pytorch.org/whl/test/cu129"
+          pypi_index: "https://download.pytorch.org/whl/cu129"


   # Linux L40S runners

bitsandbytes/nn/parametrize.py

Lines changed: 192 additions & 0 deletions
@@ -0,0 +1,192 @@
from functools import partial
from typing import Any, Literal, Optional

import torch
import torch.nn as nn
import torch.nn.utils.parametrize as P

from .. import functional as F


class Bnb4bitParametrization(nn.Module):
    """
    A parametrization module that handles dequantization of a 4-bit quantized parameter.

    The parameter data is expected to be already quantized when this parametrization is applied.
    This module will dequantize the parameter data to its original floating-point representation
    when the forward method is called (i.e. when the parameter is accessed).

    Args:
        quant_state (`F.QuantState`):
            The quantization state containing the necessary information for dequantization.
    """

    def __init__(self, quant_state: F.QuantState):
        super().__init__()
        self.quant_state = quant_state

    @torch.no_grad()
    def forward(self, quantized_param: torch.Tensor) -> torch.Tensor:
        """
        Forward pass to dequantize the parameter.

        Args:
            quantized_param (`torch.Tensor`): The quantized parameter tensor (from .original)

        Returns:
            `torch.Tensor`: The dequantized parameter tensor in the original shape and dtype.
        """
        return F.dequantize_4bit(quantized_param, self.quant_state)


def replace_parameter_4bit_prequantized(
    module: nn.Module, param_name: str, qs_dict: dict[str, Any], device: torch.device
):
    if not hasattr(module, param_name):
        raise AttributeError(f"Module does not have parameter '{param_name}'")

    original_param = getattr(module, param_name)

    if not isinstance(original_param, nn.Parameter):
        raise TypeError(f"Parameter '{param_name}' is not an instance of nn.Parameter")

    quant_state = F.QuantState.from_dict(qs_dict, device=device)

    # Apply a parametrization to the module to handle dequantization.
    P.register_parametrization(module, param_name, Bnb4bitParametrization(quant_state), unsafe=True)

    # Next, register hooks.
    _register_parametrization_hooks(module, param_name)


def replace_parameter_4bit(
    module: nn.Module,
    param_name: str,
    compress_statistics: bool = False,
    quant_type: Literal["nf4", "fp4"] = "nf4",
    blocksize: Optional[int] = None,
):
    """
    Replace a module parameter with a 4-bit quantized version using parametrization.

    This function quantizes an existing parameter in a PyTorch module to 4-bit precision
    and sets up parametrization to handle automatic dequantization during forward passes.
    The original parameter is replaced with quantized data, and a parametrization layer
    is registered to manage the quantization state and dequantization process.

    Additionally, it registers a state dict post-hook to ensure that the quantization state
    is saved correctly when the model's state dict is saved.

    It is useful for MoE models or other scenarios where you want to quantize parameters
    outside of nn.Linear layers without changing the model's architecture.

    <Tip warning={true}>This feature is experimental and may change in future releases.</Tip>

    Args:
        module (`nn.Module`):
            The PyTorch module containing the parameter to be quantized.
        param_name (`str`):
            The name of the parameter within the module to quantize.
        compress_statistics (`bool`, *optional*, defaults to `False`):
            Whether to compress quantization statistics to reduce memory usage.
        quant_type (`Literal["nf4", "fp4"]`, *optional*, defaults to `"nf4"`):
            The quantization format to use.
        blocksize (`int`, *optional*, defaults to `None`):
            The block size for quantization. If None, uses the default block size.

    Raises:
        AttributeError: If the module does not have the specified parameter.
        TypeError: If the specified attribute is not an instance of nn.Parameter.
    """

    if not hasattr(module, param_name):
        raise AttributeError(f"Module does not have parameter '{param_name}'")

    original_param = getattr(module, param_name)

    if not isinstance(original_param, nn.Parameter):
        raise TypeError(f"Parameter '{param_name}' is not an instance of nn.Parameter")

    # Quantize the original parameter.
    quantized_data, quant_state = F.quantize_4bit(
        original_param.data,
        blocksize=blocksize,
        compress_statistics=compress_statistics,
        quant_type=quant_type,
    )

    # Replace the parameter with the quantized data.
    setattr(module, param_name, nn.Parameter(quantized_data, requires_grad=False))
    del original_param

    # Apply a parametrization to the module to handle dequantization.
    P.register_parametrization(module, param_name, Bnb4bitParametrization(quant_state), unsafe=True)

    # Next, register hooks.
    _register_parametrization_hooks(module, param_name)


def _disable_parametrization_cache(module: nn.Module, inputs: tuple[Any, ...], output: Any):
    P._cache_enabled -= 1
    if not P._cache_enabled:
        P._cache = {}


def _enable_parametrization_cache(module: nn.Module, inputs: tuple[Any, ...]):
    P._cache_enabled += 1


def _register_parametrization_hooks(module: nn.Module, param_name: str):
    # Register a state dict hook for saving. Note that this requires torch >= 2.5.0.
    if torch.__version__ >= (2, 5):
        module.register_state_dict_post_hook(
            partial(
                _parametrized_state_dict_post_hook,
                param_name=param_name,
            )
        )

    # Register hooks to enable caching for the dequantization parametrization.
    # This helps preserve time and memory when the same quantized parameter
    # is accessed multiple times in the forward computation.
    module.register_forward_pre_hook(_enable_parametrization_cache)
    module.register_forward_hook(_disable_parametrization_cache)


def _parametrized_state_dict_post_hook(
    module: nn.Module,
    state_dict: dict[str, Any],
    prefix: str,
    local_metadata: Any,
    *,
    param_name: str = "weight",
    **kwargs: dict[str, Any],
) -> None:
    """
    Hook to modify the state dict to include the quantization state.
    """

    original_key = f"{prefix}parametrizations.{param_name}.original"

    if original_key in state_dict:
        # Create a clean entry.
        # The `parametrizations.{param_name}.original` key will have the quantized data,
        # but we would like to keep it in the state_dict as `{param_name}`.
        clean_key = f"{prefix}{param_name}"
        state_dict[clean_key] = state_dict.pop(original_key)

        assert P.is_parametrized(module, param_name)

        # Find the parametrization, which should have the quantization state.
        parametrization: Bnb4bitParametrization = next(
            filter(lambda x: isinstance(x, Bnb4bitParametrization), module.parametrizations[param_name]), None
        )

        assert parametrization is not None, "Parametrization not found for the parameter."

        quant_state = parametrization.quant_state

        # Next, we need to store the quantization state.
        if quant_state is not None:
            for k, v in quant_state.as_dict(packed=True).items():
                state_dict[f"{prefix}{param_name}.{k}"] = v

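For orientation (not part of the commit): a minimal usage sketch of the new replace_parameter_4bit API. The Expert module, its gate_proj parameter, and the bitsandbytes.nn.parametrize import path are illustrative assumptions, not something this diff defines outside the file above.

import torch
import torch.nn as nn

from bitsandbytes.nn.parametrize import replace_parameter_4bit  # assumed import path

class Expert(nn.Module):
    # Toy module with a parameter that is not owned by an nn.Linear,
    # e.g. fused MoE expert weights (hypothetical example).
    def __init__(self, dim: int = 128):
        super().__init__()
        self.gate_proj = nn.Parameter(torch.randn(dim, dim, dtype=torch.float16))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Reading self.gate_proj goes through the parametrization,
        # which dequantizes the 4-bit payload on the fly.
        return x @ self.gate_proj

expert = Expert().to("cuda")

# Quantize the parameter in place; the packed 4-bit data now lives under
# expert.parametrizations.gate_proj.original.
replace_parameter_4bit(expert, "gate_proj", quant_type="nf4")

out = expert(torch.randn(2, 128, dtype=torch.float16, device="cuda"))

# On torch >= 2.5 the state dict post-hook keeps the quantized payload under the
# original name ("gate_proj") and stores the packed quantization-state entries next to it.
state_dict = expert.state_dict()

The forward pre/post hooks registered by _register_parametrization_hooks bump the same _cache_enabled counter that torch.nn.utils.parametrize.cached() manages, so a parameter read several times within one forward pass is only dequantized once.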
tests/conftest.py

Lines changed: 2 additions & 0 deletions
@@ -34,6 +34,8 @@ def pytest_runtest_teardown(item, nextitem):
     gc.collect()
     if torch.cuda.is_available():
         torch.cuda.empty_cache()
+    elif torch.backends.mps.is_available() and torch.backends.mps.is_built():
+        torch.mps.empty_cache()


 @pytest.fixture(scope="session")

tests/test_functional.py

Lines changed: 21 additions & 15 deletions
@@ -1,9 +1,10 @@
 import math
+import platform
 import random
 import time

 import einops
-import numpy as np
+from packaging import version
 import pytest
 import torch

@@ -101,16 +102,16 @@ class Test8BitBlockwiseQuantizeFunctional:
     def test_dynamic_blockwise_quantization(self, device, dtype, nested, blocksize, signed):
         iters = 100

-        if device == "cpu":
+        if device != "cuda":
             iters = 10

-        # This test is slow on CPU, so avoid atypical use cases.
+        # This test is slow in our non-CUDA implementations, so avoid atypical use cases.
         if nested:
             pytest.skip("Not a typical use case.")
         if blocksize != 256:
-            pytest.skip("Only blocksize 256 is used in CPU/XPU")
+            pytest.skip("Only blocksize 256 is used in CPU/MPS/XPU")
         if dtype != torch.float32:
-            pytest.skip("Only float32 is used in CPU/XPU")
+            pytest.skip("Only float32 is used in CPU/MPS/XPU")

         diffs = []
         reldiffs = []

@@ -239,7 +240,7 @@ def test_fp8_quant(self, device):
         abserr = []
         relerr = []
-        for i in range(100):
+        for i in range(10):
             A1 = torch.randn(1024, 1024, device=device)
             C, SC = F.quantize_blockwise(A1, code=code)
             A2 = F.dequantize_blockwise(C, SC)

@@ -253,7 +254,7 @@ def test_fp8_quant(self, device):
         abserr = []
         relerr = []
-        for i in range(100):
+        for i in range(10):
             A1 = torch.rand(1024, 1024, device=device)
             C, SC = F.quantize_blockwise(A1, code=code)
             A2 = F.dequantize_blockwise(C, SC)

@@ -267,7 +268,7 @@ def test_fp8_quant(self, device):
         abserr = []
         relerr = []
-        for i in range(100):
+        for i in range(10):
             A1 = torch.randn(1024, 1024, device=device)
             C, SC = F.quantize_blockwise(A1)
             A2 = F.dequantize_blockwise(C, SC)

@@ -1169,8 +1170,12 @@ def test_4bit_quant(self, device, dtype, quant_type, blocksize):
             4096: 0.262457,
         }

-        assert err < error_dict[quant_type]["err"][blocksize] + 1e-3
-        assert relerr < error_dict[quant_type]["rel_err"][blocksize] + 1e-3
+        # Allow higher tolerance for fp32 on CPU with larger block sizes
+        reltol = 2.8e-3 if dtype == torch.float32 and blocksize >= 128 and device == "cpu" else 1e-3
+        errtol = 1.2e-3 if dtype == torch.float32 and blocksize >= 1024 and device == "cpu" else 1e-3
+
+        assert err < error_dict[quant_type]["err"][blocksize] + errtol
+        assert relerr < error_dict[quant_type]["rel_err"][blocksize] + reltol

     @pytest.mark.parametrize("device", get_available_devices())
     @pytest.mark.parametrize("quant_type", ["fp4", "nf4"])

@@ -1403,28 +1408,29 @@ def test_gemv_4bit(self, device, dim, dtype, storage_type, quant_storage, double
     @pytest.mark.parametrize("device", get_available_devices())
     @pytest.mark.parametrize("storage_type", ["nf4", "fp4"], ids=["nf4", "fp4"])
     @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=describe_dtype)
-    @pytest.mark.parametrize("double_quant", [False], ids=["DQ_True"])
     @pytest.mark.skipif(
         HIP_ENVIRONMENT and ROCM_GPU_ARCH == "gfx90a",
         reason="this test is not supported on ROCm with gfx90a architecture yet",
     )
-    def test_gemv_eye_4bit(self, device, storage_type, dtype, double_quant):
+    def test_gemv_eye_4bit(self, device, storage_type, dtype):
         if device == "cpu" and dtype == torch.bfloat16 and torch.__version__ < (2, 3):
             pytest.skip("eye doe not support bfloat16 on CPU in torch < 2.3")

         if device == "hpu" and not is_supported_on_hpu(storage_type, dtype):
             pytest.skip("This configuration is not supported on HPU.")

-        dims = 10
-        torch.random.manual_seed(np.random.randint(0, 412424242))
+        if device == "cpu" and platform.system() == "Windows" and version.parse(torch.__version__).release == (2, 8, 0):
+            pytest.skip("Regression: CPU crash on Windows with torch 2.8.0")
+
+        dims = 4
         dims = get_test_dims(0, 8192, n=dims)
         dims = [dim + (64 - (dim % 64)) for dim in dims]
         # for dim in [576, 5120, 3520, 5184, 1280, 4992, 5312, 2048]:
         for dim in dims:
             A = torch.normal(0, 0.1, size=(1, 1, dim), dtype=dtype, device=device)
             B = torch.eye(dim, dtype=dtype, device=device)

-            qB, state = F.quantize_4bit(B, quant_type=storage_type, compress_statistics=double_quant)
+            qB, state = F.quantize_4bit(B, quant_type=storage_type, compress_statistics=False)
             C3 = torch.matmul(A, B.t())
             C2 = bnb.matmul_4bit(A, qB.t(), state)
             A.requires_grad = True

tests/test_optim.py

Lines changed: 11 additions & 0 deletions
@@ -172,6 +172,10 @@ def rm_path(path):
 @pytest.mark.parametrize("device", get_available_devices(no_cpu=True), ids=id_formatter("device"))
 @pytest.mark.skipif(not get_available_devices(no_cpu=True), reason="No device")
 def test_optimizer32bit(dim1, dim2, gtype, optim_name, device):
+
+    if device not in ["cuda", "xpu"]:
+        pytest.skip("Optimizers are only supported on CUDA and XPU")
+
     if optim_name.startswith("paged_") and sys.platform == "win32":
         pytest.skip("Paged optimizers can have issues on Windows.")

@@ -253,6 +257,9 @@ def test_optimizer32bit(dim1, dim2, gtype, optim_name, device):
 @pytest.mark.parametrize("device", get_available_devices(no_cpu=True))
 @pytest.mark.skipif(not get_available_devices(no_cpu=True), reason="No device")
 def test_global_config(dim1, dim2, gtype, device):
+    if device not in ["cuda", "xpu"]:
+        pytest.skip("Optimizers are only supported on CUDA and XPU")
+
     if dim1 == 1 and dim2 == 1:
         return
     p1 = torch.randn(dim1, dim2, device="cpu", dtype=gtype) * 0.1

@@ -310,6 +317,10 @@ def test_global_config(dim1, dim2, gtype, device):
 @pytest.mark.parametrize("device", get_available_devices(no_cpu=True))
 @pytest.mark.skipif(not get_available_devices(no_cpu=True), reason="No device")
 def test_optimizer8bit(dim1, dim2, gtype, optim_name, device):
+
+    if device not in ["cuda", "xpu"]:
+        pytest.skip("8-bit optimizers are only supported on CUDA and XPU")
+
     torch.set_printoptions(precision=6)

     if dim1 == 1 and dim2 == 1:
