|
| 1 | +from typing import TYPE_CHECKING, Any, Dict, List, Optional |
| 2 | + |
| 3 | +from ...utils import is_accelerate_available, is_torch_available, logging |
| 4 | +from ..base import DiffusersQuantizer |
| 5 | +from ...utils import get_module_from_name |
| 6 | + |
| 7 | + |
| 8 | +if is_torch_available(): |
| 9 | + import torch |
| 10 | + |
| 11 | +logger = logging.get_logger(__name__) |
| 12 | + |
| 13 | +if TYPE_CHECKING: |
| 14 | + from ...models.modeling_utils import ModelMixin |
| 15 | + |
| 16 | +class FinegrainedFP8Quantizer(DiffusersQuantizer): |
| 17 | + """ |
| 18 | + FP8 quantization implementation supporting both standard and MoE models. |
| 19 | + Supports both e4m3fn formats based on platform. |
| 20 | + """ |
| 21 | + |
| 22 | + requires_parameters_quantization = True |
| 23 | + requires_calibration = False |
| 24 | + required_packages = ["accelerate"] |
| 25 | + |
| 26 | + def __init__(self, quantization_config, **kwargs): |
| 27 | + super().__init__(quantization_config, **kwargs) |
| 28 | + self.quantization_config = quantization_config |
| 29 | + |
| 30 | + def validate_environment(self, *args, **kwargs): |
| 31 | + if not is_torch_available(): |
| 32 | + raise ImportError( |
| 33 | + "Using fp8 quantization requires torch >= 2.1.0" |
| 34 | + "Please install the latest version of torch ( pip install --upgrade torch )" |
| 35 | + ) |
| 36 | + |
| 37 | + if not is_accelerate_available(): |
| 38 | + raise ImportError("Loading an FP8 quantized model requires accelerate (`pip install accelerate`)") |
| 39 | + |
| 40 | + if kwargs.get("from_tf", False) or kwargs.get("from_flax", False): |
| 41 | + raise ValueError( |
| 42 | + "Converting into FP8 weights from tf/flax weights is currently not supported, " |
| 43 | + "please make sure the weights are in PyTorch format." |
| 44 | + ) |
| 45 | + |
| 46 | + if torch.cuda.is_available(): |
| 47 | + compute_capability = torch.cuda.get_device_capability() |
| 48 | + major, minor = compute_capability |
| 49 | + if (major < 8) or (major == 8 and minor < 9): |
| 50 | + raise ValueError( |
| 51 | + "FP8 quantized models is only supported on GPUs with compute capability >= 8.9 (e.g 4090/H100)" |
| 52 | + f", actual = `{major}.{minor}`" |
| 53 | + ) |
| 54 | + |
| 55 | + device_map = kwargs.get("device_map", None) |
| 56 | + if device_map is None: |
| 57 | + logger.warning_once( |
| 58 | + "You have loaded an FP8 model on CPU and have a CUDA device available, make sure to set " |
| 59 | + "your model on a GPU device in order to run your model. To remove this warning, pass device_map = 'cuda'. " |
| 60 | + ) |
| 61 | + elif device_map is not None: |
| 62 | + if ( |
| 63 | + not self.pre_quantized |
| 64 | + and isinstance(device_map, dict) |
| 65 | + and ("cpu" in device_map.values() or "disk" in device_map.values()) |
| 66 | + ): |
| 67 | + raise ValueError( |
| 68 | + "You are attempting to load an FP8 model with a device_map that contains a cpu/disk device." |
| 69 | + "This is not supported when the model is quantized on the fly. " |
| 70 | + "Please use a quantized checkpoint or remove the cpu/disk device from the device_map." |
| 71 | + ) |
| 72 | + |
| 73 | + def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype": |
| 74 | + if torch_dtype is None: |
| 75 | + logger.info("Setting torch_dtype to torch.float32 as no torch_dtype was specified in from_pretrained") |
| 76 | + torch_dtype = torch.float32 |
| 77 | + return torch_dtype |
| 78 | + |
| 79 | + def create_quantized_param( |
| 80 | + self, |
| 81 | + model: "ModelMixin", |
| 82 | + param_value: "torch.Tensor", |
| 83 | + param_name: str, |
| 84 | + target_device: "torch.device", |
| 85 | + state_dict: Dict[str, Any], |
| 86 | + unexpected_keys: Optional[List[str]] = None, |
| 87 | + **kwargs, |
| 88 | + ): |
| 89 | + """ |
| 90 | + Quantizes weights to FP8 format using Block-wise quantization |
| 91 | + """ |
| 92 | + # print("############ create quantized param ########") |
| 93 | + from accelerate.utils import set_module_tensor_to_device |
| 94 | + |
| 95 | + set_module_tensor_to_device(model, param_name, target_device, param_value) |
| 96 | + |
| 97 | + module, tensor_name = get_module_from_name(model, param_name) |
| 98 | + |
| 99 | + # Get FP8 min/max values |
| 100 | + fp8_min = torch.finfo(torch.float8_e4m3fn).min |
| 101 | + fp8_max = torch.finfo(torch.float8_e4m3fn).max |
| 102 | + |
| 103 | + block_size_m, block_size_n = self.quantization_config.weight_block_size |
| 104 | + |
| 105 | + rows, cols = param_value.shape[-2:] |
| 106 | + |
| 107 | + if rows % block_size_m != 0 or cols % block_size_n != 0: |
| 108 | + raise ValueError( |
| 109 | + f"Matrix dimensions ({rows}, {cols}) must be divisible by block sizes ({block_size_m}, {block_size_n})" |
| 110 | + ) |
| 111 | + param_value_orig_shape = param_value.shape |
| 112 | + |
| 113 | + param_value = param_value.reshape( |
| 114 | + -1, rows // block_size_m, block_size_m, cols // block_size_n, block_size_n |
| 115 | + ).permute(0, 1, 3, 2, 4) |
| 116 | + |
| 117 | + # Calculate scaling factor for each block |
| 118 | + max_abs = torch.amax(torch.abs(param_value), dim=(-1, -2)) |
| 119 | + scale = fp8_max / max_abs |
| 120 | + scale_orig_shape = scale.shape |
| 121 | + scale = scale.unsqueeze(-1).unsqueeze(-1) |
| 122 | + |
| 123 | + # Quantize the weights |
| 124 | + quantized_param = torch.clamp(param_value * scale, min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn) |
| 125 | + |
| 126 | + quantized_param = quantized_param.permute(0, 1, 3, 2, 4) |
| 127 | + # Reshape back to matrix shape |
| 128 | + quantized_param = quantized_param.reshape(param_value_orig_shape) |
| 129 | + |
| 130 | + # Reshape scale to match the number of blocks |
| 131 | + scale = scale.reshape(scale_orig_shape).squeeze().reciprocal() |
| 132 | + |
| 133 | + # Load into the model |
| 134 | + module._buffers[tensor_name] = quantized_param.to(target_device) |
| 135 | + module._buffers["weight_scale_inv"] = scale.to(target_device) |
| 136 | + # print("_buffers[0]", module._buffers["weight_scale_inv"]) |
| 137 | + |
| 138 | + def check_if_quantized_param( |
| 139 | + self, |
| 140 | + model: "ModelMixin", |
| 141 | + param_value: "torch.Tensor", |
| 142 | + param_name: str, |
| 143 | + state_dict: Dict[str, Any], |
| 144 | + **kwargs, |
| 145 | + ): |
| 146 | + from .utils import FP8Linear |
| 147 | + |
| 148 | + module, tensor_name = get_module_from_name(model, param_name) |
| 149 | + if isinstance(module, FP8Linear): |
| 150 | + if self.pre_quantized or tensor_name == "bias": |
| 151 | + if tensor_name == "weight" and param_value.dtype != torch.float8_e4m3fn: |
| 152 | + raise ValueError("Expect quantized weights but got an unquantized weight") |
| 153 | + return False |
| 154 | + else: |
| 155 | + if tensor_name == "weight_scale_inv": |
| 156 | + raise ValueError("Expect unquantized weights but got a quantized weight_scale") |
| 157 | + return True |
| 158 | + return False |
| 159 | + |
| 160 | + def _process_model_before_weight_loading( |
| 161 | + self, |
| 162 | + model: "ModelMixin", |
| 163 | + keep_in_fp32_modules: Optional[List[str]] = None, |
| 164 | + **kwargs, |
| 165 | + ): |
| 166 | + from .utils import replace_with_fp8_linear |
| 167 | + |
| 168 | + if self.quantization_config.modules_to_not_convert is not None: |
| 169 | + self.modules_to_not_convert.extend(self.quantization_config.modules_to_not_convert) |
| 170 | + |
| 171 | + model = replace_with_fp8_linear( |
| 172 | + model, |
| 173 | + modules_to_not_convert=self.modules_to_not_convert, |
| 174 | + quantization_config=self.quantization_config, |
| 175 | + ) |
| 176 | + |
| 177 | + model.config.quantization_config = self.quantization_config |
| 178 | + |
| 179 | + def _process_model_after_weight_loading(self, model: "ModelMixin", **kwargs): |
| 180 | + return model |
| 181 | + |
| 182 | + def update_missing_keys(self, model, missing_keys: List[str], prefix: str) -> List[str]: |
| 183 | + from .utils import FP8Linear |
| 184 | + |
| 185 | + not_missing_keys = [] |
| 186 | + for name, module in model.named_modules(): |
| 187 | + if isinstance(module, FP8Linear): |
| 188 | + for missing in missing_keys: |
| 189 | + if ( |
| 190 | + (name in missing or name in f"{prefix}.{missing}") |
| 191 | + and not missing.endswith(".weight") |
| 192 | + and not missing.endswith(".bias") |
| 193 | + ): |
| 194 | + not_missing_keys.append(missing) |
| 195 | + return [k for k in missing_keys if k not in not_missing_keys] |
| 196 | + |
| 197 | + def is_serializable(self, safe_serialization=None): |
| 198 | + return True |
| 199 | + |
| 200 | + @property |
| 201 | + def is_trainable(self) -> bool: |
| 202 | + return False |
| 203 | + |
| 204 | + def get_cuda_warm_up_factor(self): |
| 205 | + # Pre-processing is done cleanly, so we can allocate everything here |
| 206 | + return 2 |
0 commit comments