@@ -1531,15 +1531,15 @@ def __init__(
15311531 if ops is not None :
15321532 self .act_quant_func = self .act_quant_fp8_perchannel_sym_vllm
15331533 else :
1534- self .act_quant_func = self . fp8_quantize_triton
1534+ self .act_quant_func = fp8_quantize_triton
15351535
15361536 def apply (self , input_tensor ):
15371537 input_tensor_quant , input_tensor_scale = self .act_quant_func (input_tensor )
15381538 output_tensor = fp8_linear (
15391539 input_tensor_quant ,
15401540 self .weight ,
15411541 self .bias .float () if self .bias is not None else None ,
1542- input_tensor_scale ,
1542+ input_tensor_scale . float () ,
15431543 self .weight_scale ,
15441544 out_dtype = self .infer_dtype ,
15451545 )
@@ -1582,15 +1582,15 @@ def __init__(
15821582 if ops is not None :
15831583 self .act_quant_func = self .act_quant_int8_perchannel_sym_vllm
15841584 else :
1585- self .act_quant_func = self . int8_quantize_triton
1585+ self .act_quant_func = int8_quantize_triton
15861586
15871587 def apply (self , input_tensor ):
15881588 input_tensor_quant , input_tensor_scale = self .act_quant_func (input_tensor )
15891589 output_tensor = q8_linear (
15901590 input_tensor_quant ,
15911591 self .weight ,
15921592 self .bias .float () if self .bias is not None else None ,
1593- input_tensor_scale ,
1593+ input_tensor_scale . float () ,
15941594 self .weight_scale ,
15951595 fuse_gelu = False ,
15961596 out_dtype = self .infer_dtype ,
0 commit comments