Commit 7953b38

quantize model weights to fp8

1 parent 2f056b9

1 file changed: 96 additions & 10 deletions
@@ -1,43 +1,129 @@
 from typing import Any, Dict, List, Optional
 
 import torch
+from torch.nn.parameter import Parameter
 
 from vllm.logger import init_logger
 from vllm.model_executor.layers.linear import (LinearBase,
                                                UnquantizedLinearMethod)
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizeMethodBase)
-from vllm.model_executor.layers.quantization.fbgemm_fp8 import (
-    FBGEMMFp8Config, FBGEMMFp8LinearMethod)
+from vllm.model_executor.layers.quantization.fp8 import (
+    Fp8Config, Fp8LinearMethod, Fp8KVCacheMethod)
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     is_layer_skipped)
 from vllm.platforms import current_platform
+from vllm import _custom_ops as ops
+from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
+    apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin)
+from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
+    apply_fp8_linear)
+
+ACTIVATION_SCHEMES = ["static", "dynamic"]
 
 logger = init_logger(__name__)
 
 
-class PTPCFp8Config(FBGEMMFp8Config):
+class PTPCFp8Config(Fp8Config):
     """Config class for Per-Token-Per-Channel Fp8."""
 
-    def __init__(self, ignore_list: Optional[List[str]] = None):
+    def __init__(
+        self,
+        activation_scheme: str = "dynamic",
+        ignored_layers: Optional[List[str]] = None,
+    ) -> None:
         if not current_platform.is_rocm():
             raise ValueError("ptpc_fp8 quantization is supported only on ROCm")
-        super().__init__(ignore_list, 1.0)  # Dummy values
+        super().__init__(
+            is_checkpoint_fp8_serialized=False,
+            activation_scheme=activation_scheme,
+            ignored_layers=ignored_layers)
 
     @classmethod
     def get_name(cls) -> str:
         return "ptpc_fp8"
 
     @classmethod
     def from_config(cls, config: Dict[str, Any]) -> "PTPCFp8Config":
-        ignore_list = cls.get_from_keys(config, ["modules_to_not_convert"])
-        input_scale_ub = cls.get_from_keys(config, ["activation_scale_ub"])
-        return cls(ignore_list=ignore_list, input_scale_ub=input_scale_ub)
+        activation_scheme = cls.get_from_keys(config, ["activation_scheme"])
+        ignored_layers = cls.get_from_keys_or(config, ["ignored_layers"], None)
+        return cls(activation_scheme=activation_scheme,
                   ignored_layers=ignored_layers)
 
     def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional["QuantizeMethodBase"]:
+        from vllm.attention.layer import Attention  # Avoid circular import
+
         if isinstance(layer, LinearBase):
-            if is_layer_skipped(prefix, self.ignore_list):
+            if is_layer_skipped(prefix, self.ignored_layers):
                 return UnquantizedLinearMethod()
-            return FBGEMMFp8LinearMethod(self)
+            return PTPCFp8LinearMethod(self)
+        elif isinstance(layer, Attention):
+            return Fp8KVCacheMethod(self)
         return None
+
+
+class PTPCFp8LinearMethod(Fp8LinearMethod):
+    """Linear method for Per-Token and Per-Channel FP8 Quantization.
+    Only supports loading FP16/BF16 model checkpoints with dynamic
+    activation scaling. The weight scaling factor will be initialized after
+    the model weights are loaded.
+
+    Limitations:
+    1. Only supports the float8_e4m3fn data type due to the limitation of
+       torch._scaled_mm (https://github.com/pytorch/pytorch/blob/2e48b39603411a41c5025efbe52f89560b827825/aten/src/ATen/native/cuda/Blas.cpp#L854-L856)
+
+    Args:
+        quant_config: The quantization config.
+    """
+
+    def __init__(self, quant_config: PTPCFp8Config):
+        super().__init__(quant_config=quant_config)
+        # Force weight quantization
+        self.quant_config.is_checkpoint_fp8_serialized = False
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        layer.weight = torch.nn.Parameter(layer.weight.data,
+                                          requires_grad=False)
+
+        # Quantize the weights.
+        qweight, weight_scale = ops.scaled_fp8_quant(
+            layer.weight,
+            scale=None,
+            use_per_token_if_dynamic=True)
+
+        # Update the layer with the new values.
+        layer.weight = Parameter(qweight.t(), requires_grad=False)
+        layer.weight_scale = Parameter(weight_scale, requires_grad=False)
+        layer.input_scale = None
+
+        if self.use_marlin:
+            prepare_fp8_layer_for_marlin(layer)
+            # Activations not quantized for marlin.
+            del layer.input_scale
+
+    def apply(self,
+              layer: torch.nn.Module,
+              x: torch.Tensor,
+              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+
+        if self.use_marlin:
+            return apply_fp8_marlin_linear(
+                input=x,
+                weight=layer.weight,
+                weight_scale=layer.weight_scale,
+                workspace=layer.workspace,
+                size_n=layer.output_size_per_partition,
+                size_k=layer.input_size_per_partition,
+                bias=bias)
+
+        return apply_fp8_linear(
+            input=x,
+            weight=layer.weight,
+            weight_scale=layer.weight_scale,
+            out_dtype=self.out_dtype,
+            input_scale=None,
+            input_scale_ub=None,
+            bias=bias,
+            cutlass_fp8_supported=None,
+            use_per_token_if_dynamic=True)
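
For reference, the per-channel weight scaling that ops.scaled_fp8_quant(..., use_per_token_if_dynamic=True) performs in process_weights_after_loading can be sketched in plain PyTorch as below. This is a minimal illustration rather than the custom kernel itself; the helper name, the clamping epsilon, and the float8_e4m3fn target are assumptions based on the docstring above.

import torch

def per_channel_fp8_quant_sketch(weight: torch.Tensor):
    # weight: [out_features, in_features] in fp16/bf16.
    # One dynamic scale per output row, chosen so the largest magnitude in
    # each row maps onto the fp8 representable range.
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    amax = weight.abs().amax(dim=-1, keepdim=True).float()
    scale = (amax / fp8_max).clamp(min=1e-12)
    qweight = (weight.float() / scale).clamp(-fp8_max, fp8_max).to(
        torch.float8_e4m3fn)
    # scale is float32 with shape [out_features, 1]; dequantization is
    # qweight.float() * scale.
    return qweight, scale

At run time, apply_fp8_linear pairs this per-channel weight scale with a per-token activation scale computed on the fly (input_scale=None, use_per_token_if_dynamic=True), so the fp8 GEMM output is rescaled by both factors before being returned in out_dtype.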

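Assuming this config is registered in vLLM's quantization registry under the name returned by get_name() above, a model could then be loaded with per-token-per-channel FP8 on a ROCm GPU roughly as follows. The model name is only an example; any FP16/BF16 checkpoint should work in principle, since the weights are quantized while the model is being loaded.

from vllm import LLM, SamplingParams

llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct",
          quantization="ptpc_fp8",
          dtype="bfloat16")
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)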