 from .quantize_method import QuantizationMethod
 from .registry import QUANTMETHODS
 import torch.nn.functional as F
+from lightllm.common.quantization.triton_quant.fp8.fp8act_quant_kernel import per_token_group_quant_fp8
 
 try:
     HAS_VLLM = True
@@ -119,7 +120,23 @@ def apply(self, input_tensor, weights, bias=None, out=None, workspace=None, use_
     def apply_scaled_mm_fp8(
         self, input_tensor, weights, bias=None, out=None, workspace=None, use_custom_tensor_mananger=True
     ):
-        x_q, x_scale = ops.scaled_fp8_quant(input_tensor, scale=None, scale_ub=None, use_per_token_if_dynamic=True)
+        qweight, weight_scale, input_scale = weights
+        if weight_scale.ndim == 2:
+            # block-wise weight scales: quantize activations in groups of block_size
+            m, k = input_tensor.shape
+            if input_scale is None:
+                input_scale = self.cache_manager.alloc_tensor(
+                    (m, k // self.block_size), torch.float32, device=input_tensor.device, is_graph_out=False
+                )
+            qinput_tensor = self.cache_manager.alloc_tensor(
+                (m, k), qweight.dtype, device=qweight.device, is_graph_out=False
+            )
+            per_token_group_quant_fp8(input_tensor, self.block_size, qinput_tensor, input_scale)
+        else:
+            # per-channel/per-tensor weight scales: dynamic per-token activation quantization
+            qinput_tensor, input_scale = ops.scaled_fp8_quant(
+                input_tensor, scale=None, scale_ub=None, use_per_token_if_dynamic=True
+            )
         m = input_tensor.shape[0]
         n = weights[0].shape[1]
         if out is None:
@@ -129,7 +146,7 @@ def apply_scaled_mm_fp8(
                 )
             else:
                 out = torch.empty((m, n), dtype=input_tensor.dtype, device=input_tensor.device)
-        torch.ops._C.cutlass_scaled_mm(out, x_q, weights[0], x_scale, weights[1], bias)
+        torch.ops._C.cutlass_scaled_mm(out, qinput_tensor, qweight, input_scale, weight_scale, bias)
         return out
 
     def apply_pingpong_fp8(
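
For review context, here is a minimal sketch of what the new block-wise activation path computes, assuming per_token_group_quant_fp8 gives each contiguous group of block_size elements in a row its own dynamic scale (the group's amax divided by 448.0, the finite max of torch.float8_e4m3fn) and writes into the preallocated output buffers. The reference helper name and the shapes below are illustrative, not part of this commit.

    import torch

    FP8_E4M3_MAX = 448.0  # finite max of torch.float8_e4m3fn

    def per_token_group_quant_fp8_ref(x, group_size, x_q, x_scale):
        # x: (m, k) activations; x_q: (m, k) float8 output buffer;
        # x_scale: (m, k // group_size) float32 output buffer.
        m, k = x.shape
        assert k % group_size == 0
        grouped = x.view(m, k // group_size, group_size).float()
        # One dynamic scale per group, guarded against all-zero groups.
        scale = grouped.abs().amax(dim=-1, keepdim=True).clamp(min=1e-10) / FP8_E4M3_MAX
        q = (grouped / scale).clamp(-FP8_E4M3_MAX, FP8_E4M3_MAX)
        x_q.copy_(q.view(m, k).to(torch.float8_e4m3fn))
        x_scale.copy_(scale.squeeze(-1))

With scales laid out this way, cutlass_scaled_mm computes, roughly, (x_q.float() * x_scale_expanded) @ (w_q.float() * w_scale_expanded) + bias for a (k, n) qweight, which is why the block-wise branch can feed the same GEMM call as the existing per-token path.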