DSV3_USE_FP8_GEMM = os.getenv("DSV3_USE_FP8_GEMM", "False").lower() == "true"
+DSV3_USE_FP8_DISPATCH = os.getenv("DSV3_USE_FP8_DISPATCH", "False").lower() == "true"


def parse_args(args):
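For orientation, the new `DSV3_USE_FP8_DISPATCH` flag is parsed exactly like `DSV3_USE_FP8_GEMM`: only the literal string `"true"` (case-insensitive) enables it, and the value is fixed at module import time. A minimal, self-contained sketch of that parsing rule; `_env_flag` is an illustrative helper, not part of the patch:

```python
import os


def _env_flag(name: str, default: str = "False") -> bool:
    # Same rule as the module-level flags: only "true" (any casing) enables it.
    return os.getenv(name, default).lower() == "true"


os.environ["DSV3_USE_FP8_DISPATCH"] = "True"
assert _env_flag("DSV3_USE_FP8_DISPATCH") is True

os.environ["DSV3_USE_FP8_DISPATCH"] = "1"  # "1" does not count as enabled
assert _env_flag("DSV3_USE_FP8_DISPATCH") is False
```

Because the real flag is evaluated at import, exporting the variable after the process has started has no effect.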
@@ -156,6 +157,40 @@ def __init__(
        if self.using_post_norm_recompute:
            assert self.shared_experts is not None
            assert self.shared_experts.norm_weight is not None and self.shared_experts.norm_eps is not None
+    def forward_without_residual(self, inputs):
+
+        if isinstance(inputs, list):
+            inputs = tuple(inputs)
+
+        if self.send_mtp_embed:
+            (inputs_embeds_mtp, hidden_states, residual, l_aux, final_hidden_states) = inputs
+        else:
+            (hidden_states, residual, l_aux, final_hidden_states) = inputs
+
+        with paddle.no_grad():
+            if self.shared_experts is not None:
+                if self.using_post_norm_recompute:
+                    shared_expert_output = fp8_mlp_fwd_norm_rc(
+                        hidden_states,
+                        self.shared_experts.norm_weight,
+                        self.shared_experts.norm_eps,
+                        self.shared_experts.w1,
+                        self.shared_experts.w2,
+                    )
+                else:
+                    shared_expert_output = fp8_mlp_fwd(hidden_states, self.shared_experts.w1, self.shared_experts.w2)
+                residual = residual + shared_expert_output
+
+        self.x = hidden_states
+        self.l_aux = l_aux
+
+        hidden_states = residual
+        hidden_states.stop_gradient = False
+
+        if self.send_mtp_embed:
+            hidden_states = paddle.concat([hidden_states, inputs_embeds_mtp], axis=-1)
+
+        return return_args(hidden_states)


    def forward(self, inputs):
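For readers skimming the diff: `forward_without_residual` folds the shared-expert output into the residual stream under `no_grad`, stashes `hidden_states` and `l_aux` on the node, and returns the residual without adding the routed-expert `final_hidden_states`; the scheduler adds the combine output itself later, possibly on a different stream (see the `pp_stream` path further down). A toy sketch of that contract with made-up stand-in functions, not the real node API:

```python
import paddle


def post_process_with_residual(residual, shared_out, combine_out):
    # Roughly what the existing forward() path produces in one go.
    return residual + shared_out + combine_out


def post_process_without_residual(residual, shared_out):
    # The new path: the routed-expert combine output is intentionally left out.
    return residual + shared_out


residual = paddle.ones([2, 4])
shared_out = paddle.full([2, 4], 0.5)
combine_out = paddle.full([2, 4], 0.25)

eager = post_process_with_residual(residual, shared_out, combine_out)
deferred = post_process_without_residual(residual, shared_out) + combine_out
assert bool(paddle.allclose(eager, deferred))
```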
@@ -431,9 +466,15 @@ def __init__(self, forward_nodes, backward_nodes, use_fuion=True):
        for f, b in zip(forward_nodes, backward_nodes):
            self.nodes.append(schedule_node_class(f, b, f"OverlapedNode_{len(self.nodes)}"))

-    def forward_backward(self, inputs, output_grad, event_to_wait=None):
-        for n in self.nodes:
-            inputs, output_grad, event_to_wait = n.forward_backward(inputs, output_grad, event_to_wait)
+    def forward_backward(self, inputs, output_grad, combine_bw_event_to_wait=None, pp_stream=None):
+        event_to_wait = combine_bw_event_to_wait
+        for i, n in enumerate(self.nodes):
+            pp_stream_t = pp_stream
+            if i + 1 != len(self.nodes):
+                pp_stream_t = None
+
+            inputs, output_grad, event_to_wait = n.forward_backward(inputs, output_grad, combine_bw_event_to_wait=event_to_wait, pp_stream=pp_stream_t)
        return inputs, output_grad, None

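The chaining rule here is easy to miss in the diff: each node's returned event becomes the next node's `combine_bw_event_to_wait`, while the real `pp_stream` is handed only to the last node in the chunk (all earlier nodes get `None`). A stripped-down, runnable sketch of just that control flow; `FakeNode` is a stand-in, not the scheduler's node class:

```python
class FakeNode:
    def __init__(self, name):
        self.name = name

    def forward_backward(self, inputs, output_grad, combine_bw_event_to_wait=None, pp_stream=None):
        print(f"{self.name}: waits on {combine_bw_event_to_wait}, pp_stream={pp_stream}")
        return inputs, output_grad, f"event_from_{self.name}"


nodes = [FakeNode("node0"), FakeNode("node1"), FakeNode("node2")]
inputs, grads, event = "x", "dx", "initial_event"
for i, n in enumerate(nodes):
    pp_stream_t = "pp_stream" if i + 1 == len(nodes) else None
    inputs, grads, event = n.forward_backward(inputs, grads, combine_bw_event_to_wait=event, pp_stream=pp_stream_t)
# node0 and node1 see pp_stream=None; only node2 receives the real stream.
```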
@@ -586,7 +627,7 @@ def combine_forward(self, inputs, async_finish=False, previous_event=None, alloc
        ret = (inputs_embeds_mtp, *ret) if self.send_mtp_embed else ret
        return ret

-    def post_process_forward(self, inputs):
+    def post_process_forward(self, inputs, with_residual=True):
        if self.send_mtp_embed:
            (inputs_embeds_mtp, hidden_states, residual, l_aux, output_combine) = inputs
        else:
@@ -596,7 +637,10 @@ def post_process_forward(self, inputs):
        inputs = (hidden_states, residual, l_aux, final_hidden_states)
        inputs = (inputs_embeds_mtp, *inputs) if self.send_mtp_embed else inputs

-        inputs = self.post_process_node.forward(inputs)
+        if with_residual:
+            inputs = self.post_process_node.forward(inputs)
+        else:
+            inputs = self.post_process_node.forward_without_residual(inputs)
        return inputs

    def post_process_backward(self, output_grad, event_to_wait=None):
@@ -615,7 +659,7 @@ def post_process_backward(self, output_grad, event_to_wait=None):
        ret = (inputs_embeds_mtp_grad, *ret) if self.send_mtp_embed else ret
        return ret

-    def combine_backward(self, output_grad, async_finish=False, allocate_on_comm_stream=False):
+    def combine_backward(self, output_grad, previous_event=None, async_finish=False, allocate_on_comm_stream=False):
        if self.send_mtp_embed:
            (
                inputs_embeds_mtp_grad,
@@ -626,12 +670,22 @@ def combine_backward(self, output_grad, async_finish=False, allocate_on_comm_str
                quant_event,
            ) = output_grad
        else:
-            hidden_states_grad, residual_grad, l_aux_grad, output_combine_grad, quant_event = output_grad
-
+            (
+                hidden_states_grad,
+                residual_grad,
+                l_aux_grad,
+                output_combine_grad,
+                quant_event,
+            ) = output_grad
+
+        if DSV3_USE_FP8_DISPATCH and quant_event is not None:
+            combine_backward_wait_event = quant_event
+        else:
+            combine_backward_wait_event = previous_event
        hidden_states_out_grad = self.fp8_fusion_moe_node.combine_node.backward(
            output_combine_grad,
            async_finish=async_finish,
-            previous_event=quant_event,
+            previous_event=combine_backward_wait_event,
            allocate_on_comm_stream=allocate_on_comm_stream and quant_event is not None,
        )

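The only behavioural change in `combine_backward` is which event the combine kernel waits on: the quantization event when FP8 dispatch produced one, otherwise the `previous_event` supplied by the caller. The same branch, pulled out as a plain function for clarity (placeholders instead of real deep_ep events):

```python
def pick_combine_backward_event(quant_event, previous_event, use_fp8_dispatch):
    # Mirrors the branch above: prefer the quant event only when FP8 dispatch
    # actually produced one; otherwise fall back to the caller's event.
    if use_fp8_dispatch and quant_event is not None:
        return quant_event
    return previous_event


assert pick_combine_backward_event("quant", "prev", use_fp8_dispatch=True) == "quant"
assert pick_combine_backward_event(None, "prev", use_fp8_dispatch=True) == "prev"
assert pick_combine_backward_event("quant", "prev", use_fp8_dispatch=False) == "prev"
```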
@@ -738,25 +792,32 @@ def __init__(self, forward_node, backward_node, name=""):
        self.backward_node = backward_node
        self.name = name

-    def forward_backward(self, inputs, output_grad, event_to_wait=None):
+    def forward_backward(self, inputs, output_grad, combine_bw_event_to_wait=None, pp_stream=None):
        paddle.base.core.nvprof_nvtx_push("forward_backward")

+        combine_bwd_event = deep_ep.get_event_from_calc_stream(self.backward_node.moe_group.id)
+
+        paddle.base.core.nvprof_nvtx_push("attn_forward")
+        inputs = self.forward_node.attn_forward(inputs)
+        paddle.base.core.nvprof_nvtx_pop()
+        attn_compute_event = deep_ep.get_event_from_calc_stream(self.forward_node.moe_group.id)
+
        paddle.base.core.nvprof_nvtx_push("post_process_backward")
-        output_grad = self.backward_node.post_process_backward(output_grad, event_to_wait)
+        output_grad = self.backward_node.post_process_backward(output_grad, combine_bw_event_to_wait)
        paddle.base.core.nvprof_nvtx_pop()

        paddle.base.core.nvprof_nvtx_push("combine_backward")
-        output_grad = self.backward_node.combine_backward(output_grad, async_finish=True, allocate_on_comm_stream=True)
+        if combine_bw_event_to_wait is not None:
+            output_grad = self.backward_node.combine_backward(
+                output_grad, previous_event=combine_bw_event_to_wait, async_finish=True, allocate_on_comm_stream=True
+            )
+        else:
+            output_grad = self.backward_node.combine_backward(
+                output_grad, previous_event=combine_bwd_event, async_finish=True, allocate_on_comm_stream=True
+            )
        # get combine event
        combine_backward_event = deep_ep.get_event_from_comm_stream(self.backward_node.moe_group.id)
        paddle.base.core.nvprof_nvtx_pop()

-        paddle.base.core.nvprof_nvtx_push("attn_forward")
-        inputs = self.forward_node.attn_forward(inputs)
-        paddle.base.core.nvprof_nvtx_pop()
-
-        attn_compute_event = deep_ep.get_event_from_calc_stream(self.forward_node.moe_group.id)
-
        combine_backward_event.calc_stream_wait(self.backward_node.moe_group.id)
        paddle.base.core.nvprof_nvtx_push("mlp_backward_dx")
        output_grad = self.backward_node.mlp_backward(output_grad)
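The structural change in this hunk is the ordering: `attn_forward` now runs, and `combine_bwd_event` is recorded, before the backward combine is launched, so the combine's all-to-all on the comm stream can overlap with attention compute on the calc stream. A comments-only schematic of the resulting interleaving (not executable scheduling code, just the intended timeline):

```python
# Intended interleaving after this change; each tuple is (stream, operation).
STEPS = [
    ("calc stream", "record combine_bwd_event"),
    ("calc stream", "attn_forward (overlaps with the combine below)"),
    ("calc stream", "post_process_backward"),
    ("comm stream", "combine_backward, gated on combine_bw_event_to_wait or combine_bwd_event"),
    ("calc stream", "calc_stream_wait(combine_backward_event), then mlp_backward"),
]

for stream, op in STEPS:
    print(f"[{stream:>11}] {op}")
```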
@@ -787,26 +848,61 @@ def forward_backward(self, inputs, output_grad, event_to_wait=None):
        paddle.base.core.nvprof_nvtx_push("mlp_forward")
        inputs = self.forward_node.mlp_forward(inputs)
        paddle.base.core.nvprof_nvtx_pop()
+        mlp_fwd_event = deep_ep.get_event_from_calc_stream(self.forward_node.moe_group.id)

-        inputs_event = deep_ep.get_event_from_calc_stream(self.forward_node.moe_group.id)

+        if pp_stream is not None:
+            final_out = self.forward_node.post_process_node.forward_without_residual(inputs)
+
+            final_out_event = deep_ep.get_event_from_calc_stream(self.forward_node.moe_group.id)
+
        paddle.base.core.nvprof_nvtx_push("combine_forward")
-        inputs = self.forward_node.combine_forward(
-            inputs, async_finish=True, previous_event=inputs_event, allocate_on_comm_stream=True
-        )
+        inputs = self.forward_node.combine_forward(inputs, previous_event=mlp_fwd_event, async_finish=True, allocate_on_comm_stream=True)
        paddle.base.core.nvprof_nvtx_pop()
-        combine_forward_event = deep_ep.get_event_from_comm_stream(self.forward_node.moe_group.id)
+
+        combine_forward_event = deep_ep.get_event_from_comm_stream(self.forward_node.moe_group.id)
+
+        combine_fwd_out = inputs[-1]
+
+        if pp_stream is not None:
+            send_recv_stream = paddle.device.Stream(stream_base=pp_stream)
+
+            # combine_forward_event.custom_stream_wait(pp_stream)
+            # final_out_event.custom_stream_wait(pp_stream)
+
+            paddle.base.core.nvprof_nvtx_push("pp stream add")
+
+            with paddle.device.stream_guard(send_recv_stream):
+                combine_forward_event.current_stream_wait()
+                final_out_event.current_stream_wait()
+
+                inputs = final_out + combine_fwd_out
+
+                final_out._record_stream()
+                combine_fwd_out._record_stream()
+
+            paddle.base.core.nvprof_nvtx_pop()

        dispatch_backward_event.calc_stream_wait(self.backward_node.moe_group.id)
+        paddle.base.core.nvprof_nvtx_push("post_process_forward")
+
+        paddle.base.core.nvprof_nvtx_pop()
        paddle.base.core.nvprof_nvtx_push("attn_backward")
        output_grad = self.backward_node.attn_backward(output_grad)
        event_to_wait = deep_ep.get_event_from_calc_stream(self.backward_node.moe_group.id)
-        paddle.base.core.nvprof_nvtx_pop()

-        combine_forward_event.calc_stream_wait(self.forward_node.moe_group.id)
-        paddle.base.core.nvprof_nvtx_push("post_process_forward")
-        inputs = self.forward_node.post_process_forward(inputs)
        paddle.base.core.nvprof_nvtx_pop()
+
+        # residual add
+        if pp_stream is None:
+            combine_forward_event.calc_stream_wait(self.forward_node.moe_group.id)
+
+            final_out = self.forward_node.post_process_node.forward_without_residual(inputs)
+            inputs = final_out + combine_fwd_out
+
+            combine_fwd_out._record_stream()
+
        paddle.base.core.nvprof_nvtx_pop()
        return inputs, output_grad, event_to_wait

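The `pp_stream` branch is the heart of this hunk: the residual add is enqueued on a stream built from the pipeline-parallel communication stream, after that stream has waited on the combine-forward and final-out events, and `_record_stream()` marks both tensors as still in use by that stream so the allocator does not recycle their memory early. A condensed sketch of just the stream_guard/_record_stream pattern, assuming a CUDA build of Paddle; the event waits are elided and `side_stream` here is a fresh stream rather than one wrapping `pp_stream`:

```python
import paddle

final_out = paddle.ones([4, 8])             # stands in for the post-process output
combine_fwd_out = paddle.full([4, 8], 0.5)  # stands in for the combine output

side_stream = paddle.device.Stream()  # the patch wraps pp_stream via stream_base=
with paddle.device.stream_guard(side_stream):
    # In the real code the side stream first waits on combine_forward_event
    # and final_out_event before this add is enqueued.
    fused = final_out + combine_fwd_out
    # Mark both inputs as used by this stream so their buffers are not reused
    # by the allocator before the asynchronous add has finished.
    final_out._record_stream()
    combine_fwd_out._record_stream()
```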
@@ -1580,8 +1676,8 @@ def overlapped_forward_backward(
    forward_inputs = forward_pre_node.forward(forward_inputs)
    backward_input_grads = backward_pre_node.backward(backward_input_grads)
    forward_inputs, backward_input_grads, _ = overlap_node.forward_backward(
-        forward_inputs, backward_input_grads, combine_bw_event_to_wait
-    )
+        forward_inputs, backward_input_grads, combine_bw_event_to_wait=combine_bw_event_to_wait, pp_stream=pp_stream
+    )
    forward_inputs = forward_post_node.forward(forward_inputs)
    backward_input_grads = backward_post_node.backward(backward_input_grads)

@@ -1592,3 +1688,4 @@ def overlapped_forward_backward(

    forward_inputs = [forward_inputs] if isinstance(forward_inputs, paddle.Tensor) else forward_inputs
    return forward_inputs, forward_loss, backward_input_grads
+