import torch.nn.functional as F
from lightllm.common.quantization.triton_quant.fp8.fp8act_quant_kernel import per_token_group_quant_fp8
from lightllm.common.quantization.triton_quant.fp8.fp8w8a8_block_gemm_kernel import w8a8_block_fp8_matmul
+ from lightllm.utils.vllm_utils import HAS_VLLM, vllm_ops, cutlass_scaled_mm

- try:
-     HAS_VLLM = True
-     from lightllm.common.vllm_kernel import _custom_ops as ops
- except:
-     HAS_VLLM = False

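For context on the new import above, a minimal sketch of what lightllm/utils/vllm_utils.py is assumed to export; the real helper module may differ. Only the torch.ops._C.cutlass_scaled_mm argument order is taken from the call sites replaced in this diff, the rest is an illustrative guess.

import torch

try:
    # vLLM's custom quant kernels (scaled_int8_quant, scaled_fp8_quant, ...)
    from vllm import _custom_ops as vllm_ops
    HAS_VLLM = True
except ImportError:
    vllm_ops = None
    HAS_VLLM = False

def cutlass_scaled_mm(out, a, b, scale_a, scale_b, bias=None):
    # Thin wrapper over the raw CUTLASS kernel; writes into `out` and returns it.
    torch.ops._C.cutlass_scaled_mm(out, a, b, scale_a, scale_b, bias)
    return out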
-
- class vLLMBaseQuantizationMethod(QuantizationMethod):
+ class BaseQuantizationMethod(QuantizationMethod):
    def __init__(self):
        super().__init__()
-         assert HAS_VLLM, "vllm is not installed, you can't use quant api of it "
+         assert HAS_VLLM, "vllm is not installed, you can't use its quant api."
        from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager

        self.cache_manager = g_cache_manager
@@ -30,8 +25,8 @@ def apply(self, input_tensor, weights, bias=None, out=None, workspace=None):
        pass


- @QUANTMETHODS.register(["vllm-w8a8"])
- class vLLMw8a8QuantizationMethod(vLLMBaseQuantizationMethod):
+ @QUANTMETHODS.register(["vllm-w8a8", "w8a8"])
+ class w8a8QuantizationMethod(BaseQuantizationMethod):
    def __init__(self):
        super().__init__()

@@ -53,7 +48,7 @@ def apply(self, input_tensor, weights, bias=None, out=None, workspace=None, use_
        else:
            raise ValueError("vllm-quant Weights must be a tuple of length 2 or 3.")

-         x_q, x_scale, x_zp = ops.scaled_int8_quant(input_tensor, scale=input_scale, azp=None, symmetric=True)
+         x_q, x_scale, x_zp = vllm_ops.scaled_int8_quant(input_tensor, scale=input_scale, azp=None, symmetric=True)
        m = input_tensor.shape[0]
        n = qweight.shape[1]
        if out is None:
@@ -63,51 +58,31 @@ def apply(self, input_tensor, weights, bias=None, out=None, workspace=None, use_
                )
            else:
                out = torch.empty((m, n), dtype=input_tensor.dtype, device=input_tensor.device)
-         torch.ops._C.cutlass_scaled_mm(out, x_q, qweight, x_scale, weight_scale, bias)
+         cutlass_scaled_mm(out, x_q, qweight, x_scale, weight_scale, bias)
        return out

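For reference, a self-contained sketch of the dynamic (scale=None) per-token int8 path that w8a8QuantizationMethod.apply() implements above; only the vllm_ops / cutlass_scaled_mm calls mirror this diff, the wrapper function and shape comments are illustrative assumptions and not part of the commit.

import torch
from lightllm.utils.vllm_utils import vllm_ops, cutlass_scaled_mm

def w8a8_linear_sketch(x: torch.Tensor, qweight: torch.Tensor, weight_scale: torch.Tensor) -> torch.Tensor:
    # x: (M, K) fp16/bf16 activations; qweight and weight_scale as produced by this class's quantize().
    x_q, x_scale, _ = vllm_ops.scaled_int8_quant(x, scale=None, azp=None, symmetric=True)
    out = torch.empty((x.shape[0], qweight.shape[1]), dtype=x.dtype, device=x.device)
    cutlass_scaled_mm(out, x_q, qweight, x_scale, weight_scale, None)
    return out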
- @QUANTMETHODS.register(["vllm-fp8w8a8"])
- class vLLMFP8w8a8QuantizationMethod(vLLMBaseQuantizationMethod):
+ @QUANTMETHODS.register(["vllm-fp8w8a8", "fp8w8a8"])
+ class FP8w8a8QuantizationMethod(BaseQuantizationMethod):
    def __init__(self):
        super().__init__()
        self.is_moe = False
-         # PINGPONG_FP8_GEMM is per tensor quant way.
-         self.use_pingpong_fp8_gemm = os.getenv("ENABLE_PINGPONG_FP8_GEMM", "0").upper() in ["ON", "TRUE", "1"]
-
-         if self.use_pingpong_fp8_gemm:
-             self.quantize = self.quantize_pingpong_fp8
-             self.apply = self.apply_pingpong_fp8
-         else:
-             self.quantize = self.quantize_scaled_mm_fp8
-             self.apply = self.apply_scaled_mm_fp8

    def quantize(self, weight: torch.Tensor):
-         raise Exception("This function needs to be bound.")
-
-     def quantize_scaled_mm_fp8(self, weight: torch.Tensor):
        if self.is_moe:
            return self.quantize_moe(weight)
-         qweight, weight_scale = ops.scaled_fp8_quant(
+         qweight, weight_scale = vllm_ops.scaled_fp8_quant(
            weight.contiguous().cuda(self.device_id_), scale=None, use_per_token_if_dynamic=True
        )
        return qweight.transpose(0, 1), weight_scale

-     def quantize_pingpong_fp8(self, weight: torch.Tensor):
-         if self.is_moe:
-             return self.quantize_moe(weight)
-         qweight, weight_scale = ops.scaled_fp8_quant(
-             weight.contiguous().cuda(), scale=None, use_per_token_if_dynamic=False
-         )
-         return qweight.transpose(0, 1), weight_scale
-
    def quantize_moe(self, weight):
        num_experts = weight.shape[0]
        qweights = []
        weight_scales = []
        qweights = torch.empty_like(weight, dtype=torch.float8_e4m3fn).cuda(self.device_id_)
        for i in range(num_experts):
-             qweight, weight_scale = ops.scaled_fp8_quant(
+             qweight, weight_scale = vllm_ops.scaled_fp8_quant(
                weight[i].contiguous().cuda(self.device_id_), scale=None, use_per_token_if_dynamic=False
            )
            qweights[i] = qweight
@@ -116,12 +91,7 @@ def quantize_moe(self, weight):
        return qweights, weight_scale

    def apply(self, input_tensor, weights, bias=None, out=None, workspace=None, use_custom_tensor_mananger=True):
-         raise Exception("This function needs to be bound.")
-
-     def apply_scaled_mm_fp8(
-         self, input_tensor, weights, bias=None, out=None, workspace=None, use_custom_tensor_mananger=True
-     ):
-         x_q, x_scale = ops.scaled_fp8_quant(input_tensor, scale=None, scale_ub=None, use_per_token_if_dynamic=True)
+         x_q, x_scale = vllm_ops.scaled_fp8_quant(input_tensor, scale=None, scale_ub=None, use_per_token_if_dynamic=True)
        m = input_tensor.shape[0]
        n = weights[0].shape[1]
        if out is None:
@@ -131,31 +101,12 @@ def apply_scaled_mm_fp8(
                )
            else:
                out = torch.empty((m, n), dtype=input_tensor.dtype, device=input_tensor.device)
-         torch.ops._C.cutlass_scaled_mm(out, x_q, weights[0], x_scale, weights[1], bias)
+         cutlass_scaled_mm(out, x_q, weights[0], x_scale, weights[1], bias)
        return out

-     def apply_pingpong_fp8(
-         self, input_tensor, weights, bias=None, out=None, workspace=None, use_custom_tensor_mananger=True
-     ):
-         x_q, x_scale = ops.scaled_fp8_quant(input_tensor, scale=None, scale_ub=None, use_per_token_if_dynamic=False)
-         assert bias is None
-         m = input_tensor.shape[0]
-         n = weights[0].shape[1]
-         if out is None:
-             if use_custom_tensor_mananger:
-                 out = self.cache_manager.alloc_tensor(
-                     (m, n), input_tensor.dtype, device=input_tensor.device, is_graph_out=False
-                 )
-             else:
-                 out = torch.empty((m, n), dtype=input_tensor.dtype, device=input_tensor.device)
-
-         from fp8_pingpong_gemm import cutlass_scaled_mm
-
-         return cutlass_scaled_mm(x_q, weights[0], x_scale, weights[1], out)
-

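A hedged illustration of what the per-token FP8 quantize() above produces for a dense weight: one FP8 row per output channel plus a per-row fp32 scale, stored transposed so apply() can pass it straight to cutlass_scaled_mm. Sizes and the round-trip check below are illustrative only, not part of this commit.

import torch
from lightllm.utils.vllm_utils import vllm_ops

w = torch.randn(1024, 4096, dtype=torch.float16, device="cuda")  # (out_features, in_features), sizes illustrative
w_q, w_scale = vllm_ops.scaled_fp8_quant(w.contiguous(), scale=None, use_per_token_if_dynamic=True)
w_ref = w_q.to(torch.float32) * w_scale       # dequantized approximation of w; scale broadcasts per row
print((w_ref - w.float()).abs().max())        # small quantization error expected
qweight = w_q.transpose(0, 1)                 # (in_features, out_features) layout returned by quantize()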
- @QUANTMETHODS.register(["vllm-fp8w8a8-b128"])
- class vLLMFP8w8a8B128QuantizationMethod(vLLMBaseQuantizationMethod):
+ @QUANTMETHODS.register(["vllm-fp8w8a8-b128", "fp8w8a8-b128"])
+ class FP8w8a8B128QuantizationMethod(BaseQuantizationMethod):
    def __init__(self):
        super().__init__()
        self.block_size = 128
@@ -197,5 +148,5 @@ def apply(self, input_tensor, weights, bias=None, out=None, workspace=None, use_
            )
        else:
            input_scale = input_scale.t().contiguous().t()
-         torch.ops._C.cutlass_scaled_mm(out, qinput_tensor, qweight, input_scale, weight_scale, bias)
+         cutlass_scaled_mm(out, qinput_tensor, qweight, input_scale, weight_scale, bias)
        return out
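The b128 variant above keeps FP8 weight scales per 128x128 block rather than per output channel. A conceptual sketch of how such a block-scaled weight maps back to full precision, assuming a (ceil(K/128), ceil(N/128)) scale layout; the kernels in this diff consume the quantized form directly and never materialize this.

import torch

def dequant_block_fp8_weight(qweight: torch.Tensor, weight_scale: torch.Tensor, block_size: int = 128) -> torch.Tensor:
    # qweight: (K, N) float8_e4m3fn; weight_scale: (ceil(K/block), ceil(N/block)) fp32 -- assumed layout.
    k, n = qweight.shape
    scale = weight_scale.repeat_interleave(block_size, dim=0)[:k]
    scale = scale.repeat_interleave(block_size, dim=1)[:, :n]
    return qweight.to(torch.float32) * scale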