diff --git a/paddlenlp/transformers/deepseek_v2/modeling_pp.py b/paddlenlp/transformers/deepseek_v2/modeling_pp.py
index de9592375fee..bece3a144f16 100644
--- a/paddlenlp/transformers/deepseek_v2/modeling_pp.py
+++ b/paddlenlp/transformers/deepseek_v2/modeling_pp.py
@@ -67,7 +67,6 @@
 
 DSV3_USE_FP8_GEMM = os.getenv("DSV3_USE_FP8_GEMM", "False").lower() == "true"
-DSV3_USE_FP8_DISPATCH = os.getenv("DSV3_USE_FP8_DISPATCH", "False").lower() == "true"
 
 
 def parse_args(args):
@@ -158,43 +157,6 @@ def __init__(
         assert self.shared_experts is not None
         assert self.shared_experts.norm_weight is not None and self.shared_experts.norm_eps is not None
 
-    def forward_without_residual(self, inputs):
-
-        if isinstance(inputs, list):
-            inputs = tuple(inputs)
-
-        if self.send_mtp_embed:
-            (inputs_embeds_mtp, hidden_states, residual, l_aux, final_hidden_states) = inputs
-        else:
-            (hidden_states, residual, l_aux, final_hidden_states) = inputs
-
-        with paddle.no_grad():
-            if self.shared_experts is not None:
-                if self.using_post_norm_recompute:
-                    shared_expert_output = FP8LinearFunctionBase.fp8_mlp_fwd_norm_rc(
-                        hidden_states,
-                        self.shared_experts.norm_weight,
-                        self.shared_experts.norm_eps,
-                        self.shared_experts.w1,
-                        self.shared_experts.w2,
-                    )
-                else:
-                    _, _, shared_expert_output = FP8LinearFunctionBase.fp8_mlp_fwd(
-                        hidden_states, self.shared_experts.w1, self.shared_experts.w2
-                    )
-                residual = residual + shared_expert_output
-
-            self.x = hidden_states
-            self.l_aux = l_aux
-
-            hidden_states = residual
-            hidden_states.stop_gradient = False
-
-            if self.send_mtp_embed:
-                hidden_states = paddle.concat([hidden_states, inputs_embeds_mtp], axis=-1)
-
-            return return_args(hidden_states)
-
     def forward(self, inputs):
 
         if isinstance(inputs, list):
@@ -469,17 +431,9 @@ def __init__(self, forward_nodes, backward_nodes, use_fuion=True):
         for f, b in zip(forward_nodes, backward_nodes):
             self.nodes.append(schedule_node_class(f, b, f"OverlapedNode_{len(self.nodes)}"))
 
-    def forward_backward(self, inputs, output_grad, combine_bw_event_to_wait=None, pp_stream=None):
-        # print(" fwd pp stream", pp_stream)
-        event_to_wait = combine_bw_event_to_wait
-        for i, n in enumerate(self.nodes):
-            pp_stream_t = pp_stream
-            if i + 1 != len(self.nodes):
-                pp_stream_t = None
-
-            inputs, output_grad, event_to_wait = n.forward_backward(
-                inputs, output_grad, combine_bw_event_to_wait=event_to_wait, pp_stream=pp_stream_t
-            )
+    def forward_backward(self, inputs, output_grad, event_to_wait=None):
+        for n in self.nodes:
+            inputs, output_grad, event_to_wait = n.forward_backward(inputs, output_grad, event_to_wait)
 
         return inputs, output_grad, None
 
@@ -632,7 +586,7 @@ def combine_forward(self, inputs, async_finish=False, previous_event=None, alloc
         ret = (inputs_embeds_mtp, *ret) if self.send_mtp_embed else ret
         return ret
 
-    def post_process_forward(self, inputs, with_residual=True):
+    def post_process_forward(self, inputs):
         if self.send_mtp_embed:
             (inputs_embeds_mtp, hidden_states, residual, l_aux, output_combine) = inputs
         else:
@@ -642,10 +596,7 @@ def post_process_forward(self, inputs):
         inputs = (hidden_states, residual, l_aux, final_hidden_states)
         inputs = (inputs_embeds_mtp, *inputs) if self.send_mtp_embed else inputs
 
-        if with_residual:
-            inputs = self.post_process_node.forward(inputs)
-        else:
-            inputs = self.post_process_node.forward_without_residual(inputs)
+        inputs = self.post_process_node.forward(inputs)
         return inputs
 
     def post_process_backward(self, output_grad, event_to_wait=None):
@@ -664,7 +615,7 @@ def post_process_backward(self, output_grad, event_to_wait=None):
         ret = (inputs_embeds_mtp_grad, *ret) if self.send_mtp_embed else ret
         return ret
 
-    def combine_backward(self, output_grad, previous_event=None, async_finish=False, allocate_on_comm_stream=False):
+    def combine_backward(self, output_grad, async_finish=False, allocate_on_comm_stream=False):
         if self.send_mtp_embed:
             (
                 inputs_embeds_mtp_grad,
@@ -675,22 +626,12 @@ def combine_backward(self, output_grad, previous_event=None, async_finish=False,
                 quant_event,
             ) = output_grad
         else:
-            (
-                hidden_states_grad,
-                residual_grad,
-                l_aux_grad,
-                output_combine_grad,
-                quant_event,
-            ) = output_grad
+            hidden_states_grad, residual_grad, l_aux_grad, output_combine_grad, quant_event = output_grad
 
-        if DSV3_USE_FP8_DISPATCH and quant_event is not None:
-            combine_backward_wait_event = quant_event
-        else:
-            combine_backward_wait_event = previous_event
         hidden_states_out_grad = self.fp8_fusion_moe_node.combine_node.backward(
             output_combine_grad,
             async_finish=async_finish,
-            previous_event=combine_backward_wait_event,
+            previous_event=quant_event,
             allocate_on_comm_stream=allocate_on_comm_stream and quant_event is not None,
         )
 
@@ -797,34 +738,25 @@ def __init__(self, forward_node, backward_node, name=""):
         self.backward_node = backward_node
         self.name = name
 
-    def forward_backward(self, inputs, output_grad, combine_bw_event_to_wait=None, pp_stream=None):
+    def forward_backward(self, inputs, output_grad, event_to_wait=None):
         paddle.base.core.nvprof_nvtx_push("forward_backward")
-        combine_bwd_event = deep_ep.get_event_from_calc_stream(self.backward_node.moe_group.id)
-
-        paddle.base.core.nvprof_nvtx_push("attn_forward")
-        inputs = self.forward_node.attn_forward(inputs)
-        paddle.base.core.nvprof_nvtx_pop()
-        attn_compute_event = deep_ep.get_event_from_calc_stream(self.forward_node.moe_group.id)
-
         paddle.base.core.nvprof_nvtx_push("post_process_backward")
-        output_grad = self.backward_node.post_process_backward(output_grad, combine_bw_event_to_wait)
+        output_grad = self.backward_node.post_process_backward(output_grad, event_to_wait)
         paddle.base.core.nvprof_nvtx_pop()
         paddle.base.core.nvprof_nvtx_push("combine_backward")
-        if combine_bw_event_to_wait is not None:
-            # print(" event", combine_bw_event_to_wait)
-            output_grad = self.backward_node.combine_backward(
-                output_grad, previous_event=combine_bw_event_to_wait, async_finish=True, allocate_on_comm_stream=True
-            )
-        else:
-            output_grad = self.backward_node.combine_backward(
-                output_grad, previous_event=combine_bwd_event, async_finish=True, allocate_on_comm_stream=True
-            )
+        output_grad = self.backward_node.combine_backward(output_grad, async_finish=True, allocate_on_comm_stream=True)
         # get combine event
         combine_backward_event = deep_ep.get_event_from_comm_stream(self.backward_node.moe_group.id)
         paddle.base.core.nvprof_nvtx_pop()
 
+        paddle.base.core.nvprof_nvtx_push("attn_forward")
+        inputs = self.forward_node.attn_forward(inputs)
+        paddle.base.core.nvprof_nvtx_pop()
+
+        attn_compute_event = deep_ep.get_event_from_calc_stream(self.forward_node.moe_group.id)
+
         combine_backward_event.calc_stream_wait(self.backward_node.moe_group.id)
         paddle.base.core.nvprof_nvtx_push("mlp_backward_dx")
         output_grad = self.backward_node.mlp_backward(output_grad)
@@ -855,61 +787,26 @@ def forward_backward(self, inputs, output_grad, combine_bw_event_to_wait=None, p
         paddle.base.core.nvprof_nvtx_push("mlp_forward")
         inputs = self.forward_node.mlp_forward(inputs)
         paddle.base.core.nvprof_nvtx_pop()
-        mlp_fwd_event = deep_ep.get_event_from_calc_stream(self.forward_node.moe_group.id)
-
-        if pp_stream is not None:
-            final_out = self.forward_node.post_process_node.forward_without_residual(inputs)
-            final_out_event = deep_ep.get_event_from_calc_stream(self.forward_node.moe_group.id)
+        inputs_event = deep_ep.get_event_from_calc_stream(self.forward_node.moe_group.id)
 
         paddle.base.core.nvprof_nvtx_push("combine_forward")
         inputs = self.forward_node.combine_forward(
-            inputs, previous_event=mlp_fwd_event, async_finish=True, allocate_on_comm_stream=True
+            inputs, async_finish=True, previous_event=inputs_event, allocate_on_comm_stream=True
         )
         paddle.base.core.nvprof_nvtx_pop()
-
         combine_forward_event = deep_ep.get_event_from_comm_stream(self.forward_node.moe_group.id)
-        combine_fwd_out = inputs[-1]
-
-        if pp_stream is not None:
-            send_recv_stream = paddle.device.Stream(stream_base=pp_stream)
-
-            # combine_forward_event.custom_stream_wait( pp_stream)
-            # final_out_event.custom_stream_wait(pp_stream)
-
-            paddle.base.core.nvprof_nvtx_push("pp stream add")
-
-            with paddle.device.stream_guard(send_recv_stream):
-                combine_forward_event.current_stream_wait()
-                final_out_event.current_stream_wait()
-
-                inputs = final_out + combine_fwd_out
-
-                final_out._record_stream()
-                combine_fwd_out._record_stream()
-
-            paddle.base.core.nvprof_nvtx_pop()
-
         dispatch_backward_event.calc_stream_wait(self.backward_node.moe_group.id)
-        paddle.base.core.nvprof_nvtx_push("post_process_forward")
-
-        paddle.base.core.nvprof_nvtx_pop()
 
         paddle.base.core.nvprof_nvtx_push("attn_backward")
         output_grad = self.backward_node.attn_backward(output_grad)
         event_to_wait = deep_ep.get_event_from_calc_stream(self.backward_node.moe_group.id)
-        paddle.base.core.nvprof_nvtx_pop()
 
-        # residual add
-        if pp_stream is None:
-            combine_forward_event.calc_stream_wait(self.forward_node.moe_group.id)
-
-            final_out = self.forward_node.post_process_node.forward_without_residual(inputs)
-            inputs = final_out + combine_fwd_out
-
-            combine_fwd_out._record_stream()
-
+        combine_forward_event.calc_stream_wait(self.forward_node.moe_group.id)
+        paddle.base.core.nvprof_nvtx_push("post_process_forward")
+        inputs = self.forward_node.post_process_forward(inputs)
+        paddle.base.core.nvprof_nvtx_pop()
         paddle.base.core.nvprof_nvtx_pop()
 
         return inputs, output_grad, event_to_wait
@@ -1683,10 +1580,7 @@ def overlapped_forward_backward(
         forward_inputs = forward_pre_node.forward(forward_inputs)
         backward_input_grads = backward_pre_node.backward(backward_input_grads)
         forward_inputs, backward_input_grads, _ = overlap_node.forward_backward(
-            forward_inputs,
-            backward_input_grads,
-            combine_bw_event_to_wait=combine_bw_event_to_wait,
-            pp_stream=pp_stream,
+            forward_inputs, backward_input_grads, combine_bw_event_to_wait
         )
         forward_inputs = forward_post_node.forward(forward_inputs)
         backward_input_grads = backward_post_node.backward(backward_input_grads)
diff --git a/paddlenlp/transformers/fp8_utils.py b/paddlenlp/transformers/fp8_utils.py
index 1f59ca011627..72110eab3ba1 100644
--- a/paddlenlp/transformers/fp8_utils.py
+++ b/paddlenlp/transformers/fp8_utils.py
@@ -917,17 +917,18 @@ def bwd_gate_up_weight(self, do1, input_x, expert_w1, clear_input=False):
     @paddle.no_grad()
     def forward(self, hs_out, unzipped_probs, tokens_per_expert, origin_token_per_experts, output=None):
         self.origin_token_per_experts = origin_token_per_experts
-        # deal 0 size
-        dtype = paddle.bfloat16
         if hs_out is None:
             assert self.input_fp8 is not None
             assert self.input_scale is not None
             shape = self.input_fp8.shape
+            dtype = paddle.bfloat16
         else:
             if isinstance(hs_out, tuple):
                 shape = hs_out[0].shape
+                dtype = hs_out[0].dtype
             else:
                 shape = hs_out.shape
+                dtype = hs_out.dtype
 
         if shape[0] == 0:
             o3 = paddle.zeros(shape, dtype=dtype)
@@ -957,12 +958,6 @@ def forward(self, hs_out, unzipped_probs, tokens_per_expert, origin_token_per_ex
 
     @paddle.no_grad()
     def backward(self, out_grad):
-        # deal 0 size
-        dtype = paddle.bfloat16
-        shape = out_grad[0].shape if isinstance(out_grad, tuple) else out_grad.shape
-        if shape[0] == 0:
-            return paddle.zeros_like(out_grad, dtype=dtype), paddle.zeros_like(self.unzipped_probs, dtype=dtype)
-
         # recompute expert_w2 and expert_w1
         expert_w1 = [x.w1 for x in self.experts if x is not None]
         expert_w2 = [x.w2 for x in self.experts if x is not None]
@@ -1000,12 +995,6 @@ def backward(self, out_grad):
 
     @paddle.no_grad()
     def backward_dx(self, out_grad):
-        # deal 0 size
-        dtype = paddle.bfloat16
-        shape = out_grad[0].shape if isinstance(out_grad, tuple) else out_grad.shape
-        if shape[0] == 0:
-            return paddle.zeros_like(out_grad, dtype=dtype), paddle.zeros_like(self.unzipped_probs, dtype=dtype)
-
         # recompute expert_w2 and expert_w1
         expert_w1 = [x.w1 for x in self.experts if x is not None]
         expert_w2 = [x.w2 for x in self.experts if x is not None]
@@ -1038,9 +1027,6 @@ def backward_dx(self, out_grad):
 
     @paddle.no_grad()
     def backward_dw(self):
-        # deal 0 size
-        if self.input_fp8 is None or self.input_fp8.shape[0] == 0:
-            return
         # recompute expert_w2 and expert_w1
         expert_w1 = [x.w1 for x in self.experts if x is not None]
         expert_w2 = [x.w2 for x in self.experts if x is not None]
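Illustrative sketch (not part of the diff): the fp8_utils change above stops hard-coding paddle.bfloat16 for the zero-token path and instead reads the dtype from the incoming activation. A minimal standalone version of that selection logic, using a hypothetical helper name, might look like this:

    import paddle

    def infer_shape_and_dtype(hs_out, input_fp8=None):
        # No activation passed in: fall back to the cached FP8 input's shape
        # and keep bfloat16 as the assumed compute dtype.
        if hs_out is None:
            assert input_fp8 is not None
            return input_fp8.shape, paddle.bfloat16
        # Otherwise take both shape and dtype from the (possibly tuple-wrapped)
        # tensor, so zero-size placeholders match the real activation dtype.
        tensor = hs_out[0] if isinstance(hs_out, tuple) else hs_out
        return tensor.shape, tensor.dtype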