
Commit 9610a20

per token per channel quantization initial setup

1 parent: 466334a

File tree

4 files changed: +56 −2 lines

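The scheme this commit begins wiring up quantizes activations with one FP8 scale per token and weights with one FP8 scale per output channel, rather than a single tensor-wide scale. A minimal numerical sketch in plain PyTorch (shapes and names illustrative, not code from this commit):

import torch

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0 for e4m3fn

def quantize_per_token(x: torch.Tensor):
    # One scale per token: amax over the hidden dimension of each row.
    scale = x.abs().amax(dim=-1, keepdim=True).float() / FP8_MAX
    scale = scale.clamp(min=1e-12)
    return (x / scale).clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn), scale

def quantize_per_channel(w: torch.Tensor):
    # One scale per output channel: amax over each weight row.
    scale = w.abs().amax(dim=-1, keepdim=True).float() / FP8_MAX
    scale = scale.clamp(min=1e-12)
    return (w / scale).clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn), scale

x = torch.randn(4, 8)   # [num_tokens, hidden_size] activations
w = torch.randn(16, 8)  # [out_channels, hidden_size] weights
x_fp8, x_scale = quantize_per_token(x)
w_fp8, w_scale = quantize_per_channel(w)
# y ~= x @ w.T, rebuilt from the FP8 payloads and the two scale vectors;
# a fused kernel would instead apply both scales in the matmul epilogue.
y = (x_fp8.float() * x_scale) @ (w_fp8.float() * w_scale).T

Per-token scales are computed dynamically at runtime and track activation outliers; per-channel weight scales are fixed once at load time.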

vllm/config.py

Lines changed: 1 addition & 1 deletion
@@ -427,7 +427,7 @@ def _verify_quantization(self) -> None:
         optimized_quantization_methods = [
             "fp8", "marlin", "modelopt", "gptq_marlin_24", "gptq_marlin",
             "awq_marlin", "fbgemm_fp8", "compressed_tensors",
-            "compressed-tensors", "experts_int8"
+            "compressed-tensors", "experts_int8", "ptpc_fp8"
         ]
         if self.quantization is not None:
             self.quantization = self.quantization.lower()
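
With "ptpc_fp8" on the optimized list, selecting it no longer logs the "quantization method ... is not fully optimized" warning. Assuming the standard vLLM entry point, the new backend would be requested like any other (model name illustrative; at this initial-setup stage not every model is expected to work):

from vllm import LLM

llm = LLM(model="facebook/opt-125m", quantization="ptpc_fp8")
outputs = llm.generate("Per-token per-channel FP8 is")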

vllm/model_executor/layers/quantization/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -10,6 +10,7 @@
     "tpu_int8",
     "fp8",
     "fbgemm_fp8",
+    "ptpc_fp8",
     "modelopt",
     # The order of gptq methods is important for config.py iteration over
     # override_quantization_method(..)
@@ -43,6 +44,7 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
     from .deepspeedfp import DeepSpeedFPConfig
     from .experts_int8 import ExpertsInt8Config
     from .fbgemm_fp8 import FBGEMMFp8Config
+    from .ptpc_fp8 import PTPCFp8Config
     from .fp8 import Fp8Config
     from .gguf import GGUFConfig
     from .gptq import GPTQConfig
@@ -63,6 +65,7 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
         "tpu_int8": Int8TpuConfig,
         "fp8": Fp8Config,
         "fbgemm_fp8": FBGEMMFp8Config,
+        "ptpc_fp8": PTPCFp8Config,
         "modelopt": ModelOptFp8Config,
         # The order of gptq methods is important for config.py iteration over
         # override_quantization_method(..)
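
After this registration the class resolves through the existing factory, so a quick sanity check is possible (assuming a build containing this commit):

from vllm.model_executor.layers.quantization import get_quantization_config

cfg_cls = get_quantization_config("ptpc_fp8")
print(cfg_cls.get_name())  # ptpc_fp8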

vllm/model_executor/layers/quantization/ptpc_fp8.py

Lines changed: 51 additions & 0 deletions

@@ -0,0 +1,51 @@
+from typing import Any, Dict, List, Optional
+
+import torch
+from torch.nn import Module
+from torch.nn.parameter import Parameter
+
+from vllm.logger import init_logger
+from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
+                                               UnquantizedLinearMethod)
+from vllm.model_executor.layers.quantization.base_config import (
+    QuantizationConfig, QuantizeMethodBase)
+from vllm.model_executor.layers.quantization.fp8 import cutlass_fp8_supported
+from vllm.model_executor.layers.quantization.fbgemm_fp8 import (
+    FBGEMMFp8Config, FBGEMMFp8LinearMethod)
+from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
+    apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin)
+from vllm.model_executor.layers.quantization.utils.quant_utils import (
+    is_layer_skipped)
+from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
+    apply_fp8_linear, normalize_e4m3fn_to_e4m3fnuz)
+from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
+                                           ModelWeightParameter)
+from vllm.platforms import current_platform
+
+logger = init_logger(__name__)
+
+
+class PTPCFp8Config(FBGEMMFp8Config):
+    """Config class for Per-Token-Per-Channel Fp8."""
+
+    def __init__(self, ignore_list: Optional[List[str]] = None,
+                 input_scale_ub: float = 1.0):  # 1.0 is a dummy default
+        super().__init__(ignore_list, input_scale_ub)
+
+    @classmethod
+    def get_name(cls) -> str:
+        return "ptpc_fp8"
+
+    @classmethod
+    def from_config(cls, config: Dict[str, Any]) -> "PTPCFp8Config":
+        ignore_list = cls.get_from_keys(config, ["modules_to_not_convert"])
+        input_scale_ub = cls.get_from_keys(config, ["activation_scale_ub"])
+        return cls(ignore_list=ignore_list, input_scale_ub=input_scale_ub)
+
+    def get_quant_method(self, layer: torch.nn.Module,
+                         prefix: str) -> Optional["QuantizeMethodBase"]:
+        if isinstance(layer, LinearBase):
+            if is_layer_skipped(prefix, self.ignore_list):
+                return UnquantizedLinearMethod()
+            return FBGEMMFp8LinearMethod(self)
+        return None
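
Since PTPCFp8Config inherits FBGEMMFp8Config and reuses its checkpoint keys, a model's quantization_config section would be parsed unchanged; a sketch with illustrative values:

from vllm.model_executor.layers.quantization.ptpc_fp8 import PTPCFp8Config

hf_quant_config = {
    "modules_to_not_convert": ["lm_head"],  # layers to leave unquantized
    "activation_scale_ub": 1200.0,          # upper bound for per-token scales
}
config = PTPCFp8Config.from_config(hf_quant_config)
assert config.get_name() == "ptpc_fp8"

For now get_quant_method delegates linear layers wholesale to FBGEMMFp8LinearMethod, which already performs dynamic per-token activation scaling against per-channel weight scales; a dedicated PTPC linear method can replace it as the setup matures.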

vllm/platforms/rocm.py

Lines changed: 1 addition & 1 deletion
@@ -79,7 +79,7 @@ class RocmPlatform(Platform):
     dispatch_key: str = "CUDA"
     supported_quantization: list[str] = [
         "awq", "gptq", "fp8", "compressed_tensors", "compressed-tensors",
-        "fbgemm_fp8", "gguf"
+        "fbgemm_fp8", "gguf", "ptpc_fp8"
     ]

     @classmethod