from functools import partial
from typing import Any, Literal, Optional

import torch
import torch.nn as nn
import torch.nn.utils.parametrize as P

from .. import functional as F


class Bnb4bitParametrization(nn.Module):
    """
    A parametrization module that handles dequantization of a 4-bit quantized parameter.

    The parameter data is expected to be already quantized when this parametrization is applied.
    This module will dequantize the parameter data to its original floating-point representation
    when the forward method is called (i.e. when the parameter is accessed).

    Args:
        quant_state (`F.QuantState`):
            The quantization state containing the necessary information for dequantization.
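
    Example:

        A minimal sketch of manual registration (the `linear` module is illustrative; this is what
        [`replace_parameter_4bit`] does for you):

        ```python
        import torch
        import torch.nn as nn
        import torch.nn.utils.parametrize as P

        from bitsandbytes import functional as F

        linear = nn.Linear(64, 64, dtype=torch.float16, device="cuda")

        # Quantize the weight and swap in the packed 4-bit storage.
        packed, quant_state = F.quantize_4bit(linear.weight.data)
        linear.weight = nn.Parameter(packed, requires_grad=False)

        # Accessing `linear.weight` now returns the dequantized fp16 tensor.
        P.register_parametrization(linear, "weight", Bnb4bitParametrization(quant_state), unsafe=True)
        assert linear.weight.shape == (64, 64)
        ```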
    """

    def __init__(self, quant_state: F.QuantState):
        super().__init__()
        self.quant_state = quant_state

    @torch.no_grad()
    def forward(self, quantized_param: torch.Tensor) -> torch.Tensor:
        """
        Forward pass to dequantize the parameter.

        Args:
            quantized_param (`torch.Tensor`): The quantized parameter tensor (from `.original`).

        Returns:
            `torch.Tensor`: The dequantized parameter tensor in the original shape and dtype.
        """
        return F.dequantize_4bit(quantized_param, self.quant_state)

def replace_parameter_4bit_prequantized(
    module: nn.Module, param_name: str, qs_dict: dict[str, Any], device: torch.device
):
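    """
    Attach a dequantization parametrization to a parameter whose data is already 4-bit quantized.

    This is the loading-time counterpart of [`replace_parameter_4bit`]: the parameter is expected
    to already hold the packed quantized data (e.g. loaded from a checkpoint), while `qs_dict`
    carries the serialized quantization state written out by the state dict post-hook.

    Args:
        module (`nn.Module`):
            The PyTorch module containing the prequantized parameter.
        param_name (`str`):
            The name of the parameter within the module.
        qs_dict (`dict[str, Any]`):
            The serialized quantization state for the parameter.
        device (`torch.device`):
            The device on which to materialize the quantization state.

    Raises:
        AttributeError: If the module does not have the specified parameter.
        TypeError: If the specified attribute is not an instance of nn.Parameter.
    """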
    if not hasattr(module, param_name):
        raise AttributeError(f"Module does not have parameter '{param_name}'")

    original_param = getattr(module, param_name)

    if not isinstance(original_param, nn.Parameter):
        raise TypeError(f"Parameter '{param_name}' is not an instance of nn.Parameter")

    quant_state = F.QuantState.from_dict(qs_dict, device=device)

    # Apply a parametrization to the module to handle dequantization.
    P.register_parametrization(module, param_name, Bnb4bitParametrization(quant_state), unsafe=True)

    # Next, register the state dict and caching hooks.
    _register_parametrization_hooks(module, param_name)

def replace_parameter_4bit(
    module: nn.Module,
    param_name: str,
    compress_statistics: bool = False,
    quant_type: Literal["nf4", "fp4"] = "nf4",
    blocksize: Optional[int] = None,
):
    """
    Replace a module parameter with a 4-bit quantized version using parametrization.

    This function quantizes an existing parameter in a PyTorch module to 4-bit precision
    and sets up parametrization to handle automatic dequantization during forward passes.
    The original parameter is replaced with quantized data, and a parametrization layer
    is registered to manage the quantization state and dequantization process.

    Additionally, it registers a state dict post-hook to ensure that the quantization state
    is saved correctly when the model's state dict is saved.

    This is useful for MoE models or other scenarios where you want to quantize parameters
    outside of nn.Linear layers without changing the model's architecture.

    <Tip warning={true}>This feature is experimental and may change in future releases.</Tip>

    Args:
        module (`nn.Module`):
            The PyTorch module containing the parameter to be quantized.
        param_name (`str`):
            The name of the parameter within the module to quantize.
        compress_statistics (`bool`, *optional*, defaults to `False`):
            Whether to compress quantization statistics to reduce memory usage.
        quant_type (`Literal["nf4", "fp4"]`, *optional*, defaults to `"nf4"`):
            The quantization format to use.
        blocksize (`int`, *optional*, defaults to `None`):
            The block size for quantization. If `None`, uses the default block size.

    Raises:
        AttributeError: If the module does not have the specified parameter.
        TypeError: If the specified attribute is not an instance of nn.Parameter.
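
    Example:

        A minimal sketch (the expert module and parameter names are illustrative, and the import
        path assumes this module is exposed as `bitsandbytes.nn.parametrize`):

        ```python
        import torch
        import torch.nn as nn

        from bitsandbytes.nn.parametrize import replace_parameter_4bit

        expert = nn.Module()
        expert.w1 = nn.Parameter(torch.randn(1024, 1024, dtype=torch.bfloat16, device="cuda"))

        # Quantize in place; `expert.w1` is now stored as packed NF4 data and is
        # dequantized back to bfloat16 whenever it is accessed.
        replace_parameter_4bit(expert, "w1", quant_type="nf4")
        ```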
    """

    if not hasattr(module, param_name):
        raise AttributeError(f"Module does not have parameter '{param_name}'")

    original_param = getattr(module, param_name)

    if not isinstance(original_param, nn.Parameter):
        raise TypeError(f"Parameter '{param_name}' is not an instance of nn.Parameter")

    # Quantize the original parameter.
    quantized_data, quant_state = F.quantize_4bit(
        original_param.data,
        blocksize=blocksize,
        compress_statistics=compress_statistics,
        quant_type=quant_type,
    )

    # Replace the parameter with the quantized data.
    setattr(module, param_name, nn.Parameter(quantized_data, requires_grad=False))
    del original_param

    # Apply a parametrization to the module to handle dequantization.
    P.register_parametrization(module, param_name, Bnb4bitParametrization(quant_state), unsafe=True)

    # Next, register the state dict and caching hooks.
    _register_parametrization_hooks(module, param_name)

def _disable_parametrization_cache(module: nn.Module, inputs: tuple[Any, ...], output: Any):
    """Forward hook: decrement the parametrization cache refcount and clear the cache once unused."""
    P._cache_enabled -= 1
    if not P._cache_enabled:
        P._cache = {}


def _enable_parametrization_cache(module: nn.Module, inputs: tuple[Any, ...]):
    """Forward pre-hook: increment the parametrization cache refcount (cf. `torch.nn.utils.parametrize.cached`)."""
    P._cache_enabled += 1

def _register_parametrization_hooks(module: nn.Module, param_name: str):
    # Register a state dict hook for saving. Note that this requires torch >= 2.5.0.
    if torch.__version__ >= (2, 5):
        module.register_state_dict_post_hook(
            partial(
                _parametrized_state_dict_post_hook,
                param_name=param_name,
            )
        )

    # Register hooks to enable caching for the dequantization parametrization.
    # This helps preserve time and memory when the same quantized parameter
    # is accessed multiple times in the forward computation.
    module.register_forward_pre_hook(_enable_parametrization_cache)
    module.register_forward_hook(_disable_parametrization_cache)

def _parametrized_state_dict_post_hook(
    module: nn.Module,
    state_dict: dict[str, Any],
    prefix: str,
    local_metadata: Any,
    *,
    param_name: str = "weight",
    **kwargs: dict[str, Any],
) -> None:
    """
    Hook to modify the state dict to include the quantization state.
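
    The quantized data saved under `{prefix}parametrizations.{param_name}.original` is renamed back
    to `{prefix}{param_name}`, and the serialized quantization state is stored alongside it under
    `{prefix}{param_name}.*` keys. For example, a hypothetical `expert.parametrizations.w1.original`
    entry becomes `expert.w1`, with companion entries such as `expert.w1.absmax` (the exact keys
    come from `QuantState.as_dict(packed=True)`).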
    """

    original_key = f"{prefix}parametrizations.{param_name}.original"

    if original_key in state_dict:
        # Create a clean entry.
        # The `parametrizations.{param_name}.original` key holds the quantized data,
        # but we want to keep it in the state_dict as `{param_name}`.
        clean_key = f"{prefix}{param_name}"
        state_dict[clean_key] = state_dict.pop(original_key)

        assert P.is_parametrized(module, param_name)

        # Find the parametrization, which should hold the quantization state.
        parametrization: Optional[Bnb4bitParametrization] = next(
            filter(lambda x: isinstance(x, Bnb4bitParametrization), module.parametrizations[param_name]), None
        )

        assert parametrization is not None, "Parametrization not found for the parameter."

        quant_state = parametrization.quant_state

        # Next, we need to store the quantization state.
        if quant_state is not None:
            for k, v in quant_state.as_dict(packed=True).items():
                state_dict[f"{prefix}{param_name}.{k}"] = v