Skip to content

Commit 27549fb

Browse files
4bit quantization for arbitrary nn.Parameter (#1720)
* Add parametrize util for targeting parameters outside of nn.Linear modules * Parametrize 4bit: replace existing prequantized weight * cleanup * Add caching for parametrization * Add tests * Fix tests * Guard for torch < 2.5 * Guard for torch < 2.5 * Another test gaurd for torch >= 2.5
1 parent 39dd847 commit 27549fb

File tree

2 files changed

+603
-0
lines changed

2 files changed

+603
-0
lines changed

bitsandbytes/nn/parametrize.py

Lines changed: 192 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,192 @@
1+
from functools import partial
2+
from typing import Any, Literal, Optional
3+
4+
import torch
5+
import torch.nn as nn
6+
import torch.nn.utils.parametrize as P
7+
8+
from .. import functional as F
9+
10+
11+
class Bnb4bitParametrization(nn.Module):
12+
"""
13+
A parametrization module that handles dequantization of a 4-bit quantized parameter.
14+
15+
The parameter data is expected to be already quantized when this parametrization is applied.
16+
This module will dequantize the parameter data to its original floating-point representation
17+
when the forward method is called (i.e. when the parameter is accessed).
18+
19+
Args:
20+
quant_state (`F.QuantState`):
21+
The quantization state containing the necessary information for dequantization.
22+
"""
23+
24+
def __init__(self, quant_state: F.QuantState):
25+
super().__init__()
26+
self.quant_state = quant_state
27+
28+
@torch.no_grad()
29+
def forward(self, quantized_param: torch.Tensor) -> torch.Tensor:
30+
"""
31+
Forward pass to dequantize the parameter.
32+
33+
Args:
34+
quantized_param (`torch.Tensor`): The quantized parameter tensor (from .original)
35+
36+
Returns:
37+
`torch.Tensor`: The dequantized parameter tensor in the original shape and dtype.
38+
"""
39+
return F.dequantize_4bit(quantized_param, self.quant_state)
40+
41+
42+
def replace_parameter_4bit_prequantized(
43+
module: nn.Module, param_name: str, qs_dict: dict[str, Any], device: torch.device
44+
):
45+
if not hasattr(module, param_name):
46+
raise AttributeError(f"Module does not have parameter '{param_name}'")
47+
48+
original_param = getattr(module, param_name)
49+
50+
if not isinstance(original_param, nn.Parameter):
51+
raise TypeError(f"Parameter '{param_name}' is not an instance of nn.Parameter")
52+
53+
quant_state = F.QuantState.from_dict(qs_dict, device=device)
54+
55+
# Apply a parametrization to the module to handle dequantization.
56+
P.register_parametrization(module, param_name, Bnb4bitParametrization(quant_state), unsafe=True)
57+
58+
# Next, register hooks.
59+
_register_parametrization_hooks(module, param_name)
60+
61+
62+
def replace_parameter_4bit(
63+
module: nn.Module,
64+
param_name: str,
65+
compress_statistics: bool = False,
66+
quant_type: Literal["nf4", "fp4"] = "nf4",
67+
blocksize: Optional[int] = None,
68+
):
69+
"""
70+
Replace a module parameter with a 4-bit quantized version using parametrization.
71+
72+
This function quantizes an existing parameter in a PyTorch module to 4-bit precision
73+
and sets up parametrization to handle automatic dequantization during forward passes.
74+
The original parameter is replaced with quantized data, and a parametrization layer
75+
is registered to manage the quantization state and dequantization process.
76+
77+
Additional, it registers a state dict post-hook to ensure that the quantization state
78+
is saved correctly when the model's state dict is saved.
79+
80+
It is useful for MoE models or other scenarios where you want to quantize parameters
81+
outside of nn.Linear layers without changing the model's architecture.
82+
83+
<Tip warning={true}>This feature is experimental and may change in future releases.</Tip>
84+
85+
Args:
86+
module (`nn.Module`):
87+
The PyTorch module containing the parameter to be quantized.
88+
param_name (`str`):
89+
The name of the parameter within the module to quantize.
90+
compress_statistics (`bool`, *optional*, defaults to `False`):
91+
Whether to compress quantization statistics to reduce memory usage.
92+
quant_type (`Literal["nf4", "fp4"]`, *optional*, defaults to `"nf4"`):
93+
The quantization format to use.
94+
blocksize (`int`, *optional*, defaults to `None`):
95+
The block size for quantization. If None, uses the default block size.
96+
97+
Raises:
98+
AttributeError: If the module does not have the specified parameter.
99+
TypeError: If the specified attribute is not an instance of nn.Parameter.
100+
"""
101+
102+
if not hasattr(module, param_name):
103+
raise AttributeError(f"Module does not have parameter '{param_name}'")
104+
105+
original_param = getattr(module, param_name)
106+
107+
if not isinstance(original_param, nn.Parameter):
108+
raise TypeError(f"Parameter '{param_name}' is not an instance of nn.Parameter")
109+
110+
# Quantize the original parameter.
111+
quantized_data, quant_state = F.quantize_4bit(
112+
original_param.data,
113+
blocksize=blocksize,
114+
compress_statistics=compress_statistics,
115+
quant_type=quant_type,
116+
)
117+
118+
# Replace the parameter with the quantized data.
119+
setattr(module, param_name, nn.Parameter(quantized_data, requires_grad=False))
120+
del original_param
121+
122+
# Apply a parametrization to the module to handle dequantization.
123+
P.register_parametrization(module, param_name, Bnb4bitParametrization(quant_state), unsafe=True)
124+
125+
# Next, register hooks.
126+
_register_parametrization_hooks(module, param_name)
127+
128+
129+
def _disable_parametrization_cache(module: nn.Module, inputs: tuple[Any, ...], output: Any):
130+
P._cache_enabled -= 1
131+
if not P._cache_enabled:
132+
P._cache = {}
133+
134+
135+
def _enable_parametrization_cache(module: nn.Module, inputs: tuple[Any, ...]):
136+
P._cache_enabled += 1
137+
138+
139+
def _register_parametrization_hooks(module: nn.Module, param_name: str):
140+
# Register a state dict hook for saving. Note that this requires torch >= 2.5.0.
141+
if torch.__version__ >= (2, 5):
142+
module.register_state_dict_post_hook(
143+
partial(
144+
_parametrized_state_dict_post_hook,
145+
param_name=param_name,
146+
)
147+
)
148+
149+
# Register hooks to enable caching for the dequantization parametrization.
150+
# This helps preserve time and memory when the same quantized parameter
151+
# is accessed multiple times in the forward computation.
152+
module.register_forward_pre_hook(_enable_parametrization_cache)
153+
module.register_forward_hook(_disable_parametrization_cache)
154+
155+
156+
def _parametrized_state_dict_post_hook(
157+
module: nn.Module,
158+
state_dict: dict[str, Any],
159+
prefix: str,
160+
local_metadata: Any,
161+
*,
162+
param_name: str = "weight",
163+
**kwargs: dict[str, Any],
164+
) -> None:
165+
"""
166+
Hook to modify the state dict to include the quantization state.
167+
"""
168+
169+
original_key = f"{prefix}parametrizations.{param_name}.original"
170+
171+
if original_key in state_dict:
172+
# Create a clean entry.
173+
# The `parametrizations.{param_name}.original` key will have the quantized data,
174+
# but we would like it to keep it in the state_dict as `{param_name}`.
175+
clean_key = f"{prefix}{param_name}"
176+
state_dict[clean_key] = state_dict.pop(original_key)
177+
178+
assert P.is_parametrized(module, param_name)
179+
180+
# Find the parametrization, which should have the quantization state.
181+
parametrization: Bnb4bitParametrization = next(
182+
filter(lambda x: isinstance(x, Bnb4bitParametrization), module.parametrizations[param_name]), None
183+
)
184+
185+
assert parametrization is not None, "Parametrization not found for the parameter."
186+
187+
quant_state = parametrization.quant_state
188+
189+
# Next, we need to store the quantization state.
190+
if quant_state is not None:
191+
for k, v in quant_state.as_dict(packed=True).items():
192+
state_dict[f"{prefix}{param_name}.{k}"] = v

0 commit comments

Comments
 (0)