@@ -114,7 +114,9 @@ def quantize_moe(self, weight):
114114 def apply (self , input_tensor , weights , bias = None , out = None , workspace = None , use_custom_tensor_mananger = True ):
115115 raise Exception ("This function needs to be bound." )
116116
117- def apply_scaled_mm_fp8 (self , input_tensor , weights , bias = None , out = None , workspace = None , use_custom_tensor_mananger = True ):
117+ def apply_scaled_mm_fp8 (
118+ self , input_tensor , weights , bias = None , out = None , workspace = None , use_custom_tensor_mananger = True
119+ ):
118120 x_q , x_scale = ops .scaled_fp8_quant (input_tensor , scale = None , scale_ub = None , use_per_token_if_dynamic = True )
119121 m = input_tensor .shape [0 ]
120122 n = weights [0 ].shape [1 ]
@@ -128,7 +130,9 @@ def apply_scaled_mm_fp8(self, input_tensor, weights, bias=None, out=None, worksp
128130 torch .ops ._C .cutlass_scaled_mm (out , x_q , weights [0 ], x_scale , weights [1 ], bias )
129131 return out
130132
131- def apply_pingpong_fp8 (self , input_tensor , weights , bias = None , out = None , workspace = None , use_custom_tensor_mananger = True ):
133+ def apply_pingpong_fp8 (
134+ self , input_tensor , weights , bias = None , out = None , workspace = None , use_custom_tensor_mananger = True
135+ ):
132136 x_q , x_scale = ops .scaled_fp8_quant (input_tensor , scale = None , scale_ub = None , use_per_token_if_dynamic = False )
133137 assert bias is None
134138 m = input_tensor .shape [0 ]
@@ -140,7 +144,7 @@ def apply_pingpong_fp8(self, input_tensor, weights, bias=None, out=None, workspa
140144 )
141145 else :
142146 out = torch .empty ((m , n ), dtype = input_tensor .dtype , device = input_tensor .device )
143-
147+
144148 from fp8_pingpong_gemm import cutlass_scaled_mm
145149
146150 return cutlass_scaled_mm (x_q , weights [0 ], x_scale , weights [1 ], out )
0 commit comments