Commit ac58fa4

support send recv overlap (#10853)
* support send recv overlap
* update
* update
* update
* update
* fix bug
1 parent 0e4be4d commit ac58fa4
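
In short, the commit overlaps the expert send/recv communication with compute: the backward combine is launched asynchronously on the communication stream while the forward attention runs on the calculation stream, and the calculation stream only waits on the combine event right before the kernels that need its result. Below is a toy CPU sketch of that launch/overlap/wait pattern, using Python threads and threading.Event as stand-ins for CUDA streams and deep_ep events; every name in it is illustrative and is not the PaddleNLP API.

# Toy model of compute/communication overlap via events (threads stand in
# for CUDA streams; threading.Event stands in for deep_ep events).
import threading
import time


def combine_backward_comm(combine_done: threading.Event) -> None:
    """Pretend send/recv of expert outputs on the 'comm stream'."""
    time.sleep(0.05)          # the all-to-all combine is in flight here
    combine_done.set()        # analogue of recording an event on the comm stream


def attn_forward_compute() -> str:
    """Pretend attention forward on the 'calc stream'."""
    time.sleep(0.03)
    return "attn_out"


def overlapped_step() -> str:
    combine_done = threading.Event()

    # Launch the backward combine asynchronously (async_finish=True analogue).
    comm_worker = threading.Thread(target=combine_backward_comm, args=(combine_done,))
    comm_worker.start()

    # Forward attention overlaps with the in-flight communication.
    attn_out = attn_forward_compute()

    # calc_stream_wait analogue: block only when the combine result is needed.
    combine_done.wait()
    comm_worker.join()
    return attn_out


if __name__ == "__main__":
    print(overlapped_step())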

File tree

1 file changed: +126 -29 lines
paddlenlp/transformers/deepseek_v2/modeling_pp.py

Lines changed: 126 additions & 29 deletions
@@ -67,6 +67,7 @@


 DSV3_USE_FP8_GEMM = os.getenv("DSV3_USE_FP8_GEMM", "False").lower() == "true"
+DSV3_USE_FP8_DISPATCH = os.getenv("DSV3_USE_FP8_DISPATCH", "False").lower() == "true"


 def parse_args(args):
@@ -156,6 +157,40 @@ def __init__(
         if self.using_post_norm_recompute:
             assert self.shared_experts is not None
             assert self.shared_experts.norm_weight is not None and self.shared_experts.norm_eps is not None
+
+    def forward_without_residual(self, inputs):
+        if isinstance(inputs, list):
+            inputs = tuple(inputs)
+
+        if self.send_mtp_embed:
+            (inputs_embeds_mtp, hidden_states, residual, l_aux, final_hidden_states) = inputs
+        else:
+            (hidden_states, residual, l_aux, final_hidden_states) = inputs
+
+        with paddle.no_grad():
+            if self.shared_experts is not None:
+                if self.using_post_norm_recompute:
+                    shared_expert_output = fp8_mlp_fwd_norm_rc(
+                        hidden_states,
+                        self.shared_experts.norm_weight,
+                        self.shared_experts.norm_eps,
+                        self.shared_experts.w1,
+                        self.shared_experts.w2,
+                    )
+                else:
+                    shared_expert_output = fp8_mlp_fwd(hidden_states, self.shared_experts.w1, self.shared_experts.w2)
+                residual = residual + shared_expert_output
+
+        self.x = hidden_states
+        self.l_aux = l_aux
+
+        hidden_states = residual
+        hidden_states.stop_gradient = False
+
+        if self.send_mtp_embed:
+            hidden_states = paddle.concat([hidden_states, inputs_embeds_mtp], axis=-1)
+
+        return return_args(hidden_states)

     def forward(self, inputs):
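
The new forward_without_residual above folds only the shared-expert output into the residual (under no_grad) and returns that partial sum; the routed-expert output that comes back from the combine is added later, once its communication has finished, possibly on a different stream. A tiny sketch of that split, with plain floats standing in for tensors and made-up function names that are not part of the file:

def post_process_with_residual(residual, shared_out, routed_combined):
    # single-step path: everything is summed in one place
    return residual + shared_out + routed_combined


def post_process_without_residual(residual, shared_out):
    # forward_without_residual analogue: only residual + shared-expert output
    return residual + shared_out


residual, shared_out, routed_combined = 1.0, 2.0, 3.0
partial = post_process_without_residual(residual, shared_out)
# ... the combine communication for the routed experts finishes here ...
final = partial + routed_combined
assert final == post_process_with_residual(residual, shared_out, routed_combined)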

@@ -431,9 +466,15 @@ def __init__(self, forward_nodes, backward_nodes, use_fuion=True):
         for f, b in zip(forward_nodes, backward_nodes):
             self.nodes.append(schedule_node_class(f, b, f"OverlapedNode_{len(self.nodes)}"))

-    def forward_backward(self, inputs, output_grad, event_to_wait=None):
-        for n in self.nodes:
-            inputs, output_grad, event_to_wait = n.forward_backward(inputs, output_grad, event_to_wait)
+    def forward_backward(self, inputs, output_grad, combine_bw_event_to_wait=None, pp_stream=None):
+        #print(" fwd pp stream", pp_stream)
+        event_to_wait = combine_bw_event_to_wait
+        for i, n in enumerate(self.nodes):
+            pp_stream_t = pp_stream
+            if i + 1 != len(self.nodes):
+                pp_stream_t = None
+
+            inputs, output_grad, event_to_wait = n.forward_backward(inputs, output_grad, combine_bw_event_to_wait=event_to_wait, pp_stream=pp_stream_t)
         return inputs, output_grad, None

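
The scheduler change above forwards the pipeline-parallel stream only to the last overlapped node of the chunk; earlier nodes receive None and therefore take the calc-stream path for their residual add. The rule in isolation, as a hedged sketch (pp_stream_for_node is a made-up helper name, not in the file):

def pp_stream_for_node(index, num_nodes, pp_stream):
    # Only the last node of the chunk gets the pipeline-parallel stream.
    return pp_stream if index + 1 == num_nodes else None


assert pp_stream_for_node(0, 3, "pp_stream") is None
assert pp_stream_for_node(1, 3, "pp_stream") is None
assert pp_stream_for_node(2, 3, "pp_stream") == "pp_stream"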

@@ -586,7 +627,7 @@ def combine_forward(self, inputs, async_finish=False, previous_event=None, alloc
         ret = (inputs_embeds_mtp, *ret) if self.send_mtp_embed else ret
         return ret

-    def post_process_forward(self, inputs):
+    def post_process_forward(self, inputs, with_residual=True):
         if self.send_mtp_embed:
             (inputs_embeds_mtp, hidden_states, residual, l_aux, output_combine) = inputs
         else:
@@ -596,7 +637,10 @@ def post_process_forward(self, inputs):
         inputs = (hidden_states, residual, l_aux, final_hidden_states)
         inputs = (inputs_embeds_mtp, *inputs) if self.send_mtp_embed else inputs

-        inputs = self.post_process_node.forward(inputs)
+        if with_residual:
+            inputs = self.post_process_node.forward(inputs)
+        else:
+            inputs = self.post_process_node.forward_without_residual(inputs)
         return inputs

     def post_process_backward(self, output_grad, event_to_wait=None):
@@ -615,7 +659,7 @@ def post_process_backward(self, output_grad, event_to_wait=None):
         ret = (inputs_embeds_mtp_grad, *ret) if self.send_mtp_embed else ret
         return ret

-    def combine_backward(self, output_grad, async_finish=False, allocate_on_comm_stream=False):
+    def combine_backward(self, output_grad, previous_event=None, async_finish=False, allocate_on_comm_stream=False):
         if self.send_mtp_embed:
             (
                 inputs_embeds_mtp_grad,
@@ -626,12 +670,22 @@ def combine_backward(self, output_grad, async_finish=False, allocate_on_comm_str
                 quant_event,
             ) = output_grad
         else:
-            hidden_states_grad, residual_grad, l_aux_grad, output_combine_grad, quant_event = output_grad
-
+            (
+                hidden_states_grad,
+                residual_grad,
+                l_aux_grad,
+                output_combine_grad,
+                quant_event,
+            ) = output_grad
+
+        if DSV3_USE_FP8_DISPATCH and quant_event is not None :
+            combine_backward_wait_event = quant_event
+        else:
+            combine_backward_wait_event = previous_event
         hidden_states_out_grad = self.fp8_fusion_moe_node.combine_node.backward(
             output_combine_grad,
             async_finish=async_finish,
-            previous_event=quant_event,
+            previous_event=combine_backward_wait_event,
             allocate_on_comm_stream=allocate_on_comm_stream and quant_event is not None,
         )
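
combine_backward now accepts a previous_event and chooses what the combine kernel should wait on: the FP8 quantization event when DSV3_USE_FP8_DISPATCH produced one, otherwise the event passed in by the scheduler. The same selection logic, isolated into a hedged sketch (the helper name is made up, and plain strings stand in for events):

def pick_combine_backward_wait_event(use_fp8_dispatch, quant_event, previous_event):
    # Wait on the quantization event when FP8 dispatch produced one, so the
    # combine cannot start before quantization finishes; otherwise fall back
    # to the event handed in by the scheduler.
    if use_fp8_dispatch and quant_event is not None:
        return quant_event
    return previous_event


assert pick_combine_backward_wait_event(True, "quant_evt", "prev_evt") == "quant_evt"
assert pick_combine_backward_wait_event(True, None, "prev_evt") == "prev_evt"
assert pick_combine_backward_wait_event(False, "quant_evt", "prev_evt") == "prev_evt"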

@@ -738,25 +792,32 @@ def __init__(self, forward_node, backward_node, name=""):
         self.backward_node = backward_node
         self.name = name

-    def forward_backward(self, inputs, output_grad, event_to_wait=None):
+    def forward_backward(self, inputs, output_grad, combine_bw_event_to_wait=None, pp_stream=None):
         paddle.base.core.nvprof_nvtx_push("forward_backward")

+        combine_bwd_event = deep_ep.get_event_from_calc_stream(self.backward_node.moe_group.id)
+
+        paddle.base.core.nvprof_nvtx_push("attn_forward")
+        inputs = self.forward_node.attn_forward(inputs)
+        paddle.base.core.nvprof_nvtx_pop()
+        attn_compute_event = deep_ep.get_event_from_calc_stream(self.forward_node.moe_group.id)
+
         paddle.base.core.nvprof_nvtx_push("post_process_backward")
-        output_grad = self.backward_node.post_process_backward(output_grad, event_to_wait)
+        output_grad = self.backward_node.post_process_backward(output_grad, combine_bw_event_to_wait)
         paddle.base.core.nvprof_nvtx_pop()

         paddle.base.core.nvprof_nvtx_push("combine_backward")
-        output_grad = self.backward_node.combine_backward(output_grad, async_finish=True, allocate_on_comm_stream=True)
+        if combine_bw_event_to_wait is not None:
+            # print(" event", combine_bw_event_to_wait)
+            output_grad = self.backward_node.combine_backward(output_grad, previous_event= combine_bw_event_to_wait, async_finish=True,
+                allocate_on_comm_stream=True)
+        else:
+            output_grad = self.backward_node.combine_backward(output_grad, previous_event= combine_bwd_event, async_finish=True,
+                allocate_on_comm_stream=True)
         # get combine event
         combine_backward_event = deep_ep.get_event_from_comm_stream(self.backward_node.moe_group.id)
         paddle.base.core.nvprof_nvtx_pop()

-        paddle.base.core.nvprof_nvtx_push("attn_forward")
-        inputs = self.forward_node.attn_forward(inputs)
-        paddle.base.core.nvprof_nvtx_pop()
-
-        attn_compute_event = deep_ep.get_event_from_calc_stream(self.forward_node.moe_group.id)
-
         combine_backward_event.calc_stream_wait(self.backward_node.moe_group.id)
         paddle.base.core.nvprof_nvtx_push("mlp_backward_dx")
         output_grad = self.backward_node.mlp_backward(output_grad)
@@ -787,26 +848,61 @@ def forward_backward(self, inputs, output_grad, event_to_wait=None):
         paddle.base.core.nvprof_nvtx_push("mlp_forward")
         inputs = self.forward_node.mlp_forward(inputs)
         paddle.base.core.nvprof_nvtx_pop()
+        mlp_fwd_event = deep_ep.get_event_from_calc_stream(self.forward_node.moe_group.id)

-        inputs_event = deep_ep.get_event_from_calc_stream(self.forward_node.moe_group.id)

+        if pp_stream is not None:
+            final_out = self.forward_node.post_process_node.forward_without_residual(inputs)
+
+            final_out_event = deep_ep.get_event_from_calc_stream(self.forward_node.moe_group.id)
+
         paddle.base.core.nvprof_nvtx_push("combine_forward")
-        inputs = self.forward_node.combine_forward(
-            inputs, async_finish=True, previous_event=inputs_event, allocate_on_comm_stream=True
-        )
+        inputs = self.forward_node.combine_forward(inputs, previous_event= mlp_fwd_event, async_finish=True, allocate_on_comm_stream=True)
         paddle.base.core.nvprof_nvtx_pop()
-        combine_forward_event = deep_ep.get_event_from_comm_stream(self.forward_node.moe_group.id)
+
+        combine_forward_event = deep_ep.get_event_from_comm_stream( self.forward_node.moe_group.id)
+
+        combine_fwd_out = inputs[-1]
+
+        if pp_stream is not None:
+            send_recv_stream = paddle.device.Stream(stream_base= pp_stream )
+
+            # combine_forward_event.custom_stream_wait( pp_stream)
+            # final_out_event.custom_stream_wait(pp_stream)
+
+            paddle.base.core.nvprof_nvtx_push("pp stream add")
+
+            with paddle.device.stream_guard(send_recv_stream):
+                combine_forward_event.current_stream_wait()
+                final_out_event.current_stream_wait()
+
+                inputs = final_out + combine_fwd_out
+
+                final_out._record_stream()
+                combine_fwd_out._record_stream()
+
+            paddle.base.core.nvprof_nvtx_pop()

         dispatch_backward_event.calc_stream_wait(self.backward_node.moe_group.id)
+        paddle.base.core.nvprof_nvtx_push("post_process_forward")
+
+
+        paddle.base.core.nvprof_nvtx_pop()
         paddle.base.core.nvprof_nvtx_push("attn_backward")
         output_grad = self.backward_node.attn_backward(output_grad)
         event_to_wait = deep_ep.get_event_from_calc_stream(self.backward_node.moe_group.id)
-        paddle.base.core.nvprof_nvtx_pop()

-        combine_forward_event.calc_stream_wait(self.forward_node.moe_group.id)
-        paddle.base.core.nvprof_nvtx_push("post_process_forward")
-        inputs = self.forward_node.post_process_forward(inputs)
         paddle.base.core.nvprof_nvtx_pop()
+
+        # residual add
+        if pp_stream is None:
+            combine_forward_event.calc_stream_wait(self.forward_node.moe_group.id)
+
+            final_out = self.forward_node.post_process_node.forward_without_residual(inputs)
+            inputs = final_out + combine_fwd_out
+
+            combine_fwd_out._record_stream()
+
         paddle.base.core.nvprof_nvtx_pop()
         return inputs, output_grad, event_to_wait
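
When a pp_stream is supplied, the hunk above performs the residual add (final_out + combine_fwd_out) under a stream guard on that stream, after waiting on the combine-forward and final-out events, and calls _record_stream() so the tensors' memory is not reclaimed while the side stream still uses them; without a pp_stream, the calc stream waits for the combine and does the add itself. A toy model of the pp_stream path, with Python threads and threading.Event standing in for streams and events (all names illustrative, not the PaddleNLP API):

import threading
import time


def pp_stream_residual_add(final_out, combine_fwd_out, combine_ready, result):
    # Runs on the side "stream": wait for the combine output, then add.
    combine_ready.wait()                      # combine_forward_event analogue
    result["out"] = final_out + combine_fwd_out


combine_ready = threading.Event()
result = {}
side = threading.Thread(
    target=pp_stream_residual_add, args=(1.0, 2.0, combine_ready, result)
)
side.start()

# The main "calc stream" keeps working while the residual add is pending.
time.sleep(0.01)                              # attn_backward stand-in
combine_ready.set()                           # combine communication finished
side.join()
assert result["out"] == 3.0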

@@ -1580,8 +1676,8 @@ def overlapped_forward_backward(
         forward_inputs = forward_pre_node.forward(forward_inputs)
         backward_input_grads = backward_pre_node.backward(backward_input_grads)
         forward_inputs, backward_input_grads, _ = overlap_node.forward_backward(
-            forward_inputs, backward_input_grads, combine_bw_event_to_wait
-        )
+            forward_inputs, backward_input_grads, combine_bw_event_to_wait = combine_bw_event_to_wait,
+            pp_stream = pp_stream)
         forward_inputs = forward_post_node.forward(forward_inputs)
         backward_input_grads = backward_post_node.backward(backward_input_grads)

@@ -1592,3 +1688,4 @@ def overlapped_forward_backward(

     forward_inputs = [forward_inputs] if isinstance(forward_inputs, paddle.Tensor) else forward_inputs
     return forward_inputs, forward_loss, backward_input_grads
+
