
Commit 6cd171e

alloc on comm (#10859)
1 parent 3fee150 commit 6cd171e

File tree

3 files changed: +45 -11 lines changed

paddlenlp/transformers/deepseek_v2/modeling_pp.py

Lines changed: 23 additions & 6 deletions
@@ -569,13 +569,18 @@ def mlp_forward(self, inputs):
         ret = (inputs_embeds_mtp, *ret) if self.send_mtp_embed else ret
         return ret
 
-    def combine_forward(self, inputs, async_finish=False):
+    def combine_forward(self, inputs, async_finish=False, previous_event=None, allocate_on_comm_stream=False):
         if self.send_mtp_embed:
             (inputs_embeds_mtp, hidden_states, residual, l_aux, hidden_states_out) = inputs
         else:
             (hidden_states, residual, l_aux, hidden_states_out) = inputs
 
-        output_combine = self.fp8_fusion_moe_node.combine_node.forward(hidden_states_out, async_finish=async_finish)
+        output_combine = self.fp8_fusion_moe_node.combine_node.forward(
+            hidden_states_out,
+            async_finish=async_finish,
+            previous_event=previous_event,
+            allocate_on_comm_stream=allocate_on_comm_stream and previous_event is not None,
+        )
 
         ret = (hidden_states, residual, l_aux, output_combine)
 
@@ -652,7 +657,7 @@ def mlp_backward(self, output_grad):
         ret = (inputs_embeds_mtp_grad, *ret) if self.send_mtp_embed else ret
         return ret
 
-    def dispatch_backward(self, output_grad, async_finish=False):
+    def dispatch_backward(self, output_grad, async_finish=False, previous_event=None, allocate_on_comm_stream=False):
         if self.send_mtp_embed:
             (
                 inputs_embeds_mtp_grad,
@@ -666,7 +671,11 @@ def dispatch_backward(self, output_grad, async_finish=False):
             hidden_states_grad, residual_grad, l_aux_grad, hs_dispatched_grad, dispatched_probs_grad = output_grad
 
         hs_grad, token_probs_grad = self.fp8_fusion_moe_node.dispatch_node.backward(
-            hs_dispatched_grad, dispatched_probs_grad, async_finish=async_finish
+            hs_dispatched_grad,
+            dispatched_probs_grad,
+            async_finish=async_finish,
+            previous_event=previous_event,
+            allocate_on_comm_stream=allocate_on_comm_stream and previous_event is not None,
         )
 
         ret = (hidden_states_grad, residual_grad, l_aux_grad, hs_grad, token_probs_grad)
@@ -755,6 +764,8 @@ def forward_backward(self, inputs, output_grad, event_to_wait=None):
         output_grad = self.backward_node.mlp_backward(output_grad)
         paddle.base.core.nvprof_nvtx_pop()
 
+        output_grad_event = deep_ep.get_event_from_calc_stream(self.backward_node.moe_group.id)
+
         paddle.base.core.nvprof_nvtx_push("dispatch_forward")
         inputs = self.forward_node.dispatch_forward(
             inputs, previous_event=attn_compute_event, async_finish=True, allocate_on_comm_stream=True
@@ -763,7 +774,9 @@ def forward_backward(self, inputs, output_grad, event_to_wait=None):
         dispatch_forward_event = deep_ep.get_event_from_comm_stream(self.forward_node.moe_group.id)
 
         paddle.base.core.nvprof_nvtx_push("dispatch_backward")
-        output_grad = self.backward_node.dispatch_backward(output_grad, async_finish=True)
+        output_grad = self.backward_node.dispatch_backward(
+            output_grad, async_finish=True, previous_event=output_grad_event, allocate_on_comm_stream=True
+        )
         paddle.base.core.nvprof_nvtx_pop()
         # get dispatch backward event
         dispatch_backward_event = deep_ep.get_event_from_comm_stream(self.backward_node.moe_group.id)
@@ -777,8 +790,12 @@ def forward_backward(self, inputs, output_grad, event_to_wait=None):
         inputs = self.forward_node.mlp_forward(inputs)
         paddle.base.core.nvprof_nvtx_pop()
 
+        inputs_event = deep_ep.get_event_from_calc_stream(self.forward_node.moe_group.id)
+
         paddle.base.core.nvprof_nvtx_push("combine_forward")
-        inputs = self.forward_node.combine_forward(inputs, async_finish=True)
+        inputs = self.forward_node.combine_forward(
+            inputs, async_finish=True, previous_event=inputs_event, allocate_on_comm_stream=True
+        )
         paddle.base.core.nvprof_nvtx_pop()
         combine_forward_event = deep_ep.get_event_from_comm_stream(self.forward_node.moe_group.id)
 
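The modeling_pp.py changes chain a calculation-stream event into the asynchronous communication ops: an event is recorded after the preceding compute, handed to dispatch/combine as previous_event, and the op is then allowed to allocate its buffers on the communication stream. A minimal sketch of the pattern, reusing only names that appear in this diff (illustrative, not the full overlap scheduler):

# Record an event on the calculation stream once mlp_backward has produced its grads.
output_grad_event = deep_ep.get_event_from_calc_stream(self.backward_node.moe_group.id)

# Launch the backward dispatch asynchronously; the comm kernel waits on that event
# and may allocate its workspace on the communication stream.
output_grad = self.backward_node.dispatch_backward(
    output_grad,
    async_finish=True,
    previous_event=output_grad_event,
    allocate_on_comm_stream=True,
)

# Fetch the comm-stream event so later compute can be ordered after the dispatch.
dispatch_backward_event = deep_ep.get_event_from_comm_stream(self.backward_node.moe_group.id)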

paddlenlp/transformers/fused_a2a.py

Lines changed: 11 additions & 3 deletions
@@ -308,7 +308,9 @@ def forward(
 
         return recv_x, recv_token_probs, states
 
-    def backward(self, grad_output, grad_token_probs, previous_event=None, async_finish=False):
+    def backward(
+        self, grad_output, grad_token_probs, previous_event=None, async_finish=False, allocate_on_comm_stream=False
+    ):
         """Backward pass of fused dispatch."""
         out = fused_dispatch_backward_func(
             grad_output,
@@ -317,6 +319,7 @@ def backward(self, grad_output, grad_token_probs, previous_event=None, async_fin
             self.handle,
             previous_event=previous_event,
             async_finish=async_finish,
+            allocate_on_comm_stream=allocate_on_comm_stream,
         )
         self.reset_statue()
         return out
@@ -329,12 +332,17 @@ def __init__(self, name="combine"):
     def reset_statue(self):
         self.handle = None
 
-    def forward(self, x, group, handle, previous_event=None, async_finish=False):
+    def forward(self, x, group, handle, previous_event=None, async_finish=False, allocate_on_comm_stream=False):
         """Forward pass of fused combine."""
         states = dict()
         states["handle"] = handle
         combined_x = fused_combine_forward_func(
-            x, group, states, previous_event=previous_event, async_finish=async_finish
+            x,
+            group,
+            states,
+            previous_event=previous_event,
+            async_finish=async_finish,
+            allocate_on_comm_stream=allocate_on_comm_stream,
         )
 
         self.handle = handle
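
With the new keyword in fused_a2a.py, callers of the fused dispatch/combine wrappers opt into comm-stream allocation explicitly; the flag defaults to False, so existing call sites keep their old behavior. A short usage sketch against the updated combine forward signature (combine_node, x, group, handle, and event are assumed to be set up as in the surrounding code):

combined_x = combine_node.forward(
    x,
    group,
    handle,
    previous_event=event,          # event recorded on the calculation stream
    async_finish=True,             # do not block the calculation stream
    allocate_on_comm_stream=True,  # allow output buffers to be allocated on the comm stream
)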

paddlenlp/transformers/moe_layer.py

Lines changed: 11 additions & 2 deletions
@@ -596,13 +596,21 @@ def forward(
         return hs_2d_dispatched, dispatched_indices, dispatched_probs
 
     @paddle.no_grad()
-    def backward(self, hs_dispatched_grad, dispatched_probs_grad, previous_event=None, async_finish=False):
+    def backward(
+        self,
+        hs_dispatched_grad,
+        dispatched_probs_grad,
+        previous_event=None,
+        async_finish=False,
+        allocate_on_comm_stream=False,
+    ):
         # dispatch grad
         hs_grad, _, token_probs_grad = self.dispatch_act_node.backward(
             hs_dispatched_grad,
             dispatched_probs_grad,
             previous_event=previous_event,
             async_finish=async_finish,
+            allocate_on_comm_stream=allocate_on_comm_stream,
         )
         return hs_grad, token_probs_grad
 
@@ -614,14 +622,15 @@ def __init__(self, token_dispatcher, name="fp8_combine_node"):
         self.name = name
 
     @paddle.no_grad()
-    def forward(self, hidden_states_out, previous_event=None, async_finish=False):
+    def forward(self, hidden_states_out, previous_event=None, async_finish=False, allocate_on_comm_stream=False):
         # combine
         output_combine = self.combine_node.forward(
             hidden_states_out,
             self.token_dispatcher._comm_manager.group,
             self.token_dispatcher._comm_manager.handle,
             previous_event=previous_event,
             async_finish=async_finish,
+            allocate_on_comm_stream=allocate_on_comm_stream,
         )
         output_combine.stop_gradient = False
         return output_combine
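
In moe_layer.py the FP8 dispatch and combine nodes simply forward the new flag to the underlying fused ops. Note that the call sites in modeling_pp.py additionally guard it with "allocate_on_comm_stream and previous_event is not None", so comm-stream allocation is only requested when there is an event for the communication kernel to wait on. A brief sketch of calling the updated backward under that guard (dispatch_node and calc_event are placeholders for objects the caller already holds):

hs_grad, token_probs_grad = dispatch_node.backward(
    hs_dispatched_grad,
    dispatched_probs_grad,
    previous_event=calc_event,                       # e.g. from deep_ep.get_event_from_calc_stream(...)
    async_finish=True,
    allocate_on_comm_stream=calc_event is not None,  # mirrors the guard used in modeling_pp.py
)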
