@@ -157,6 +157,7 @@ def __init__(
         if self.using_post_norm_recompute:
             assert self.shared_experts is not None
             assert self.shared_experts.norm_weight is not None and self.shared_experts.norm_eps is not None
+
     def forward_without_residual(self, inputs):

         if isinstance(inputs, list):
@@ -178,13 +179,15 @@ def forward_without_residual(self, inputs):
                     self.shared_experts.w2,
                 )
             else:
-                shared_expert_output = FP8LinearFunctionBase.fp8_mlp_fwd(hidden_states, self.shared_experts.w1, self.shared_experts.w2)
+                _, _, shared_expert_output = FP8LinearFunctionBase.fp8_mlp_fwd(
+                    hidden_states, self.shared_experts.w1, self.shared_experts.w2
+                )
             residual = residual + shared_expert_output

         self.x = hidden_states
         self.l_aux = l_aux

-        hidden_states = residual
+        hidden_states = residual
         hidden_states.stop_gradient = False

         if self.send_mtp_embed:
@@ -467,14 +470,16 @@ def __init__(self, forward_nodes, backward_nodes, use_fuion=True):
             self.nodes.append(schedule_node_class(f, b, f"OverlapedNode_{len(self.nodes)}"))

     def forward_backward(self, inputs, output_grad, combine_bw_event_to_wait=None, pp_stream=None):
-        #print(" fwd pp stream", pp_stream)
+        # print(" fwd pp stream", pp_stream)
         event_to_wait = combine_bw_event_to_wait
         for i, n in enumerate(self.nodes):
             pp_stream_t = pp_stream
             if i + 1 != len(self.nodes):
                 pp_stream_t = None
-
-            inputs, output_grad, event_to_wait = n.forward_backward(inputs, output_grad, combine_bw_event_to_wait=event_to_wait, pp_stream=pp_stream_t)
+
+            inputs, output_grad, event_to_wait = n.forward_backward(
+                inputs, output_grad, combine_bw_event_to_wait=event_to_wait, pp_stream=pp_stream_t
+            )
         return inputs, output_grad, None

@@ -677,8 +682,8 @@ def combine_backward(self, output_grad, previous_event=None, async_finish=False,
                 output_combine_grad,
                 quant_event,
             ) = output_grad
-
-            if DSV3_USE_FP8_DISPATCH and quant_event is not None:
+
+            if DSV3_USE_FP8_DISPATCH and quant_event is not None:
                 combine_backward_wait_event = quant_event
             else:
                 combine_backward_wait_event = previous_event
@@ -809,11 +814,13 @@ def forward_backward(self, inputs, output_grad, combine_bw_event_to_wait=None, p
         paddle.base.core.nvprof_nvtx_push("combine_backward")
         if combine_bw_event_to_wait is not None:
             # print(" event", combine_bw_event_to_wait)
-            output_grad = self.backward_node.combine_backward(output_grad, previous_event=combine_bw_event_to_wait, async_finish=True,
-                                                              allocate_on_comm_stream=True)
+            output_grad = self.backward_node.combine_backward(
+                output_grad, previous_event=combine_bw_event_to_wait, async_finish=True, allocate_on_comm_stream=True
+            )
         else:
-            output_grad = self.backward_node.combine_backward(output_grad, previous_event=combine_bwd_event, async_finish=True,
-                                                              allocate_on_comm_stream=True)
+            output_grad = self.backward_node.combine_backward(
+                output_grad, previous_event=combine_bwd_event, async_finish=True, allocate_on_comm_stream=True
+            )
         # get combine event
         combine_backward_event = deep_ep.get_event_from_comm_stream(self.backward_node.moe_group.id)
         paddle.base.core.nvprof_nvtx_pop()
@@ -850,22 +857,23 @@ def forward_backward(self, inputs, output_grad, combine_bw_event_to_wait=None, p
         paddle.base.core.nvprof_nvtx_pop()
         mlp_fwd_event = deep_ep.get_event_from_calc_stream(self.forward_node.moe_group.id)

-
         if pp_stream is not None:
-            final_out = self.forward_node.post_process_node.forward_without_residual(inputs)
-
+            final_out = self.forward_node.post_process_node.forward_without_residual(inputs)
+
             final_out_event = deep_ep.get_event_from_calc_stream(self.forward_node.moe_group.id)
-
+
         paddle.base.core.nvprof_nvtx_push("combine_forward")
-        inputs = self.forward_node.combine_forward(inputs, previous_event=mlp_fwd_event, async_finish=True, allocate_on_comm_stream=True)
+        inputs = self.forward_node.combine_forward(
+            inputs, previous_event=mlp_fwd_event, async_finish=True, allocate_on_comm_stream=True
+        )
         paddle.base.core.nvprof_nvtx_pop()

-        combine_forward_event = deep_ep.get_event_from_comm_stream( self.forward_node.moe_group.id)
+        combine_forward_event = deep_ep.get_event_from_comm_stream(self.forward_node.moe_group.id)

         combine_fwd_out = inputs[-1]

         if pp_stream is not None:
-            send_recv_stream = paddle.device.Stream(stream_base=pp_stream)
+            send_recv_stream = paddle.device.Stream(stream_base=pp_stream)

             # combine_forward_event.custom_stream_wait( pp_stream)
             # final_out_event.custom_stream_wait(pp_stream)
@@ -876,16 +884,15 @@ def forward_backward(self, inputs, output_grad, combine_bw_event_to_wait=None, p
             combine_forward_event.current_stream_wait()
             final_out_event.current_stream_wait()

-            inputs = final_out + combine_fwd_out
+            inputs = final_out + combine_fwd_out

             final_out._record_stream()
             combine_fwd_out._record_stream()
-
+
             paddle.base.core.nvprof_nvtx_pop()

         dispatch_backward_event.calc_stream_wait(self.backward_node.moe_group.id)
         paddle.base.core.nvprof_nvtx_push("post_process_forward")
-

         paddle.base.core.nvprof_nvtx_pop()
         paddle.base.core.nvprof_nvtx_push("attn_backward")
@@ -899,10 +906,10 @@ def forward_backward(self, inputs, output_grad, combine_bw_event_to_wait=None, p
             combine_forward_event.calc_stream_wait(self.forward_node.moe_group.id)

             final_out = self.forward_node.post_process_node.forward_without_residual(inputs)
-            inputs = final_out + combine_fwd_out
+            inputs = final_out + combine_fwd_out

             combine_fwd_out._record_stream()
-
+
         paddle.base.core.nvprof_nvtx_pop()
         return inputs, output_grad, event_to_wait

@@ -1676,8 +1683,11 @@ def overlapped_forward_backward(
         forward_inputs = forward_pre_node.forward(forward_inputs)
         backward_input_grads = backward_pre_node.backward(backward_input_grads)
         forward_inputs, backward_input_grads, _ = overlap_node.forward_backward(
-            forward_inputs, backward_input_grads, combine_bw_event_to_wait=combine_bw_event_to_wait,
-            pp_stream=pp_stream)
+            forward_inputs,
+            backward_input_grads,
+            combine_bw_event_to_wait=combine_bw_event_to_wait,
+            pp_stream=pp_stream,
+        )
         forward_inputs = forward_post_node.forward(forward_inputs)
         backward_input_grads = backward_post_node.backward(backward_input_grads)

@@ -1688,4 +1698,3 @@ def overlapped_forward_backward(

         forward_inputs = [forward_inputs] if isinstance(forward_inputs, paddle.Tensor) else forward_inputs
         return forward_inputs, forward_loss, backward_input_grads
-