@@ -223,21 +223,6 @@ def compute_expert_w_grad(
         weight._apply_backward_hook()
         return result
 
-    @staticmethod
-    def common_fp8_mlp_fwd(x, w1, w2):
-        # ===== o1 = deep_gemm(x_fp8, w1_t_fp8) =====
-        o1, x_fp8, x_scale = FP8LinearFunctionBase.compute_fp8_linear(
-            x, w1, weight_transpose=True, return_transpose_only=True, return_mode="with_input_quant"
-        )
-
-        # ===== o2 = swiglu(o1) =====
-        o2 = swiglu(o1)
-
-        # ===== o3 = deep_gemm(o2_fp8, w2_t_fp8) =====
-        o3 = FP8LinearFunctionBase.compute_fp8_linear(o2, w2, weight_transpose=True, return_transpose_only=True)
-
-        return x_fp8, x_scale, o3
-
     @staticmethod
     def common_fp8_mlp_bwd(do3, x_fp8, x_scale, x_t_fp8, x_t_scale, w1, w2, apply_backward_hook=False):
 
@@ -303,12 +288,21 @@ def fp8_mlp_fwd(x, w1, w2):
         x_orig_shape = x.shape
         x = x.reshape([-1, x_orig_shape[-1]])
 
-        _, _, o3 = FP8LinearFunctionBase.common_fp8_mlp_fwd(x, w1, w2)
+        # ===== o1 = deep_gemm(x_fp8, w1_t_fp8) =====
+        o1, x_fp8, x_scale = FP8LinearFunctionBase.compute_fp8_linear(
+            x, w1, weight_transpose=True, return_transpose_only=True, return_mode="with_input_quant"
+        )
+
+        # ===== o2 = swiglu(o1) =====
+        o2 = swiglu(o1)
+
+        # ===== o3 = deep_gemm(o2_fp8, w2_t_fp8) =====
+        o3 = FP8LinearFunctionBase.compute_fp8_linear(o2, w2, weight_transpose=True, return_transpose_only=True)
 
         if len(x_orig_shape) > 2:
             o3 = o3.reshape([x_orig_shape[0], -1, o3.shape[-1]])
 
-        return o3
+        return x_fp8, x_scale, o3
 
     @staticmethod
     def fp8_mlp_fwd_norm_rc(x, norm_w, norm_eps, w1, w2):
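
For readers unfamiliar with the fused activation used above: swiglu(o1) is assumed here to follow the standard SwiGLU semantics, splitting the last dimension of o1 into gate/up halves and computing silu(gate) * up. A minimal reference sketch under that assumption (the op in this repo is a fused kernel; swiglu_reference is an illustrative name, not from this file):

import paddle
import paddle.nn.functional as F

def swiglu_reference(o1):
    # Assumed semantics of the fused swiglu above: split the last
    # dimension into gate/up halves and gate with SiLU.
    gate, up = paddle.chunk(o1, 2, axis=-1)
    return F.silu(gate) * up
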
@@ -462,7 +456,7 @@ def forward(self, x):
         return FP8LinearFunction.apply(x, self, keep_x=True)
 
 
-class FP8NormMlpRecomputeFunction(paddle.autograd.PyLayer):
+class FusedNormFP8MLPFunction(paddle.autograd.PyLayer):
     @staticmethod
     def forward(ctx, x, norm_w, w1, w2, norm_eps):
         # ===== compute norm_output =====
@@ -529,7 +523,7 @@ def forward(ctx, x, w1, w2):
         x = x.reshape([-1, x_orig_shape[-1]])
 
         # ===== call func fp8_mlp_fwd =====
-        x_fp8, x_scale, o3 = FP8LinearFunctionBase.common_fp8_mlp_fwd(x, w1, w2)
+        x_fp8, x_scale, o3 = FP8LinearFunctionBase.fp8_mlp_fwd(x, w1, w2)
         # ===== reshape to origin shape =====
         if len(x_orig_shape) > 2:
             o3 = o3.reshape([x_orig_shape[0], -1, o3.shape[-1]])
@@ -610,7 +604,7 @@ def __init__(
 
     def forward(self, x):
         if self.using_post_norm_recompute:
-            return FP8NormMlpRecomputeFunction.apply(x, self.norm_weight, self.w1, self.w2, self.norm_eps)
+            return FusedNormFP8MLPFunction.apply(x, self.norm_weight, self.w1, self.w2, self.norm_eps)
         else:
             return FP8MlpFunction.apply(x, self.w1, self.w2)
 
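
As a side note, both FP8MlpFunction and the renamed FusedNormFP8MLPFunction follow the standard paddle.autograd.PyLayer contract: a staticmethod forward(ctx, ...) invoked through .apply(...), with tensors stashed on ctx for the matching backward. A minimal sketch of that contract (ScaleByTwo and its arithmetic are illustrative, not from this file):

import paddle

class ScaleByTwo(paddle.autograd.PyLayer):
    @staticmethod
    def forward(ctx, x):
        # Stash inputs needed by backward, mirroring how the FP8 MLP
        # functions save x_fp8/x_scale for their backward pass.
        ctx.save_for_backward(x)
        return x * 2

    @staticmethod
    def backward(ctx, dy):
        (x,) = ctx.saved_tensor()
        return dy * 2

y = ScaleByTwo.apply(paddle.ones([2, 3]))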