@@ -284,28 +284,35 @@ def fp8_mlp_fwd(x, w1, w2):
284
284
if len (x_orig_shape ) > 2 :
285
285
o3 = o3 .reshape ([x_orig_shape [0 ], - 1 , o3 .shape [- 1 ]])
286
286
287
- return x_fp8 , x_scale , o3
287
+ return o3
288
288
289
289
290
- def fp8_mlp_bwd (do3 , x_fp8 , x_scale , w1 , w2 ):
290
+ def fp8_mlp_bwd (do3 , x , w1 , w2 ):
291
291
do3_orig_shape = do3 .shape
292
292
do3 = do3 .reshape ([- 1 , do3_orig_shape [- 1 ]])
293
293
294
- x_orig_shape = x_fp8 .shape
294
+ x_orig_shape = x .shape
295
+ x = x .reshape ([- 1 , x_orig_shape [- 1 ]])
296
+
297
+ if x .shape [0 ] % 128 == 0 :
298
+ x_fp8 , x_scale , x_t_fp8 , x_t_scale = paddle .incubate .nn .functional .fp8_quant_blockwise (
299
+ x , output_scale_transpose = True , quant_method = "1x128" , input_transpose = True
300
+ )
301
+ else :
302
+ x_fp8 , x_scale = paddle .incubate .nn .functional .fp8_quant_blockwise (
303
+ x , output_scale_transpose = True , quant_method = "1x128" , input_transpose = False
304
+ )
305
+ x = padding (x , 0 )
306
+ _ , _ , x_t_fp8 , x_t_scale = paddle .incubate .nn .functional .fp8_quant_blockwise (
307
+ x , output_scale_transpose = True , quant_method = "1x128" , input_transpose = True
308
+ )
295
309
296
310
_ , _ , w1_fp8 , w1_sacle = paddle .incubate .nn .functional .fp8_quant_blockwise (
297
311
w1 , output_scale_transpose = False , quant_method = "128x128" , input_transpose = True
298
312
)
299
313
o1 = paddle .empty ([x_fp8 .shape [0 ], w1_fp8 .shape [0 ]], dtype = do3 .dtype )
300
314
deep_gemm .gemm_fp8_fp8_bf16_nt ((x_fp8 , x_scale .T ), (w1_fp8 , w1_sacle ), o1 , num_sms = 112 )
301
315
302
- x_dequant_fp16 = paddle .incubate .nn .functional .fused_act_dequant (x_fp8 , x_scale .T .contiguous ())
303
- x_dequant_fp16 = padding (x_dequant_fp16 , 0 )
304
-
305
- _ , _ , x_t_fp8 , x_t_scale = paddle .incubate .nn .functional .fp8_quant_blockwise (
306
- x_dequant_fp16 , output_scale_transpose = True , quant_method = "1x128" , input_transpose = True
307
- )
308
-
309
316
# ===== [recompute] o2 = swiglu(o1) =====
310
317
o2 = swiglu (o1 )
311
318
@@ -577,7 +584,10 @@ def split_group_gemm(x_fp8, x_scale, w_fp8, w_scale, tokens_per_expert, gemm_out
577
584
x_scale_tma_align = x_scale [start_idx :end_idx ].T .contiguous ().T
578
585
579
586
deep_gemm .gemm_fp8_fp8_bf16_nt (
580
- (x_fp8 [start_idx :end_idx ], x_scale_tma_align ), (w_fp8 [i ], w_scale [i ]), gemm_out [start_idx :end_idx ], num_sms = 112
587
+ (x_fp8 [start_idx :end_idx ], x_scale_tma_align ),
588
+ (w_fp8 [i ], w_scale [i ]),
589
+ gemm_out [start_idx :end_idx ],
590
+ num_sms = 112 ,
581
591
)
582
592
583
593
start_idx = end_idx
@@ -763,7 +773,7 @@ def bwd_dowm_input(self, expert_w2, unzipped_grad, o1, inplace_swiglu_prob=False
763
773
(bw_w2_quant , bw_w2_scale ),
764
774
do2_s ,
765
775
m_indices = self .m_indices ,
766
- num_sms = 112
776
+ num_sms = 112 ,
767
777
)
768
778
769
779
with paddle .amp .auto_cast (False ):
0 commit comments