11from typing import Any , Dict , List , Optional
22
33import torch
4- from torch .nn import Module
5- from torch .nn .parameter import Parameter
64
75from vllm .logger import init_logger
8- from vllm .model_executor .layers .linear import (LinearBase , LinearMethodBase ,
6+ from vllm .model_executor .layers .linear import (LinearBase ,
97 UnquantizedLinearMethod )
108from vllm .model_executor .layers .quantization .base_config import (
11- QuantizationConfig , QuantizeMethodBase )
12- from vllm .model_executor .layers .quantization .fp8 import cutlass_fp8_supported
9+ QuantizeMethodBase )
1310from vllm .model_executor .layers .quantization .fbgemm_fp8 import (
1411 FBGEMMFp8Config , FBGEMMFp8LinearMethod )
15- from vllm .model_executor .layers .quantization .utils .marlin_utils_fp8 import (
16- apply_fp8_marlin_linear , prepare_fp8_layer_for_marlin )
1712from vllm .model_executor .layers .quantization .utils .quant_utils import (
1813 is_layer_skipped )
19- from vllm .model_executor .layers .quantization .utils .w8a8_utils import (
20- apply_fp8_linear , normalize_e4m3fn_to_e4m3fnuz )
21- from vllm .model_executor .parameter import (ChannelQuantScaleParameter ,
22- ModelWeightParameter )
2314from vllm .platforms import current_platform
2415
2516logger = init_logger (__name__ )
@@ -29,6 +20,8 @@ class PTPCFp8Config(FBGEMMFp8Config):
2920 """Config class for Per-Token-Per-Channel Fp8."""
3021
def __init__(self, ignore_list: Optional[List[str]] = None):
    """Build the Per-Token-Per-Channel FP8 config.

    Args:
        ignore_list: Optional list of names forwarded unchanged to the
            parent config's constructor.

    Raises:
        ValueError: If the current platform is not ROCm.
    """
    # Fail fast: this quantization method is only accepted on ROCm.
    if not current_platform.is_rocm():
        raise ValueError("ptpc_fp8 quantization is supported only on ROCm")
    # The second positional argument is a placeholder ("dummy") value
    # required by the parent constructor's signature.
    super().__init__(ignore_list, 1.0)
3326
3427 @classmethod
0 commit comments