Skip to content

Commit 82ee356

Browse files
Parametrize 4bit: replace existing prequantized weight
1 parent 4579891 commit 82ee356

File tree

1 file changed

+29
-2
lines changed

1 file changed

+29
-2
lines changed

bitsandbytes/nn/parametrize.py

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,10 @@ class Bnb4bitParametrization(nn.Module):
2121
The quantization state containing the necessary information for dequantization.
2222
"""
2323

24-
def __init__(self, quant_state: F.QuantState):
24+
def __init__(self, quant_state: F.QuantState, p_name="unknown"):
2525
super().__init__()
2626
self.quant_state = quant_state
27+
self.p_name = p_name
2728

2829
def forward(self, quantized_param: torch.Tensor) -> torch.Tensor:
2930
"""
@@ -35,9 +36,35 @@ def forward(self, quantized_param: torch.Tensor) -> torch.Tensor:
3536
Returns:
3637
`torch.Tensor`: The dequantized parameter tensor in the original shape and dtype.
3738
"""
39+
# print(f"Dequantizing parameter '{self.p_name}'")
3840
return F.dequantize_4bit(quantized_param, self.quant_state)
3941

4042

43+
def replace_parameter_4bit_prequantized(
44+
module: nn.Module, param_name: str, qs_dict: dict[str, Any], device: torch.device
45+
):
46+
if not hasattr(module, param_name):
47+
raise AttributeError(f"Module does not have parameter '{param_name}'")
48+
49+
original_param = getattr(module, param_name)
50+
51+
if not isinstance(original_param, nn.Parameter):
52+
raise TypeError(f"Parameter '{param_name}' is not an instance of nn.Parameter")
53+
54+
quant_state = F.QuantState.from_dict(qs_dict, device=device)
55+
56+
# Apply a parametrization to the module to handle dequantization.
57+
P.register_parametrization(module, param_name, Bnb4bitParametrization(quant_state, p_name=param_name), unsafe=True)
58+
59+
# Next, register state dict hook for saving.
60+
module.register_state_dict_post_hook(
61+
partial(
62+
_parametrized_state_dict_post_hook,
63+
param_name=param_name,
64+
)
65+
)
66+
67+
4168
def replace_parameter_4bit(
4269
module: nn.Module,
4370
param_name: str,
@@ -99,7 +126,7 @@ def replace_parameter_4bit(
99126
del original_param
100127

101128
# Apply a parametrization to the module to handle dequantization.
102-
P.register_parametrization(module, param_name, Bnb4bitParametrization(quant_state), unsafe=True)
129+
P.register_parametrization(module, param_name, Bnb4bitParametrization(quant_state, p_name=param_name), unsafe=True)
103130

104131
# Next, register state dict hook for saving.
105132
module.register_state_dict_post_hook(

0 commit comments

Comments
 (0)