@@ -5,18 +5,14 @@
 import torch.nn.functional as F
 from lightllm.common.quantization.triton_quant.fp8.fp8act_quant_kernel import per_token_group_quant_fp8
 from lightllm.common.quantization.triton_quant.fp8.fp8w8a8_block_gemm_kernel import w8a8_block_fp8_matmul
+from lightllm.utils.vllm_utils import HAS_VLLM, vllm_ops
+from lightllm.utils.sgl_utils import HAS_SGL_KERNEL, sgl_ops
 
-try:
-    HAS_VLLM = True
-    from lightllm.common.vllm_kernel import _custom_ops as ops
-except:
-    HAS_VLLM = False
 
-
-class vLLMBaseQuantizationMethod(QuantizationMethod):
+class BaseQuantizationMethod(QuantizationMethod):
     def __init__(self):
         super().__init__()
-        assert HAS_VLLM, "vllm is not installed, you can't use quant api of it"
+        assert HAS_VLLM and HAS_SGL_KERNEL, "vllm and sgl_kernel are not installed, you can't use quant api of them."
         from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager
 
         self.cache_manager = g_cache_manager
@@ -30,8 +26,8 @@ def apply(self, input_tensor, weights, bias=None, out=None, workspace=None):
         pass
 
 
-@QUANTMETHODS.register(["vllm-w8a8"])
-class vLLMw8a8QuantizationMethod(vLLMBaseQuantizationMethod):
+@QUANTMETHODS.register(["vllm-w8a8", "w8a8"])
+class w8a8QuantizationMethod(BaseQuantizationMethod):
     def __init__(self):
         super().__init__()
 
@@ -53,7 +49,7 @@ def apply(self, input_tensor, weights, bias=None, out=None, workspace=None, use_
         else:
             raise ValueError("vllm-quant Weights must be a tuple of length 2 or 3.")
 
-        x_q, x_scale, x_zp = ops.scaled_int8_quant(input_tensor, scale=input_scale, azp=None, symmetric=True)
+        x_q, x_scale, x_zp = vllm_ops.scaled_int8_quant(input_tensor, scale=input_scale, azp=None, symmetric=True)
         m = input_tensor.shape[0]
         n = qweight.shape[1]
         if out is None:
@@ -67,8 +63,8 @@ def apply(self, input_tensor, weights, bias=None, out=None, workspace=None, use_
             return out
 
 
-@QUANTMETHODS.register(["vllm-fp8w8a8"])
-class vLLMFP8w8a8QuantizationMethod(vLLMBaseQuantizationMethod):
+@QUANTMETHODS.register(["vllm-fp8w8a8", "fp8w8a8"])
+class FP8w8a8QuantizationMethod(BaseQuantizationMethod):
     def __init__(self):
         super().__init__()
         self.is_moe = False
@@ -88,15 +84,15 @@ def quantize(self, weight: torch.Tensor):
     def quantize_scaled_mm_fp8(self, weight: torch.Tensor):
         if self.is_moe:
             return self.quantize_moe(weight)
-        qweight, weight_scale = ops.scaled_fp8_quant(
+        qweight, weight_scale = vllm_ops.scaled_fp8_quant(
             weight.contiguous().cuda(self.device_id_), scale=None, use_per_token_if_dynamic=True
         )
         return qweight.transpose(0, 1), weight_scale
 
     def quantize_pingpong_fp8(self, weight: torch.Tensor):
         if self.is_moe:
             return self.quantize_moe(weight)
-        qweight, weight_scale = ops.scaled_fp8_quant(
+        qweight, weight_scale = vllm_ops.scaled_fp8_quant(
             weight.contiguous().cuda(), scale=None, use_per_token_if_dynamic=False
         )
         return qweight.transpose(0, 1), weight_scale
@@ -107,7 +103,7 @@ def quantize_moe(self, weight):
         weight_scales = []
         qweights = torch.empty_like(weight, dtype=torch.float8_e4m3fn).cuda(self.device_id_)
         for i in range(num_experts):
-            qweight, weight_scale = ops.scaled_fp8_quant(
+            qweight, weight_scale = vllm_ops.scaled_fp8_quant(
                 weight[i].contiguous().cuda(self.device_id_), scale=None, use_per_token_if_dynamic=False
             )
             qweights[i] = qweight
@@ -121,7 +117,7 @@ def apply(self, input_tensor, weights, bias=None, out=None, workspace=None, use_
     def apply_scaled_mm_fp8(
         self, input_tensor, weights, bias=None, out=None, workspace=None, use_custom_tensor_mananger=True
     ):
-        x_q, x_scale = ops.scaled_fp8_quant(input_tensor, scale=None, scale_ub=None, use_per_token_if_dynamic=True)
+        x_q, x_scale = vllm_ops.scaled_fp8_quant(input_tensor, scale=None, scale_ub=None, use_per_token_if_dynamic=True)
         m = input_tensor.shape[0]
         n = weights[0].shape[1]
         if out is None:
@@ -137,7 +133,9 @@ def apply_scaled_mm_fp8(
     def apply_pingpong_fp8(
         self, input_tensor, weights, bias=None, out=None, workspace=None, use_custom_tensor_mananger=True
     ):
-        x_q, x_scale = ops.scaled_fp8_quant(input_tensor, scale=None, scale_ub=None, use_per_token_if_dynamic=False)
+        x_q, x_scale = vllm_ops.scaled_fp8_quant(
+            input_tensor, scale=None, scale_ub=None, use_per_token_if_dynamic=False
+        )
         assert bias is None
         m = input_tensor.shape[0]
         n = weights[0].shape[1]
@@ -154,8 +152,8 @@ def apply_pingpong_fp8(
         return cutlass_scaled_mm(x_q, weights[0], x_scale, weights[1], out)
 
 
-@QUANTMETHODS.register(["vllm-fp8w8a8-b128"])
-class vLLMFP8w8a8B128QuantizationMethod(vLLMBaseQuantizationMethod):
+@QUANTMETHODS.register(["vllm-fp8w8a8-b128", "fp8w8a8-b128"])
+class FP8w8a8B128QuantizationMethod(BaseQuantizationMethod):
     def __init__(self):
         super().__init__()
         self.block_size = 128
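
Note on the two activation-quantization modes exercised above: a minimal sketch, assuming vLLM is installed and that vllm_ops re-exports vLLM's scaled_fp8_quant as the new lightllm.utils.vllm_utils import suggests. Exact signatures may vary across vLLM versions; the tensor shapes below are illustrative only.

import torch
from lightllm.utils.vllm_utils import HAS_VLLM, vllm_ops

if HAS_VLLM:
    x = torch.randn(16, 4096, dtype=torch.float16, device="cuda")
    # Per-token dynamic scaling (one scale per row), as in quantize_scaled_mm_fp8 / apply_scaled_mm_fp8.
    x_q_tok, scale_tok = vllm_ops.scaled_fp8_quant(x, scale=None, use_per_token_if_dynamic=True)
    # Per-tensor dynamic scaling (a single scale), as in quantize_pingpong_fp8 / apply_pingpong_fp8.
    x_q_tsr, scale_tsr = vllm_ops.scaled_fp8_quant(x, scale=None, use_per_token_if_dynamic=False)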