@@ -165,3 +165,40 @@ def apply_pingpong_fp8(
     from fp8_pingpong_gemm import cutlass_scaled_mm
 
     return cutlass_scaled_mm(x_q, weights[0], x_scale, weights[1], out)
+
+
+@QUANTMETHODS.register(["vllm-fp8w8a8-b128"])
+class vLLMFP8w8a8B128QuantizationMethod(vLLMBaseQuantizationMethod):
+    def __init__(self):
+        super().__init__()
+        self.block_size = 128
+
+    def quantize(self, weight: torch.Tensor):
+        if self.is_moe:
+            return self.quantize_moe(weight)
+        qweight, weight_scale = ops.scaled_fp8_quant(
+            weight.contiguous().cuda(self.device_id_), scale=None, use_per_token_if_dynamic=True
+        )
+        return qweight.transpose(0, 1), weight_scale
+
+    def apply(self, input_tensor, weights, bias=None, out=None, workspace=None, use_custom_tensor_mananger=True):
+        qweight, weight_scale, input_scale = weights
+        m, k = input_tensor.shape
+        n = weights[0].shape[1]
+        if input_scale is None:
+            input_scale = self.cache_manager.alloc_tensor(
+                (m, k // self.block_size), torch.float32, device=input_tensor.device, is_graph_out=False
+            )
+        qinput_tensor = self.cache_manager.alloc_tensor(
+            (m, k), qweight.dtype, device=qweight.device, is_graph_out=False
+        )
+        per_token_group_quant_fp8(input_tensor, self.block_size, qinput_tensor, input_scale)
+        if out is None:
+            if use_custom_tensor_mananger:
+                out = self.cache_manager.alloc_tensor(
+                    (m, n), input_tensor.dtype, device=input_tensor.device, is_graph_out=False
+                )
+            else:
+                out = torch.empty((m, n), dtype=input_tensor.dtype, device=input_tensor.device)
+        torch.ops._C.cutlass_scaled_mm(out, qinput_tensor, qweight, input_scale, weight_scale, bias)
+        return out
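
For reference, below is a minimal pure-PyTorch sketch of the arithmetic that the per_token_group_quant_fp8 + cutlass_scaled_mm pair in this diff performs on the GPU: activations are quantized to fp8 in contiguous groups of 128 elements along k with one float32 scale per group, then multiplied against the fp8 weight with its scale applied. The helper names group_quant_fp8 and dequant_matmul are illustrative only and not part of the diff; the weight here gets a single per-tensor scale for brevity (the method above produces per-token weight scales via ops.scaled_fp8_quant), and the dequantize-then-matmul runs in float32 instead of a fused CUTLASS kernel.

import torch

# Illustrative sketch only: assumes torch >= 2.1 (torch.float8_e4m3fn) and k divisible by 128.
BLOCK = 128
FP8_MAX = torch.finfo(torch.float8_e4m3fn).max

def group_quant_fp8(x, group=BLOCK):
    # Quantize each contiguous group of `group` elements along the last dim,
    # returning fp8 values plus one float32 scale per (row, group).
    m, k = x.shape
    xg = x.view(m, k // group, group)
    scale = xg.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12) / FP8_MAX
    q = (xg / scale).clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn)
    return q.view(m, k), scale.view(m, k // group).float()

def dequant_matmul(x_q, x_scale, w_q, w_scale, group=BLOCK):
    # Dequantize to float32 and multiply; the real path fuses this inside cutlass_scaled_mm.
    m, k = x_q.shape
    x = x_q.float().view(m, k // group, group) * x_scale.unsqueeze(-1)
    w = w_q.float() * w_scale  # single per-tensor weight scale in this sketch
    return x.view(m, k) @ w

x = torch.randn(4, 256)
w = torch.randn(256, 64)
w_scale = w.abs().max() / FP8_MAX
w_q = (w / w_scale).to(torch.float8_e4m3fn)
x_q, x_scale = group_quant_fp8(x)
print(dequant_matmul(x_q, x_scale, w_q, w_scale).shape)  # torch.Size([4, 64])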