from typing import Any, Dict, List, Optional

import torch

from vllm.logger import init_logger
from vllm.model_executor.layers.linear import (LinearBase,
                                               UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization.base_config import (
    QuantizeMethodBase)
from vllm.model_executor.layers.quantization.fbgemm_fp8 import (
    FBGEMMFp8Config, FBGEMMFp8LinearMethod)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    is_layer_skipped)

logger = init_logger(__name__)


class PTPCFp8Config(FBGEMMFp8Config):
    """Config class for Per-Token-Per-Channel (PTPC) FP8: dynamic
    per-token activation scales with per-channel weight scales."""

    def __init__(self,
                 ignore_list: Optional[List[str]] = None,
                 input_scale_ub: Optional[float] = None):
        # Delegate to the FBGEMM FP8 config. Fall back to a dummy
        # input_scale_ub of 1.0 when the checkpoint does not provide
        # activation_scale_ub.
        super().__init__(ignore_list or [], input_scale_ub or 1.0)

    @classmethod
    def get_name(cls) -> str:
        return "ptpc_fp8"

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "PTPCFp8Config":
        ignore_list = cls.get_from_keys(config, ["modules_to_not_convert"])
        input_scale_ub = cls.get_from_keys(config, ["activation_scale_ub"])
        return cls(ignore_list=ignore_list, input_scale_ub=input_scale_ub)

    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional["QuantizeMethodBase"]:
        if isinstance(layer, LinearBase):
            # Layers named in the ignore list stay unquantized; everything
            # else reuses the FBGEMM FP8 linear method.
            if is_layer_skipped(prefix, self.ignore_list):
                return UnquantizedLinearMethod()
            return FBGEMMFp8LinearMethod(self)
        return None
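
# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the commit): building the
# config from an HF-style quantization_config dict. The two key names match
# what from_config reads above; the concrete values are hypothetical.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    hf_quant_config = {
        "quant_method": "ptpc_fp8",            # hypothetical checkpoint entry
        "modules_to_not_convert": ["lm_head"],
        "activation_scale_ub": 1200.0,
    }
    quant_config = PTPCFp8Config.from_config(hf_quant_config)
    assert quant_config.get_name() == "ptpc_fp8"
    # Layers whose prefix matches the ignore list fall back to
    # UnquantizedLinearMethod; other LinearBase layers get
    # FBGEMMFp8LinearMethod.
    print(quant_config.ignore_list, quant_config.input_scale_ub)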