from typing import TYPE_CHECKING, Any, Dict, List, Optional

from ...utils import get_module_from_name, is_accelerate_available, is_torch_available, logging
from ..base import DiffusersQuantizer


if is_torch_available():
    import torch

logger = logging.get_logger(__name__)

if TYPE_CHECKING:
    from ...models.modeling_utils import ModelMixin

class FinegrainedFP8Quantizer(DiffusersQuantizer):
    """
    FP8 quantization implementation supporting both standard and MoE models. Weights are quantized to the
    torch.float8_e4m3fn format using fine-grained (block-wise) scaling factors.
    """

    requires_parameters_quantization = True
    requires_calibration = False
    required_packages = ["accelerate"]

    def __init__(self, quantization_config, **kwargs):
        super().__init__(quantization_config, **kwargs)
        self.quantization_config = quantization_config

    def validate_environment(self, *args, **kwargs):
        if not is_torch_available():
            raise ImportError(
                "Using fp8 quantization requires torch >= 2.1.0. "
                "Please install the latest version of torch (`pip install --upgrade torch`)."
            )

        if not is_accelerate_available():
            raise ImportError("Loading an FP8 quantized model requires accelerate (`pip install accelerate`)")

        if kwargs.get("from_tf", False) or kwargs.get("from_flax", False):
            raise ValueError(
                "Converting into FP8 weights from tf/flax weights is currently not supported, "
                "please make sure the weights are in PyTorch format."
            )

        if torch.cuda.is_available():
            compute_capability = torch.cuda.get_device_capability()
            major, minor = compute_capability
            if (major < 8) or (major == 8 and minor < 9):
                raise ValueError(
                    "FP8 quantized models are only supported on GPUs with compute capability >= 8.9 (e.g. RTX 4090/H100)"
                    f", actual = `{major}.{minor}`"
                )

        device_map = kwargs.get("device_map", None)
        if device_map is None:
            logger.warning_once(
                "You have loaded an FP8 model without specifying a device_map. Make sure to move the model to a GPU "
                "device in order to run it. To remove this warning, pass `device_map='cuda'`."
            )
        else:
            if (
                not self.pre_quantized
                and isinstance(device_map, dict)
                and ("cpu" in device_map.values() or "disk" in device_map.values())
            ):
                raise ValueError(
                    "You are attempting to load an FP8 model with a device_map that contains a cpu/disk device. "
                    "This is not supported when the model is quantized on the fly. "
                    "Please use a quantized checkpoint or remove the cpu/disk device from the device_map."
                )

    def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype":
        if torch_dtype is None:
            logger.info("Setting torch_dtype to torch.float32 as no torch_dtype was specified in from_pretrained")
            torch_dtype = torch.float32
        return torch_dtype

    def create_quantized_param(
        self,
        model: "ModelMixin",
        param_value: "torch.Tensor",
        param_name: str,
        target_device: "torch.device",
        state_dict: Dict[str, Any],
        unexpected_keys: Optional[List[str]] = None,
        **kwargs,
    ):
        """
        Quantizes weights to the FP8 (e4m3fn) format using block-wise quantization.
        """
        from accelerate.utils import set_module_tensor_to_device

        set_module_tensor_to_device(model, param_name, target_device, param_value)

        module, tensor_name = get_module_from_name(model, param_name)

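        # Block-wise quantization: the last two dims of the weight are tiled into (block_size_m, block_size_n)
        # blocks, and each block gets its own scale chosen so that the block's largest magnitude maps onto the
        # e4m3fn range. For example, a (256, 512) weight with a (128, 128) block size yields a 2x4 grid of scales
        # that is stored alongside the FP8 weight.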
        # Get FP8 min/max values
        fp8_min = torch.finfo(torch.float8_e4m3fn).min
        fp8_max = torch.finfo(torch.float8_e4m3fn).max

        block_size_m, block_size_n = self.quantization_config.weight_block_size

        rows, cols = param_value.shape[-2:]

        if rows % block_size_m != 0 or cols % block_size_n != 0:
            raise ValueError(
                f"Matrix dimensions ({rows}, {cols}) must be divisible by block sizes ({block_size_m}, {block_size_n})"
            )
        param_value_orig_shape = param_value.shape

        param_value = param_value.reshape(
            -1, rows // block_size_m, block_size_m, cols // block_size_n, block_size_n
        ).permute(0, 1, 3, 2, 4)

        # Calculate scaling factor for each block
        max_abs = torch.amax(torch.abs(param_value), dim=(-1, -2))
        scale = fp8_max / max_abs
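        # fp8_max / max_abs maps each block's largest magnitude onto the top of the e4m3fn range (±448)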
        scale_orig_shape = scale.shape
        scale = scale.unsqueeze(-1).unsqueeze(-1)

        # Quantize the weights
        quantized_param = torch.clamp(param_value * scale, min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)

        # Undo the block permutation and reshape back to the original matrix shape
        quantized_param = quantized_param.permute(0, 1, 3, 2, 4)
        quantized_param = quantized_param.reshape(param_value_orig_shape)

        # Reshape scale to match the number of blocks
        scale = scale.reshape(scale_orig_shape).squeeze().reciprocal()
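        # The stored scale is the reciprocal (one value per block), so the original weight can be approximately
        # recovered block by block as quantized_weight.float() * weight_scale_inv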

        # Load the quantized weight and its inverse scale into the module's buffers
        module._buffers[tensor_name] = quantized_param.to(target_device)
        module._buffers["weight_scale_inv"] = scale.to(target_device)

    def check_if_quantized_param(
        self,
        model: "ModelMixin",
        param_value: "torch.Tensor",
        param_name: str,
        state_dict: Dict[str, Any],
        **kwargs,
    ):
        from .utils import FP8Linear

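        # Only unquantized weights of FP8Linear modules are quantized on the fly. Pre-quantized weights, scales
        # and biases are loaded as-is (with a sanity check on the dtype of pre-quantized weights).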
        module, tensor_name = get_module_from_name(model, param_name)
        if isinstance(module, FP8Linear):
            if self.pre_quantized or tensor_name == "bias":
                if tensor_name == "weight" and param_value.dtype != torch.float8_e4m3fn:
                    raise ValueError("Expected a quantized weight but got an unquantized weight")
                return False
            else:
                if tensor_name == "weight_scale_inv":
                    raise ValueError("Expected an unquantized weight but got a weight_scale_inv tensor")
                return True
        return False

    def _process_model_before_weight_loading(
        self,
        model: "ModelMixin",
        keep_in_fp32_modules: Optional[List[str]] = None,
        **kwargs,
    ):
        from .utils import replace_with_fp8_linear

        if self.quantization_config.modules_to_not_convert is not None:
            self.modules_to_not_convert.extend(self.quantization_config.modules_to_not_convert)
        if keep_in_fp32_modules is not None:
            self.modules_to_not_convert.extend(keep_in_fp32_modules)

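        # Swap eligible linear layers for FP8Linear modules, which hold the FP8 weight together with a
        # per-block `weight_scale_inv` buffer (see create_quantized_param above).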
        model = replace_with_fp8_linear(
            model,
            modules_to_not_convert=self.modules_to_not_convert,
            quantization_config=self.quantization_config,
        )

        model.config.quantization_config = self.quantization_config

    def _process_model_after_weight_loading(self, model: "ModelMixin", **kwargs):
        return model

    def update_missing_keys(self, model, missing_keys: List[str], prefix: str) -> List[str]:
        from .utils import FP8Linear

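        # FP8Linear modules carry extra tensors (e.g. `weight_scale_inv`) that do not exist in unquantized
        # checkpoints, so those keys should not be reported as missing when quantizing on the fly.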
        not_missing_keys = []
        for name, module in model.named_modules():
            if isinstance(module, FP8Linear):
                for missing in missing_keys:
                    if (
                        (name in missing or name in f"{prefix}.{missing}")
                        and not missing.endswith(".weight")
                        and not missing.endswith(".bias")
                    ):
                        not_missing_keys.append(missing)
        return [k for k in missing_keys if k not in not_missing_keys]

    def is_serializable(self, safe_serialization=None):
        return True

    @property
    def is_trainable(self) -> bool:
        return False

    def get_cuda_warm_up_factor(self):
        # Pre-processing is done cleanly, so we can allocate everything here
        return 2
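
# Example usage (sketch): this quantizer is normally selected automatically when a fine-grained FP8
# quantization config is passed to `from_pretrained`. The config class name `FinegrainedFP8Config` and the
# (128, 128) block size below are assumptions; refer to the quantization config that accompanies this quantizer.
#
#   import torch
#   from diffusers import FluxTransformer2DModel
#
#   quantization_config = FinegrainedFP8Config(weight_block_size=(128, 128))  # assumed config class
#   transformer = FluxTransformer2DModel.from_pretrained(
#       "<repo_id>",
#       subfolder="transformer",
#       quantization_config=quantization_config,
#       torch_dtype=torch.bfloat16,
#   )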