 from typing import Any, Dict, List, Optional

 import torch
+from torch.nn.parameter import Parameter

 from vllm.logger import init_logger
 from vllm.model_executor.layers.linear import (LinearBase,
                                                UnquantizedLinearMethod)
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizeMethodBase)
-from vllm.model_executor.layers.quantization.fbgemm_fp8 import (
-    FBGEMMFp8Config, FBGEMMFp8LinearMethod)
+from vllm.model_executor.layers.quantization.fp8 import (
+    Fp8Config, Fp8LinearMethod, Fp8KVCacheMethod)
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     is_layer_skipped)
 from vllm.platforms import current_platform
+from vllm import _custom_ops as ops
+from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
+    apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin)
+from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
+    apply_fp8_linear)
+
+ACTIVATION_SCHEMES = ["static", "dynamic"]

 logger = init_logger(__name__)


-class PTPCFp8Config(FBGEMMFp8Config):
+class PTPCFp8Config(Fp8Config):
     """Config class for Per-Token-Per-Channel Fp8."""

-    def __init__(self, ignore_list: Optional[List[str]] = None):
+    def __init__(
+        self,
+        activation_scheme: str = "dynamic",
+        ignored_layers: Optional[List[str]] = None,
+    ) -> None:
         if not current_platform.is_rocm():
             raise ValueError(
                 "ptpc_fp8 quantization is supported only on ROCm")
-        super().__init__(ignore_list, 1.0)  # Dummy values
+        super().__init__(
+            is_checkpoint_fp8_serialized=False,
+            activation_scheme=activation_scheme,
+            ignored_layers=ignored_layers)

     @classmethod
     def get_name(cls) -> str:
         return "ptpc_fp8"

     @classmethod
     def from_config(cls, config: Dict[str, Any]) -> "PTPCFp8Config":
-        ignore_list = cls.get_from_keys(config, ["modules_to_not_convert"])
-        input_scale_ub = cls.get_from_keys(config, ["activation_scale_ub"])
-        return cls(ignore_list=ignore_list, input_scale_ub=input_scale_ub)
+        activation_scheme = cls.get_from_keys(config, ["activation_scheme"])
+        ignored_layers = cls.get_from_keys_or(config, ["ignored_layers"],
+                                              None)
+        return cls(activation_scheme=activation_scheme,
+                   ignored_layers=ignored_layers)

     def get_quant_method(self, layer: torch.nn.Module,
                          prefix: str) -> Optional["QuantizeMethodBase"]:
+        from vllm.attention.layer import Attention  # Avoid circular import
+
         if isinstance(layer, LinearBase):
-            if is_layer_skipped(prefix, self.ignore_list):
+            if is_layer_skipped(prefix, self.ignored_layers):
                 return UnquantizedLinearMethod()
-            return FBGEMMFp8LinearMethod(self)
+            return PTPCFp8LinearMethod(self)
+        elif isinstance(layer, Attention):
+            return Fp8KVCacheMethod(self)
         return None
+
+
+class PTPCFp8LinearMethod(Fp8LinearMethod):
+    """Linear method for Per-Token and Per-Channel FP8 Quantization.
+    Only supports loading FP16/BF16 model checkpoints with dynamic
+    activation scaling. The weight scaling factor is initialized after
+    the model weights are loaded.
+
+    Limitations:
+    1. Only supports the float8_e4m3fn data type, due to the limitation
+       of torch._scaled_mm (https://github.com/pytorch/pytorch/blob/2e48b39603411a41c5025efbe52f89560b827825/aten/src/ATen/native/cuda/Blas.cpp#L854-L856)
+
+    Args:
+        quant_config: The quantization config.
+    """
+
+    def __init__(self, quant_config: PTPCFp8Config):
+        super().__init__(quant_config=quant_config)
+        # Force weight quantization
+        self.quant_config.is_checkpoint_fp8_serialized = False
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        layer.weight = torch.nn.Parameter(layer.weight.data,
+                                          requires_grad=False)
+
+        # Quantize the weights.
+        qweight, weight_scale = ops.scaled_fp8_quant(
+            layer.weight, scale=None, use_per_token_if_dynamic=True)
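+        # With a 2D weight and use_per_token_if_dynamic=True, dynamic
+        # quantization computes one scale per row, so weight_scale has
+        # shape (out_features, 1): a per-output-channel scale.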
+
+        # Update the layer with the new values.
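+        # The transpose stores the quantized weight as
+        # (in_features, out_features), the layout expected for the
+        # right-hand operand of torch._scaled_mm.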
+        layer.weight = Parameter(qweight.t(), requires_grad=False)
+        layer.weight_scale = Parameter(weight_scale, requires_grad=False)
+        layer.input_scale = None
+
+        if self.use_marlin:
+            prepare_fp8_layer_for_marlin(layer)
+            # Activations are not quantized for marlin.
+            del layer.input_scale
+
+    def apply(self,
+              layer: torch.nn.Module,
+              x: torch.Tensor,
+              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+
+        if self.use_marlin:
+            return apply_fp8_marlin_linear(
+                input=x,
+                weight=layer.weight,
+                weight_scale=layer.weight_scale,
+                workspace=layer.workspace,
+                size_n=layer.output_size_per_partition,
+                size_k=layer.input_size_per_partition,
+                bias=bias)
+
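+        # input_scale=None together with use_per_token_if_dynamic=True
+        # quantizes the activations on the fly, with one scale per token
+        # (per row of x).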
+        return apply_fp8_linear(
+            input=x,
+            weight=layer.weight,
+            weight_scale=layer.weight_scale,
+            out_dtype=self.out_dtype,
+            input_scale=None,
+            input_scale_ub=None,
+            bias=bias,
+            cutlass_fp8_supported=False,
+            use_per_token_if_dynamic=True)
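
For intuition, here is a minimal standalone sketch of the arithmetic this method performs, not vLLM's actual kernels: the weight gets one FP8 scale per output channel, each activation row gets its own scale per token, and the matmul result is rescaled by the outer product of the two scale vectors. The helper names (quant_rowwise, ptpc_fp8_matmul_reference) are illustrative only, and float8_e4m3fn is used for portability even though ROCm uses a different FP8 variant.

import torch

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0 for e4m3fn


def quant_rowwise(t: torch.Tensor):
    """Quantize each row of a 2D tensor to FP8 with its own scale."""
    scale = t.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12) / FP8_MAX
    q = (t / scale).clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn)
    return q, scale


def ptpc_fp8_matmul_reference(x: torch.Tensor,
                              w: torch.Tensor) -> torch.Tensor:
    qx, sx = quant_rowwise(x)  # per-token scales, shape (num_tokens, 1)
    qw, sw = quant_rowwise(w)  # per-channel scales, shape (out_features, 1)
    # Emulate the FP8 matmul in fp32, then undo both scalings.
    out = qx.to(torch.float32) @ qw.to(torch.float32).t()
    return out * sx * sw.t()


x = torch.randn(4, 64)    # 4 tokens
w = torch.randn(128, 64)  # 128 output channels
err = (ptpc_fp8_matmul_reference(x, w) - x @ w.t()).abs().max()
print(f"max abs error of the FP8 emulation: {err:.3f}")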
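Assuming the config is registered under the name returned by get_name() (as vLLM's quantization registry does for in-tree methods), enabling it from the offline API on a ROCm machine would look roughly like the following; the model name is a placeholder for any FP16/BF16 checkpoint.

from vllm import LLM, SamplingParams

# quantization="ptpc_fp8" matches PTPCFp8Config.get_name(); weights are
# quantized to FP8 at load time, activations dynamically per token.
llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct",
          quantization="ptpc_fp8",
          dtype="bfloat16")
out = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=32))
print(out[0].outputs[0].text)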