diff --git a/paddlenlp/transformers/deepseek_v2/modeling.py b/paddlenlp/transformers/deepseek_v2/modeling.py
index 4950152ad732..3811f95b772d 100644
--- a/paddlenlp/transformers/deepseek_v2/modeling.py
+++ b/paddlenlp/transformers/deepseek_v2/modeling.py
@@ -747,6 +747,18 @@ def forward(self, x):
 
 class FusedNormGateFunc(paddle.autograd.PyLayer):
     """recompute of postnorm and gate"""
+    _current_norm_output = None
+    _current_invar = None
+
+    @classmethod
+    def set_temporary_vars(cls, norm_output, invar):
+        FusedNormGateFunc._current_norm_output = norm_output
+        FusedNormGateFunc._current_invar = invar
+
+    @classmethod
+    def clear_temporary_vars(cls):
+        FusedNormGateFunc._current_norm_output = None
+        FusedNormGateFunc._current_invar = None
 
     @staticmethod
     def forward(ctx, x, rms_norm_weight, moe_gate_weight, eps):
@@ -762,7 +774,10 @@ def forward(ctx, x, rms_norm_weight, moe_gate_weight, eps):
     def backward(ctx, d_gate_logits, d_norm_output):
         x, rms_norm_weight, moe_gate_weight, eps = ctx.saved_tensor()
         # recompute rmsnorm
-        norm_output, invar = fused_ln.fused_rms_norm(x, rms_norm_weight, eps)
+        norm_output = FusedNormGateFunc._current_norm_output
+        invar = FusedNormGateFunc._current_invar
+        if norm_output is None or invar is None:
+            norm_output, invar = fused_ln.fused_rms_norm(x, rms_norm_weight, eps)
         d_norm_output_linear, d_moe_gate_weight = paddle._C_ops.matmul_grad(
             cast_if_needed(norm_output, ctx.dtype),
             cast_if_needed(moe_gate_weight, ctx.dtype),
@@ -779,6 +794,16 @@ def backward(ctx, d_gate_logits, d_norm_output):
         return dx, d_rms_norm_weight, d_moe_gate_weight
 
 
+class TemporaryVarContext:
+    def __init__(self, norm_output, invar):
+        self.norm_output = norm_output
+        self.invar = invar
+
+    def __enter__(self):
+        FusedNormGateFunc.set_temporary_vars(self.norm_output, self.invar)
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        FusedNormGateFunc.clear_temporary_vars()
+
 def balance_expert_assignment(n, m, k):
     assert k * n % m == 0
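
The modeling.py side of the change: FusedNormGateFunc.backward no longer has to re-run fused_rms_norm when a caller has already produced the post-norm activations, because the cached norm_output/invar pair can be parked in class-level slots and installed around the backward call through TemporaryVarContext, with a fallback to recomputation when the slots are empty. Below is a minimal, self-contained sketch of the same caching pattern using only basic Paddle ops; the CachedExpFunc/ExpTemporaryVarContext names are illustrative stand-ins, not code from this PR.

    import paddle


    class CachedExpFunc(paddle.autograd.PyLayer):
        """Toy PyLayer: backward needs exp(x), which it normally recomputes from
        the saved input; a scheduler may hand it a cached copy before backward runs."""

        _cached_y = None  # class-level slot, mirrors FusedNormGateFunc._current_norm_output

        @classmethod
        def set_temporary_vars(cls, y):
            cls._cached_y = y

        @classmethod
        def clear_temporary_vars(cls):
            cls._cached_y = None

        @staticmethod
        def forward(ctx, x):
            ctx.save_for_backward(x)  # only the input is saved, as in the PR
            return paddle.exp(x)

        @staticmethod
        def backward(ctx, dy):
            (x,) = ctx.saved_tensor()
            y = CachedExpFunc._cached_y
            if y is None:  # no cache installed: fall back to recomputation
                y = paddle.exp(x)
            return dy * y  # d/dx exp(x) = exp(x)


    class ExpTemporaryVarContext:
        """Same shape as TemporaryVarContext: install the cache on enter, clear on exit."""

        def __init__(self, y):
            self.y = y

        def __enter__(self):
            CachedExpFunc.set_temporary_vars(self.y)

        def __exit__(self, exc_type, exc_val, exc_tb):
            CachedExpFunc.clear_temporary_vars()


    x = paddle.randn([4, 8])
    x.stop_gradient = False
    y = CachedExpFunc.apply(x)
    with ExpTemporaryVarContext(y.detach()):  # reuse the forward activation in backward
        y.sum().backward()
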
diff --git a/paddlenlp/transformers/deepseek_v2/modeling_pp.py b/paddlenlp/transformers/deepseek_v2/modeling_pp.py
index f39283b46678..40f022d09a9b 100644
--- a/paddlenlp/transformers/deepseek_v2/modeling_pp.py
+++ b/paddlenlp/transformers/deepseek_v2/modeling_pp.py
@@ -50,6 +50,7 @@
     DeepseekV2PretrainingCriterion,
     DeepseekV2RMSNorm,
     set_global_step,
+    TemporaryVarContext,
 )
 
 try:
@@ -172,22 +173,25 @@ def forward_without_residual(self, inputs):
         if isinstance(inputs, list):
             inputs = tuple(inputs)
-
-        if self.send_mtp_embed:
-            (inputs_embeds_mtp, hidden_states, residual, l_aux, final_hidden_states) = inputs
+        if self.using_post_norm_recompute:
+            if self.send_mtp_embed:
+                (inputs_embeds_mtp, hidden_states, residual, l_aux, final_hidden_states, norm_out) = inputs
+            else:
+                (hidden_states, residual, l_aux, final_hidden_states, norm_out) = inputs
         else:
-            (hidden_states, residual, l_aux, final_hidden_states) = inputs
+            if self.send_mtp_embed:
+                (inputs_embeds_mtp, hidden_states, residual, l_aux, final_hidden_states) = inputs
+            else:
+                (hidden_states, residual, l_aux, final_hidden_states) = inputs
 
         with paddle.no_grad():
             if self.shared_experts is not None:
                 if self.using_post_norm_recompute:
-                    shared_expert_output = FP8LinearFunctionBase.fp8_mlp_fwd_norm_rc(
-                        hidden_states,
-                        self.shared_experts.norm_weight,
-                        self.shared_experts.norm_eps,
-                        self.shared_experts.w1,
-                        self.shared_experts.w2,
+                    _, _, shared_expert_output = FP8LinearFunctionBase.fp8_mlp_fwd(
+                        norm_out, self.shared_experts.w1, self.shared_experts.w2
                     )
+                    norm_out = None
+                    del norm_out
                 else:
                     _, _, shared_expert_output = FP8LinearFunctionBase.fp8_mlp_fwd(
                         hidden_states, self.shared_experts.w1, self.shared_experts.w2
@@ -211,21 +215,25 @@ def forward(self, inputs):
         if isinstance(inputs, list):
             inputs = tuple(inputs)
 
-        if self.send_mtp_embed:
-            (inputs_embeds_mtp, hidden_states, residual, l_aux, final_hidden_states) = inputs
+        if self.using_post_norm_recompute:
+            if self.send_mtp_embed:
+                (inputs_embeds_mtp, hidden_states, residual, l_aux, final_hidden_states, norm_out) = inputs
+            else:
+                (hidden_states, residual, l_aux, final_hidden_states, norm_out) = inputs
         else:
-            (hidden_states, residual, l_aux, final_hidden_states) = inputs
+            if self.send_mtp_embed:
+                (inputs_embeds_mtp, hidden_states, residual, l_aux, final_hidden_states) = inputs
+            else:
+                (hidden_states, residual, l_aux, final_hidden_states) = inputs
 
         with paddle.no_grad():
             if self.shared_experts is not None:
                 if self.using_post_norm_recompute:
-                    shared_expert_output = FP8LinearFunctionBase.fp8_mlp_fwd_norm_rc(
-                        hidden_states,
-                        self.shared_experts.norm_weight,
-                        self.shared_experts.norm_eps,
-                        self.shared_experts.w1,
-                        self.shared_experts.w2,
+                    _, _, shared_expert_output = FP8LinearFunctionBase.fp8_mlp_fwd(
+                        norm_out, self.shared_experts.w1, self.shared_experts.w2
                     )
+                    norm_out = None
+                    del norm_out
                 else:
                     _, _, shared_expert_output = FP8LinearFunctionBase.fp8_mlp_fwd(
                         hidden_states, self.shared_experts.w1, self.shared_experts.w2
@@ -256,7 +264,7 @@ def backward(self, output_grad):
             inputs_embeds_mtp_grad = None
 
         if self.using_post_norm_recompute:
-            dx = FP8LinearFunctionBase.fp8_mlp_bwd_norm_rc(
+            dx, norm_out, invar = FP8LinearFunctionBase.fp8_mlp_bwd_norm_rc(
                 hidden_states_grad,
                 self.x,
                 self.shared_experts.norm_weight,
@@ -274,11 +282,17 @@ def backward(self, output_grad):
         residual_grad = hidden_states_grad
         l_aux_grad = paddle.ones(1, dtype=self.l_aux.dtype) * self.alpha
         final_hidden_states_grad = hidden_states_grad
-
-        if self.send_mtp_embed:
-            return (inputs_embeds_mtp_grad, dx, residual_grad, l_aux_grad, final_hidden_states_grad)
+
+        if self.using_post_norm_recompute:
+            if self.send_mtp_embed:
+                return (inputs_embeds_mtp_grad, dx, residual_grad, l_aux_grad, final_hidden_states_grad, norm_out, invar)
+            else:
+                return (dx, residual_grad, l_aux_grad, final_hidden_states_grad, norm_out, invar)
         else:
-            return (dx, residual_grad, l_aux_grad, final_hidden_states_grad)
+            if self.send_mtp_embed:
+                return (inputs_embeds_mtp_grad, dx, residual_grad, l_aux_grad, final_hidden_states_grad)
+            else:
+                return (dx, residual_grad, l_aux_grad, final_hidden_states_grad)
 
 
 class DecoderLayerNode(ScheduleNode):
@@ -578,23 +592,35 @@ def attn_forward(self, inputs):
             hs_2d, token_indices, token_probs = self.fp8_fusion_moe_node.dispatch_quant_node.forward(
                 norm_out, probs, routing_map
             )
+            # common return values
+            ret = (hidden_states, residual, l_aux, hs_2d, token_indices, token_probs, norm_out)
+
+            # append mtp embed if needed
+            ret = (inputs_embeds_mtp, *ret) if self.send_mtp_embed else ret
+            return ret
         else:
             hs_2d, token_indices, token_probs = self.fp8_fusion_moe_node.dispatch_quant_node.forward(
                 hidden_states, probs, routing_map
             )
 
-        # common return values
-        ret = (hidden_states, residual, l_aux, hs_2d, token_indices, token_probs)
+            # common return values
+            ret = (hidden_states, residual, l_aux, hs_2d, token_indices, token_probs)
 
-        # append mtp embed if needed
-        ret = (inputs_embeds_mtp, *ret) if self.send_mtp_embed else ret
-        return ret
+            # append mtp embed if needed
+            ret = (inputs_embeds_mtp, *ret) if self.send_mtp_embed else ret
+            return ret
 
     def dispatch_forward(self, inputs, previous_event=None, async_finish=False, allocate_on_comm_stream=False):
-        if self.send_mtp_embed:
-            inputs_embeds_mtp, hidden_states, residual, l_aux, hs_2d, token_indices, token_probs = inputs
+        if self.using_post_norm_recompute:
+            if self.send_mtp_embed:
+                inputs_embeds_mtp, hidden_states, residual, l_aux, hs_2d, token_indices, token_probs, norm_out = inputs
+            else:
+                hidden_states, residual, l_aux, hs_2d, token_indices, token_probs, norm_out = inputs
         else:
-            hidden_states, residual, l_aux, hs_2d, token_indices, token_probs = inputs
+            if self.send_mtp_embed:
+                inputs_embeds_mtp, hidden_states, residual, l_aux, hs_2d, token_indices, token_probs = inputs
+            else:
+                hidden_states, residual, l_aux, hs_2d, token_indices, token_probs = inputs
 
         (hs_dispatched, dispatched_indices, dispatched_probs,) = self.fp8_fusion_moe_node.dispatch_node.forward(
             hs_2d,
@@ -609,21 +635,37 @@ def dispatch_forward(self, inputs, previous_event=None, async_finish=False, allo
 
         # append mtp embed if needed
         ret = (inputs_embeds_mtp, *ret) if self.send_mtp_embed else ret
+        ret = (*ret, norm_out) if self.using_post_norm_recompute else ret
         return ret
 
     def mlp_forward(self, inputs):
-        if self.send_mtp_embed:
-            (
-                inputs_embeds_mtp,
-                hidden_states,
-                residual,
-                l_aux,
-                hs_dispatched,
-                dispatched_indices,
-                dispatched_probs,
-            ) = inputs
+        if self.using_post_norm_recompute:
+            if self.send_mtp_embed:
+                (
+                    inputs_embeds_mtp,
+                    hidden_states,
+                    residual,
+                    l_aux,
+                    hs_dispatched,
+                    dispatched_indices,
+                    dispatched_probs,
+                    norm_out,
+                ) = inputs
+            else:
+                hidden_states, residual, l_aux, hs_dispatched, dispatched_indices, dispatched_probs, norm_out = inputs
         else:
-            hidden_states, residual, l_aux, hs_dispatched, dispatched_indices, dispatched_probs = inputs
+            if self.send_mtp_embed:
+                (
+                    inputs_embeds_mtp,
+                    hidden_states,
+                    residual,
+                    l_aux,
+                    hs_dispatched,
+                    dispatched_indices,
+                    dispatched_probs,
+                ) = inputs
+            else:
+                hidden_states, residual, l_aux, hs_dispatched, dispatched_indices, dispatched_probs = inputs
 
         hidden_states_out = self.fp8_fusion_moe_node.mlp_node.forward(
             hs_dispatched, dispatched_indices, dispatched_probs
@@ -632,13 +674,20 @@ def mlp_forward(self, inputs):
 
         # append mtp embed if needed
         ret = (inputs_embeds_mtp, *ret) if self.send_mtp_embed else ret
+        ret = (*ret, norm_out) if self.using_post_norm_recompute else ret
         return ret
 
     def combine_forward(self, inputs, async_finish=False, previous_event=None, allocate_on_comm_stream=False):
-        if self.send_mtp_embed:
-            (inputs_embeds_mtp, hidden_states, residual, l_aux, hidden_states_out) = inputs
+        if self.using_post_norm_recompute:
+            if self.send_mtp_embed:
+                (inputs_embeds_mtp, hidden_states, residual, l_aux, hidden_states_out, norm_out) = inputs
+            else:
+                (hidden_states, residual, l_aux, hidden_states_out, norm_out) = inputs
         else:
-            (hidden_states, residual, l_aux, hidden_states_out) = inputs
+            if self.send_mtp_embed:
+                (inputs_embeds_mtp, hidden_states, residual, l_aux, hidden_states_out) = inputs
+            else:
+                (hidden_states, residual, l_aux, hidden_states_out) = inputs
 
         output_combine = self.fp8_fusion_moe_node.combine_node.forward(
             hidden_states_out,
@@ -651,17 +700,26 @@ def combine_forward(self, inputs, async_finish=False, previous_event=None, alloc
 
         # append mtp embed if needed
         ret = (inputs_embeds_mtp, *ret) if self.send_mtp_embed else ret
+        ret = (*ret, norm_out) if self.using_post_norm_recompute else ret
         return ret
 
     def post_process_forward(self, inputs, with_residual=True):
-        if self.send_mtp_embed:
-            (inputs_embeds_mtp, hidden_states, residual, l_aux, output_combine) = inputs
+        if self.using_post_norm_recompute:
+            if self.send_mtp_embed:
+                (inputs_embeds_mtp, hidden_states, residual, l_aux, output_combine, norm_out) = inputs
+            else:
+                (hidden_states, residual, l_aux, output_combine, norm_out) = inputs
         else:
-            (hidden_states, residual, l_aux, output_combine) = inputs
+            if self.send_mtp_embed:
+                (inputs_embeds_mtp, hidden_states, residual, l_aux, output_combine) = inputs
+            else:
+                (hidden_states, residual, l_aux, output_combine) = inputs
 
         final_hidden_states = self.fp8_fusion_moe_node.combine_quant_node.forward(output_combine)
         inputs = (hidden_states, residual, l_aux, final_hidden_states)
         inputs = (inputs_embeds_mtp, *inputs) if self.send_mtp_embed else inputs
+        inputs = (*inputs, norm_out) if self.using_post_norm_recompute else inputs
+
         if with_residual:
             inputs = self.post_process_node.forward(inputs)
@@ -672,10 +730,16 @@ def post_process_forward(self, inputs, with_residual=True):
     def post_process_backward(self, output_grad, event_to_wait=None):
         grad = self.post_process_node.backward(output_grad)
 
-        if self.send_mtp_embed:
-            inputs_embeds_mtp_grad, hidden_states_grad, residual_grad, l_aux_grad, final_hidden_states_grad = grad
+        if self.using_post_norm_recompute:
+            if self.send_mtp_embed:
+                inputs_embeds_mtp_grad, hidden_states_grad, residual_grad, l_aux_grad, final_hidden_states_grad, norm_out, invar = grad
+            else:
+                hidden_states_grad, residual_grad, l_aux_grad, final_hidden_states_grad, norm_out, invar = grad
         else:
-            hidden_states_grad, residual_grad, l_aux_grad, final_hidden_states_grad = grad
+            if self.send_mtp_embed:
+                inputs_embeds_mtp_grad, hidden_states_grad, residual_grad, l_aux_grad, final_hidden_states_grad = grad
+            else:
+                hidden_states_grad, residual_grad, l_aux_grad, final_hidden_states_grad = grad
 
         output_combine_grad, quant_event = self.fp8_fusion_moe_node.combine_quant_node.backward(
             final_hidden_states_grad, event_to_wait
@@ -683,26 +747,50 @@
         )
         ret = (hidden_states_grad, residual_grad, l_aux_grad, output_combine_grad, quant_event)
         ret = (inputs_embeds_mtp_grad, *ret) if self.send_mtp_embed else ret
+        ret = (*ret, norm_out, invar) if self.using_post_norm_recompute else ret
         return ret
 
     def combine_backward(self, output_grad, previous_event=None, async_finish=False, allocate_on_comm_stream=False):
-        if self.send_mtp_embed:
-            (
-                inputs_embeds_mtp_grad,
-                hidden_states_grad,
-                residual_grad,
-                l_aux_grad,
-                output_combine_grad,
-                quant_event,
-            ) = output_grad
+        if self.using_post_norm_recompute:
+            if self.send_mtp_embed:
+                (
+                    inputs_embeds_mtp_grad,
+                    hidden_states_grad,
+                    residual_grad,
+                    l_aux_grad,
+                    output_combine_grad,
+                    quant_event,
+                    norm_out,
+                    invar,
+                ) = output_grad
+            else:
+                (
+                    hidden_states_grad,
+                    residual_grad,
+                    l_aux_grad,
+                    output_combine_grad,
+                    quant_event,
+                    norm_out,
+                    invar,
+                ) = output_grad
         else:
-            (
-                hidden_states_grad,
-                residual_grad,
-                l_aux_grad,
-                output_combine_grad,
-                quant_event,
-            ) = output_grad
+            if self.send_mtp_embed:
+                (
+                    inputs_embeds_mtp_grad,
+                    hidden_states_grad,
+                    residual_grad,
+                    l_aux_grad,
+                    output_combine_grad,
+                    quant_event,
+                ) = output_grad
+            else:
+                (
+                    hidden_states_grad,
+                    residual_grad,
+                    l_aux_grad,
+                    output_combine_grad,
+                    quant_event,
+                ) = output_grad
 
         if DSV3_USE_FP8_DISPATCH and quant_event is not None:
             combine_backward_wait_event = quant_event
@@ -717,31 +805,55 @@ def combine_backward(self, output_grad, previous_event=None, async_finish=False,
 
         ret = (hidden_states_grad, residual_grad, l_aux_grad, hidden_states_out_grad)
         ret = (inputs_embeds_mtp_grad, *ret) if self.send_mtp_embed else ret
+        ret = (*ret, norm_out, invar) if self.using_post_norm_recompute else ret
         return ret
 
     def mlp_backward(self, output_grad):
-        if self.send_mtp_embed:
-            inputs_embeds_mtp_grad, hidden_states_grad, residual_grad, l_aux_grad, hidden_states_out_grad = output_grad
+        if self.using_post_norm_recompute:
+            if self.send_mtp_embed:
+                inputs_embeds_mtp_grad, hidden_states_grad, residual_grad, l_aux_grad, hidden_states_out_grad, norm_out, invar = output_grad
+            else:
+                hidden_states_grad, residual_grad, l_aux_grad, hidden_states_out_grad, norm_out, invar = output_grad
         else:
-            hidden_states_grad, residual_grad, l_aux_grad, hidden_states_out_grad = output_grad
+            if self.send_mtp_embed:
+                inputs_embeds_mtp_grad, hidden_states_grad, residual_grad, l_aux_grad, hidden_states_out_grad = output_grad
+            else:
+                hidden_states_grad, residual_grad, l_aux_grad, hidden_states_out_grad = output_grad
 
         hs_dispatched_grad, dispatched_probs_grad = self.fp8_fusion_moe_node.mlp_node.backward(hidden_states_out_grad)
+
         ret = (hidden_states_grad, residual_grad, l_aux_grad, hs_dispatched_grad, dispatched_probs_grad)
         ret = (inputs_embeds_mtp_grad, *ret) if self.send_mtp_embed else ret
+        ret = (*ret, norm_out, invar) if self.using_post_norm_recompute else ret
         return ret
 
     def dispatch_backward(self, output_grad, async_finish=False, previous_event=None, allocate_on_comm_stream=False):
-        if self.send_mtp_embed:
-            (
-                inputs_embeds_mtp_grad,
-                hidden_states_grad,
-                residual_grad,
-                l_aux_grad,
-                hs_dispatched_grad,
-                dispatched_probs_grad,
-            ) = output_grad
+        if self.using_post_norm_recompute:
+            if self.send_mtp_embed:
+                (
+                    inputs_embeds_mtp_grad,
+                    hidden_states_grad,
+                    residual_grad,
+                    l_aux_grad,
+                    hs_dispatched_grad,
+                    dispatched_probs_grad,
+                    norm_out,
+                    invar,
+                ) = output_grad
+            else:
+                hidden_states_grad, residual_grad, l_aux_grad, hs_dispatched_grad, dispatched_probs_grad, norm_out, invar = output_grad
         else:
-            hidden_states_grad, residual_grad, l_aux_grad, hs_dispatched_grad, dispatched_probs_grad = output_grad
+            if self.send_mtp_embed:
+                (
+                    inputs_embeds_mtp_grad,
+                    hidden_states_grad,
+                    residual_grad,
+                    l_aux_grad,
+                    hs_dispatched_grad,
+                    dispatched_probs_grad,
+                ) = output_grad
+            else:
+                hidden_states_grad, residual_grad, l_aux_grad, hs_dispatched_grad, dispatched_probs_grad = output_grad
 
         hs_grad, token_probs_grad = self.fp8_fusion_moe_node.dispatch_node.backward(
             hs_dispatched_grad,
@@ -753,23 +865,42 @@ def dispatch_backward(self, output_grad, async_finish=False, previous_event=None
 
         ret = (hidden_states_grad, residual_grad, l_aux_grad, hs_grad, token_probs_grad)
         ret = (inputs_embeds_mtp_grad, *ret) if self.send_mtp_embed else ret
+        ret = (*ret, norm_out, invar) if self.using_post_norm_recompute else ret
         return ret
 
     def attn_backward(self, output_grad):
-        if self.send_mtp_embed:
-            (
-                inputs_embeds_mtp_grad,
-                hidden_states_grad,
-                residual_grad,
-                l_aux_grad,
-                hs_grad,
-                token_probs_grad,
-            ) = output_grad
-            inputs_embeds_mtp_grad_shape = hidden_states_grad.shape
-            inputs_embeds_mtp_grad_shape[-1] = -1
-            inputs_embeds_mtp_grad = inputs_embeds_mtp_grad.view(inputs_embeds_mtp_grad_shape)
+        if self.using_post_norm_recompute:
+            if self.send_mtp_embed:
+                (
+                    inputs_embeds_mtp_grad,
+                    hidden_states_grad,
+                    residual_grad,
+                    l_aux_grad,
+                    hs_grad,
+                    token_probs_grad,
+                    norm_out,
+                    invar,
+                ) = output_grad
+                inputs_embeds_mtp_grad_shape = hidden_states_grad.shape
+                inputs_embeds_mtp_grad_shape[-1] = -1
+                inputs_embeds_mtp_grad = inputs_embeds_mtp_grad.view(inputs_embeds_mtp_grad_shape)
+            else:
+                hidden_states_grad, residual_grad, l_aux_grad, hs_grad, token_probs_grad, norm_out, invar = output_grad
        else:
-            hidden_states_grad, residual_grad, l_aux_grad, hs_grad, token_probs_grad = output_grad
+            if self.send_mtp_embed:
+                (
+                    inputs_embeds_mtp_grad,
+                    hidden_states_grad,
+                    residual_grad,
+                    l_aux_grad,
+                    hs_grad,
+                    token_probs_grad,
+                ) = output_grad
+                inputs_embeds_mtp_grad_shape = hidden_states_grad.shape
+                inputs_embeds_mtp_grad_shape[-1] = -1
+                inputs_embeds_mtp_grad = inputs_embeds_mtp_grad.view(inputs_embeds_mtp_grad_shape)
+            else:
+                hidden_states_grad, residual_grad, l_aux_grad, hs_grad, token_probs_grad = output_grad
 
         hidden_states_grad_, probs_grad, routing_map_grad = self.fp8_fusion_moe_node.dispatch_quant_node.backward(
             hs_grad, token_probs_grad
@@ -784,7 +915,11 @@ def attn_backward(self, output_grad):
         )
 
         output_grad = (inputs_embeds_mtp_grad, *output_grad) if self.send_mtp_embed else output_grad
-        output_grad = self.attn_and_gate_node.backward(output_grad)
+        if self.using_post_norm_recompute:
+            with TemporaryVarContext(norm_out, invar):
+                output_grad = self.attn_and_gate_node.backward(output_grad)
+        else:
+            output_grad = self.attn_and_gate_node.backward(output_grad)
         return output_grad
 
     def forward(self, inputs):
@@ -896,7 +1031,7 @@ def forward_backward(self, inputs, output_grad, combine_bw_event_to_wait=None, p
 
         combine_forward_event = deep_ep.get_event_from_comm_stream(self.forward_node.moe_group.id)
 
-        combine_fwd_out = inputs[-1]
+        combine_fwd_out = inputs[-2] if self.forward_node.using_post_norm_recompute else inputs[-1]
 
         if pp_stream is not None:
             send_recv_stream = paddle.device.Stream(stream_base=pp_stream)
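
Most of the modeling_pp.py churn above is the same tuple bookkeeping repeated in every schedule-node method: inputs_embeds_mtp(_grad) is prepended when send_mtp_embed is set, and norm_out (plus invar on the backward path) is appended when using_post_norm_recompute is set, which is also why forward_backward now picks combine_fwd_out from inputs[-2] instead of inputs[-1]. A hedged sketch of that packing convention follows; the pack/unpack helpers are hypothetical and exist only to illustrate the ordering, they are not part of the PR.

    def pack(core, mtp_embed=None, extras=()):
        """Prepend the optional MTP embedding, then append any post-norm-recompute extras."""
        ret = core if mtp_embed is None else (mtp_embed, *core)
        return (*ret, *extras)


    def unpack(packed, send_mtp_embed, n_extras):
        """Reverse of pack(): returns (mtp_embed_or_None, core_tuple, extras_tuple)."""
        mtp_embed = packed[0] if send_mtp_embed else None
        body = packed[1:] if send_mtp_embed else tuple(packed)
        if n_extras:
            return mtp_embed, body[:-n_extras], body[-n_extras:]
        return mtp_embed, body, ()


    # e.g. the forward path appends norm_out only, the backward path appends (norm_out, invar)
    packed = pack(("hidden_states", "residual", "l_aux", "out"), mtp_embed="mtp", extras=("norm_out",))
    assert unpack(packed, send_mtp_embed=True, n_extras=1) == (
        "mtp",
        ("hidden_states", "residual", "l_aux", "out"),
        ("norm_out",),
    )
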
diff --git a/paddlenlp/transformers/fp8_utils.py b/paddlenlp/transformers/fp8_utils.py
index 2723e112f536..37f17176b732 100644
--- a/paddlenlp/transformers/fp8_utils.py
+++ b/paddlenlp/transformers/fp8_utils.py
@@ -411,14 +411,6 @@ def fp8_mlp_fwd(x, w1, w2):
 
         return x_fp8, x_scale, o3
 
-    @staticmethod
-    def fp8_mlp_fwd_norm_rc(x, norm_w, norm_eps, w1, w2):
-        # ===== compute norm_output =====
-        norm_output, _ = fused_ln.fused_rms_norm(x, norm_w, norm_eps)
-        # ===== compute fp8_mlp_fwd =====
-        _, _, o3 = FP8LinearFunctionBase.fp8_mlp_fwd(norm_output, w1, w2)
-        return o3
-
     @staticmethod
     def fp8_mlp_bwd(do3, x, w1, w2, apply_backward_hook=False):
         do3_orig_shape = do3.shape
@@ -453,22 +445,10 @@ def fp8_mlp_bwd_norm_rc(do3, x, norm_w, norm_eps, w1, w2):
         # ===== compute fp8_mlp_fwd =====
         d_norm_output = FP8LinearFunctionBase.fp8_mlp_bwd(do3, norm_output, w1, w2, True)
 
-        # ===== compute norm grad =====
-        dx, d_rms_norm_weight = fused_ln.fused_rms_norm_grad_func(x, norm_w, invar, d_norm_output, norm_eps)
-
-        if hasattr(norm_w, "main_grad"):
-            if norm_w.main_grad is None:
-                norm_w.main_grad = paddle.zeros(shape=norm_w.shape, dtype=paddle.float32)
-            norm_w.main_grad += d_rms_norm_weight
-        else:
-            if norm_w.grad is None:
-                norm_w.grad = paddle.zeros(shape=norm_w.shape, dtype=paddle.float32)
-            norm_w.grad += d_rms_norm_weight
-
         if hasattr(norm_w, "_apply_backward_hook"):
             norm_w._apply_backward_hook()
 
-        return dx
+        return d_norm_output, norm_output, invar
 
 
 class FP8LinearFunction(paddle.autograd.PyLayer):
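
In fp8_utils.py, fp8_mlp_fwd_norm_rc disappears because callers now feed the already-computed norm_out straight into fp8_mlp_fwd, and fp8_mlp_bwd_norm_rc stops applying the RMSNorm gradient itself: it returns (d_norm_output, norm_output, invar) so the pair can ride along the gradient tuples and be installed via TemporaryVarContext for FusedNormGateFunc.backward. For reference, here is a plain-op sketch of the assumed semantics of the two values returned by fused_ln.fused_rms_norm; this is standard RMSNorm written with basic Paddle ops, not the fused kernel itself.

    import paddle


    def rms_norm_ref(x, w, eps):
        """Plain-op reference for the assumed fused_rms_norm outputs: the normalized
        activations and the per-row inverse RMS ("invar") that the fused backward
        kernel consumes alongside them."""
        x_fp32 = x.astype("float32")
        invar = paddle.rsqrt(paddle.mean(paddle.square(x_fp32), axis=-1, keepdim=True) + eps)
        norm_output = (x_fp32 * invar).astype(x.dtype) * w
        return norm_output, invar


    x = paddle.randn([2, 16])
    w = paddle.ones([16])
    norm_output, invar = rms_norm_ref(x, w, eps=1e-6)  # the pair that now travels with the grads
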