66from lightllm .common .quantization .triton_quant .fp8 .fp8act_quant_kernel import per_token_group_quant_fp8
77from lightllm .common .quantization .triton_quant .fp8 .fp8w8a8_block_gemm_kernel import w8a8_block_fp8_matmul
88from lightllm .utils .vllm_utils import HAS_VLLM , vllm_ops , cutlass_scaled_mm
9+ from lightllm .utils .light_utils import HAS_LIGHTLLM_KERNEL , light_ops
10+
# Select the FP8 per-token quantization backend once at import time.
# Prefer the lightllm custom kernel when it is available; otherwise fall
# back to the vLLM implementation.
#
# BUGFIX: the original guard was inverted (`if not HAS_LIGHTLLM_KERNEL`),
# which called `light_ops.per_token_quant_bf16_fp8` precisely when the
# lightllm kernel was NOT installed, and used vllm_ops when it WAS.
if HAS_LIGHTLLM_KERNEL:

    def scaled_fp8_quant(tensor, *args, **kwargs):
        """Quantize `tensor` to FP8 with per-token scales via the lightllm kernel.

        Extra positional/keyword arguments (e.g. `scale`,
        `use_per_token_if_dynamic`) are accepted for signature
        compatibility with `vllm_ops.scaled_fp8_quant` but ignored —
        the lightllm kernel is always dynamic per-token.
        """
        return light_ops.per_token_quant_bf16_fp8(tensor)

else:
    # vLLM's op already has the (tensor, scale=None, ...) -> (qweight, scale)
    # contract the callers below rely on.
    scaled_fp8_quant = vllm_ops.scaled_fp8_quant
918
1019
1120class BaseQuantizationMethod (QuantizationMethod ):
@@ -71,7 +80,7 @@ def __init__(self):
7180 def quantize (self , weight : torch .Tensor ):
7281 if self .is_moe :
7382 return self .quantize_moe (weight )
74- qweight , weight_scale = vllm_ops . scaled_fp8_quant (
83+ qweight , weight_scale = scaled_fp8_quant (
7584 weight .contiguous ().cuda (self .device_id_ ), scale = None , use_per_token_if_dynamic = True
7685 )
7786 return qweight .transpose (0 , 1 ), weight_scale
@@ -82,7 +91,7 @@ def quantize_moe(self, weight):
8291 weight_scales = []
8392 qweights = torch .empty_like (weight , dtype = torch .float8_e4m3fn ).cuda (self .device_id_ )
8493 for i in range (num_experts ):
85- qweight , weight_scale = vllm_ops . scaled_fp8_quant (
94+ qweight , weight_scale = scaled_fp8_quant (
8695 weight [i ].contiguous ().cuda (self .device_id_ ), scale = None , use_per_token_if_dynamic = False
8796 )
8897 qweights [i ] = qweight
@@ -91,7 +100,7 @@ def quantize_moe(self, weight):
91100 return qweights , weight_scale
92101
93102 def apply (self , input_tensor , weights , bias = None , out = None , workspace = None , use_custom_tensor_mananger = True ):
94- x_q , x_scale = vllm_ops . scaled_fp8_quant (input_tensor , scale = None , scale_ub = None , use_per_token_if_dynamic = True )
103+ x_q , x_scale = scaled_fp8_quant (input_tensor , scale = None , scale_ub = None , use_per_token_if_dynamic = True )
95104 m = input_tensor .shape [0 ]
96105 n = weights [0 ].shape [1 ]
97106 if out is None :
0 commit comments