from typing import TYPE_CHECKING, Any, Dict, List, Union

from ...utils import (
    get_module_from_name,
    is_accelerate_available,
    is_nunchaku_available,
    is_torch_available,
    logging,
)
from ...utils.import_utils import is_nunchaku_version
from ...utils.torch_utils import is_fp8_available
from ..base import DiffusersQuantizer


if TYPE_CHECKING:
    from ...models.modeling_utils import ModelMixin


if is_torch_available():
    import torch

if is_nunchaku_available():
    from .utils import replace_with_nunchaku_linear

logger = logging.get_logger(__name__)


class NunchakuQuantizer(DiffusersQuantizer):
    r"""
    Diffusers Quantizer for Nunchaku.
    """

    use_keep_in_fp32_modules = True
    requires_calibration = False
    required_packages = ["nunchaku", "accelerate"]

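    # Map nunchaku precision names to the torch dtypes used to represent the quantized weights
    # (INT4 weights are stored in int8 tensors; NVFP4 uses float8 when available).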
    dtype_map = {"int4": torch.int8}
    if is_fp8_available():
        dtype_map["nvfp4"] = torch.float8_e4m3fn

    def __init__(self, quantization_config, **kwargs):
        super().__init__(quantization_config, **kwargs)

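    # Fail fast before touching any weights: nunchaku needs a CUDA device and recent enough
    # `nunchaku` / `accelerate` installations.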
    def validate_environment(self, *args, **kwargs):
        if not torch.cuda.is_available():
            raise RuntimeError("No GPU found. A GPU is needed for nunchaku quantization.")

        if not is_nunchaku_available():
            raise ImportError(
                "Loading a nunchaku quantized model requires the nunchaku library (follow https://nunchaku.tech/docs/nunchaku/installation/installation.html)."
            )
        if not is_nunchaku_version(">=", "0.3.1"):
            raise ImportError(
                "Loading a nunchaku quantized model requires `nunchaku>=0.3.1`. "
                "Please upgrade your installation by following https://nunchaku.tech/docs/nunchaku/installation/installation.html."
            )

        if not is_accelerate_available():
            raise ImportError(
                "Loading a nunchaku quantized model requires the accelerate library (`pip install accelerate`)."
            )

        # TODO: decide whether multi-GPU or CPU/disk-offload `device_map`s should be rejected here.
        # device_map = kwargs.get("device_map", None)
        # if isinstance(device_map, dict) and len(device_map.keys()) > 1:
        #     raise ValueError(
        #         "`device_map` for multi-GPU inference or CPU/disk offload is currently not supported with Diffusers and the nunchaku backend"
        #     )

    def check_if_quantized_param(
        self,
        model: "ModelMixin",
        param_value: "torch.Tensor",
        param_name: str,
        state_dict: Dict[str, Any],
        **kwargs,
    ):
        # Import locally: nunchaku imports diffusers internally, so a module-level import would
        # create a circular import.
        from nunchaku.models.linear import SVDQW4A4Linear

        module, tensor_name = get_module_from_name(model, param_name)
        if self.pre_quantized and isinstance(module, SVDQW4A4Linear):
            return True

        return False

    def create_quantized_param(
        self,
        model: "ModelMixin",
        param_value: "torch.Tensor",
        param_name: str,
        target_device: "torch.device",
        *args,
        **kwargs,
    ):
        """
        Create a quantized parameter: either load a pre-quantized tensor into an existing nunchaku
        layer, or convert the owning `nn.Linear` into an `SVDQW4A4Linear` on the fly.
        """
        from nunchaku.models.linear import SVDQW4A4Linear

        module, tensor_name = get_module_from_name(model, param_name)
        if tensor_name not in module._parameters and tensor_name not in module._buffers:
            raise ValueError(f"{module} does not have a parameter or a buffer named {tensor_name}.")

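        # Pre-quantized checkpoints already contain nunchaku-packed tensors, so the value only
        # needs to be moved onto the target device.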
        if self.pre_quantized:
            if tensor_name in module._parameters:
                # `requires_grad=False`: the packed quantized tensors are not trainable and may
                # not be of a floating-point dtype.
                module._parameters[tensor_name] = torch.nn.Parameter(
                    param_value.to(device=target_device), requires_grad=False
                )
            if tensor_name in module._buffers:
                # Buffers must stay plain tensors, not `nn.Parameter`s.
                module._buffers[tensor_name] = param_value.to(device=target_device)

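        # A plain `nn.Linear` being quantized on the fly: load the value onto the target device,
        # then swap the module for its nunchaku counterpart.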
        elif isinstance(module, torch.nn.Linear):
            if tensor_name in module._parameters:
                module._parameters[tensor_name] = torch.nn.Parameter(param_value.to(device=target_device))
            if tensor_name in module._buffers:
                module._buffers[tensor_name] = param_value.to(device=target_device)

            # Replace the module on its parent rather than `setattr(model, param_name, ...)`,
            # which would not resolve a dotted parameter name.
            module_name = param_name.rpartition(".")[0]
            parent_name, _, child_name = module_name.rpartition(".")
            parent_module = model.get_submodule(parent_name) if parent_name else model
            setattr(parent_module, child_name, SVDQW4A4Linear.from_linear(module))

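    # Leave roughly 10% of each device's reported memory unused when accelerate computes the
    # device map.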
    def adjust_max_memory(self, max_memory: Dict[str, Union[int, str]]) -> Dict[str, Union[int, str]]:
        max_memory = {key: val * 0.90 for key, val in max_memory.items()}
        return max_memory

    def adjust_target_dtype(self, target_dtype: "torch.dtype") -> "torch.dtype":
        precision = self.quantization_config.precision
        expected_target_dtypes = [torch.int8]
        if is_fp8_available():
            expected_target_dtypes.append(torch.float8_e4m3fn)
        if target_dtype not in expected_target_dtypes:
            new_target_dtype = self.dtype_map[precision]

            logger.info(f"target_dtype {target_dtype} is replaced by {new_target_dtype} for `nunchaku` quantization")
            return new_target_dtype
        else:
            raise ValueError(f"Wrong `target_dtype` ({target_dtype}) provided.")

    def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype":
        if torch_dtype is None:
            # Default to bfloat16, which `nunchaku` expects for the remaining non-quantized layers.
            logger.info(
                "Overriding torch_dtype=%s with `torch_dtype=torch.bfloat16` due to "
                "requirements of `nunchaku` to enable model loading in 4-bit. "
                "Pass your own torch_dtype to specify the dtype of the remaining non-linear layers or pass"
                " torch_dtype=torch.bfloat16 to remove this warning.",
                torch_dtype,
            )
            torch_dtype = torch.bfloat16
        return torch_dtype

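    # Swap eligible `nn.Linear` modules for nunchaku linear layers before the checkpoint weights
    # are loaded; modules listed in `modules_to_not_convert` or kept in fp32 are left untouched.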
    def _process_model_before_weight_loading(
        self,
        model: "ModelMixin",
        device_map,
        keep_in_fp32_modules: List[str] = [],
        **kwargs,
    ):
        # TODO: deal with `device_map`
        self.modules_to_not_convert = self.quantization_config.modules_to_not_convert

        if not isinstance(self.modules_to_not_convert, list):
            self.modules_to_not_convert = [self.modules_to_not_convert]

        self.modules_to_not_convert.extend(keep_in_fp32_modules)

        model = replace_with_nunchaku_linear(
            model,
            modules_to_not_convert=self.modules_to_not_convert,
            quantization_config=self.quantization_config,
            pre_quantized=self.pre_quantized,
        )
        model.config.quantization_config = self.quantization_config

    def _process_model_after_weight_loading(self, model, **kwargs):
        return model

    # @property
    # def is_serializable(self):
    #     return True