@@ -648,9 +648,10 @@ def backward(self, output_combine_grad, previous_event=None, async_finish=False,
 
 
 class Fp8CombineQuantNode:
-    def __init__(self, token_dispatcher, name="fp8_combine_quant_node"):
+    def __init__(self, token_dispatcher, moe_group=None, name="fp8_combine_quant_node"):
         self.token_dispatcher = token_dispatcher
         self.name = name
+        self.moe_group = moe_group
 
     @paddle.no_grad()
     def forward(self, output_combine):
@@ -666,13 +667,14 @@ def backward(self, output_grad, event_to_wait=None):
         # post combine grad
         if DSV3_USE_FP8_DISPATCH:
             if event_to_wait is not None:
+                assert self.moe_group is not None
+                event_to_wait.comm_stream_wait(self.moe_group.id)
                 buffer = get_buffer(self.token_dispatcher._comm_manager.group, get_hidden_bytes(output_grad))
                 custom_stream = paddle.device.Stream(stream_base=buffer.runtime.get_comm_stream())
-                custom_stream.wait_event(event_to_wait)
             else:
                 custom_stream = paddle.device.current_stream()
             with paddle.device.stream_guard(custom_stream):
-                output_combine_grad = paddle.reshape(output_grad, self.output_combine_shape)
+                output_combine_grad = paddle.reshape(output_grad, [-1, output_grad.shape[-1]])
                 # output_combine_grad quant to fp8
                 output_combine_grad_fp8, output_combine_grad_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
                     output_combine_grad, output_scale_transpose=False, quant_method="1x128", input_transpose=False
@@ -683,7 +685,7 @@ def backward(self, output_grad, event_to_wait=None):
             quant_event = deep_ep.get_event_from_custom_stream(custom_stream.stream_base)
             return (output_combine_grad_fp8, output_combine_grad_scale), quant_event
         else:
-            output_combine_grad = paddle.reshape(output_grad, self.output_combine_shape)
+            output_combine_grad = paddle.reshape(output_grad, [-1, output_grad.shape[-1]])
             return output_combine_grad, None
 
 
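Review note on the backward path: the cross-stream wait moves earlier. Before this change, the quant stream waited on the incoming event only after the comm buffer was fetched (custom_stream.wait_event(event_to_wait)); now the event blocks the MoE group's communication stream up front via event_to_wait.comm_stream_wait(self.moe_group.id), which is why the node must be handed a moe_group. A minimal sketch of the general event/stream ordering, assuming the Paddle stream APIs already visible in the diff plus paddle.device.Event, a GPU build, and hypothetical tensors (the DeepEP event type is not modeled here):

    import paddle

    x = paddle.randn([4, 128])

    # Record an event on the current (producer) stream; the event completes
    # once all work queued so far has finished.
    event = paddle.device.Event()
    event.record(paddle.device.current_stream())

    # A secondary (consumer) stream blocks on the event, so anything enqueued
    # under the guard runs strictly after the producer's kernels.
    side_stream = paddle.device.Stream()
    side_stream.wait_event(event)

    with paddle.device.stream_guard(side_stream):
        y = x * 2.0  # ordered after everything recorded before `event`
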
@@ -873,8 +875,8 @@ def __init__(
             is_split_group_gemm=is_split_group_gemm,
         )
         self.combine_node = Fp8CombineNode(self.token_dispatcher)
-        self.combine_quant_node = Fp8CombineQuantNode(self.token_dispatcher)
+        self.combine_quant_node = Fp8CombineQuantNode(self.token_dispatcher, custom_map.moe_group)
         self.name = name
 
 
     @paddle.no_grad()
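Review note on the reshape and wiring: both reshape call sites now derive the 2-D shape from output_grad itself rather than from self.output_combine_shape, so backward no longer depends on state captured during forward; correspondingly, the constructor gains a moe_group argument (defaulting to None and asserted only on the event path) and the caller passes custom_map.moe_group. A self-contained check of the reshape equivalence, with hypothetical sizes:

    import paddle

    g = paddle.randn([6, 128])               # stands in for output_grad
    stored_shape = [6, 128]                  # what self.output_combine_shape held
    a = paddle.reshape(g, stored_shape)
    b = paddle.reshape(g, [-1, g.shape[-1]])
    assert a.shape == b.shape == [6, 128]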