|
| 1 | +from typing import Any, Dict, List, Optional |
| 2 | + |
| 3 | +import torch |
| 4 | +from torch.nn import Module |
| 5 | +from torch.nn.parameter import Parameter |
| 6 | + |
| 7 | +from vllm.logger import init_logger |
| 8 | +from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase |
| 9 | +from vllm.model_executor.layers.quantization.base_config import ( |
| 10 | + QuantizationConfig, QuantizeMethodBase) |
| 11 | +from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod |
| 12 | +from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( |
| 13 | + apply_fp8_linear, cutlass_fp8_supported, requantize_with_max_scale) |
| 14 | +from vllm.model_executor.parameter import (ModelWeightParameter, |
| 15 | + PerTensorScaleParameter) |
| 16 | + |
| 17 | +logger = init_logger(__name__) |
| 18 | + |
| 19 | +ACTIVATION_SCHEMES = ["static"] |
| 20 | + |
| 21 | + |
| 22 | +class ModelOptFp8Config(QuantizationConfig): |
| 23 | + """Config class for ModelOpt FP8.""" |
| 24 | + |
| 25 | + def __init__( |
| 26 | + self, |
| 27 | + is_checkpoint_fp8_serialized: bool = False, |
| 28 | + ) -> None: |
| 29 | + self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized |
| 30 | + if is_checkpoint_fp8_serialized: |
| 31 | + logger.warning("Detected ModelOpt fp8 checkpoint. Please note that" |
| 32 | + " the format is experimental and could change.") |
| 33 | + |
| 34 | + @classmethod |
| 35 | + def get_name(cls) -> str: |
| 36 | + return "modelopt" |
| 37 | + |
| 38 | + @classmethod |
| 39 | + def get_supported_act_dtypes(cls) -> List[torch.dtype]: |
| 40 | + return [torch.bfloat16, torch.half] |
| 41 | + |
| 42 | + @classmethod |
| 43 | + def get_min_capability(cls) -> int: |
| 44 | + return 89 |
| 45 | + |
| 46 | + @classmethod |
| 47 | + def get_config_filenames(cls) -> List[str]: |
| 48 | + return ["hf_quant_config.json"] |
| 49 | + |
| 50 | + @classmethod |
| 51 | + def from_config(cls, config: Dict[str, Any]) -> "ModelOptFp8Config": |
| 52 | + quant_config = cls.get_from_keys(config, ["quantization"]) |
| 53 | + quant_method = quant_config["quant_algo"] |
| 54 | + is_checkpoint_fp8_serialized = ("FP8" in quant_method) |
| 55 | + if not is_checkpoint_fp8_serialized: |
| 56 | + raise ValueError("ModelOpt currently only supports static FP8" |
| 57 | + "quantization in vLLM. Please check the " |
| 58 | + "`hf_quant_config.json` file for your model's " |
| 59 | + "quant configuration.") |
| 60 | + return cls(is_checkpoint_fp8_serialized) |
| 61 | + |
| 62 | + def get_quant_method(self, layer: torch.nn.Module, |
| 63 | + prefix: str) -> Optional["QuantizeMethodBase"]: |
| 64 | + from vllm.attention.layer import Attention # Avoid circular import |
| 65 | + if isinstance(layer, LinearBase): |
| 66 | + return ModelOptFp8LinearMethod(self) |
| 67 | + elif isinstance(layer, Attention): |
| 68 | + return ModelOptFp8KVCacheMethod(self) |
| 69 | + return None |
| 70 | + |
| 71 | + def get_scaled_act_names(self) -> List[str]: |
| 72 | + return [] |
| 73 | + |
| 74 | + |
| 75 | +class ModelOptFp8KVCacheMethod(BaseKVCacheMethod): |
| 76 | + """ |
| 77 | + Supports loading kv-cache scaling factors from FP8 checkpoints. |
| 78 | + """ |
| 79 | + |
| 80 | + def __init__(self, quant_config: ModelOptFp8Config): |
| 81 | + super().__init__(quant_config) |
| 82 | + |
| 83 | + |
| 84 | +class ModelOptFp8LinearMethod(LinearMethodBase): |
| 85 | + """Linear method for Model Optimizer static quantization. |
| 86 | + Supports loading FP8 checkpoints with static weight scale and |
| 87 | + activation scale. Future support might be added for dynamic |
| 88 | + scales. |
| 89 | +
|
| 90 | + Limitations: |
| 91 | + 1. Only support per-tensor quantization due to torch._scaled_mm support. |
| 92 | + 2. Only support float8_e4m3fn datatype |
| 93 | + Args: quant_config: The ModelOpt quantization config. |
| 94 | + """ |
| 95 | + |
| 96 | + def __init__(self, quant_config: ModelOptFp8Config): |
| 97 | + self.quant_config = quant_config |
| 98 | + self.cutlass_fp8_supported = cutlass_fp8_supported() |
| 99 | + |
| 100 | + def create_weights( |
| 101 | + self, |
| 102 | + layer: torch.nn.Module, |
| 103 | + input_size_per_partition: int, |
| 104 | + output_partition_sizes: List[int], |
| 105 | + input_size: int, |
| 106 | + output_size: int, |
| 107 | + params_dtype: torch.dtype, |
| 108 | + **extra_weight_attrs, |
| 109 | + ): |
| 110 | + del input_size, output_size |
| 111 | + output_size_per_partition = sum(output_partition_sizes) |
| 112 | + weight_loader = extra_weight_attrs.get("weight_loader") |
| 113 | + layer.logical_widths = output_partition_sizes |
| 114 | + layer.input_size_per_partition = input_size_per_partition |
| 115 | + layer.output_size_per_partition = output_size_per_partition |
| 116 | + weight_dtype = (torch.float8_e4m3fn |
| 117 | + if self.quant_config.is_checkpoint_fp8_serialized else |
| 118 | + params_dtype) |
| 119 | + weight = ModelWeightParameter(data=torch.empty( |
| 120 | + output_size_per_partition, |
| 121 | + input_size_per_partition, |
| 122 | + dtype=weight_dtype), |
| 123 | + input_dim=1, |
| 124 | + output_dim=0, |
| 125 | + weight_loader=weight_loader) |
| 126 | + layer.register_parameter("weight", weight) |
| 127 | + |
| 128 | + if self.quant_config.is_checkpoint_fp8_serialized: |
| 129 | + # WEIGHT SCALE |
| 130 | + weight_scale = PerTensorScaleParameter(data=torch.empty( |
| 131 | + len(output_partition_sizes), dtype=torch.float32), |
| 132 | + weight_loader=weight_loader) |
| 133 | + weight_scale[:] = torch.finfo(torch.float32).min |
| 134 | + layer.register_parameter("weight_scale", weight_scale) |
| 135 | + # INPUT SCALE |
| 136 | + scale = PerTensorScaleParameter(data=torch.empty( |
| 137 | + len(output_partition_sizes), dtype=torch.float32), |
| 138 | + weight_loader=weight_loader) |
| 139 | + |
| 140 | + scale[:] = torch.finfo(torch.float32).min |
| 141 | + layer.register_parameter("input_scale", scale) |
| 142 | + |
| 143 | + def process_weights_after_loading(self, layer: Module) -> None: |
| 144 | + max_w_scale, weight = requantize_with_max_scale( |
| 145 | + layer.weight, layer.weight_scale, layer.logical_widths) |
| 146 | + layer.weight = Parameter(weight.t(), requires_grad=False) |
| 147 | + layer.weight_scale = Parameter(max_w_scale, requires_grad=False) |
| 148 | + layer.input_scale = Parameter(layer.input_scale.max(), |
| 149 | + requires_grad=False) |
| 150 | + |
| 151 | + def apply( |
| 152 | + self, |
| 153 | + layer: torch.nn.Module, |
| 154 | + x: torch.Tensor, |
| 155 | + bias: Optional[torch.Tensor] = None, |
| 156 | + ) -> torch.Tensor: |
| 157 | + return apply_fp8_linear( |
| 158 | + input=x, |
| 159 | + weight=layer.weight, |
| 160 | + weight_scale=layer.weight_scale, |
| 161 | + input_scale=layer.input_scale, |
| 162 | + bias=bias, |
| 163 | + cutlass_fp8_supported=self.cutlass_fp8_supported) |
0 commit comments