
Commit 48e3e32

Fix conflict (#10902)
* refine fp8_utils
* refine fp8_utils
* refine fp8_utils
* fix
* fix after review
* fix
* fix
* fix
* fix
* fix
1 parent a3b6c2c commit 48e3e32

File tree

1 file changed: +35 -26 lines

paddlenlp/transformers/deepseek_v2/modeling_pp.py

Lines changed: 35 additions & 26 deletions
@@ -157,6 +157,7 @@ def __init__(
         if self.using_post_norm_recompute:
             assert self.shared_experts is not None
             assert self.shared_experts.norm_weight is not None and self.shared_experts.norm_eps is not None
+
     def forward_without_residual(self, inputs):
 
         if isinstance(inputs, list):
@@ -178,13 +179,15 @@ def forward_without_residual(self, inputs):
                     self.shared_experts.w2,
                 )
             else:
-                shared_expert_output = FP8LinearFunctionBase.fp8_mlp_fwd(hidden_states, self.shared_experts.w1, self.shared_experts.w2)
+                _, _, shared_expert_output = FP8LinearFunctionBase.fp8_mlp_fwd(
+                    hidden_states, self.shared_experts.w1, self.shared_experts.w2
+                )
             residual = residual + shared_expert_output
 
         self.x = hidden_states
         self.l_aux = l_aux
 
-        hidden_states = residual
+        hidden_states = residual
         hidden_states.stop_gradient = False
 
         if self.send_mtp_embed:
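
The functional change in this hunk is the call to FP8LinearFunctionBase.fp8_mlp_fwd: it now returns a tuple rather than a single tensor, and the shared-expert branch keeps only the final element. A minimal caller-side sketch of that adjustment, using a stand-in function (the names and meaning of the first two return values are assumptions, not taken from fp8_utils):

# Stand-in with the same call shape as FP8LinearFunctionBase.fp8_mlp_fwd,
# which now returns intermediates alongside the MLP output.
def fp8_mlp_fwd_stub(hidden_states, w1, w2):
    h_fp8, h_scale = None, None   # hypothetical intermediates (names assumed)
    output = hidden_states        # placeholder for the real FP8 MLP computation
    return h_fp8, h_scale, output

hidden_states, w1, w2 = 1.0, None, None
# Before: shared_expert_output = fp8_mlp_fwd_stub(hidden_states, w1, w2)
# would now bind the whole 3-tuple; the updated caller unpacks it instead.
_, _, shared_expert_output = fp8_mlp_fwd_stub(hidden_states, w1, w2)
print(shared_expert_output)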
@@ -467,14 +470,16 @@ def __init__(self, forward_nodes, backward_nodes, use_fuion=True):
             self.nodes.append(schedule_node_class(f, b, f"OverlapedNode_{len(self.nodes)}"))
 
     def forward_backward(self, inputs, output_grad, combine_bw_event_to_wait=None, pp_stream=None):
-        #print(" fwd pp stream", pp_stream)
+        # print(" fwd pp stream", pp_stream)
         event_to_wait = combine_bw_event_to_wait
         for i, n in enumerate(self.nodes):
             pp_stream_t = pp_stream
             if i + 1 != len(self.nodes):
                 pp_stream_t = None
-
-            inputs, output_grad, event_to_wait = n.forward_backward(inputs, output_grad, combine_bw_event_to_wait=event_to_wait, pp_stream=pp_stream_t)
+
+            inputs, output_grad, event_to_wait = n.forward_backward(
+                inputs, output_grad, combine_bw_event_to_wait=event_to_wait, pp_stream=pp_stream_t
+            )
         return inputs, output_grad, None
 
 
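For context, the loop reformatted above chains the overlapped schedule nodes: each node returns updated inputs, gradients, and an event that the next node waits on, and only the last node is handed the real pp_stream. A dependency-free sketch of that scheduling pattern with a mock node class (not the real schedule node):

class MockNode:
    # Stand-in for a schedule node; a real node launches overlapped
    # compute/communication work and records an event for its successor.
    def __init__(self, name):
        self.name = name

    def forward_backward(self, inputs, output_grad, combine_bw_event_to_wait=None, pp_stream=None):
        return inputs, output_grad, f"event_after_{self.name}"


def run_overlapped(nodes, inputs, output_grad, combine_bw_event_to_wait=None, pp_stream="pp_stream"):
    event_to_wait = combine_bw_event_to_wait
    for i, n in enumerate(nodes):
        # Only the final node may use the pipeline-parallel stream.
        pp_stream_t = pp_stream if i + 1 == len(nodes) else None
        inputs, output_grad, event_to_wait = n.forward_backward(
            inputs, output_grad, combine_bw_event_to_wait=event_to_wait, pp_stream=pp_stream_t
        )
    return inputs, output_grad, None


print(run_overlapped([MockNode("n0"), MockNode("n1")], "x", "dy"))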
@@ -677,8 +682,8 @@ def combine_backward(self, output_grad, previous_event=None, async_finish=False,
                 output_combine_grad,
                 quant_event,
             ) = output_grad
-
-            if DSV3_USE_FP8_DISPATCH and quant_event is not None :
+
+            if DSV3_USE_FP8_DISPATCH and quant_event is not None:
                 combine_backward_wait_event = quant_event
             else:
                 combine_backward_wait_event = previous_event
@@ -809,11 +814,13 @@ def forward_backward(self, inputs, output_grad, combine_bw_event_to_wait=None, pp_stream=None):
         paddle.base.core.nvprof_nvtx_push("combine_backward")
         if combine_bw_event_to_wait is not None:
             # print(" event", combine_bw_event_to_wait)
-            output_grad = self.backward_node.combine_backward(output_grad, previous_event= combine_bw_event_to_wait, async_finish=True,
-                                                              allocate_on_comm_stream=True)
+            output_grad = self.backward_node.combine_backward(
+                output_grad, previous_event=combine_bw_event_to_wait, async_finish=True, allocate_on_comm_stream=True
+            )
         else:
-            output_grad = self.backward_node.combine_backward(output_grad, previous_event= combine_bwd_event, async_finish=True,
-                                                              allocate_on_comm_stream=True)
+            output_grad = self.backward_node.combine_backward(
+                output_grad, previous_event=combine_bwd_event, async_finish=True, allocate_on_comm_stream=True
+            )
         # get combine event
         combine_backward_event = deep_ep.get_event_from_comm_stream(self.backward_node.moe_group.id)
         paddle.base.core.nvprof_nvtx_pop()
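
The two reformatted calls differ only in which event is passed as previous_event; the underlying pattern is: pick the event the communication must wait on, launch the MoE combine asynchronously on the comm stream, and capture the resulting event for later synchronization. A simplified sketch of that control flow with a mock combine call (the event objects here are plain placeholders):

def combine_backward_mock(output_grad, previous_event=None, async_finish=False, allocate_on_comm_stream=False):
    # A real implementation would enqueue the MoE combine on the communication
    # stream after `previous_event` and return immediately when async_finish is set.
    return output_grad


def schedule_combine_backward(output_grad, combine_bw_event_to_wait, combine_bwd_event):
    # Prefer the externally provided event; otherwise fall back to the local one.
    previous_event = combine_bw_event_to_wait if combine_bw_event_to_wait is not None else combine_bwd_event
    return combine_backward_mock(
        output_grad, previous_event=previous_event, async_finish=True, allocate_on_comm_stream=True
    )


print(schedule_combine_backward("dY", None, "combine_bwd_event"))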
@@ -850,22 +857,23 @@ def forward_backward(self, inputs, output_grad, combine_bw_event_to_wait=None, pp_stream=None):
         paddle.base.core.nvprof_nvtx_pop()
         mlp_fwd_event = deep_ep.get_event_from_calc_stream(self.forward_node.moe_group.id)
 
-
         if pp_stream is not None:
-            final_out = self.forward_node.post_process_node.forward_without_residual(inputs)
-
+            final_out = self.forward_node.post_process_node.forward_without_residual(inputs)
+
             final_out_event = deep_ep.get_event_from_calc_stream(self.forward_node.moe_group.id)
-
+
         paddle.base.core.nvprof_nvtx_push("combine_forward")
-        inputs = self.forward_node.combine_forward(inputs, previous_event= mlp_fwd_event, async_finish=True, allocate_on_comm_stream=True)
+        inputs = self.forward_node.combine_forward(
+            inputs, previous_event=mlp_fwd_event, async_finish=True, allocate_on_comm_stream=True
+        )
         paddle.base.core.nvprof_nvtx_pop()
 
-        combine_forward_event = deep_ep.get_event_from_comm_stream( self.forward_node.moe_group.id)
+        combine_forward_event = deep_ep.get_event_from_comm_stream(self.forward_node.moe_group.id)
 
         combine_fwd_out = inputs[-1]
 
         if pp_stream is not None:
-            send_recv_stream = paddle.device.Stream(stream_base= pp_stream )
+            send_recv_stream = paddle.device.Stream(stream_base=pp_stream)
 
             # combine_forward_event.custom_stream_wait( pp_stream)
             # final_out_event.custom_stream_wait(pp_stream)
@@ -876,16 +884,15 @@ def forward_backward(self, inputs, output_grad, combine_bw_event_to_wait=None, pp_stream=None):
             combine_forward_event.current_stream_wait()
             final_out_event.current_stream_wait()
 
-            inputs = final_out + combine_fwd_out
+            inputs = final_out + combine_fwd_out
 
             final_out._record_stream()
             combine_fwd_out._record_stream()
-
+
         paddle.base.core.nvprof_nvtx_pop()
 
         dispatch_backward_event.calc_stream_wait(self.backward_node.moe_group.id)
         paddle.base.core.nvprof_nvtx_push("post_process_forward")
-
 
         paddle.base.core.nvprof_nvtx_pop()
         paddle.base.core.nvprof_nvtx_push("attn_backward")
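
The unchanged context around the whitespace fixes above shows the usual multi-stream safety pattern: the consuming stream first waits on the producers' events before computing final_out + combine_fwd_out, and _record_stream() then marks both tensors as used on that stream so their memory is not reclaimed while asynchronous work may still reference it. A dependency-free sketch of the same idea (the event and tensor classes are mocks, not deep_ep or Paddle objects):

class MockEvent:
    def current_stream_wait(self):
        # A real deep_ep event blocks the current (calc) stream until the
        # recorded communication/compute work has completed.
        pass


class MockTensor(float):
    def _record_stream(self):
        # A real tensor is marked as in-use on the current stream so the
        # allocator cannot reuse its memory while async work is pending.
        pass


combine_forward_event, final_out_event = MockEvent(), MockEvent()
final_out, combine_fwd_out = MockTensor(1.0), MockTensor(2.0)

# Wait for both producers before consuming their outputs ...
combine_forward_event.current_stream_wait()
final_out_event.current_stream_wait()
inputs = final_out + combine_fwd_out

# ... then keep the source tensors alive across streams.
final_out._record_stream()
combine_fwd_out._record_stream()
print(inputs)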
@@ -899,10 +906,10 @@ def forward_backward(self, inputs, output_grad, combine_bw_event_to_wait=None, pp_stream=None):
             combine_forward_event.calc_stream_wait(self.forward_node.moe_group.id)
 
             final_out = self.forward_node.post_process_node.forward_without_residual(inputs)
-            inputs = final_out + combine_fwd_out
+            inputs = final_out + combine_fwd_out
 
             combine_fwd_out._record_stream()
-
+
         paddle.base.core.nvprof_nvtx_pop()
         return inputs, output_grad, event_to_wait
 
@@ -1676,8 +1683,11 @@ def overlapped_forward_backward(
         forward_inputs = forward_pre_node.forward(forward_inputs)
         backward_input_grads = backward_pre_node.backward(backward_input_grads)
         forward_inputs, backward_input_grads, _ = overlap_node.forward_backward(
-            forward_inputs, backward_input_grads, combine_bw_event_to_wait = combine_bw_event_to_wait,
-            pp_stream = pp_stream)
+            forward_inputs,
+            backward_input_grads,
+            combine_bw_event_to_wait=combine_bw_event_to_wait,
+            pp_stream=pp_stream,
+        )
         forward_inputs = forward_post_node.forward(forward_inputs)
         backward_input_grads = backward_post_node.backward(backward_input_grads)
 
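The reformatted call is the middle of the 1F1B overlap step: the pre-nodes advance the forward and backward halves, the overlap node interleaves the remaining forward of one micro-batch with the backward of another, and the post-nodes finish both directions. A minimal mock of that call sequence (the chunk classes are stand-ins for the real schedule nodes):

class MockChunk:
    def __init__(self, name):
        self.name = name

    def forward(self, forward_inputs):
        return forward_inputs + [f"fwd:{self.name}"]

    def backward(self, backward_input_grads):
        return backward_input_grads + [f"bwd:{self.name}"]


class MockOverlapNode:
    def forward_backward(self, forward_inputs, backward_input_grads, combine_bw_event_to_wait=None, pp_stream=None):
        # The real node overlaps MoE dispatch/combine communication of the
        # backward micro-batch with the forward micro-batch's compute.
        return forward_inputs + ["fwd:overlap"], backward_input_grads + ["bwd:overlap"], None


forward_inputs, backward_input_grads = [], []
forward_inputs = MockChunk("pre").forward(forward_inputs)
backward_input_grads = MockChunk("pre").backward(backward_input_grads)
forward_inputs, backward_input_grads, _ = MockOverlapNode().forward_backward(
    forward_inputs,
    backward_input_grads,
    combine_bw_event_to_wait=None,
    pp_stream=None,
)
forward_inputs = MockChunk("post").forward(forward_inputs)
backward_input_grads = MockChunk("post").backward(backward_input_grads)
print(forward_inputs, backward_input_grads)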
@@ -1688,4 +1698,3 @@ def overlapped_forward_backward(
 
         forward_inputs = [forward_inputs] if isinstance(forward_inputs, paddle.Tensor) else forward_inputs
         return forward_inputs, forward_loss, backward_input_grads
-