diff --git a/paddleformers/transformers/__init__.py b/paddleformers/transformers/__init__.py index 7db5c7ca1d1..6bd9613b239 100644 --- a/paddleformers/transformers/__init__.py +++ b/paddleformers/transformers/__init__.py @@ -122,6 +122,10 @@ "get_triangle_upper_mask", "DeepseekV2LinearScalingRotaryEmbedding", ], + "deepseek_v2.modeling_fast": [ + "DeepseekV2ModelFast", + "DeepseekV2PretrainedModelFast", + ], "deepseek_v2.modeling_auto": [ "DeepseekV2LMHeadAuto", "DeepseekV2ForCausalLMAuto", diff --git a/paddleformers/transformers/deepseek_v2/__init__.py b/paddleformers/transformers/deepseek_v2/__init__.py index a0fac197982..2c7634b8810 100644 --- a/paddleformers/transformers/deepseek_v2/__init__.py +++ b/paddleformers/transformers/deepseek_v2/__init__.py @@ -56,6 +56,12 @@ "yarn_find_correction_range", "get_triangle_upper_mask", "DeepseekV2LinearScalingRotaryEmbedding", + "set_global_step", + "get_global_step", + ], + "modeling_fast": [ + "DeepseekV2ModelFast", + "DeepseekV2PretrainedModelFast", ], "modeling_auto": [ "DeepseekV2LMHeadAuto", diff --git a/paddleformers/transformers/deepseek_v2/configuration.py b/paddleformers/transformers/deepseek_v2/configuration.py index 1feba3cbec7..e62ae3dc5ef 100644 --- a/paddleformers/transformers/deepseek_v2/configuration.py +++ b/paddleformers/transformers/deepseek_v2/configuration.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """ DeepSeekV2 model configuration""" -from ..configuration_utils import PretrainedConfig +from paddleformers.transformers.configuration_utils import PretrainedConfig __all__ = [ "DeepseekV2Config", @@ -179,6 +179,18 @@ def __init__( attention_dropout=0.0, speculate_model_type=False, using_flex_token=False, + use_dualpipev=False, + send_mtp_embed=True, + using_post_norm_recompute=False, + recompute_fwd_gate_up=0, + is_split_group_gemm=False, + fakse_gate_restrict_balance=False, + adaptive_remained_O1_recompute_ratio=0, + offline_quant_expert_weight=True, + clear_origin_weight_when_offline_quant=True, + mlp_bwd_subbatch_rows=0, + mlp_fwd_subbatch_rows=0, + output_subbatch_rows=0, **kwargs, ): self.vocab_size = vocab_size @@ -227,6 +239,18 @@ def __init__( self.speculate_model_type = speculate_model_type self.use_fp8 = False self.using_flex_token = using_flex_token + self.use_dualpipev = use_dualpipev + self.send_mtp_embed = send_mtp_embed + self.using_post_norm_recompute = using_post_norm_recompute + self.recompute_fwd_gate_up = recompute_fwd_gate_up + self.is_split_group_gemm = is_split_group_gemm + self.fakse_gate_restrict_balance = fakse_gate_restrict_balance + self.adaptive_remained_O1_recompute_ratio = adaptive_remained_O1_recompute_ratio + self.offline_quant_expert_weight = offline_quant_expert_weight + self.clear_origin_weight_when_offline_quant = clear_origin_weight_when_offline_quant + self.mlp_bwd_subbatch_rows = mlp_bwd_subbatch_rows + self.mlp_fwd_subbatch_rows = mlp_fwd_subbatch_rows + self.output_subbatch_rows = output_subbatch_rows super().__init__( pad_token_id=pad_token_id, diff --git a/paddleformers/transformers/deepseek_v2/modeling.py b/paddleformers/transformers/deepseek_v2/modeling.py index 04a8651f43e..7e35a4e58b1 100644 --- a/paddleformers/transformers/deepseek_v2/modeling.py +++ b/paddleformers/transformers/deepseek_v2/modeling.py @@ -23,6 +23,7 @@ import contextlib import math +import os import warnings from functools import partial from typing import List, Optional, Tuple, Union @@ -35,7 +36,9 @@ from paddle.distributed import 
fleet from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker from paddle.distributed.fleet.recompute.recompute import recompute +from paddle.jit import to_static from paddle.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from paddle.utils import try_import try: from paddle.incubate.nn.functional import fused_rotary_position_embedding @@ -51,11 +54,12 @@ except: pass +from paddle import _C_ops try: from paddle.nn.functional.flash_attention import flash_attention except: flash_attention = None - +from paddleformers.transformers.model_utils import dtype_guard from ...utils.initializer import kaiming_uniform_ from ...utils.log import logger @@ -72,11 +76,46 @@ from ..model_utils import PretrainedModel, dtype_guard, register_base_model from ..moe_gate import PretrainedMoEGate from ..moe_layer import MoEFlexTokenLayer, MoELayer -from ..utils import device_guard +from ..utils import cast_if_needed, device_guard from . import fp8_linear as linear_utils from .configuration import DeepseekV2Config + +FA_VERSION = int(os.getenv("FA_VERSION", 2)) + +from paddle.distributed.fleet.meta_parallel.zero_bubble_utils import WeightGradStore + +from ..fp8_utils import ( + FP8KeepXLinear, + FP8Linear, + FP8LinearFunctionBase, + FP8Mlp, + cache_fp8_weight, + set_parameter_color, +) from .fp8_linear import Linear +DSV3_USE_FP8_GEMM = os.getenv("DSV3_USE_FP8_GEMM", "False").lower() == "true" +DSV3_USE_ATTEN_RECOMPUTE = os.getenv("DSV3_USE_ATTEN_RECOMPUTE", "False").lower() == "true" + +Linear = FP8Linear if DSV3_USE_FP8_GEMM else Linear + +try: + import fused_ln + from paddle.incubate.nn.functional import swiglu +except ImportError: + + def swiglu(x, y=None): + if y is None: + x, y = paddle.chunk(x, chunks=2, axis=-1) + return F.silu(x) * y + + +try: + from paddle.incubate.nn.functional import fused_partial_rope +except ImportError: + fused_partial_rope = None + + __all__ = [ "DeepseekV2LMHead", "DeepseekV2PretrainingCriterion", @@ -84,8 +123,54 @@ "DeepseekV2ForSequenceClassification", "DeepseekV2Model", "DeepseekV2PretrainedModel", + "set_global_step", + "get_global_step", ] +global_step = 0 + + +def set_global_step(cur_step): + global global_step + global_step = cur_step + + +def get_global_step(): + global global_step + return global_step + + +def rms_norm_fused(x_in, w, eps, use_fast_ln=False): + if use_fast_ln: + fast_ln = try_import("fast_ln") + return fast_ln.fast_rms_norm(x_in, w, eps)[0] + else: + fused_ln = try_import("fused_ln") + return fused_ln.fused_rms_norm(x_in, w, eps)[0] + + +def fusion_rms_norm(hidden_states, weight, variance_epsilon, use_fast_ln=False): + if get_env_device() == "npu": + return paddle.base.core.eager._run_custom_op("rms_norm_npu", hidden_states, weight, variance_epsilon)[0] + if get_env_device() == "mlu": + return paddle.base.core.eager._run_custom_op("rms_norm_mlu", hidden_states, weight, variance_epsilon)[0] + elif get_env_device() == "gcu": + return paddle.base.core.eager._run_custom_op("rms_norm_gcu", hidden_states, weight, variance_epsilon)[0] + elif get_env_device() == "intel_hpu": + return paddle.incubate.nn.functional.fused_rms_norm( + hidden_states, weight, None, variance_epsilon, hidden_states.dim() - 1 + )[0] + elif get_env_device() == "xpu": + try: + import paddle_xpu_nn # noqa: F821 + + return paddle_xpu_nn.xpu_rms_norm(hidden_states, weight, variance_epsilon)[0] + except ImportError: + raise NotImplementedError( + f"Implementation of fused_rms_norm is not available on {get_env_device()}. 
Please install paddle_xpu to use this feature" + ) + return rms_norm_fused(hidden_states, weight, variance_epsilon, use_fast_ln) + def get_triangle_upper_mask(x, mask=None): if mask is not None: @@ -129,7 +214,35 @@ def assign_kv_heads(num_kv_heads: int, num_gpus: int): return assignment_list -def parallel_matmul(x: Tensor, y: Tensor, tensor_parallel_output=True): +class LMHeadFunction(paddle.autograd.PyLayer): + @staticmethod + def forward(ctx, x, weight, transpose_y): + out = paddle.matmul(x, weight, transpose_y = transpose_y) + + ctx.save_for_backward(x, weight, transpose_y) + return out + + @staticmethod + def backward(ctx, dout): + if dout.dtype == paddle.float32: + dout = dout.cast( paddle.bfloat16) + + x, weight, transpose_y = ctx.saved_tensor() + + dx = paddle.matmul( dout, weight, transpose_y = not transpose_y) + if transpose_y: + with paddle.amp.auto_cast(False): + paddle._C_ops.fused_linear_param_grad_add( + dout.reshape( [-1, dout.shape[-1]]), x.reshape( [-1, x.shape[-1]]), weight.main_grad, None, True, False + ) + else: + with paddle.amp.auto_cast(False): + paddle._C_ops.fused_linear_param_grad_add( + x.reshape([-1, x.shape[-1]]), dout.reshape([-1, dout.shape[-1]]), weight.main_grad, None, True, False + ) + return dx, None + +def parallel_matmul(x: Tensor, y: Tensor, transpose_y=False, tensor_parallel_output=True): is_fleet_init = True tensor_parallel_degree = 1 try: @@ -147,7 +260,7 @@ def parallel_matmul(x: Tensor, y: Tensor, tensor_parallel_output=True): if is_fleet_init and tensor_parallel_degree > 1 and y_is_distributed: # if not running under distributed.launch, it will raise AttributeError: 'Fleet' object has no attribute '_hcg' input_parallel = paddle.distributed.collective._c_identity(x, group=model_parallel_group) - logits = paddle.matmul(input_parallel, y, transpose_y=False) + logits = paddle.matmul(input_parallel, y, transpose_y=transpose_y) if tensor_parallel_output: return logits @@ -155,7 +268,7 @@ def parallel_matmul(x: Tensor, y: Tensor, tensor_parallel_output=True): return paddle.distributed.collective._c_concat(logits, group=model_parallel_group) else: - logits = paddle.matmul(x, y, transpose_y=False) + logits = LMHeadFunction.apply(x, y, transpose_y=transpose_y) return logits @@ -328,17 +441,9 @@ def __init__(self, config: DeepseekV2Config, hidden_size=None, eps=1e-6, use_seq mark_as_sequence_parallel_parameter(self.weight) def forward(self, hidden_states): - if self.config.use_fused_rms_norm and get_env_device() == "xpu": - if self.weight.dtype != hidden_states.dtype: - hidden_states = paddle.cast(hidden_states, self.weight.dtype) - try: - import paddle_xpu_nn # noqa: F821 - - return paddle_xpu_nn.xpu_rms_norm(hidden_states, self.weight, self.variance_epsilon)[0] - except ImportError: - raise NotImplementedError( - f"Implementation of fused_rms_norm is not available on {get_env_device()}. 
Please install paddle_xpu to use this feature" - ) + if self.config.use_fused_rms_norm: + # fusion_rms_norm集成了多硬件功能 + return fusion_rms_norm(hidden_states, self.weight, self.variance_epsilon, self.config.use_fast_layer_norm) with paddle.amp.auto_cast(False): hidden_states = hidden_states.astype("float32") @@ -528,34 +633,35 @@ def __init__( super().__init__(dim, max_position_embeddings, base) def _set_cos_sin_cache(self, seq_len): - self.max_seq_len_cached = seq_len - dim = self.dim - - freq_extra = 1.0 / (self.base ** (paddle.arange(0, dim, 2, dtype=paddle.float32) / dim)) - freq_inter = 1.0 / (self.scaling_factor * self.base ** (paddle.arange(0, dim, 2, dtype=paddle.float32) / dim)) - - low, high = yarn_find_correction_range( - self.beta_fast, - self.beta_slow, - dim, - self.base, - self.original_max_position_embeddings, - ) - inv_freq_mask = 1.0 - yarn_linear_ramp_mask(low, high, dim // 2) - self.inv_freq = freq_inter * (1 - inv_freq_mask) + freq_extra * inv_freq_mask + with paddle.amp.auto_cast(False): + self.max_seq_len_cached = seq_len + dim = self.dim + + freq_extra = 1.0 / (self.base ** (paddle.arange(0, dim, 2, dtype=paddle.float32) / dim)) + freq_inter = 1.0 / (self.scaling_factor * self.base ** (paddle.arange(0, dim, 2, dtype=paddle.float32) / dim)) + + low, high = yarn_find_correction_range( + self.beta_fast, + self.beta_slow, + dim, + self.base, + self.original_max_position_embeddings, + ) + inv_freq_mask = 1.0 - yarn_linear_ramp_mask(low, high, dim // 2) + self.inv_freq = freq_inter * (1 - inv_freq_mask) + freq_extra * inv_freq_mask - t = paddle.arange(seq_len, dtype=paddle.float32) + t = paddle.arange(seq_len, dtype=paddle.float32) - freqs = paddle.outer(t, paddle.cast(self.inv_freq, dtype="float32")) + freqs = paddle.outer(t, paddle.cast(self.inv_freq, dtype="float32")) - _mscale = float( - yarn_get_mscale(self.scaling_factor, self.mscale) - / yarn_get_mscale(self.scaling_factor, self.mscale_all_dim) - ) + _mscale = float( + yarn_get_mscale(self.scaling_factor, self.mscale) + / yarn_get_mscale(self.scaling_factor, self.mscale_all_dim) + ) - emb = paddle.concat((freqs, freqs), axis=-1) - self.cos_cached = emb.cos() * _mscale - self.sin_cached = emb.sin() * _mscale + emb = paddle.concat((freqs, freqs), axis=-1) + self.cos_cached = emb.cos() * _mscale + self.sin_cached = emb.sin() * _mscale def rotate_half(x): @@ -592,7 +698,7 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, fuse_rope=False): b, s, h, d = k.shape k = k.reshape([b, s, h, d // 2, 2]).transpose([0, 1, 2, 4, 3]).reshape([b, s, h, d]) - if get_env_device() == "xpu" and fuse_rope: + if (get_env_device() == "xpu" or get_env_device() == "gpu") and fuse_rope: q_embed, k_embed, _ = fused_rotary_position_embedding( q, k, @@ -671,9 +777,84 @@ def forward(self, x): return down_proj + +class FusedNormGateFunc(paddle.autograd.PyLayer): + """recompute of postnorm and gate""" + + _current_norm_output = None + _current_invar = None + + @classmethod + def set_temporary_vars(cls, norm_output, invar): + FusedNormGateFunc._current_norm_output = norm_output + FusedNormGateFunc._current_invar = invar + + @classmethod + def clear_temporary_vars(cls): + FusedNormGateFunc._current_norm_output = None + FusedNormGateFunc._current_invar = None + + @staticmethod + def forward(ctx, x, rms_norm_weight, moe_gate_weight, eps): + ctx.dtype = paddle.float32 + norm_output, invar = fused_ln.fused_rms_norm(x, rms_norm_weight, eps) + with paddle.amp.auto_cast(False): + gate_logits = F.linear(cast_if_needed(norm_output, ctx.dtype), 
cast_if_needed(moe_gate_weight, ctx.dtype)) + + ctx.save_for_backward(x, rms_norm_weight, moe_gate_weight, eps) + return gate_logits, norm_output + + @staticmethod + def backward(ctx, d_gate_logits, d_norm_output): + x, rms_norm_weight, moe_gate_weight, eps = ctx.saved_tensor() + # recompute rmsnorm + norm_output = FusedNormGateFunc._current_norm_output + invar = FusedNormGateFunc._current_invar + if norm_output is None or invar is None: + norm_output, invar = fused_ln.fused_rms_norm(x, rms_norm_weight, eps) + d_norm_output_linear, d_moe_gate_weight = paddle._C_ops.matmul_grad( + cast_if_needed(norm_output, ctx.dtype), + cast_if_needed(moe_gate_weight, ctx.dtype), + d_gate_logits, + False, + False, + ) + d_norm_output_linear, d_moe_gate_weight = cast_if_needed( + d_norm_output_linear, norm_output.dtype + ), cast_if_needed(d_moe_gate_weight, moe_gate_weight.dtype) + + d_norm_output = d_norm_output + d_norm_output_linear + dx, d_rms_norm_weight = fused_ln.fused_rms_norm_grad_func(x, rms_norm_weight, invar, d_norm_output, eps) + + return dx, d_rms_norm_weight, d_moe_gate_weight + + +class TemporaryVarContext: + def __init__(self, norm_output, invar): + self.norm_output = norm_output + self.invar = invar + + def __enter__(self): + FusedNormGateFunc.set_temporary_vars(self.norm_output, self.invar) + + def __exit__(self, exc_type, exc_val, exc_tb): + FusedNormGateFunc.clear_temporary_vars() + + +def balance_expert_assignment(n, m, k): + assert k * n % m == 0 + matrix = paddle.zeros((n, m), dtype=paddle.int32) + for row in range(n): + start_col = row % m + for i in range(k): + col = (start_col + i) % m + matrix[row, col] = 1 + return matrix + + class FakeGate(paddle.autograd.PyLayer): @staticmethod - def forward(ctx, hidden_states, weight): + def forward(ctx, hidden_states, weight, fakse_gate_restrict_balance=False, num_experts_per_tok=8): expert_num = weight.shape[1] bsz, seq, _ = hidden_states.shape @@ -681,8 +862,12 @@ def forward(ctx, hidden_states, weight): ctx.x_dtype = hidden_states.dtype ctx.y_shape = weight.shape ctx.y_dtype = weight.dtype - - return paddle.randn([bsz, seq, expert_num]).cast(weight.dtype) + if fakse_gate_restrict_balance: + return paddle.reshape( + balance_expert_assignment(bsz * seq, expert_num, num_experts_per_tok), [bsz, seq, expert_num] + ) + else: + return paddle.randn([bsz, seq, expert_num]).cast(weight.dtype) @staticmethod def backward(ctx, grad_output): @@ -882,6 +1067,792 @@ def repeat_kv(hidden_states: paddle.Tensor, n_rep: int) -> paddle.Tensor: return hidden_states.reshape([batch, slen, num_key_value_heads * n_rep, head_axis]) +@to_static(backend="CINN") +def qkv_pre_process_no_fuse( + q, kv, k_pe, rotary_emb, num_heads, q_head_dim, qk_nope_head_dim, v_head_dim, qk_rope_head_dim, position_ids +): + bsz, q_len, _ = q.shape + + target_query_shape = [0, 0, num_heads, q_head_dim] + target_key_value_shape = [0, 0, num_heads, qk_nope_head_dim + v_head_dim] + + q = q.reshape(shape=target_query_shape) + q_nope = q[..., :qk_nope_head_dim] + q_pe = q[..., qk_nope_head_dim:] + + # DeepSeekV2 kv_lora_rank+qk_rope_head_dim=512+64 + + kv = kv.reshape(shape=target_key_value_shape) + + k_pe = k_pe.reshape([-1, q_len, 1, qk_rope_head_dim]).expand([-1, q_len, num_heads, qk_rope_head_dim]) + + # self.q_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim = 128+64 + # self.num_heads * (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim) = config.qk_nope_head_dim + self.v_head_dim = 128+128 + k_nope = kv[..., :qk_nope_head_dim] + value_states = kv[..., 
qk_nope_head_dim:] + + kv_seq_len = value_states.shape[1] + + cos, sin = rotary_emb(value_states, seq_len=kv_seq_len) + cos = cos[None, :, None, :] + sin = sin[None, :, None, :] + q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids, False) + + query_states = paddle.concat([q_nope, q_pe], axis=-1) + key_states = paddle.concat([k_nope, k_pe], axis=-1) + + return query_states, key_states, value_states + + +@to_static(backend="CINN") +def rearrange_kv(kv, k_pe, qk_nope_head_dim, num_heads): + k_nope = kv[..., :qk_nope_head_dim] + value_states = kv[..., qk_nope_head_dim:] + + k_pe = k_pe.expand([k_pe.shape[0], k_pe.shape[1], num_heads, k_pe.shape[3]]) + key_states = paddle.concat([k_nope, k_pe], axis=-1) + + return key_states, value_states + + +def qkv_pre_process( + q, kv, k_pe, rotary_emb, num_heads, q_head_dim, qk_nope_head_dim, v_head_dim, qk_rope_head_dim, position_ids +): + if (fused_partial_rope is None) or (position_ids is not None): + return qkv_pre_process_no_fuse( + q, + kv, + k_pe, + rotary_emb, + num_heads, + q_head_dim, + qk_nope_head_dim, + v_head_dim, + qk_rope_head_dim, + position_ids, + ) + + bsz, q_len, _ = q.shape + + target_query_shape = [0, 0, num_heads, q_head_dim] + target_key_value_shape = [0, 0, num_heads, qk_nope_head_dim + v_head_dim] + + q = q.reshape(shape=target_query_shape) + kv = kv.reshape(shape=target_key_value_shape) + k_pe = k_pe.reshape([-1, q_len, 1, qk_rope_head_dim]) + + value_states = kv[..., qk_nope_head_dim:] + + kv_seq_len = value_states.shape[1] + + cos, sin = rotary_emb(value_states, seq_len=kv_seq_len) + cos = cos[None, :, None, :] + sin = sin[None, :, None, :] + + query_states = fused_partial_rope(q, cos, sin) + k_pe = fused_partial_rope(k_pe, cos, sin) + + key_states, value_states = rearrange_kv(kv, k_pe, qk_nope_head_dim, num_heads) + + return query_states, key_states, value_states + + +def manul_fwd( + q_init, + kv_init, + q_ln_weight, + kv_ln_weight, + q_up_weight, + kv_up_weight, + rotary_emb, + num_heads, + q_head_dim, + qk_nope_head_dim, + v_head_dim, + qk_rope_head_dim, + position_ids, + eps, + kv_lora_rank, + softmax_scale, +): + + q_ln_t, q_ln_invar = fused_ln.fused_rms_norm(q_init, q_ln_weight, eps) + q = paddle.matmul(q_ln_t, q_up_weight) + + compressed_kv, k_pe = paddle.split(kv_init, [kv_lora_rank, qk_rope_head_dim], axis=-1) + + kv_ln_t, kv_ln_invar = fused_ln.fused_rms_norm(compressed_kv, kv_ln_weight, eps) + + kv = paddle.matmul(kv_ln_t, kv_up_weight) + + query_states, key_states, value_states = qkv_pre_process( + q, kv, k_pe, rotary_emb, num_heads, q_head_dim, qk_nope_head_dim, v_head_dim, qk_rope_head_dim, position_ids + ) + + q_head_dim = query_states.shape[-1] + softmax_scale = softmax_scale * (q_head_dim**0.5) + query_states = query_states * softmax_scale + + attn_out, _, softmax_lse, seed_offset = _C_ops.flash_attn( + query_states, + key_states, + query_states, + None, + None, + 0.0, + True, + False, + False, + "", + ) + + return attn_out + + +class MemroyRecomputeAttnFunc(paddle.autograd.PyLayer): + @staticmethod + def forward( + ctx, + q_init, + kv_init, + q_ln_weight, + kv_ln_weight, + q_up_weight, + kv_up_weight, + rotary_emb, + num_heads, + q_head_dim, + qk_nope_head_dim, + v_head_dim, + qk_rope_head_dim, + position_ids, + eps, + kv_lora_rank, + softmax_scale, + ): + + bsz = q_init.shape[0] + q_ln_t, q_ln_invar = fused_ln.fused_rms_norm(q_init, q_ln_weight, eps) + # q = paddle.matmul(q_ln_t, q_up_weight) + q_orig_shape = q_ln_t.shape + q = FP8LinearFunctionBase.compute_fp8_linear( + q_ln_t.reshape([-1, 
q_orig_shape[-1]]), q_up_weight, weight_transpose=True, return_transpose_only=True + ) + q = q.reshape(q_orig_shape[:-1] + [q_up_weight.shape[-1]]) + + compressed_kv, k_pe = paddle.split(kv_init, [kv_lora_rank, qk_rope_head_dim], axis=-1) + + kv_ln_t, kv_ln_invar = fused_ln.fused_rms_norm(compressed_kv, kv_ln_weight, eps) + # kv = paddle.matmul(kv_ln_t, kv_up_weight) + kv_orig_shape = kv_ln_t.shape + kv = FP8LinearFunctionBase.compute_fp8_linear( + kv_ln_t.reshape([-1, kv_orig_shape[-1]]), kv_up_weight, weight_transpose=True, return_transpose_only=True + ) + kv = kv.reshape(kv_orig_shape[:-1] + [kv_up_weight.shape[-1]]) + + query_states, key_states, value_states = qkv_pre_process( + q, + kv, + k_pe, + rotary_emb, + num_heads, + q_head_dim, + qk_nope_head_dim, + v_head_dim, + qk_rope_head_dim, + position_ids, + ) + + q_head_dim = query_states.shape[-1] + + if FA_VERSION == 2: + softmax_scale = softmax_scale * (q_head_dim**0.5) + query_states = query_states * softmax_scale + kv_seq_len = value_states.shape[1] + v_num_heads = value_states.shape[2] + value_padding = paddle.zeros( + [bsz, kv_seq_len, v_num_heads, q_head_dim - v_head_dim], + dtype=value_states.dtype, + ) + value_states_pad = paddle.concat([value_states, value_padding], axis=-1) + + attn_out, _, softmax_lse, seed_offset = _C_ops.flash_attn( + query_states, + key_states, + value_states_pad, + None, + None, + 0.0, + True, + False, + False, + "", + ) + + elif FA_VERSION == 3: + attn_out, softmax_lse = _C_ops.flash_attn_v3( + query_states, + key_states, + value_states, + None, # q_v_ + None, # q_descale_ + None, # k_descale_ + None, # v_descale_ + softmax_scale, + True, + -1, # window_size_left + -1, # window_size_right + 0.0, # softcap + 1, # num_splits + False, # manual_set_pack_gqa + False, # pack_gqa_ + 0, # sm_margin + ) + else: + assert False, f"invalid {FA_VERSION=}" + + if FA_VERSION == 2: + ctx.save_for_backward( + q_init, + kv_init, + attn_out, + softmax_lse, + seed_offset, + q_ln_weight, + kv_ln_weight, + q_up_weight, + kv_up_weight, + rotary_emb, + num_heads, + q_head_dim, + qk_nope_head_dim, + v_head_dim, + qk_rope_head_dim, + position_ids, + eps, + kv_lora_rank, + softmax_scale, + ) + elif FA_VERSION == 3: + ctx.save_for_backward( + q_init, + kv_init, + attn_out, + softmax_lse, + q_ln_weight, + kv_ln_weight, + q_up_weight, + kv_up_weight, + rotary_emb, + num_heads, + q_head_dim, + qk_nope_head_dim, + v_head_dim, + qk_rope_head_dim, + position_ids, + eps, + kv_lora_rank, + softmax_scale, + ) + else: + assert False, f"invalid {FA_VERSION=}" + + return attn_out + + @staticmethod + def backward(ctx, dout): + if FA_VERSION == 2: + ( + q_init, + kv_init, + attn_out, + softmax_lse, + seed_offset, + q_ln_weight, + kv_ln_weight, + q_up_weight, + kv_up_weight, + rotary_emb, + num_heads, + q_head_dim, + qk_nope_head_dim, + v_head_dim, + qk_rope_head_dim, + position_ids, + eps, + kv_lora_rank, + softmax_scale, + ) = ctx.saved_tensor() + elif FA_VERSION == 3: + ( + q_init, + kv_init, + attn_out, + softmax_lse, + q_ln_weight, + kv_ln_weight, + q_up_weight, + kv_up_weight, + rotary_emb, + num_heads, + q_head_dim, + qk_nope_head_dim, + v_head_dim, + qk_rope_head_dim, + position_ids, + eps, + kv_lora_rank, + softmax_scale, + ) = ctx.saved_tensor() + else: + assert False, f"invalid {FA_VERSION=}" + + q_ln_t, q_ln_invar = fused_ln.fused_rms_norm(q_init, q_ln_weight, eps) + + q_ln_fp8, q_ln_scale, q_ln_trans_fp8, q_ln_trans_scale = paddle.incubate.nn.functional.fp8_quant_blockwise( + q_ln_t.reshape([-1, q_ln_t.shape[-1]]), + 
output_scale_transpose=True, + quant_method="1x128", + input_transpose=True, + ) + + q_orig_shape = q_ln_t.shape + q = FP8LinearFunctionBase.compute_fp8_linear( + (q_ln_fp8, q_ln_scale), q_up_weight, weight_transpose=True, return_transpose_only=True + ) + q = q.reshape(q_orig_shape[:-1] + [q_up_weight.shape[-1]]) + + compressed_kv, k_pe = paddle.split(kv_init, [kv_lora_rank, qk_rope_head_dim], axis=-1) + + kv_ln_t, kv_ln_invar = fused_ln.fused_rms_norm(compressed_kv, kv_ln_weight, eps) + + kv_ln_fp8, kv_ln_scale, kv_ln_trans_fp8, kv_ln_trans_scale = paddle.incubate.nn.functional.fp8_quant_blockwise( + kv_ln_t.reshape([-1, kv_ln_t.shape[-1]]), + output_scale_transpose=True, + quant_method="1x128", + input_transpose=True, + ) + kv_orig_shape = kv_ln_t.shape + kv = FP8LinearFunctionBase.compute_fp8_linear( + (kv_ln_fp8, kv_ln_scale), kv_up_weight, weight_transpose=True, return_transpose_only=True + ) + kv = kv.reshape(kv_orig_shape[:-1] + [kv_up_weight.shape[-1]]) + + paddle.base.core._set_has_grad(True) + q.stop_gradient = False + kv.stop_gradient = False + k_pe.stop_gradient = False + query_states, key_states, value_states = qkv_pre_process( + q, + kv, + k_pe, + rotary_emb, + num_heads, + q_head_dim, + qk_nope_head_dim, + v_head_dim, + qk_rope_head_dim, + position_ids, + ) + + if FA_VERSION == 2: + q_head_dim = query_states.shape[-1] + query_states = query_states * softmax_scale + + bsz = value_states.shape[0] + kv_seq_len = value_states.shape[1] + v_num_heads = value_states.shape[2] + value_padding = paddle.zeros( + [bsz, kv_seq_len, v_num_heads, q_head_dim - v_head_dim], + dtype=value_states.dtype, + ) + value_states_pad = paddle.concat([value_states, value_padding], axis=-1) + + with paddle.no_grad(): + + q_grad, k_grad, v_grad = _C_ops.flash_attn_grad( + query_states, + key_states, + value_states_pad, + attn_out, + softmax_lse.view("bfloat16"), + seed_offset, + None, + dout, + 0.0, + True, + ) + + v_grad = v_grad[..., :v_head_dim] + q_grad = q_grad * softmax_scale + elif FA_VERSION == 3: + with paddle.no_grad(): + q_grad, k_grad, v_grad = _C_ops.flash_attn_v3_grad( + query_states, + key_states, + value_states, + attn_out, + softmax_lse.view("bfloat16"), + dout, + softmax_scale, + True, + -1, + -1, + 0.0, + 0, + ) + else: + assert False, f"invalid {FA_VERSION=}" + + d_q, d_kv, d_k_pe = paddle.grad( + outputs=[query_states, key_states, value_states], + inputs=[q, kv, k_pe], + grad_outputs=[q_grad, k_grad, v_grad], + create_graph=False, + retain_graph=False, + ) + + paddle.base.core._set_has_grad(False) + + # call up proj + if hasattr(kv_up_weight, "main_grad"): + d_kv_fp8, d_kv_scale, d_kv_t_fp8, d_kv_t_scale = paddle.incubate.nn.functional.fp8_quant_blockwise( + d_kv.reshape([-1, d_kv.shape[-1]]), + output_scale_transpose=True, + quant_method="1x128", + input_transpose=True, + ) + + d_kv_ln_t = FP8LinearFunctionBase.compute_fp8_linear( + (d_kv_fp8, d_kv_scale), kv_up_weight, weight_transpose=False + ) + d_kv_ln_t = d_kv_ln_t.reshape(d_kv.shape[:-1] + [kv_up_weight.shape[0]]) + + def kv_up_weight_grad(kv_ln_trans_fp8, kv_ln_trans_scale, d_kv_t_fp8, d_kv_t_scale, kv_up_weight): + FP8LinearFunctionBase.kitchen_gemm( + kv_ln_trans_fp8, + kv_ln_trans_scale, + d_kv_t_fp8, + d_kv_t_scale, + True, + True, + kv_up_weight.main_grad, + paddle.float32, + ) + + if WeightGradStore.enabled: + + WeightGradStore.put( + partial( + kv_up_weight_grad, kv_ln_trans_fp8, kv_ln_trans_scale, d_kv_t_fp8, d_kv_t_scale, kv_up_weight + ) + ) + else: + kv_up_weight_grad(kv_ln_trans_fp8, kv_ln_trans_scale, d_kv_t_fp8, 
d_kv_t_scale, kv_up_weight) + + d_kv_up_weight = None + + else: + d_kv_ln_t, d_kv_up_weight = _C_ops.matmul_grad(kv_ln_t, kv_up_weight, d_kv, False, False) + + d_compressed_kv, d_kv_ln_weight = fused_ln.fused_rms_norm_grad_func( + compressed_kv, kv_ln_weight, kv_ln_invar, d_kv_ln_t, eps + ) + + d_kv_init = paddle.concat([d_compressed_kv, d_k_pe], axis=-1) + + if hasattr(q_up_weight, "main_grad"): + + d_q_fp8, d_q_scale, d_q_t_fp8, d_q_t_scale = paddle.incubate.nn.functional.fp8_quant_blockwise( + d_q.reshape([-1, d_q.shape[-1]]), + output_scale_transpose=True, + quant_method="1x128", + input_transpose=True, + ) + # d_q_ln_t = paddle.matmul(d_q, q_up_weight, transpose_y=True) + + d_q_ln_t = FP8LinearFunctionBase.compute_fp8_linear( + (d_q_fp8, d_q_scale), q_up_weight, weight_transpose=False + ) + d_q_ln_t = d_q_ln_t.reshape(d_q.shape[:-1] + [q_up_weight.shape[0]]) + + def q_up_weight_grad(q_ln_trans_fp8, q_ln_trans_scale, d_q_t_fp8, d_q_t_scale, q_up_weight): + FP8LinearFunctionBase.kitchen_gemm( + q_ln_trans_fp8, + q_ln_trans_scale, + d_q_t_fp8, + d_q_t_scale, + True, + True, + q_up_weight.main_grad, + paddle.float32, + ) + + if WeightGradStore.enabled: + WeightGradStore.put( + partial(q_up_weight_grad, q_ln_trans_fp8, q_ln_trans_scale, d_q_t_fp8, d_q_t_scale, q_up_weight) + ) + else: + q_up_weight_grad(q_ln_trans_fp8, q_ln_trans_scale, d_q_t_fp8, d_q_t_scale, q_up_weight) + + d_q_up_weight = None + + else: + d_q_ln_t, d_q_up_weight = _C_ops.matmul_grad(q_ln_t, q_up_weight, d_q, False, False) + + d_q_init, d_q_ln_weight = fused_ln.fused_rms_norm_grad_func(q_init, q_ln_weight, q_ln_invar, d_q_ln_t, eps) + + return d_q_init, d_kv_init, d_q_ln_weight, d_kv_ln_weight, d_q_up_weight, d_kv_up_weight + + +class MemroyRecomputeAttn(paddle.nn.Layer): + def __init__( + self, + q_norm_hidden_size, + kv_norm_hidden_size, + q_up_in_dim, + q_up_out_dim, + kv_up_in_dim, + kv_up_out_dim, + rotary_emb, + num_heads, + q_head_dim, + qk_nope_head_dim, + v_head_dim, + qk_rope_head_dim, + eps, + kv_lora_rank, + softmax_scale, + ) -> None: + super().__init__() + self._dtype = self._helper.get_default_dtype() + + self.q_ln_weight = paddle.create_parameter( + shape=[q_norm_hidden_size], + dtype=self._dtype, + default_initializer=nn.initializer.Constant(1.0), + ) + self.kv_ln_weight = paddle.create_parameter( + shape=[kv_norm_hidden_size], + dtype=self._dtype, + default_initializer=nn.initializer.Constant(1.0), + ) + + self.q_up_weight = self.create_parameter( + shape=[q_up_in_dim, q_up_out_dim], + dtype=self._dtype, + is_bias=False, + ) + + self.kv_up_weight = self.create_parameter( + shape=[kv_up_in_dim, kv_up_out_dim], + dtype=self._dtype, + is_bias=False, + ) + ( + self.rotary_emb, + self.num_heads, + self.q_head_dim, + self.qk_nope_head_dim, + self.v_head_dim, + self.qk_rope_head_dim, + self.eps, + self.kv_lora_rank, + self.softmax_scale, + ) = ( + rotary_emb, + num_heads, + q_head_dim, + qk_nope_head_dim, + v_head_dim, + qk_rope_head_dim, + eps, + kv_lora_rank, + softmax_scale, + ) + set_parameter_color([self.q_up_weight, self.kv_up_weight], "memory_attn") + + def fp8_quant_weight(self): + cache_fp8_weight(self.q_up_weight) + cache_fp8_weight(self.kv_up_weight) + + def forward(self, q_init, kv_init, position_ids): + + seq_len = q_init.shape[1] + + if self.rotary_emb.max_seq_len_cached is None or seq_len > self.rotary_emb.max_seq_len_cached: + self.rotary_emb._set_cos_sin_cache(seq_len) + + return MemroyRecomputeAttnFunc.apply( + q_init, + kv_init, + self.q_ln_weight, + self.kv_ln_weight, + self.q_up_weight, + 
self.kv_up_weight, + self.rotary_emb, + self.num_heads, + self.q_head_dim, + self.qk_nope_head_dim, + self.v_head_dim, + self.qk_rope_head_dim, + position_ids, + self.eps, + self.kv_lora_rank, + self.softmax_scale, + ) + + +class FusedRMSLinearFunc(paddle.autograd.PyLayer): + @staticmethod + def forward(ctx, x, rms_norm_weight, q_down_weight, kv_down_weight, eps): + + hidden_states, invar = fused_ln.fused_rms_norm(x, rms_norm_weight, eps) + + h_fp8, h_scale = paddle.incubate.nn.functional.fp8_quant_blockwise( + hidden_states.reshape([-1, hidden_states.shape[-1]]), output_scale_transpose=True, quant_method="1x128" + ) + + h_orig_shape = hidden_states.shape + q = FP8LinearFunctionBase.compute_fp8_linear( + (h_fp8, h_scale), q_down_weight, weight_transpose=True, return_transpose_only=True + ) + q = q.reshape(h_orig_shape[:-1] + [q_down_weight.shape[-1]]) + + kv = paddle.matmul(hidden_states, kv_down_weight) + + ctx.save_for_backward(x, rms_norm_weight, q_down_weight, kv_down_weight) + ctx.eps = eps + return q, kv + + @staticmethod + def backward(ctx, d_q, d_kv): + x, rms_norm_weight, q_down_weight, kv_down_weight = ctx.saved_tensor() + eps = ctx.eps + hidden_states, invar = fused_ln.fused_rms_norm(x, rms_norm_weight, eps) + + h_t_fp8, h_t_scale = paddle.incubate.nn.functional.fp8_quant_blockwise( + hidden_states.reshape([-1, hidden_states.shape[-1]]), + output_scale_transpose=True, + quant_method="1x128", + input_transpose=True, + return_transpose_only=True, + ) + + h_grad, d_kv_down_weight = _C_ops.matmul_grad(hidden_states, kv_down_weight, d_kv, False, False) + + if hasattr(q_down_weight, "main_grad"): + d_q_fp8, d_q_scale, d_q_t_fp8, d_q_t_scale = paddle.incubate.nn.functional.fp8_quant_blockwise( + d_q.reshape([-1, d_q.shape[-1]]), + output_scale_transpose=True, + quant_method="1x128", + input_transpose=True, + ) + FP8LinearFunctionBase.compute_fp8_linear( + (d_q_fp8, d_q_scale), q_down_weight, weight_transpose=False, out=h_grad.view([-1, h_grad.shape[-1]]) + ) + + def q_down_weight_grad(h_t_fp8, h_t_scale, d_q_t_fp8, d_q_t_scale, q_down_weight): + FP8LinearFunctionBase.kitchen_gemm( + h_t_fp8, h_t_scale, d_q_t_fp8, d_q_t_scale, True, True, q_down_weight.main_grad, paddle.float32 + ) + + if WeightGradStore.enabled: + WeightGradStore.put( + partial(q_down_weight_grad, h_t_fp8, h_t_scale, d_q_t_fp8, d_q_t_scale, q_down_weight) + ) + else: + q_down_weight_grad(h_t_fp8, h_t_scale, d_q_t_fp8, d_q_t_scale, q_down_weight) + + d_q_down_weight = None + + else: + h_grad_0, d_q_down_weight = _C_ops.matmul_grad(hidden_states, q_down_weight, d_q, False, False) + h_grad = h_grad + h_grad_0 + + dx, d_rms_norm_weight = fused_ln.fused_rms_norm_grad_func(x, rms_norm_weight, invar, h_grad, eps) + + return dx, d_rms_norm_weight, d_q_down_weight, d_kv_down_weight + + +class FusedRMSLinear(paddle.nn.Layer): + def __init__(self, hidden_size, q_out_dim, kv_outdim, eps=1e-6) -> None: + super().__init__() + self._dtype = self._helper.get_default_dtype() + + self.rms_norm_weight = paddle.create_parameter( + shape=[hidden_size], + dtype=self._dtype, + default_initializer=nn.initializer.Constant(1.0), + ) + + self.q_down_weight = self.create_parameter( + shape=[hidden_size, q_out_dim], + dtype=self._dtype, + is_bias=False, + ) + + self.kv_down_weight = self.create_parameter( + shape=[hidden_size, kv_outdim], + dtype=self._dtype, + is_bias=False, + ) + self.eps = eps + set_parameter_color([self.q_down_weight], "rms_linear") + + def fp8_quant_weight(self): + cache_fp8_weight(self.q_down_weight) + + def forward(self, 
x): + + return FusedRMSLinearFunc.apply(x, self.rms_norm_weight, self.q_down_weight, self.kv_down_weight, self.eps) + + +class FusedRMSLinearSingleFunc(paddle.autograd.PyLayer): + @staticmethod + def forward(ctx, x, rms_norm_weight, linear_weight, eps): + + hidden_states, invar = fused_ln.fused_rms_norm(x, rms_norm_weight, eps) + q = paddle.matmul(hidden_states, linear_weight) + + ctx.save_for_backward(x, rms_norm_weight, linear_weight, eps) + return q + + @staticmethod + def backward(ctx, d_q, d_kv): + x, rms_norm_weight, linear_weight, eps = ctx.saved_tensor() + hidden_states, invar = fused_ln.fused_rms_norm(x, rms_norm_weight, eps) + + h_grad, d_linear_weight = _C_ops.matmul_grad(hidden_states, linear_weight, d_q, False, False) + + dx, d_rms_norm_weight = fused_ln.fused_rms_norm_grad_func(x, rms_norm_weight, invar, h_grad, eps) + + return dx, d_rms_norm_weight, d_linear_weight + + +class FusedRMSLinearSingle(paddle.nn.Layer): + def __init__(self, hidden_size, q_out_dim, kv_outdim, eps=1e-6) -> None: + super().__init__() + self._dtype = self._helper.get_default_dtype() + + self.rms_norm_weight = paddle.create_parameter( + shape=[hidden_size], + dtype=self._dtype, + default_initializer=nn.initializer.Constant(1.0), + ) + + self.linear_weight = self.create_parameter( + shape=[hidden_size, q_out_dim], + dtype=self._dtype, + is_bias=False, + ) + self.eps = eps + + def forward(self, x): + + return FusedRMSLinearFunc.apply(x, self.rms_norm_weight, self.linear_weight, self.eps) + + # Copied from transformers.models.llama.modeling_llama.LlamaAttention with Llama->DeepseekV2 class DeepseekV2Attention(nn.Layer): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -1922,10 +2893,11 @@ def compute_loss(preds, labels): masked_lm_loss > 0, paddle.ones_like(masked_lm_loss), paddle.zeros_like(masked_lm_loss) ) count = paddle.sum(binary_sequence) - if count == 0: - loss = paddle.sum(masked_lm_loss * binary_sequence) - else: - loss = paddle.sum(masked_lm_loss * binary_sequence) / count + loss = paddle.where( + count == 0, + paddle.sum(masked_lm_loss * binary_sequence), + paddle.sum(masked_lm_loss * binary_sequence) / count, + ) return loss def add_loss(main_loss, loss): @@ -1956,7 +2928,7 @@ def add_loss(main_loss, loss): class DeepseekV2LMHead(nn.Layer): - def __init__(self, config: DeepseekV2Config): + def __init__(self, config: DeepseekV2Config, embedding_weight=None): super(DeepseekV2LMHead, self).__init__() self.config = config @@ -1970,11 +2942,16 @@ def __init__(self, config: DeepseekV2Config): else: vocab_size = config.vocab_size - self.weight = self.create_parameter( - shape=[config.hidden_size, vocab_size], - dtype=paddle.get_default_dtype(), - default_initializer=nn.initializer.XavierNormal(1.0), - ) + if embedding_weight is not None: + self.transpose_y = True + self.weight = embedding_weight + else: + self.transpose_y = False + self.weight = self.create_parameter( + shape=[config.hidden_size, vocab_size], + dtype=paddle.get_default_dtype(), + default_initializer=nn.initializer.XavierNormal(1.0), + ) # Must set distributed attr for Tensor Parallel ! 
self.weight.is_distributed = True if (vocab_size != config.vocab_size) else False if get_env_device() == "xpu": @@ -2004,7 +2981,7 @@ def forward(self, hidden_states, tensor_parallel_output=None): training=self.training, ) else: - logits = parallel_matmul(hidden_states, self.weight, tensor_parallel_output=tensor_parallel_output) + logits = parallel_matmul(hidden_states, self.weight, transpose_y=self.transpose_y, tensor_parallel_output=tensor_parallel_output) return logits def extra_repr(self): diff --git a/paddleformers/transformers/deepseek_v2/modeling_fast.py b/paddleformers/transformers/deepseek_v2/modeling_fast.py new file mode 100644 index 00000000000..db176683fc4 --- /dev/null +++ b/paddleformers/transformers/deepseek_v2/modeling_fast.py @@ -0,0 +1,1580 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2023 DeepSeek. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Paddle DeepSeek model.""" + +import contextlib +import math +import os +import warnings +from functools import partial +from typing import List, Optional, Tuple, Union + +import paddle +import paddle.distributed as dist +import paddle.distributed.fleet.meta_parallel as mpu +import paddle.nn.functional as F +from paddle import Tensor, nn +from paddle.distributed import fleet +from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker +from paddle.distributed.fleet.recompute.recompute import recompute +from paddle.jit import to_static +from paddle.utils import try_import + +try: + from paddle.distributed.fleet.utils.sequence_parallel_utils import ( + GatherOp, + ScatterOp, + mark_as_sequence_parallel_parameter, + ) +except: + pass + +from paddle import _C_ops +from paddleformers.transformers.model_utils import dtype_guard + +from ...utils.initializer import kaiming_uniform_ +from ...utils.log import logger +from ...utils.tools import get_env_device +from ..activations import ACT2FN +from ..conversion_utils import StateDictNameMapping, init_name_mappings +from ..llama import fusion_ops +from ..llama.modeling import get_use_casual_mask +from ..model_outputs import ( + BaseModelOutputWithPastAndMTP, + CausalLMOutputWithPast, + SequenceClassifierOutputWithPast, +) +from ..model_utils import PretrainedModel, dtype_guard, register_base_model +from ..moe_gate import PretrainedMoEGate +from ..moe_layer import MoEFlexTokenLayer, MoELayer +from ..utils import cast_if_needed, device_guard +from . 
import fp8_linear as linear_utils +from .configuration import DeepseekV2Config + +FA_VERSION = int(os.getenv("FA_VERSION", 2)) + +from paddle.distributed.fleet.meta_parallel.zero_bubble_utils import WeightGradStore + +from ..fp8_utils import ( + FP8KeepXLinear, + FP8Linear, + FP8Mlp, + set_parameter_color, +) +from .fp8_linear import Linear + +DSV3_USE_FP8_GEMM = os.getenv("DSV3_USE_FP8_GEMM", "False").lower() == "true" +DSV3_USE_ATTEN_RECOMPUTE = os.getenv("DSV3_USE_ATTEN_RECOMPUTE", "False").lower() == "true" + +Linear = FP8Linear if DSV3_USE_FP8_GEMM else Linear + +try: + import fused_ln + from paddle.incubate.nn.functional import swiglu +except ImportError: + + def swiglu(x, y=None): + if y is None: + x, y = paddle.chunk(x, chunks=2, axis=-1) + return F.silu(x) * y + + +try: + from paddle.incubate.nn.functional import fused_partial_rope +except ImportError: + fused_partial_rope = None + +__all__ = [ + "DeepseekV2ModelFast", + "DeepseekV2PretrainedModelFast", +] + +from .modeling import (set_global_step, scaled_dot_product_attention, is_casual_mask, _make_causal_mask, _expand_2d_mask, yarn_get_mscale, apply_rotary_pos_emb, DeepseekV2RMSNorm, DeepseekV2YarnRotaryEmbedding, FusedRMSLinear, MemroyRecomputeAttn, FusedNormGateFunc, FakeGate) + +class DeepseekV2MLP(nn.Layer): + def __init__(self, config: DeepseekV2Config, hidden_size=None, intermediate_size=None, is_moe=False): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size if hidden_size is None else hidden_size + self.intermediate_size = config.intermediate_size if intermediate_size is None else intermediate_size + self.fuse_attention_ffn = config.fuse_attention_ffn + + def linear_dtype_gaurd(): + if config.use_fp8: + return dtype_guard("float8_e4m3fn") + else: + return contextlib.nullcontext() + + if config.sequence_parallel: + ColumnParallelLinear = linear_utils.ColumnSequenceParallelLinear + RowParallelLinear = linear_utils.RowSequenceParallelLinear + else: + ColumnParallelLinear = linear_utils.ColumnParallelLinear + RowParallelLinear = linear_utils.RowParallelLinear + + with linear_dtype_gaurd(): + if config.tensor_parallel_degree > 1 and not is_moe: + self.gate_proj = ColumnParallelLinear( + self.hidden_size, + self.intermediate_size, + gather_output=False, + has_bias=False, + ) + self.up_proj = ColumnParallelLinear( + self.hidden_size, + self.intermediate_size, + gather_output=False, + has_bias=False, + ) + self.down_proj = RowParallelLinear( + self.intermediate_size, + self.hidden_size, + input_is_parallel=True, + has_bias=False, + ) + else: + if config.fuse_attention_ffn: + self.gate_up_fused_proj = Linear(self.hidden_size, self.intermediate_size * 2, bias_attr=False) + else: + self.gate_proj = Linear(self.hidden_size, self.intermediate_size, bias_attr=False) + self.up_proj = Linear(self.hidden_size, self.intermediate_size, bias_attr=False) + self.down_proj = Linear(self.intermediate_size, self.hidden_size, bias_attr=False) + + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + if self.fuse_attention_ffn: + x = swiglu(self.gate_up_fused_proj(x)) + else: + x = swiglu(self.gate_proj(x), self.up_proj(x)) + out = self.down_proj(x) + return out + + +class MoEGate(PretrainedMoEGate): + def __init__( + self, + config, + num_experts, + expert_hidden_size, + using_post_norm_recompute=False, + norm_weight=None, + norm_eps=None, + **kwargs + ): + super().__init__(config, num_experts, expert_hidden_size, **kwargs) + # [hidden_size, n_expert] + + self.scoring_func = config.scoring_func + 
self.topk_method = config.topk_method + + self.weight = paddle.create_parameter( + shape=[expert_hidden_size, num_experts], + dtype=paddle.float32, + is_bias=False, + # default_initializer=nn.initializer.Constant(1.0), + ) + + self.config = config + self.using_post_norm_recompute = using_post_norm_recompute + + if config.topk_method == "noaux_tc": + self.e_score_correction_bias = paddle.create_parameter( + shape=[num_experts], + dtype=paddle.float32, + default_initializer=nn.initializer.Constant(0.0), + ) + self.e_score_correction_bias.is_distributed = True + + if self.using_post_norm_recompute: + assert norm_weight is not None and norm_eps is not None + self.norm_weight = norm_weight + self.norm_eps = norm_eps + + self.using_flex_token = False + + def forward(self, hidden_states): + """ + Args: + hidden_states (_type_): [batch_size * seq_len, hidden_size] + """ + _, _, h_dim = hidden_states.shape + + # compute gating score + if self.using_post_norm_recompute: + logits, norm_out = FusedNormGateFunc.apply(hidden_states, self.norm_weight, self.weight, self.norm_eps) + if hasattr(self.config, "using_fake_gate") and self.config.using_fake_gate: + logits = FakeGate.apply( + hidden_states, + self.weight, + self.config.fakse_gate_restrict_balance, + self.config.num_experts_per_tok, + ) + else: + with paddle.amp.auto_cast(False): + hidden_states = hidden_states.cast(self.weight.dtype) + if hasattr(self.config, "using_fake_gate") and self.config.using_fake_gate: + logits = FakeGate.apply( + hidden_states, + self.weight, + self.config.fakse_gate_restrict_balance, + self.config.num_experts_per_tok, + ) + else: + logits = F.linear(hidden_states, self.weight, None) + + scores = self.gate_score_func(logits=logits) + scores = scores.cast(paddle.float32) + + # Compute all possible return values + if self.using_flex_token: + scores, routing_map, exp_counts, l_aux, l_zloss = self.topkgating_nodrop( + scores + ) # (scores, routing_map, exp_counts, l_aux, l_zloss) + ret = (scores, routing_map, l_aux, l_zloss) + else: + ret = self.topkgating(scores) # (capacity, combine_weights, dispatch_mask, exp_counts, l_aux, l_zloss) + + # Append norm_out if needed + if self.using_post_norm_recompute: + ret = (*ret, norm_out) + + return ret + + +class AddAuxiliaryLoss(paddle.autograd.PyLayer): + """ + The trick function of adding auxiliary (aux) loss, + which includes the gradient of the aux loss during backpropagation. + """ + + @staticmethod + def forward(ctx, x, loss): + ctx.dtype = loss.dtype + ctx.required_aux_loss = not loss.stop_gradient + return x.clone() # clone to avoid inplace problem when using overlap + + @staticmethod + def backward(ctx, grad_output): + grad_loss = None + if ctx.required_aux_loss: + grad_loss = paddle.ones(1, dtype=ctx.dtype) + return grad_output, grad_loss + + +class DeepseekV2MoE(MoELayer): + """ + A mixed expert module containing shared experts. 
+ """ + + def __init__(self, config: DeepseekV2Config, norm_weight=None, norm_eps=None): + assert config.tensor_parallel_degree <= 1, "tensor_parallel_degree should be 1" + + self.using_post_norm_recompute = config.using_post_norm_recompute + if self.using_post_norm_recompute: + assert norm_weight is not None and norm_eps is not None + + gate = MoEGate( + config=config, + num_experts=config.n_routed_experts, + expert_hidden_size=config.hidden_size, + top_k=config.num_experts_per_tok, + topk_method=config.topk_method, + n_group=config.n_group, + topk_group=config.topk_group, + norm_topk_prob=config.norm_topk_prob, + routed_scaling_factor=config.routed_scaling_factor, + drop_tokens=False, + using_post_norm_recompute=self.using_post_norm_recompute, + norm_weight=norm_weight, + norm_eps=norm_eps, + ) + DeepseekV2MLPClass = FP8Mlp if DSV3_USE_FP8_GEMM else DeepseekV2MLP + + super().__init__( + config=config, + moe_num_experts=config.n_routed_experts, + expert_class=DeepseekV2MLPClass, + expert_kwargs={ + "config": config, + "intermediate_size": config.moe_intermediate_size, + "is_moe": True, + }, + gate=gate, + capacity=2.0, + moe_group="expert", + using_post_norm_recompute=self.using_post_norm_recompute, + ) + + if config.offline_quant_expert_weight and config.clear_origin_weight_when_offline_quant: + moe_grad_group = fleet.get_hybrid_communicate_group().expert_grad_comm_group + expert_w1_list = [expert.w1 for expert in self.experts if expert is not None] + expert_w2_list = [expert.w2 for expert in self.experts if expert is not None] + for p in expert_w1_list: + setattr(p, "color", {"color": "moe_expert", "group": moe_grad_group}) + for p in expert_w2_list: + setattr(p, "color", {"color": "moe_expert", "group": moe_grad_group}) + + self.alpha = config.aux_loss_alpha + if config.n_shared_experts is not None: + intermediate_size = config.moe_intermediate_size * config.n_shared_experts + if self.using_post_norm_recompute: + assert DeepseekV2MLPClass is FP8Mlp + self.shared_experts = DeepseekV2MLPClass( + config=config, + intermediate_size=intermediate_size, + is_moe=False, + using_post_norm_recompute=self.using_post_norm_recompute, + norm_weight=norm_weight, + norm_eps=norm_eps, + recompute_fwd_gate_up=True, + ) + else: + self.shared_experts = DeepseekV2MLPClass( + config=config, intermediate_size=intermediate_size, is_moe=False + ) + set_parameter_color([self.shared_experts.w1, self.shared_experts.w2], "shared_expert") + + def fp8_quant_weight(self, batch_mode=False): + """Quantize weights in FP8 format. + + Args: + batch_mode: If True, quantize all weights in batch mode using the first expert's weights. + If False, quantize each expert's weights individually. 
+ """ + + def quantize_weights(weight_list, weight_obj=None): + """Helper function to quantize a list of weights.""" + if weight_obj is None: + weight_obj = weight_list[0] + if hasattr(weight_obj, "fp8_weight_stacked"): + return + + # Quantize without transpose + fp8_weight, fp8_scale = paddle.incubate.nn.functional.fused_stack_transpose_quant( + weight_list, transpose=False + ) + setattr(weight_obj, "fp8_weight_stacked", fp8_weight) + setattr(weight_obj, "fp8_scale_stacked", fp8_scale) + + # Quantize with transpose + fp8_weight_t, fp8_scale_t = paddle.incubate.nn.functional.fused_stack_transpose_quant( + weight_list, transpose=True + ) + setattr(weight_obj, "fp8_weight_stacked_transpose", fp8_weight_t) + setattr(weight_obj, "fp8_scale_stacked_transpose", fp8_scale_t) + + if batch_mode: + # Batch mode: process all experts' weights together + expert_w1_list = [expert.w1 for expert in self.experts if expert is not None] + expert_w2_list = [expert.w2 for expert in self.experts if expert is not None] + + if expert_w1_list: + quantize_weights(expert_w1_list, expert_w1_list[0]) + if expert_w2_list: + quantize_weights(expert_w2_list, expert_w2_list[0]) + else: + # Individual mode: process each expert's weights separately + for expert in self.experts: + if expert is not None: + quantize_weights([expert.w1]) + quantize_weights([expert.w1]) + + if self.config.n_shared_experts is not None: + self.shared_experts.fp8_quant_weight() + + def forward(self, hidden_states): + if self.using_post_norm_recompute: + super().update_flex_token() + if self.using_flex_token: + probs, routing_map, l_aux, l_zloss, norm_out = self.router(hidden_states) + final_hidden_states, l_aux, l_zloss = super().forward( + norm_out, probs=probs, routing_map=routing_map, l_aux=l_aux, l_zloss=l_zloss + ) + else: + capacity, topk_weight, topk_ids, token_priority, l_aux, l_zloss, norm_out = self.gate(hidden_states) + final_hidden_states, l_aux, l_zloss = super().forward( + norm_out, + capacity=capacity, + topk_weight=topk_weight, + topk_ids=topk_ids, + token_priority=token_priority, + l_aux=l_aux, + l_zloss=l_zloss, + ) + final_hidden_states = self.post_process(hidden_states, final_hidden_states, l_aux) + else: + final_hidden_states, l_aux, l_zloss = super().forward(hidden_states) + final_hidden_states = self.post_process(hidden_states, final_hidden_states, l_aux) + return final_hidden_states + + def post_process(self, hidden_states, final_hidden_states, l_aux): + if self.training and self.alpha > 0.0: + l_aux = l_aux * self.alpha + final_hidden_states = AddAuxiliaryLoss.apply(final_hidden_states, l_aux) + + if self.config.n_shared_experts is not None: + shared_expert_output = self.shared_experts(hidden_states) + final_hidden_states = final_hidden_states + shared_expert_output + return final_hidden_states + +# # Copied from transformers.models.llama.modeling_llama.LlamaAttention with Llama->DeepseekV2 +class DeepseekV2Attention(nn.Layer): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: DeepseekV2Config, layerwise_recompute: bool = False): + super().__init__() + self.config = config + self.attention_dropout = config.attention_dropout + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.q_lora_rank = config.q_lora_rank + self.qk_rope_head_dim = config.qk_rope_head_dim + self.kv_lora_rank = config.kv_lora_rank + self.v_head_dim = 
config.v_head_dim + self.qk_nope_head_dim = config.qk_nope_head_dim + self.q_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim + + self.is_causal = True + self.fuse_rope = config.use_fused_rope + + if config.num_nextn_predict_layers > 0: + self.seq_length = config.seq_length - config.num_nextn_predict_layers + else: + self.seq_length = config.seq_length + self.sequence_parallel = config.sequence_parallel + + self.input_layernorm = DeepseekV2RMSNorm(config) + + # Note that we will actually perform a recompute only if both enable_recompute and layerwise_recompute are set to True + # Enable_recompute defaults to False and is controlled by Trainer + self.enable_recompute = False + self.layerwise_recompute = layerwise_recompute + self.recompute_granularity = config.recompute_granularity + + def linear_dtype_gaurd(): + if config.use_fp8: + return dtype_guard("float8_e4m3fn") + else: + return contextlib.nullcontext() + + # Note (@DrownFish19): For tensor parallel we consider that q_a_proj and kv_a_proj_with_mqa + # are the small weight and cannot achieve performance gain. So we use the original + # linear layers. We use the tensor parallel linear layers for q_proj,q_b_proj and kv_b_proj + # for which are the large weight and can achieve performance gain. + + self._init_rope() + self.softmax_scale = self.q_head_dim ** (-0.5) + + # fmt: off + if self.config.tensor_parallel_degree > 1: + # for tensor parallel + if config.sequence_parallel: + ColumnParallelLinear = linear_utils.ColumnSequenceParallelLinear + RowParallelLinear = linear_utils.RowSequenceParallelLinear + else: + ColumnParallelLinear = linear_utils.ColumnParallelLinear + RowParallelLinear = linear_utils.RowParallelLinear + + if self.q_lora_rank is None: + with linear_dtype_gaurd(): + self.q_proj = ColumnParallelLinear(self.hidden_size, self.num_heads * self.q_head_dim, has_bias=False, gather_output=True) + else: + with linear_dtype_gaurd(): + self.q_a_proj = Linear(self.hidden_size, config.q_lora_rank, bias_attr=config.attention_bias) + self.q_b_proj = ColumnParallelLinear(config.q_lora_rank, self.num_heads * self.q_head_dim, has_bias=False, gather_output=True) + self.q_a_layernorm = DeepseekV2RMSNorm(config=config, hidden_size=config.q_lora_rank, use_sequence_parallel=False) + + with linear_dtype_gaurd(): + self.kv_a_proj_with_mqa = paddle.nn.Linear(self.hidden_size, config.kv_lora_rank + config.qk_rope_head_dim, bias_attr=config.attention_bias) + self.kv_b_proj = ColumnParallelLinear(config.kv_lora_rank, self.num_heads * (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim), has_bias=False, gather_output=True) + self.o_proj = RowParallelLinear(self.num_heads * self.v_head_dim, self.hidden_size, has_bias=config.attention_bias, input_is_parallel=False) + self.kv_a_layernorm = DeepseekV2RMSNorm(config=config, hidden_size=config.kv_lora_rank, use_sequence_parallel=False) + else: + # for without tensor parallel + if DSV3_USE_ATTEN_RECOMPUTE: + self.fused_rms_norm_linear = FusedRMSLinear(self.hidden_size, config.q_lora_rank, config.kv_lora_rank + config.qk_rope_head_dim, 1e-6) + kv_up_dim = self.num_heads * (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim) + self.memory_recompute_att = MemroyRecomputeAttn(config.q_lora_rank, config.kv_lora_rank, config.q_lora_rank, self.num_heads * self.q_head_dim, config.kv_lora_rank, kv_up_dim, self.rotary_emb, self.num_heads, self.q_head_dim, self.qk_nope_head_dim, self.v_head_dim, self.qk_rope_head_dim, 1e-6, self.kv_lora_rank, self.softmax_scale) + self.o_proj = 
FP8KeepXLinear(self.num_heads * self.v_head_dim, self.hidden_size, bias_attr=config.attention_bias) + else: + + if self.q_lora_rank is None: + with linear_dtype_gaurd(): + self.q_proj = Linear(self.hidden_size, self.num_heads * self.q_head_dim, bias_attr=False) + else: + with linear_dtype_gaurd(): + self.q_a_proj = Linear(self.hidden_size, config.q_lora_rank, bias_attr=config.attention_bias) + self.q_b_proj = Linear(config.q_lora_rank, self.num_heads * self.q_head_dim, bias_attr=False) + self.q_a_layernorm = DeepseekV2RMSNorm(config=config, hidden_size=config.q_lora_rank) + + with linear_dtype_gaurd(): + self.kv_a_proj_with_mqa = paddle.nn.Linear(self.hidden_size, config.kv_lora_rank + config.qk_rope_head_dim, bias_attr=config.attention_bias) + self.kv_b_proj = Linear(config.kv_lora_rank, self.num_heads * (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim), bias_attr=False) + self.o_proj = Linear(self.num_heads * self.v_head_dim, self.hidden_size, bias_attr=config.attention_bias) + self.kv_a_layernorm = DeepseekV2RMSNorm(config=config, hidden_size=config.kv_lora_rank) + + # fmt: on + self.softmax_scale = self.q_head_dim ** (-0.5) + if self.config.rope_scaling is not None: + mscale_all_dim = self.config.rope_scaling.get("mscale_all_dim", 0) + scaling_factor = self.config.rope_scaling["factor"] + if mscale_all_dim: + mscale = yarn_get_mscale(scaling_factor, mscale_all_dim) + self.softmax_scale = self.softmax_scale * mscale * mscale + + self.attn_func = scaled_dot_product_attention + + def fp8_quant_weight(self): + + if DSV3_USE_ATTEN_RECOMPUTE: + self.o_proj.fp8_quant_weight() + self.memory_recompute_att.fp8_quant_weight() + self.fused_rms_norm_linear.fp8_quant_weight() + + def _init_rope(self): + if self.config.rope_scaling is None: + self.rotary_emb = DeepseekV2RotaryEmbedding( + self.qk_rope_head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.rope_theta, + ) + else: + scaling_type = self.config.rope_scaling["type"] + scaling_factor = self.config.rope_scaling["factor"] + if scaling_type == "linear": + self.rotary_emb = DeepseekV2LinearScalingRotaryEmbedding( + self.qk_rope_head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + ) + elif scaling_type == "dynamic": + self.rotary_emb = DeepseekV2DynamicNTKScalingRotaryEmbedding( + self.qk_rope_head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + ) + elif scaling_type == "yarn": + kwargs = { + key: self.config.rope_scaling[key] + for key in [ + "original_max_position_embeddings", + "beta_fast", + "beta_slow", + "mscale", + "mscale_all_dim", + ] + if key in self.config.rope_scaling + } + self.rotary_emb = DeepseekV2YarnRotaryEmbedding( + self.qk_rope_head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + **kwargs, + ) + else: + raise ValueError(f"Unknown RoPE scaling type {scaling_type}") + + def _shape(self, tensor: paddle.Tensor, seq_len: int, bsz: int): + return tensor.reshape([bsz, seq_len, self.num_heads, self.v_head_dim]).transpose([1, 0, 2, 3]) + + def forward( + self, + hidden_states: paddle.Tensor, + position_ids: Optional[Tuple[paddle.Tensor]] = None, + past_key_value: Optional[Tuple[paddle.Tensor]] = None, + attention_mask: Optional[paddle.Tensor] = None, + output_attentions: bool = False, + use_cache: bool = False, + attn_mask_startend_row_indices: Optional[paddle.Tensor] = None, + **kwargs, 
+ ) -> Tuple[paddle.Tensor, Optional[paddle.Tensor], Optional[Tuple[paddle.Tensor]]]: + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" + ) + bsz, q_len, _ = hidden_states.shape + + # DeepSeekV2 q_lora_rank=1536 + # DeepSeekV2-lite q_lora_rank=None + if DSV3_USE_ATTEN_RECOMPUTE: + + q_t1, compressed_kv = self.fused_rms_norm_linear(hidden_states) + + outputs = self.memory_recompute_att(q_t1, compressed_kv, position_ids) + + if self.v_head_dim * self.num_heads != outputs.shape[-1]: + outputs = outputs.reshape([bsz, q_len, self.num_heads, -1]) + outputs = outputs[..., : self.v_head_dim] + outputs = outputs.reshape([bsz, q_len, -1]) + else: + # 这里多了一个layernorm,是因为把 DeepseekV2Attention 之外的一次计算放进来了 + hidden_states = self.input_layernorm(hidden_states) + if self.q_lora_rank is None: + q = self.q_proj(hidden_states) + else: + q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))) + + if self.sequence_parallel: + target_query_shape = [-1, self.seq_length, self.num_heads, self.q_head_dim] + target_key_value_shape = [-1, self.seq_length, self.num_heads, self.qk_nope_head_dim + self.v_head_dim] + else: + target_query_shape = [0, 0, self.num_heads, self.q_head_dim] + target_key_value_shape = [0, 0, self.num_heads, self.qk_nope_head_dim + self.v_head_dim] + + q = q.reshape(shape=target_query_shape) + q_nope, q_pe = paddle.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], axis=-1) + + # DeepSeekV2 kv_lora_rank+qk_rope_head_dim=512+64 + compressed_kv = self.kv_a_proj_with_mqa(hidden_states) + compressed_kv, k_pe = paddle.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], axis=-1) + if self.sequence_parallel: + k_pe = GatherOp.apply(k_pe) + k_pe = k_pe.reshape([-1, q_len, 1, self.qk_rope_head_dim]).expand( + [-1, q_len, self.num_heads, self.qk_rope_head_dim] + ) + + # self.q_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim = 128+64 + # self.num_heads * (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim) = config.qk_nope_head_dim + self.v_head_dim = 128+128 + kv = self.kv_b_proj(self.kv_a_layernorm(compressed_kv)).reshape(shape=target_key_value_shape) + + k_nope, value_states = paddle.split(kv, [self.qk_nope_head_dim, self.v_head_dim], axis=-1) + kv_seq_len = value_states.shape[1] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-3] + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + cos = cos[None, :, None, :] + sin = sin[None, :, None, :] + q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids, self.fuse_rope) + + query_states = paddle.concat([q_nope, q_pe], axis=-1) + key_states = paddle.concat([k_nope, k_pe], axis=-1) + + # [bs, seq_len, num_head, head_dim] + if past_key_value is not None: + # reuse k, v, self_attention + key_states = paddle.concat([past_key_value[0], key_states], axis=1) + value_states = paddle.concat([past_key_value[1], value_states], axis=1) + past_key_value = (key_states, value_states) if use_cache else None + + has_gradient = not (query_states.stop_gradient and key_states.stop_gradient and value_states.stop_gradient) + if ( + self.enable_recompute + and self.layerwise_recompute + and has_gradient + and self.recompute_granularity == "core_attn" + ): + outputs = recompute( + self.attn_func, + query_states, + self.config, + key_states, + value_states, + attention_mask, + output_attentions, + attn_mask_startend_row_indices=attn_mask_startend_row_indices, + 
softmax_scale=self.softmax_scale, + training=self.training, + sequence_parallel=self.sequence_parallel, + use_reentrant=self.config.recompute_use_reentrant, + ) + else: + outputs = self.attn_func( + query_states, + self.config, + key_states, + value_states, + attention_mask, + output_attentions, + attn_mask_startend_row_indices=attn_mask_startend_row_indices, + softmax_scale=self.softmax_scale, + training=self.training, + sequence_parallel=self.sequence_parallel, + ) + if output_attentions: + attn_output, attn_weights = outputs + else: + attn_output = outputs + + # if sequence_parallel is true, out shape are [q_len / n, bs, num_head * head_dim] + # else their shape are [bs, q_len, num_head * head_dim], n is mp parallelism. + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + outputs = (attn_output,) + + if output_attentions: + outputs += (attn_weights,) + + if use_cache: + outputs += (past_key_value,) + + if type(outputs) is tuple and len(outputs) == 1: + outputs = outputs[0] + + return outputs + + +class DeepseekV2DecoderLayer(nn.Layer): + def __init__(self, config: DeepseekV2Config, layer_idx: int, layerwise_recompute: bool = False): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.enable_recompute = False + self.layerwise_recompute = layerwise_recompute + self.recompute_granularity = config.recompute_granularity + self.using_post_norm_recompute = config.using_post_norm_recompute + + self.hidden_size = config.hidden_size + + self.self_attn = DeepseekV2Attention(config=config, layerwise_recompute=layerwise_recompute) + + DeepseekV2MLPClass = FP8Mlp if DSV3_USE_FP8_GEMM else DeepseekV2MLP + + self.input_layernorm = DeepseekV2RMSNorm(config) + self.post_attention_layernorm = DeepseekV2RMSNorm(config) + + if ( + config.n_routed_experts is not None + and layer_idx >= config.first_k_dense_replace + and layer_idx % config.moe_layer_freq == 0 + ): + self.mlp = ( + DeepseekV2MoE( + config, self.post_attention_layernorm.weight, self.post_attention_layernorm.variance_epsilon + ) + if config.using_post_norm_recompute + else DeepseekV2MoE(config) + ) + else: + self.mlp = DeepseekV2MLPClass(config) + + def fp8_quant_weight(self, batch_mode=False): + """fp8_quant_weight""" + if isinstance(self.mlp, DeepseekV2MoE): + # logger.info(f"fp8 quant weight for mlp {type(self.mlp)}") + self.mlp.fp8_quant_weight(batch_mode) + self.self_attn.fp8_quant_weight() + elif isinstance(self.mlp, FP8Mlp): + self.self_attn.fp8_quant_weight() + + def forward( + self, + hidden_states: paddle.Tensor, + position_ids: Optional[paddle.Tensor] = None, + attention_mask: Optional[paddle.Tensor] = None, + output_attentions: Optional[bool] = False, + past_key_value: Optional[Tuple[paddle.Tensor]] = None, + use_cache: Optional[bool] = False, + attn_mask_startend_row_indices: Optional[paddle.Tensor] = None, + **kwargs, + ) -> Tuple[paddle.Tensor, Optional[Tuple[paddle.Tensor, paddle.Tensor]]]: + """ + Args: + hidden_states (`paddle.Tensor`): input to the layer of shape `(batch, seq_len, embed_axis)` + attention_mask (`paddle.Tensor`, *optional*): + attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, + query_sequence_length, key_sequence_length)` if default attention is used. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. 
+ use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(paddle.Tensor)`, *optional*): cached past key and value projection states + """ + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" + ) + residual = hidden_states + + # Self Attention + has_gradient = not hidden_states.stop_gradient + if ( + self.enable_recompute + and self.layerwise_recompute + and has_gradient + and self.recompute_granularity == "full_attn" + ): + outputs = recompute( + self.self_attn, + hidden_states=hidden_states, + position_ids=position_ids, + attention_mask=attention_mask, + output_attentions=output_attentions, + past_key_value=past_key_value, + use_cache=use_cache, + attn_mask_startend_row_indices=attn_mask_startend_row_indices, + **kwargs, + ) + else: + outputs = self.self_attn( + hidden_states=hidden_states, + position_ids=position_ids, + attention_mask=attention_mask, + output_attentions=output_attentions, + past_key_value=past_key_value, + use_cache=use_cache, + attn_mask_startend_row_indices=attn_mask_startend_row_indices, + **kwargs, + ) + + if type(outputs) is tuple: + hidden_states = outputs[0] + else: + hidden_states = outputs + + if output_attentions: + self_attn_weights = outputs[1] + + if use_cache: + present_key_value = outputs[2 if output_attentions else 1] + + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + + if not (self.using_post_norm_recompute and isinstance(self.mlp, DeepseekV2MoE)): + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + if type(outputs) is tuple and len(outputs) == 1: + outputs = outputs[0] + + return outputs + + def self_attn_compute(self, hidden_states, **kwargs): + residual = hidden_states + + # Self Attention + has_gradient = not hidden_states.stop_gradient + if ( + self.enable_recompute + and self.layerwise_recompute + and has_gradient + and self.recompute_granularity == "full_attn" + ): + outputs = recompute( + self.self_attn, + hidden_states=hidden_states, + position_ids=None, + attention_mask=None, + output_attentions=False, + past_key_value=None, + use_cache=False, + attn_mask_startend_row_indices=None, + **kwargs, + ) + else: + outputs = self.self_attn( + hidden_states=hidden_states, + position_ids=None, + attention_mask=None, + output_attentions=False, + past_key_value=None, + use_cache=False, + attn_mask_startend_row_indices=None, + **kwargs, + ) + + if type(outputs) is tuple: + hidden_states = outputs[0] + else: + hidden_states = outputs + + hidden_states = residual + hidden_states + + residual = hidden_states + + if not (self.using_post_norm_recompute and isinstance(self.mlp, DeepseekV2MoE)): + hidden_states = self.post_attention_layernorm(hidden_states) + + return hidden_states, residual + + def pre_dispatch_compute(self, hidden_states): + l_aux, l_zloss, intermediate_hidden_states, token_indices, token_probs = self.mlp.pre_dispatch_compute( + hidden_states + ) + + return l_aux, l_zloss, intermediate_hidden_states, token_indices, token_probs + + def expert_forward_compute(self, intermediate_hidden_states, dispatched_indices, 
dispatched_probs): + (global_input_tokens, token_permuted_indices, prob_permuted_indices) = self.mlp.post_dispatch_compute( + intermediate_hidden_states, dispatched_indices, dispatched_probs + ) + + expert_output = self.mlp.expert_forward(global_input_tokens) + + expert_output = self.mlp.pre_combine_compute( + expert_output, token_permuted_indices, prob_permuted_indices, dispatched_probs + ) + + return expert_output + + def post_combine_compute(self, residual, hidden_states, final_hidden_states, l_aux): + final_hidden_states = self.mlp.post_combine_compute(final_hidden_states) + + final_hidden_states = self.mlp.post_process(hidden_states, final_hidden_states, l_aux) + + final_hidden_states = residual + final_hidden_states + + outputs = (final_hidden_states,) + + if type(outputs) is tuple and len(outputs) == 1: + outputs = outputs[0] + + return outputs + + +class DeepseekV2MTPLayer(DeepseekV2DecoderLayer): + def __init__( + self, + config: DeepseekV2Config, + layer_idx: int, + layerwise_recompute: bool = False, + ): + super(DeepseekV2MTPLayer, self).__init__(config, layer_idx, layerwise_recompute) + + self.enorm = DeepseekV2RMSNorm(config) + self.hnorm = DeepseekV2RMSNorm(config) + self.eh_proj = nn.Linear(2 * config.hidden_size, config.hidden_size, bias_attr=False) + + def forward( + self, + hidden_states: paddle.Tensor, + nextn_hidden_state: paddle.Tensor, + position_ids: Optional[paddle.Tensor] = None, + attention_mask: Optional[paddle.Tensor] = None, + output_attentions: Optional[bool] = False, + past_key_value: Optional[Tuple[paddle.Tensor]] = None, + use_cache: Optional[bool] = False, + attn_mask_startend_row_indices: Optional[paddle.Tensor] = None, + **kwargs, + ) -> Tuple[paddle.Tensor, Optional[Tuple[paddle.Tensor, paddle.Tensor]]]: + hidden_states = self.hnorm(hidden_states) + nextn_hidden_state = self.enorm(nextn_hidden_state) + + concat_h = paddle.concat([hidden_states, nextn_hidden_state], axis=-1) + hidden_states = LMHeadFunction.apply( concat_h, self.eh_proj.weight, False) + + layer_outputs = super(DeepseekV2MTPLayer, self).forward( + hidden_states, + position_ids, + attention_mask, + output_attentions, + past_key_value, + use_cache, + attn_mask_startend_row_indices, + **kwargs, + ) + + if type(layer_outputs) is tuple: + hidden_states = layer_outputs[0] + else: + hidden_states = layer_outputs + + return hidden_states + + +class DeepseekV2PretrainedModelFast(PretrainedModel): + config_class = DeepseekV2Config + base_model_prefix = "deepseek_v2" + _no_split_modules = ["DeepseekV2DecoderLayer"] + + def _get_model_flops(self, batch_size=1, seq_length=None, **kwargs): + from .mfu_utils import DeepSeekProjection + + # self._ + mfu_cal_proj = DeepSeekProjection(self.config) + if seq_length is None: + if hasattr(self.config, "seq_length"): + seq_length = self.config.seq_length + else: + seq_length = 2048 + + return mfu_cal_proj.get_num_flop_per_token() + + def _get_hardware_flops(self, *args, **kwargs): + return self._get_model_flops(*args, **kwargs) + + @classmethod + def _get_name_mappings(cls, config: DeepseekV2Config) -> list[StateDictNameMapping]: + mappings: list[StateDictNameMapping] = [] + model_mappings = [ + ["embed_tokens.weight"], + ["norm.weight"], + ] + # last one layer contains MTP (eagle) parameters for inference + for layer_index in range(config.num_hidden_layers + config.num_nextn_predict_layers): + layer_mappings = [ + [f"layers.{layer_index}.self_attn.q_proj.weight", None, "transpose"], + [f"layers.{layer_index}.self_attn.q_a_proj.weight", None, "transpose"], + 
[f"layers.{layer_index}.self_attn.q_a_layernorm.weight"], + [f"layers.{layer_index}.self_attn.q_b_proj.weight", None, "transpose"], + [f"layers.{layer_index}.self_attn.kv_a_proj_with_mqa.weight", None, "transpose"], + [f"layers.{layer_index}.self_attn.kv_a_layernorm.weight"], + [f"layers.{layer_index}.self_attn.kv_b_proj.weight", None, "transpose"], + [f"layers.{layer_index}.self_attn.o_proj.weight", None, "transpose"], + [f"layers.{layer_index}.mlp.gate_proj.weight", None, "transpose"], + [f"layers.{layer_index}.mlp.up_proj.weight", None, "transpose"], + [f"layers.{layer_index}.mlp.down_proj.weight", None, "transpose"], + [f"layers.{layer_index}.input_layernorm.weight"], + [f"layers.{layer_index}.post_attention_layernorm.weight"], + ] + model_mappings.extend(layer_mappings) + + # MoE parameters + model_mappings.append([f"layers.{layer_index}.mlp.gate.weight", None, "transpose"]) + model_mappings.append([f"layers.{layer_index}.mlp.gate.e_score_correction_bias"]) + for expert_idx in range(config.n_routed_experts): + expert_mappings = [ + [f"layers.{layer_index}.mlp.experts.{expert_idx}.gate_proj.weight", None, "transpose"], + [f"layers.{layer_index}.mlp.experts.{expert_idx}.up_proj.weight", None, "transpose"], + [f"layers.{layer_index}.mlp.experts.{expert_idx}.down_proj.weight", None, "transpose"], + ] + model_mappings.extend(expert_mappings) + model_mappings.append([f"layers.{layer_index}.mlp.shared_experts.gate_proj.weight", None, "transpose"]) + model_mappings.append([f"layers.{layer_index}.mlp.shared_experts.up_proj.weight", None, "transpose"]) + model_mappings.append([f"layers.{layer_index}.mlp.shared_experts.down_proj.weight", None, "transpose"]) + + # MTP (eagle) parameters for inference + if layer_index >= config.num_hidden_layers: + model_mappings.append([f"layers.{layer_index}.embed_tokens.weight"]) + model_mappings.append([f"layers.{layer_index}.enorm.weight"]) + model_mappings.append([f"layers.{layer_index}.hnorm.weight"]) + model_mappings.append([f"layers.{layer_index}.eh_proj.weight", None, "transpose"]) + model_mappings.append([f"layers.{layer_index}.shared_head.norm.weight"]) + model_mappings.append([f"layers.{layer_index}.shared_head.head.weight", None, "transpose"]) + + init_name_mappings(mappings=model_mappings) + if cls.base_model_class.__name__ not in config.architectures: + for mapping in model_mappings: + mapping[0] = "model." + mapping[0] + mapping[1] = f"{cls.base_model_prefix}." 
+ mapping[1] + if not config.tie_word_embeddings: + model_mappings.append(["lm_head.weight", "lm_head.weight", "transpose"]) + + mappings = [StateDictNameMapping(*mapping, index=index) for index, mapping in enumerate(model_mappings)] + return mappings + + @classmethod + def _get_tensor_parallel_mappings(cls, config: DeepseekV2Config, is_split=True): + from paddleformers.transformers.conversion_utils import split_or_merge_func + + fn = split_or_merge_func( + is_split=is_split, + tensor_parallel_degree=config.tensor_parallel_degree, + tensor_parallel_rank=config.tensor_parallel_rank, + num_attention_heads=config.num_attention_heads, + ) + + def get_tensor_parallel_split_mappings(num_layers): + final_actions = {} + + base_actions = { + # Row Linear + "embed_tokens.weight": partial(fn, is_column=False), + "layers.0.self_attn.o_proj.weight": partial(fn, is_column=False), + } + if config.use_fp8: + base_actions["layers.0.self_attn.o_proj.weight.weight_scale_inv"] = partial(fn, is_column=False) + + if config.tie_word_embeddings: + base_actions["lm_head.weight"] = partial(fn, is_column=False) + else: + base_actions["lm_head.weight"] = partial(fn, is_column=True) + + if not config.vocab_size % config.tensor_parallel_degree == 0: + base_actions.pop("lm_head.weight") + base_actions.pop("embed_tokens.weight") + + # Column Linear + base_actions["layers.0.self_attn.q_proj.weight"] = partial(fn, is_column=True) + base_actions["layers.0.self_attn.q_proj.bias"] = partial(fn, is_column=True) + base_actions["layers.0.self_attn.q_b_proj.weight"] = partial(fn, is_column=True) + + # if we have enough num_key_value_heads to split, then split it. + # ??? + if config.num_key_value_heads % config.tensor_parallel_degree == 0: + base_actions["layers.0.self_attn.kv_b_proj.weight"] = partial(fn, is_column=True) + if config.use_fp8: + base_actions["layers.0.self_attn.kv_b_proj.weight.weight_scale_inv"] = partial(fn, is_column=True) + + # dense mlp + base_actions["layers.0.mlp.up_proj.weight"] = partial(fn, is_column=True) + base_actions["layers.0.mlp.gate_proj.weight"] = partial(fn, is_column=True) + base_actions["layers.0.mlp.down_proj.weight"] = partial(fn, is_column=False) + if config.use_fp8: + base_actions["layers.0.mlp.up_proj.weight.weight_scale_inv"] = partial(fn, is_column=True) + base_actions["layers.0.mlp.gate_proj.weight.weight_scale_inv"] = partial(fn, is_column=True) + base_actions["layers.0.mlp.down_proj.weight.weight_scale_inv"] = partial(fn, is_column=False) + + # moe unit routed experts + moe_group = dist.fleet.get_hybrid_communicate_group().get_data_parallel_group() + expert_parallel_degree = dist.get_world_size(moe_group) + if expert_parallel_degree <= 1: + for e_i in range(config.n_routed_experts): + base_actions[f"layers.0.mlp.experts.{e_i}.up_proj.weight"] = partial(fn, is_column=True) + base_actions[f"layers.0.mlp.experts.{e_i}.gate_proj.weight"] = partial(fn, is_column=True) + base_actions[f"layers.0.mlp.experts.{e_i}.down_proj.weight"] = partial(fn, is_column=False) + + # moe unit shared experts + base_actions["layers.0.mlp.shared_experts.gate_proj.weight"] = partial(fn, is_column=True) + base_actions["layers.0.mlp.shared_experts.up_proj.weight"] = partial(fn, is_column=True) + base_actions["layers.0.mlp.shared_experts.down_proj.weight"] = partial(fn, is_column=False) + if config.use_fp8: + base_actions["layers.0.mlp.shared_experts.gate_proj.weight.weight_scale_inv"] = partial( + fn, is_column=True + ) + base_actions["layers.0.mlp.shared_experts.up_proj.weight.weight_scale_inv"] = partial( + 
fn, is_column=True + ) + base_actions["layers.0.mlp.shared_experts.down_proj.weight.weight_scale_inv"] = partial( + fn, is_column=False + ) + + for key, action in base_actions.items(): + if "layers.0." in key: + for i in range(num_layers): + final_actions[key.replace("layers.0.", f"layers.{i}.")] = action + final_actions[key] = action + + # for MTP (eagle) parameters for inference + base_actions.pop("embed_tokens.weight") + base_actions.pop("lm_head.weight") + base_actions["layers.0.embed_tokens.weight"] = partial(fn, is_column=False) + base_actions["layers.0.shared_head.head.weight"] = partial(fn, is_column=True) + for key, action in base_actions.items(): + if "layers.0." in key: + for i in range( + config.num_hidden_layers, config.num_hidden_layers + config.num_nextn_predict_layers + ): + final_actions[key.replace("layers.0.", f"layers.{i}.")] = action + else: + final_actions[key] = action + + return final_actions + + mappings = get_tensor_parallel_split_mappings(config.num_hidden_layers) + + return mappings + + def _init_weights(self, layer): + if self.config.tensor_parallel_degree > 1: + rng_tracker = get_rng_state_tracker().rng_state + + if isinstance( + layer, + ( + nn.Linear, + nn.Embedding, + mpu.VocabParallelEmbedding, + mpu.RowParallelLinear, + mpu.ColumnParallelLinear, + linear_utils.RowSequenceParallelLinear, + linear_utils.ColumnSequenceParallelLinear, + Linear, + ), + ): + # In the dygraph mode, use the `set_value` to reset the parameter directly, + # and reset the `state_dict` to update parameter in static mode. + if isinstance(layer.weight, paddle.Tensor): + if layer.weight.is_distributed: + with rng_tracker(): + layer.weight.set_value( + paddle.tensor.normal( + mean=0.0, + std=self.config.initializer_range + if hasattr(self.config, "initializer_range") + else self.config.initializer_range, + shape=layer.weight.shape, + ) + ) + else: + layer.weight.set_value( + paddle.tensor.normal( + mean=0.0, + std=self.config.initializer_range + if hasattr(self.config, "initializer_range") + else self.config.initializer_range, + shape=layer.weight.shape, + ) + ) + + # set bias to zeros + if getattr(layer, "bias", None) is not None: + layer.bias.set_value(paddle.zeros(shape=layer.bias.shape)) + + if isinstance(layer, nn.Embedding): + if layer._padding_idx is not None: + layer.weight.data[layer._padding_idx].fill_(0) + + if isinstance(layer, MoEGate): + kaiming_uniform_(layer.weight, a=math.sqrt(5)) + + moe_grad_group = fleet.get_hybrid_communicate_group().expert_grad_comm_group + if moe_grad_group is not None and moe_grad_group.nranks > 1: + for p in layer.parameters(): + if hasattr(p, "color") and "color" in p.color: + if p.color["color"] == "moe_expert": + paddle.distributed.broadcast(p, src=moe_grad_group.ranks[0], group=moe_grad_group) + + def step_flex_token(self, cur_step): + set_global_step(cur_step) + +@register_base_model +class DeepseekV2ModelFast(DeepseekV2PretrainedModelFast): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. 
Each layer is a [`DeepseekV2DecoderLayer`] + + Args: + config: DeepseekV2Config + """ + + def __init__(self, config: DeepseekV2Config): + super().__init__(config) + + self.config = config + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + # Recompute defaults to False and is controlled by Trainer + self.enable_recompute = False + self.recompute_granularity = config.recompute_granularity + self.no_recompute_layers = config.no_recompute_layers if config.no_recompute_layers is not None else [] + + if config.tensor_parallel_degree > 1 and config.vocab_size % config.tensor_parallel_degree == 0: + self.embed_tokens = mpu.VocabParallelEmbedding(config.vocab_size, config.hidden_size) + else: + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + + self.layers = nn.LayerList( + [ + DeepseekV2DecoderLayer(config, layer_idx, layer_idx not in self.no_recompute_layers) + for layer_idx in range(config.num_hidden_layers) + ] + ) + for layer_idx in range(config.num_hidden_layers, config.num_hidden_layers + config.num_nextn_predict_layers): + self.layers.append(DeepseekV2MTPLayer(config, layer_idx, layer_idx not in self.no_recompute_layers)) + + self.norm = DeepseekV2RMSNorm(config) + + self.enable_recompute = False + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @staticmethod + def _prepare_decoder_attention_mask(attention_mask, input_shape, past_key_values_length, dtype): + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + if len(attention_mask.shape) == 2: + expanded_attn_mask = _expand_2d_mask(attention_mask, dtype, tgt_length=input_shape[-1]) + # For decoding phase in generation, seq_length = 1, we don't need to add causal mask + if input_shape[-1] > 1: + combined_attention_mask = _make_causal_mask( + input_shape, + past_key_values_length=past_key_values_length, + ) + expanded_attn_mask = expanded_attn_mask & combined_attention_mask + # [bsz, seq_len, seq_len] -> [bsz, 1, seq_len, seq_len] + elif len(attention_mask.shape) == 3: + expanded_attn_mask = attention_mask.unsqueeze(1).astype("bool") + # if attention_mask is already 4-D, do nothing + else: + expanded_attn_mask = attention_mask + else: + expanded_attn_mask = _make_causal_mask( + input_shape, + past_key_values_length=past_key_values_length, + ) + # Convert bool attention_mask to float attention mask, which will be added to attention_scores later + if get_env_device() == "xpu": + x = paddle.to_tensor(0.0, dtype="float32") + y = paddle.to_tensor(-1.7005809656952787e38, dtype="float32") + expanded_attn_mask = paddle.where(expanded_attn_mask, x, y) + else: + expanded_attn_mask = paddle.where(expanded_attn_mask.cast("bool"), 0.0, paddle.finfo(dtype).min).astype( + dtype + ) + return expanded_attn_mask + + @paddle.jit.not_to_static + def recompute_training_full( + self, + layer_module: nn.Layer, + hidden_states: Tensor, + position_ids: Optional[Tensor], + attention_mask: Tensor, + output_attentions: bool, + past_key_value: Tensor, + use_cache: bool, + attn_mask_startend_row_indices: Optional[Tensor] = None, + ): + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + hidden_states = recompute( + create_custom_forward(layer_module), + hidden_states, + position_ids, + attention_mask, + output_attentions, + past_key_value, + use_cache, + attn_mask_startend_row_indices, + 
use_reentrant=self.config.recompute_use_reentrant,
+        )
+
+        return hidden_states
+
+    def forward(
+        self,
+        input_ids: paddle.Tensor = None,
+        position_ids: Optional[paddle.Tensor] = None,
+        attention_mask: Optional[paddle.Tensor] = None,
+        inputs_embeds: Optional[paddle.Tensor] = None,
+        use_cache: Optional[bool] = None,
+        past_key_values: Optional[List[paddle.Tensor]] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        attn_mask_startend_row_indices: Optional[Tensor] = None,
+        **kwargs,
+    ) -> Union[Tuple, BaseModelOutputWithPastAndMTP]:
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        # retrieve input_ids and inputs_embeds
+        if input_ids is not None and inputs_embeds is not None:
+            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+        elif input_ids is not None:
+            batch_size, seq_length = input_ids.shape[:2]
+        elif inputs_embeds is not None:
+            batch_size, seq_length = inputs_embeds.shape[:2]
+        else:
+            raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+        if self.config.num_nextn_predict_layers > 0:
+            seq_length -= self.config.num_nextn_predict_layers
+
+            if attention_mask is not None:
+                attention_mask = attention_mask[
+                    :, :, : -self.config.num_nextn_predict_layers, : -self.config.num_nextn_predict_layers
+                ]
+
+        if self.enable_recompute and self.training:
+            if use_cache:
+                logger.warning_once(
+                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+ ) + use_cache = False + + if past_key_values is None: + past_key_values = tuple([None] * len(self.layers)) + # NOTE: to make cache can be clear in-time + past_key_values = list(past_key_values) + + seq_length_with_past = seq_length + past_key_values_length = 0 + if past_key_values[0] is not None: + past_key_values_length = past_key_values[0][0].shape[1] + seq_length_with_past += past_key_values_length + + if position_ids is None: + position_ids = paddle.arange( + past_key_values_length, seq_length + past_key_values_length, dtype=paddle.int64 + ) + position_ids = position_ids.unsqueeze(0) + + if inputs_embeds is None: + # [bs, seq_len, dim] + inputs_embeds = self.embed_tokens(input_ids) + + # embed positions + if attn_mask_startend_row_indices is not None or get_use_casual_mask(): + attention_mask = None + else: + # [bs, seq_len] + attention_mask = ( + paddle.ones((batch_size, seq_length_with_past), dtype=paddle.bool) + if attention_mask is None + else attention_mask + ) + attention_mask = self._prepare_decoder_attention_mask( + attention_mask, (batch_size, seq_length), past_key_values_length, inputs_embeds.dtype + ) # [bs, 1, seq_len, seq_len] + if self.config.use_flash_attention: + attention_mask = None if is_casual_mask(attention_mask) else attention_mask + + if self.config.num_nextn_predict_layers > 0: + inputs_embeds_extra = inputs_embeds[:, -self.config.num_nextn_predict_layers :, :] # [B, S, D] + inputs_embeds = inputs_embeds[:, : -self.config.num_nextn_predict_layers, :] + inputs_embeds_ori = inputs_embeds + + if self.config.sequence_parallel: + # [bs, seq_len, num_head * head_dim] -> [bs * seq_len, num_head * head_dim] + bs, seq_len, hidden_size = inputs_embeds.shape + inputs_embeds = paddle.reshape(inputs_embeds, [bs * seq_len, hidden_size]) + # [seq_len * bs / n, num_head * head_dim] (n is mp parallelism) + inputs_embeds = ScatterOp.apply(inputs_embeds) + + # embed positions + hidden_states = inputs_embeds + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = () if use_cache else None + mtp_outputs = [] + + for idx in range(self.config.num_hidden_layers): + decoder_layer = self.layers[idx] + + if output_hidden_states: + all_hidden_states += (hidden_states,) + + past_key_value = past_key_values[idx] if past_key_values is not None else None + + has_gradient = not hidden_states.stop_gradient + if ( + self.enable_recompute + and idx not in self.no_recompute_layers + and has_gradient + and self.recompute_granularity == "full" + ): + layer_outputs = self.recompute_training_full( + decoder_layer, + hidden_states, + position_ids, + attention_mask, + output_attentions, + past_key_value, + use_cache, + attn_mask_startend_row_indices, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + position_ids, + attention_mask, + output_attentions, + past_key_value, + use_cache, + attn_mask_startend_row_indices, + ) + + # NOTE: clear outdate cache after it has been used for memory saving + past_key_value = past_key_values[idx] = None + if type(layer_outputs) is tuple: + hidden_states = layer_outputs[0] + else: + hidden_states = layer_outputs + + if use_cache: + next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + if self.config.num_nextn_predict_layers > 0: + mtp_outputs.append(hidden_states) + + for nextn in range(self.config.num_nextn_predict_layers): + decoder_layer = self.layers[nextn + 
self.config.num_hidden_layers] + + if self.config.sequence_parallel: + hidden_states = GatherOp.apply(hidden_states) + hidden_states = hidden_states.reshape([-1, seq_length, hidden_states.shape[-1]]) + + inputs_embeds_cur_depth = paddle.concat( + [inputs_embeds_ori[:, (nextn + 1) :, :], inputs_embeds_extra[:, : (nextn + 1), :]], axis=1 + ) + + past_key_value = None + layer_outputs = decoder_layer( + hidden_states, + inputs_embeds_cur_depth, + position_ids, + attention_mask, + output_attentions, + past_key_value, + use_cache, + attn_mask_startend_row_indices, + ) + + if isinstance(layer_outputs, (tuple, list)): + hidden_states = layer_outputs[0] + else: + hidden_states = layer_outputs + + mtp_outputs.append(hidden_states) + mtp_outputs = [self.norm(hidden_states) for hidden_states in mtp_outputs] + hidden_states, mtp_outputs = mtp_outputs[0], mtp_outputs[1:] + else: + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + + if not return_dict: + return tuple( + v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, mtp_outputs] if v is not None + ) + return BaseModelOutputWithPastAndMTP( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + mtp_outputs=mtp_outputs, + ) diff --git a/paddleformers/transformers/deepseek_v2/modeling_pp.py b/paddleformers/transformers/deepseek_v2/modeling_pp.py index 42b0e5de776..a659f976e72 100644 --- a/paddleformers/transformers/deepseek_v2/modeling_pp.py +++ b/paddleformers/transformers/deepseek_v2/modeling_pp.py @@ -12,7 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
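Before the pipeline-parallel changes below, a brief illustration of the multi-token-prediction (MTP) input shift that `DeepseekV2ModelFast.forward` above builds for each extra depth: depth `nextn` consumes the input embeddings shifted left by `nextn + 1` positions, with the reserved tail slice (`inputs_embeds_extra`) appended at the end. The sketch is not part of the patch; the batch size, sequence length, and scalar "embeddings" are toy values chosen only to make the shift visible.

```python
import paddle

batch, seq_len_total, hidden = 1, 8, 1
num_nextn_predict_layers = 2

# Toy "embeddings": the value is just the position index, so the shift is easy to read off.
inputs_embeds = paddle.arange(seq_len_total, dtype="float32").reshape([batch, seq_len_total, hidden])
inputs_embeds_extra = inputs_embeds[:, -num_nextn_predict_layers:, :]  # reserved tail, positions 6 and 7
inputs_embeds_ori = inputs_embeds[:, :-num_nextn_predict_layers, :]    # positions 0..5 fed to the main layers

for nextn in range(num_nextn_predict_layers):
    # Same construction as in DeepseekV2ModelFast.forward for MTP depth `nextn`.
    inputs_embeds_cur_depth = paddle.concat(
        [inputs_embeds_ori[:, (nextn + 1):, :], inputs_embeds_extra[:, : (nextn + 1), :]], axis=1
    )
    print(nextn, inputs_embeds_cur_depth.flatten().tolist())
    # nextn=0 -> [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]  (each position sees the embedding one step ahead)
    # nextn=1 -> [2.0, 3.0, 4.0, 5.0, 6.0, 7.0]  (two steps ahead)
```

Each MTP depth therefore predicts one additional step ahead, which is why `forward` trims `seq_length` by `config.num_nextn_predict_layers` before running the main decoder layers.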
- +import math +import os from typing import OrderedDict, Tuple, Union import paddle @@ -20,32 +21,87 @@ import paddle.nn as nn from paddle.distributed.fleet.meta_parallel import ( LayerDesc, + LocalSharedLayerDesc, PipelineLayer, + ScheduleChunk, + ScheduleNode, SharedLayerDesc, ) +from paddle.distributed.fleet.meta_parallel.zero_bubble_utils import WeightGradStore + +try: + from paddle.distributed.fleet.meta_parallel.zero_bubble_utils import EventStore +except ImportError: + EventStore = None from paddle.distributed.fleet.recompute.recompute import recompute from paddle.distributed.fleet.utils.sequence_parallel_utils import ScatterOp +from ...utils.log import logger from ...utils.tools import get_env_device from ..model_utils import PipelinePretrainedModel -from .modeling import ( - DeepseekV2Config, - DeepseekV2DecoderLayer, - DeepseekV2LMHead, - DeepseekV2Model, - DeepseekV2MTPLayer, - DeepseekV2PretrainedModel, - DeepseekV2PretrainingCriterion, - DeepseekV2RMSNorm, + +if not os.getenv("DSV3_FAST_PRETRAIN", "False"): + from .modeling import ( + DeepseekV2Config, + DeepseekV2DecoderLayer, + DeepseekV2LMHead, + DeepseekV2Model, + DeepseekV2MoE, + DeepseekV2MTPLayer, + DeepseekV2PretrainedModel, + DeepseekV2PretrainingCriterion, + DeepseekV2RMSNorm, + TemporaryVarContext, + set_global_step, + ) +else: + from .modeling import ( + DeepseekV2Config, + DeepseekV2LMHead, + DeepseekV2PretrainingCriterion, + DeepseekV2RMSNorm, + TemporaryVarContext, + set_global_step, + ) + from .modeling_fast import ( + DeepseekV2MoE, + DeepseekV2DecoderLayer, + DeepseekV2MTPLayer, + ) + from .modeling_fast import DeepseekV2ModelFast as DeepseekV2Model + from .modeling_fast import DeepseekV2PretrainedModelFast as DeepseekV2PretrainedModel + + +try: + import paddle.distributed.communication.deep_ep as deep_ep +except ImportError: + deep_ep = None + +from paddleformers.transformers.fused_a2a import ( + fused_combine_backward_func, + fused_combine_forward_func, + fused_dispatch_backward_func, + fused_dispatch_forward_func, ) +from paddleformers.transformers.moe_layer import FusionMoeNode + +from ..fp8_utils import FP8LinearFunctionBase __all__ = [ "DeepseekV2ForCausalLMPipe", ] +import queue + +global_inputs_embeds_mtp_queue = queue.Queue() + + +DSV3_USE_FP8_GEMM = os.getenv("DSV3_USE_FP8_GEMM", "False").lower() == "true" +DSV3_USE_FP8_DISPATCH = os.getenv("DSV3_USE_FP8_DISPATCH", "False").lower() == "true" + def parse_args(args): - if isinstance(args, tuple): + if isinstance(args, (tuple, list)): if len(args) == 4: hidden_states, attention_mask, attn_mask_startend_row_indices, position_ids = args @@ -55,6 +111,9 @@ def parse_args(args): elif len(args) == 2: hidden_states, attention_mask = args attn_mask_startend_row_indices, position_ids = None, None + else: # len(args) == 1: + hidden_states = args[0] + attention_mask, attn_mask_startend_row_indices, position_ids = None, None, None else: hidden_states = args attention_mask, attn_mask_startend_row_indices, position_ids = None, None, None @@ -93,6 +152,1181 @@ def get_attr(layer, name): return get_attr(layer._layer, name) +def calc_stream_wait(group_id): + comm_event = deep_ep.get_event_from_comm_stream(group_id) + comm_event.calc_stream_wait(group_id) + + +class TensorMeta: + """Recording the meta info of forward inputs, to avoid 0-size problems""" + + def __init__(self, tensor): + self.shape = tensor.shape + self.dtype = tensor.dtype + + +class PostProcessNode(ScheduleNode): + def __init__( + self, + send_mtp_embed, + training, + alpha, + config, + 
shared_experts=None, + using_post_norm_recompute=False, + output_mtp_embed_first=False, + name="PostProcessNode", + ): + self.send_mtp_embed = send_mtp_embed + self.shared_experts = shared_experts + self.traning = training + self.config = config + self.alpha = alpha + self.using_post_norm_recompute = using_post_norm_recompute + self.output_mtp_embed_first = output_mtp_embed_first + self.name = name + + if self.using_post_norm_recompute: + assert self.shared_experts is not None + assert self.shared_experts.norm_weight is not None and self.shared_experts.norm_eps is not None + + def forward_without_residual(self, inputs): + + if isinstance(inputs, list): + inputs = tuple(inputs) + if self.using_post_norm_recompute: + if self.send_mtp_embed: + (inputs_embeds_mtp, hidden_states, residual, l_aux, final_hidden_states, norm_out) = inputs + else: + (hidden_states, residual, l_aux, final_hidden_states, norm_out) = inputs + else: + if self.send_mtp_embed: + (inputs_embeds_mtp, hidden_states, residual, l_aux, final_hidden_states) = inputs + else: + (hidden_states, residual, l_aux, final_hidden_states) = inputs + + with paddle.no_grad(): + if self.shared_experts is not None: + if self.using_post_norm_recompute: + _, _, _, shared_expert_output = FP8LinearFunctionBase.fp8_mlp_fwd( + norm_out, self.shared_experts.w1, self.shared_experts.w2 + ) + norm_out = None + del norm_out + else: + _, _, _, shared_expert_output = FP8LinearFunctionBase.fp8_mlp_fwd( + hidden_states, self.shared_experts.w1, self.shared_experts.w2 + ) + residual = residual + shared_expert_output + + self.x = hidden_states + self.l_aux = l_aux + + hidden_states = residual + hidden_states.stop_gradient = False + + if self.send_mtp_embed: + assert not self.output_mtp_embed_first, "forward_without_residual doesn't support output_mtp_embed_first" + hidden_states = paddle.concat([hidden_states, inputs_embeds_mtp], axis=-1) + self.mtp_embed_shape = inputs_embeds_mtp.shape # 保存mtp_embed的shape用于反向传播 + + return return_args(hidden_states) + + def forward(self, inputs): + + if isinstance(inputs, list): + inputs = tuple(inputs) + + if self.using_post_norm_recompute: + if self.send_mtp_embed: + (inputs_embeds_mtp, hidden_states, residual, l_aux, final_hidden_states, norm_out) = inputs + else: + (hidden_states, residual, l_aux, final_hidden_states, norm_out) = inputs + else: + if self.send_mtp_embed: + (inputs_embeds_mtp, hidden_states, residual, l_aux, final_hidden_states) = inputs + else: + (hidden_states, residual, l_aux, final_hidden_states) = inputs + + with paddle.no_grad(): + if self.shared_experts is not None: + if self.using_post_norm_recompute: + _, _, _, shared_expert_output = FP8LinearFunctionBase.fp8_mlp_fwd( + norm_out, self.shared_experts.w1, self.shared_experts.w2 + ) + norm_out = None + del norm_out + else: + _, _, _, shared_expert_output = FP8LinearFunctionBase.fp8_mlp_fwd( + hidden_states, self.shared_experts.w1, self.shared_experts.w2 + ) + final_hidden_states = final_hidden_states + shared_expert_output + + self.x = hidden_states + self.l_aux = l_aux + hidden_states = residual + final_hidden_states + + if self.send_mtp_embed: + if self.output_mtp_embed_first: + hidden_states = paddle.concat([inputs_embeds_mtp, hidden_states], axis=-1) + else: + hidden_states = paddle.concat([hidden_states, inputs_embeds_mtp], axis=-1) + self.mtp_embed_shape = inputs_embeds_mtp.shape # 保存mtp_embed的shape用于反向传播 + + return return_args(hidden_states) + + @paddle.no_grad() + def backward(self, output_grad): + (do3,) = output_grad + + if 
self.send_mtp_embed: + # 分割梯度:do3的前部分对应hidden_states,后部分对应inputs_embeds_mtp + hidden_size = do3.shape[-1] - self.mtp_embed_shape[-1] + if self.output_mtp_embed_first: + hidden_states_grad = do3[..., hidden_size:] + inputs_embeds_mtp_grad = do3[..., :hidden_size] + else: + hidden_states_grad = do3[..., :hidden_size] + inputs_embeds_mtp_grad = do3[..., hidden_size:] + else: + hidden_states_grad = do3 + inputs_embeds_mtp_grad = None + + if self.using_post_norm_recompute: + dx, norm_out, invar = FP8LinearFunctionBase.fp8_mlp_bwd_norm_rc( + hidden_states_grad, + self.x, + self.shared_experts.norm_weight, + self.shared_experts.norm_eps, + self.shared_experts.w1, + self.shared_experts.w2, + ) + else: + dx = FP8LinearFunctionBase.fp8_mlp_bwd( + hidden_states_grad, self.x, self.shared_experts.w1, self.shared_experts.w2, True + ) + + self.x = None + + residual_grad = hidden_states_grad + l_aux_grad = paddle.ones(1, dtype=self.l_aux.dtype) * self.alpha + final_hidden_states_grad = hidden_states_grad + + if self.using_post_norm_recompute: + if self.send_mtp_embed: + return ( + inputs_embeds_mtp_grad, + dx, + residual_grad, + l_aux_grad, + final_hidden_states_grad, + norm_out, + invar, + ) + else: + return (dx, residual_grad, l_aux_grad, final_hidden_states_grad, norm_out, invar) + else: + if self.send_mtp_embed: + return (inputs_embeds_mtp_grad, dx, residual_grad, l_aux_grad, final_hidden_states_grad) + else: + return (dx, residual_grad, l_aux_grad, final_hidden_states_grad) + + +class DecoderLayerNode(ScheduleNode): + def __init__( + self, + attn_node, + dispatch_node, + mlp_node, + combine_node, + post_process_node, + mlp_layer, + name="DecoderLayerNode", + ): + super().__init__(fwd_func=None, name=name) + assert (dispatch_node is None and combine_node is None) or ( + dispatch_node is not None and combine_node is not None + ) + self.attn_node = attn_node + self.dispatch_node = dispatch_node + self.mlp_node = mlp_node + self.combine_node = combine_node + self.post_process_node = post_process_node + + self.mlp_layer = mlp_layer + self.moe_group = mlp_layer.moe_group + self.moe_num_experts = mlp_layer.moe_num_experts + + self.states = None + self.hidden_states_meta = None + self.dispatched_probs_meta = None + self.combine_output_meta = None + + def dispatch_forward(self, inputs, previous_event=None, allocate_on_comm_stream=False): + paddle.base.core.nvprof_nvtx_push("raw_dispatch_forward") + if isinstance(inputs, list): + inputs = tuple(inputs) + ( + inputs_embeds_mtp, + hidden_states, + residual, + l_aux, + intermediate_hidden_states, + token_indices, + token_probs, + ) = inputs + + with paddle.no_grad(): + intermediate_hidden_states, dispatched_probs, states, _ = fused_dispatch_forward_func( + intermediate_hidden_states, + token_indices, + token_probs, + self.moe_num_experts, + self.moe_group, + previous_event=previous_event, + async_finish=True, + allocate_on_comm_stream=allocate_on_comm_stream, + ) + dispatched_indices = states["dispatched_indices"] + self.mlp_layer.set_tokens_per_expert(states["tokens_per_expert"]) + dispatched_indices.stop_gradient = True + intermediate_hidden_states.stop_gradient = False + dispatched_probs.stop_gradient = False + self.states = states + self.hidden_states_meta = TensorMeta(intermediate_hidden_states) + self.dispatched_probs_meta = TensorMeta(dispatched_probs) + + inputs = ( + inputs_embeds_mtp, + hidden_states, + residual, + l_aux, + intermediate_hidden_states, + dispatched_indices, + dispatched_probs, + ) + paddle.base.core.nvprof_nvtx_pop() + return inputs + + 
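`dispatch_forward` above records a `TensorMeta` for each communication output, and `dispatch_backward` / `combine_backward` below use it to materialize a zero gradient whenever a downstream stage hands back `None`. Here is a minimal sketch of that pattern, assuming an identity op in place of the `deep_ep` all-to-all; `ToyDispatchNode` and the toy shapes are illustrative only and not part of the patch.

```python
import paddle


class TensorMeta:
    """Mirror of the TensorMeta helper above: remember only shape and dtype."""

    def __init__(self, tensor):
        self.shape = tensor.shape
        self.dtype = tensor.dtype


class ToyDispatchNode:
    """Illustrative stand-in for the dispatch stage (no deep_ep, no separate comm stream)."""

    def forward(self, hidden_states):
        # The real code calls fused_dispatch_forward_func(...) here to route tokens to experts.
        dispatched = hidden_states  # identity stand-in for the all-to-all
        self.hidden_states_meta = TensorMeta(dispatched)
        return dispatched

    def backward(self, dispatched_grad):
        # dispatch_backward performs this same check before fused_dispatch_backward_func(...).
        if dispatched_grad is None:
            dispatched_grad = paddle.zeros(self.hidden_states_meta.shape, self.hidden_states_meta.dtype)
        return dispatched_grad  # identity stand-in for the reverse all-to-all


node = ToyDispatchNode()
out = node.forward(paddle.randn([4, 16]))
grad = node.backward(None)     # downstream stage produced no gradient for this output
print(grad.shape, grad.dtype)  # [4, 16] paddle.float32
```

Synthesizing the zero gradient from the recorded meta keeps the overlapped schedule from tripping over missing or 0-size gradients while the real token routing runs asynchronously on the communication stream (`async_finish=True` followed by `calc_stream_wait`).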
def combine_forward(self, inputs, previous_event=None): + paddle.base.core.nvprof_nvtx_push("raw_combine_forward") + if isinstance(inputs, list): + inputs = tuple(inputs) + (inputs_embeds_mtp, hidden_states, residual, l_aux, expert_output) = inputs + + with paddle.no_grad(): + combine_output = fused_combine_forward_func( + expert_output, self.moe_group, self.states, previous_event=previous_event, async_finish=True + ) + combine_output.stop_gradient = False + self.combine_output_meta = TensorMeta(combine_output) + inputs = (inputs_embeds_mtp, hidden_states, residual, l_aux, combine_output) + paddle.base.core.nvprof_nvtx_pop() + return inputs + + def dispatch_backward(self, output_grad): + paddle.base.core.nvprof_nvtx_push("raw_dispatch_backward") + ( + inputs_embeds_mtp_grad, + hidden_states_grad, + residual_grad, + l_aux_grad, + intermediate_hidden_states_grad, + dispatched_indices_grad, + dispatched_probs_grad, + ) = output_grad + + if intermediate_hidden_states_grad is None: + intermediate_hidden_states_grad = paddle.zeros( + self.hidden_states_meta.shape, self.hidden_states_meta.dtype + ) + if dispatched_probs_grad is None: + dispatched_probs_grad = paddle.zeros(self.dispatched_probs_meta.shape, self.dispatched_probs_meta.dtype) + with paddle.no_grad(): + intermediate_hidden_states_grad, token_indices_grad, token_probs_grad = fused_dispatch_backward_func( + intermediate_hidden_states_grad, + dispatched_probs_grad, + self.moe_group, + self.states["handle"], + async_finish=True, + ) + + output_grad = ( + inputs_embeds_mtp_grad, + hidden_states_grad, + residual_grad, + l_aux_grad, + intermediate_hidden_states_grad, + token_indices_grad, + token_probs_grad, + ) + paddle.base.core.nvprof_nvtx_pop() + return output_grad + + def combine_backward(self, output_grad): + paddle.base.core.nvprof_nvtx_push("raw_combine_backward") + ( + inputs_embeds_mtp_grad, + hidden_states_grad, + residual_grad, + l_aux_grad, + combine_output_grad, + ) = output_grad + + if combine_output_grad is None: + combine_output_grad = paddle.zeros(self.combine_output_meta.shape, self.combine_output_meta.dtype) + with paddle.no_grad(): + expert_output_grad = fused_combine_backward_func( + combine_output_grad, self.moe_group, self.states["handle"], async_finish=True + ) + + output_grad = ( + inputs_embeds_mtp_grad, + hidden_states_grad, + residual_grad, + l_aux_grad, + expert_output_grad, + ) + paddle.base.core.nvprof_nvtx_pop() + return output_grad + + def forward(self, inputs): + inputs = self.attn_node.forward(inputs) + + if self.dispatch_node is None: + inputs = self.dispatch_forward(inputs) + calc_stream_wait(self.moe_group.id) + else: + inputs = self.dispatch_node.forward(inputs) + + inputs = self.mlp_node.forward(inputs) + + if self.combine_node is None: + inputs = self.combine_forward(inputs) + calc_stream_wait(self.moe_group.id) + else: + inputs = self.combine_node.forward(inputs) + + inputs = self.post_process_node.forward(inputs) + return inputs + + def backward(self, output_grad=None, scaler=None): + assert (output_grad is not None) and (scaler is None) + + output_grad = self.post_process_node.backward(output_grad) + + if self.combine_node is None: + output_grad = self.combine_backward(output_grad) + calc_stream_wait(self.moe_group.id) + else: + output_grad = self.combine_node.backward(output_grad) + + output_grad = self.mlp_node.backward(output_grad) + + if self.dispatch_node is None: + output_grad = self.dispatch_backward(output_grad) + calc_stream_wait(self.moe_group.id) + else: + output_grad = 
self.dispatch_node.backward(output_grad) + + output_grad = self.attn_node.backward(output_grad) + return output_grad + + +class OverlapedScheduleChunk: + def __init__(self, forward_nodes, backward_nodes, use_fuion=True): + assert len(forward_nodes) == len(backward_nodes) + self.nodes = [] + for f, b in zip(forward_nodes, backward_nodes): + schedule_node_class = OverlapedScheduleNode + if use_fuion: + schedule_node_class = OverlapedFUsionScheduleNode + if isinstance(f, DenseDecoderLayerNode) or isinstance(b, DenseDecoderLayerNode): + schedule_node_class = OverlapedDenseFusionScheduleNode + self.nodes.append(schedule_node_class(f, b, f"OverlapedNode_{len(self.nodes)}")) + + def forward_backward(self, inputs, output_grad, combine_bw_event_to_wait=None, pp_stream=None): + # print(" fwd pp stream", pp_stream) + event_to_wait = combine_bw_event_to_wait + for i, n in enumerate(self.nodes): + pp_stream_t = pp_stream + if i + 1 != len(self.nodes): + pp_stream_t = None + + inputs, output_grad, event_to_wait = n.forward_backward( + inputs, output_grad, combine_bw_event_to_wait=event_to_wait, pp_stream=pp_stream_t + ) + return inputs, output_grad, None + + +class OverlapedScheduleNode: + def __init__(self, forward_node, backward_node, name=""): + assert isinstance(forward_node, DecoderLayerNode) and isinstance(backward_node, DecoderLayerNode) + self.forward_node = forward_node + self.backward_node = backward_node + self.name = name + + def forward_backward(self, inputs, output_grad, event_to_wait=None): + paddle.base.core.nvprof_nvtx_push("forward_backward") + output_grad = self.backward_node.post_process_node.backward(output_grad) + + output_grad = self.backward_node.combine_backward(output_grad) + inputs = self.forward_node.attn_node.forward(inputs) + + calc_stream_wait(self.backward_node.moe_group.id) + attn_compute_event = deep_ep.get_event_from_calc_stream(self.forward_node.moe_group.id) + output_grad = self.backward_node.mlp_node.backward(output_grad) + inputs = self.forward_node.dispatch_forward( + inputs, previous_event=attn_compute_event, allocate_on_comm_stream=True + ) + + calc_stream_wait(self.forward_node.moe_group.id) + output_grad = self.backward_node.dispatch_backward(output_grad) + inputs = self.forward_node.mlp_node.forward(inputs) + + calc_stream_wait(self.backward_node.moe_group.id) + inputs = self.forward_node.combine_forward(inputs) + output_grad = self.backward_node.attn_node.backward(output_grad) + + calc_stream_wait(self.forward_node.moe_group.id) + inputs = self.forward_node.post_process_node.forward(inputs) + paddle.base.core.nvprof_nvtx_pop() + return inputs, output_grad + + +class FusionFp8DecoderLayerNode(ScheduleNode): + def __init__( + self, + attn_and_gate_node, + fp8_fusion_moe_node, + post_process_node, + mlp_layer, + send_mtp_embed, + using_post_norm_recompute=False, + name="", + ): + self.attn_and_gate_node = attn_and_gate_node + self.fp8_fusion_moe_node = fp8_fusion_moe_node + self.post_process_node = post_process_node + self.send_mtp_embed = send_mtp_embed + + self.using_post_norm_recompute = using_post_norm_recompute + self.name = name + + self.moe_group = mlp_layer.moe_group + + def attn_forward(self, inputs): + inputs = self.attn_and_gate_node.forward(inputs) + + if self.send_mtp_embed: + if self.using_post_norm_recompute: + inputs_embeds_mtp, hidden_states, residual, probs, routing_map, l_aux, norm_out = inputs + else: + inputs_embeds_mtp, hidden_states, residual, probs, routing_map, l_aux = inputs + else: + if self.using_post_norm_recompute: + hidden_states, 
residual, probs, routing_map, l_aux, norm_out = inputs + else: + hidden_states, residual, probs, routing_map, l_aux = inputs + + if self.using_post_norm_recompute: + hs_2d, token_indices, token_probs = self.fp8_fusion_moe_node.dispatch_quant_node.forward( + norm_out, probs, routing_map + ) + # common return values + ret = (hidden_states, residual, l_aux, hs_2d, token_indices, token_probs, norm_out) + + # append mtp embed if needed + ret = (inputs_embeds_mtp, *ret) if self.send_mtp_embed else ret + return ret + else: + hs_2d, token_indices, token_probs = self.fp8_fusion_moe_node.dispatch_quant_node.forward( + hidden_states, probs, routing_map + ) + + # common return values + ret = (hidden_states, residual, l_aux, hs_2d, token_indices, token_probs) + + # append mtp embed if needed + ret = (inputs_embeds_mtp, *ret) if self.send_mtp_embed else ret + return ret + + def dispatch_forward(self, inputs, previous_event=None, async_finish=False, allocate_on_comm_stream=False): + if self.using_post_norm_recompute: + if self.send_mtp_embed: + inputs_embeds_mtp, hidden_states, residual, l_aux, hs_2d, token_indices, token_probs, norm_out = inputs + else: + hidden_states, residual, l_aux, hs_2d, token_indices, token_probs, norm_out = inputs + else: + if self.send_mtp_embed: + inputs_embeds_mtp, hidden_states, residual, l_aux, hs_2d, token_indices, token_probs = inputs + else: + hidden_states, residual, l_aux, hs_2d, token_indices, token_probs = inputs + + (hs_dispatched, dispatched_indices, dispatched_probs,) = self.fp8_fusion_moe_node.dispatch_node.forward( + hs_2d, + token_indices, + token_probs, + previous_event=previous_event, + async_finish=async_finish, + allocate_on_comm_stream=allocate_on_comm_stream, + ) + + ret = (hidden_states, residual, l_aux, hs_dispatched, dispatched_indices, dispatched_probs) + + # append mtp embed if needed + ret = (inputs_embeds_mtp, *ret) if self.send_mtp_embed else ret + ret = (*ret, norm_out) if self.using_post_norm_recompute else ret + return ret + + def mlp_forward(self, inputs): + if self.using_post_norm_recompute: + if self.send_mtp_embed: + ( + inputs_embeds_mtp, + hidden_states, + residual, + l_aux, + hs_dispatched, + dispatched_indices, + dispatched_probs, + norm_out, + ) = inputs + else: + hidden_states, residual, l_aux, hs_dispatched, dispatched_indices, dispatched_probs, norm_out = inputs + else: + if self.send_mtp_embed: + ( + inputs_embeds_mtp, + hidden_states, + residual, + l_aux, + hs_dispatched, + dispatched_indices, + dispatched_probs, + ) = inputs + else: + hidden_states, residual, l_aux, hs_dispatched, dispatched_indices, dispatched_probs = inputs + + hidden_states_out = self.fp8_fusion_moe_node.mlp_node.forward( + hs_dispatched, dispatched_indices, dispatched_probs + ) + ret = (hidden_states, residual, l_aux, hidden_states_out) + + # append mtp embed if needed + ret = (inputs_embeds_mtp, *ret) if self.send_mtp_embed else ret + ret = (*ret, norm_out) if self.using_post_norm_recompute else ret + return ret + + def combine_forward(self, inputs, async_finish=False, previous_event=None, allocate_on_comm_stream=False): + if self.using_post_norm_recompute: + if self.send_mtp_embed: + (inputs_embeds_mtp, hidden_states, residual, l_aux, hidden_states_out, norm_out) = inputs + else: + (hidden_states, residual, l_aux, hidden_states_out, norm_out) = inputs + else: + if self.send_mtp_embed: + (inputs_embeds_mtp, hidden_states, residual, l_aux, hidden_states_out) = inputs + else: + (hidden_states, residual, l_aux, hidden_states_out) = inputs + + output_combine = 
self.fp8_fusion_moe_node.combine_node.forward( + hidden_states_out, + async_finish=async_finish, + previous_event=previous_event, + allocate_on_comm_stream=allocate_on_comm_stream and previous_event is not None, + ) + + ret = (hidden_states, residual, l_aux, output_combine) + + # append mtp embed if needed + ret = (inputs_embeds_mtp, *ret) if self.send_mtp_embed else ret + ret = (*ret, norm_out) if self.using_post_norm_recompute else ret + return ret + + def post_process_forward(self, inputs, with_residual=True): + if self.using_post_norm_recompute: + if self.send_mtp_embed: + (inputs_embeds_mtp, hidden_states, residual, l_aux, output_combine, norm_out) = inputs + else: + (hidden_states, residual, l_aux, output_combine, norm_out) = inputs + else: + if self.send_mtp_embed: + (inputs_embeds_mtp, hidden_states, residual, l_aux, output_combine) = inputs + else: + (hidden_states, residual, l_aux, output_combine) = inputs + final_hidden_states = self.fp8_fusion_moe_node.combine_quant_node.forward(output_combine) + + inputs = (hidden_states, residual, l_aux, final_hidden_states) + inputs = (inputs_embeds_mtp, *inputs) if self.send_mtp_embed else inputs + inputs = (*inputs, norm_out) if self.using_post_norm_recompute else inputs + + if with_residual: + inputs = self.post_process_node.forward(inputs) + else: + inputs = self.post_process_node.forward_without_residual(inputs) + return inputs + + def post_process_backward(self, output_grad, event_to_wait=None): + grad = self.post_process_node.backward(output_grad) + + if self.using_post_norm_recompute: + if self.send_mtp_embed: + ( + inputs_embeds_mtp_grad, + hidden_states_grad, + residual_grad, + l_aux_grad, + final_hidden_states_grad, + norm_out, + invar, + ) = grad + else: + hidden_states_grad, residual_grad, l_aux_grad, final_hidden_states_grad, norm_out, invar = grad + else: + if self.send_mtp_embed: + inputs_embeds_mtp_grad, hidden_states_grad, residual_grad, l_aux_grad, final_hidden_states_grad = grad + else: + hidden_states_grad, residual_grad, l_aux_grad, final_hidden_states_grad = grad + + output_combine_grad, quant_event = self.fp8_fusion_moe_node.combine_quant_node.backward( + final_hidden_states_grad, event_to_wait + ) + + ret = (hidden_states_grad, residual_grad, l_aux_grad, output_combine_grad, quant_event) + ret = (inputs_embeds_mtp_grad, *ret) if self.send_mtp_embed else ret + ret = (*ret, norm_out, invar) if self.using_post_norm_recompute else ret + return ret + + def combine_backward(self, output_grad, previous_event=None, async_finish=False, allocate_on_comm_stream=False): + if self.using_post_norm_recompute: + if self.send_mtp_embed: + ( + inputs_embeds_mtp_grad, + hidden_states_grad, + residual_grad, + l_aux_grad, + output_combine_grad, + quant_event, + norm_out, + invar, + ) = output_grad + else: + ( + hidden_states_grad, + residual_grad, + l_aux_grad, + output_combine_grad, + quant_event, + norm_out, + invar, + ) = output_grad + else: + if self.send_mtp_embed: + ( + inputs_embeds_mtp_grad, + hidden_states_grad, + residual_grad, + l_aux_grad, + output_combine_grad, + quant_event, + ) = output_grad + else: + ( + hidden_states_grad, + residual_grad, + l_aux_grad, + output_combine_grad, + quant_event, + ) = output_grad + + if DSV3_USE_FP8_DISPATCH and quant_event is not None: + combine_backward_wait_event = quant_event + else: + combine_backward_wait_event = previous_event + hidden_states_out_grad = self.fp8_fusion_moe_node.combine_node.backward( + output_combine_grad, + async_finish=async_finish, + 
previous_event=combine_backward_wait_event, + allocate_on_comm_stream=allocate_on_comm_stream and quant_event is not None, + ) + + ret = (hidden_states_grad, residual_grad, l_aux_grad, hidden_states_out_grad) + ret = (inputs_embeds_mtp_grad, *ret) if self.send_mtp_embed else ret + ret = (*ret, norm_out, invar) if self.using_post_norm_recompute else ret + return ret + + def mlp_backward(self, output_grad): + if self.using_post_norm_recompute: + if self.send_mtp_embed: + ( + inputs_embeds_mtp_grad, + hidden_states_grad, + residual_grad, + l_aux_grad, + hidden_states_out_grad, + norm_out, + invar, + ) = output_grad + else: + hidden_states_grad, residual_grad, l_aux_grad, hidden_states_out_grad, norm_out, invar = output_grad + else: + if self.send_mtp_embed: + ( + inputs_embeds_mtp_grad, + hidden_states_grad, + residual_grad, + l_aux_grad, + hidden_states_out_grad, + ) = output_grad + else: + hidden_states_grad, residual_grad, l_aux_grad, hidden_states_out_grad = output_grad + hs_dispatched_grad, dispatched_probs_grad = self.fp8_fusion_moe_node.mlp_node.backward(hidden_states_out_grad) + + ret = (hidden_states_grad, residual_grad, l_aux_grad, hs_dispatched_grad, dispatched_probs_grad) + ret = (inputs_embeds_mtp_grad, *ret) if self.send_mtp_embed else ret + ret = (*ret, norm_out, invar) if self.using_post_norm_recompute else ret + return ret + + def dispatch_backward(self, output_grad, async_finish=False, previous_event=None, allocate_on_comm_stream=False): + if self.using_post_norm_recompute: + if self.send_mtp_embed: + ( + inputs_embeds_mtp_grad, + hidden_states_grad, + residual_grad, + l_aux_grad, + hs_dispatched_grad, + dispatched_probs_grad, + norm_out, + invar, + ) = output_grad + else: + ( + hidden_states_grad, + residual_grad, + l_aux_grad, + hs_dispatched_grad, + dispatched_probs_grad, + norm_out, + invar, + ) = output_grad + else: + if self.send_mtp_embed: + ( + inputs_embeds_mtp_grad, + hidden_states_grad, + residual_grad, + l_aux_grad, + hs_dispatched_grad, + dispatched_probs_grad, + ) = output_grad + else: + hidden_states_grad, residual_grad, l_aux_grad, hs_dispatched_grad, dispatched_probs_grad = output_grad + + hs_grad, token_probs_grad = self.fp8_fusion_moe_node.dispatch_node.backward( + hs_dispatched_grad, + dispatched_probs_grad, + async_finish=async_finish, + previous_event=previous_event, + allocate_on_comm_stream=allocate_on_comm_stream and previous_event is not None, + ) + + ret = (hidden_states_grad, residual_grad, l_aux_grad, hs_grad, token_probs_grad) + ret = (inputs_embeds_mtp_grad, *ret) if self.send_mtp_embed else ret + ret = (*ret, norm_out, invar) if self.using_post_norm_recompute else ret + return ret + + def attn_backward(self, output_grad): + if self.using_post_norm_recompute: + if self.send_mtp_embed: + ( + inputs_embeds_mtp_grad, + hidden_states_grad, + residual_grad, + l_aux_grad, + hs_grad, + token_probs_grad, + norm_out, + invar, + ) = output_grad + inputs_embeds_mtp_grad_shape = hidden_states_grad.shape + inputs_embeds_mtp_grad_shape[-1] = -1 + inputs_embeds_mtp_grad = inputs_embeds_mtp_grad.view(inputs_embeds_mtp_grad_shape) + else: + hidden_states_grad, residual_grad, l_aux_grad, hs_grad, token_probs_grad, norm_out, invar = output_grad + else: + if self.send_mtp_embed: + ( + inputs_embeds_mtp_grad, + hidden_states_grad, + residual_grad, + l_aux_grad, + hs_grad, + token_probs_grad, + ) = output_grad + inputs_embeds_mtp_grad_shape = hidden_states_grad.shape + inputs_embeds_mtp_grad_shape[-1] = -1 + inputs_embeds_mtp_grad = 
inputs_embeds_mtp_grad.view(inputs_embeds_mtp_grad_shape) + else: + hidden_states_grad, residual_grad, l_aux_grad, hs_grad, token_probs_grad = output_grad + + hidden_states_grad_, probs_grad, routing_map_grad = self.fp8_fusion_moe_node.dispatch_quant_node.backward( + hs_grad, token_probs_grad + ) + + output_grad = (residual_grad, probs_grad, routing_map_grad, l_aux_grad) + + output_grad = ( + (hidden_states_grad, *output_grad, hidden_states_grad_) + if self.using_post_norm_recompute + else (hidden_states_grad + hidden_states_grad_, *output_grad) + ) + output_grad = (inputs_embeds_mtp_grad, *output_grad) if self.send_mtp_embed else output_grad + + if self.using_post_norm_recompute: + with TemporaryVarContext(norm_out, invar): + output_grad = self.attn_and_gate_node.backward(output_grad) + else: + output_grad = self.attn_and_gate_node.backward(output_grad) + return output_grad + + def forward(self, inputs): + inputs = self.attn_forward(inputs) + inputs = self.dispatch_forward(inputs) + inputs = self.mlp_forward(inputs) + inputs = self.combine_forward(inputs) + inputs = self.post_process_forward(inputs) + return inputs + + def backward(self, output_grad=None, scaler=None): + assert (output_grad is not None) and (scaler is None) + output_grad = self.post_process_backward(output_grad) + output_grad = self.combine_backward(output_grad) + output_grad = self.mlp_backward(output_grad) + # todo(phlrain): overlap here + output_grad = self.dispatch_backward(output_grad) + output_grad = self.attn_backward(output_grad) + return output_grad + + +class DenseDecoderLayerNode(ScheduleNode): + def __init__( + self, + attn_node, + mlp_node, + name="DenseDecoderLayerNode", + ): + super().__init__(fwd_func=None, name=name) + self.attn_node = attn_node + self.mlp_node = mlp_node + + def forward(self, inputs): + inputs = self.attn_node.forward(inputs) + inputs = self.mlp_node.forward(inputs) + return inputs + + def backward(self, output_grad=None, scaler=None): + assert (output_grad is not None) and (scaler is None) + output_grad = self.mlp_node.backward(output_grad) + output_grad = self.attn_node.backward(output_grad) + return output_grad + + +class OverlapedFUsionScheduleNode: + def __init__(self, forward_node, backward_node, name=""): + assert isinstance(forward_node, FusionFp8DecoderLayerNode) and isinstance( + backward_node, FusionFp8DecoderLayerNode + ) + self.forward_node = forward_node + self.backward_node = backward_node + self.name = name + + def forward_backward(self, inputs, output_grad, combine_bw_event_to_wait=None, pp_stream=None): + paddle.base.core.nvprof_nvtx_push("forward_backward") + + combine_bwd_event = deep_ep.get_event_from_calc_stream(self.backward_node.moe_group.id) + + paddle.base.core.nvprof_nvtx_push("attn_forward") + inputs = self.forward_node.attn_forward(inputs) + paddle.base.core.nvprof_nvtx_pop() + attn_compute_event = deep_ep.get_event_from_calc_stream(self.forward_node.moe_group.id) + + paddle.base.core.nvprof_nvtx_push("post_process_backward") + output_grad = self.backward_node.post_process_backward(output_grad, combine_bw_event_to_wait) + paddle.base.core.nvprof_nvtx_pop() + + paddle.base.core.nvprof_nvtx_push("combine_backward") + if combine_bw_event_to_wait is not None: + # print(" event", combine_bw_event_to_wait) + output_grad = self.backward_node.combine_backward( + output_grad, previous_event=combine_bw_event_to_wait, async_finish=True, allocate_on_comm_stream=True + ) + else: + output_grad = self.backward_node.combine_backward( + output_grad, 
previous_event=combine_bwd_event, async_finish=True, allocate_on_comm_stream=True + ) + # get combine event + combine_backward_event = deep_ep.get_event_from_comm_stream(self.backward_node.moe_group.id) + paddle.base.core.nvprof_nvtx_pop() + + combine_backward_event.calc_stream_wait(self.backward_node.moe_group.id) + paddle.base.core.nvprof_nvtx_push("mlp_backward_dx") + assert WeightGradStore.funcs_queue.empty() + WeightGradStore.enabled = True + output_grad = self.backward_node.mlp_backward(output_grad) + WeightGradStore.enabled = False + WeightGradStore.flush() + paddle.base.core.nvprof_nvtx_pop() + + output_grad_event = deep_ep.get_event_from_calc_stream(self.backward_node.moe_group.id) + + paddle.base.core.nvprof_nvtx_push("dispatch_forward") + inputs = self.forward_node.dispatch_forward( + inputs, previous_event=attn_compute_event, async_finish=True, allocate_on_comm_stream=True + ) + paddle.base.core.nvprof_nvtx_pop() + dispatch_forward_event = deep_ep.get_event_from_comm_stream(self.forward_node.moe_group.id) + + paddle.base.core.nvprof_nvtx_push("dispatch_backward") + output_grad = self.backward_node.dispatch_backward( + output_grad, async_finish=True, previous_event=output_grad_event, allocate_on_comm_stream=True + ) + paddle.base.core.nvprof_nvtx_pop() + # get dispatch backward event + dispatch_backward_event = deep_ep.get_event_from_comm_stream(self.backward_node.moe_group.id) + + paddle.base.core.nvprof_nvtx_push("dispatch_backward_dw") + WeightGradStore.pop() + assert WeightGradStore.funcs_queue.empty() + paddle.base.core.nvprof_nvtx_pop() + + dispatch_forward_event.calc_stream_wait(self.forward_node.moe_group.id) + paddle.base.core.nvprof_nvtx_push("mlp_forward") + inputs = self.forward_node.mlp_forward(inputs) + paddle.base.core.nvprof_nvtx_pop() + mlp_fwd_event = deep_ep.get_event_from_calc_stream(self.forward_node.moe_group.id) + + if pp_stream is not None: + paddle.base.core.nvprof_nvtx_push("post_process_forward") + + final_out = self.forward_node.post_process_node.forward_without_residual(inputs) + paddle.base.core.nvprof_nvtx_pop() + + final_out_event = deep_ep.get_event_from_calc_stream(self.forward_node.moe_group.id) + + paddle.base.core.nvprof_nvtx_push("combine_forward") + inputs = self.forward_node.combine_forward( + inputs, previous_event=mlp_fwd_event, async_finish=True, allocate_on_comm_stream=True + ) + paddle.base.core.nvprof_nvtx_pop() + + combine_forward_event = deep_ep.get_event_from_comm_stream(self.forward_node.moe_group.id) + + combine_fwd_out = inputs[-2] if self.forward_node.using_post_norm_recompute else inputs[-1] + + if pp_stream is not None: + send_recv_stream = paddle.device.Stream(stream_base=pp_stream) + + # combine_forward_event.custom_stream_wait( pp_stream) + # final_out_event.custom_stream_wait(pp_stream) + + paddle.base.core.nvprof_nvtx_push("pp stream add") + + with paddle.device.stream_guard(send_recv_stream): + combine_forward_event.current_stream_wait() + final_out_event.current_stream_wait() + + inputs = final_out + combine_fwd_out + + final_out._record_stream() + combine_fwd_out._record_stream() + + paddle.base.core.nvprof_nvtx_pop() + + dispatch_backward_event.calc_stream_wait(self.backward_node.moe_group.id) + + paddle.base.core.nvprof_nvtx_push("attn_backward") + assert WeightGradStore.funcs_queue.empty() + WeightGradStore.enabled = True + output_grad = self.backward_node.attn_backward(output_grad) + event_to_wait = deep_ep.get_event_from_calc_stream(self.backward_node.moe_group.id) + + if EventStore is not None: + 
EventStore.set(event_to_wait) + + WeightGradStore.enabled = False + WeightGradStore.flush() + WeightGradStore.pop() + assert WeightGradStore.funcs_queue.empty() + + paddle.base.core.nvprof_nvtx_pop() + + # residual add + if pp_stream is None: + combine_forward_event.calc_stream_wait(self.forward_node.moe_group.id) + + final_out = self.forward_node.post_process_node.forward_without_residual(inputs) + if final_out.shape[-1] != combine_fwd_out.shape[-1]: + final_out[:, :, : combine_fwd_out.shape[-1]] += combine_fwd_out # broadcast along the last dim and add directly + else: + final_out += combine_fwd_out + inputs = final_out + combine_fwd_out._record_stream() + + paddle.base.core.nvprof_nvtx_pop() + return inputs, output_grad, event_to_wait + + + class OverlapedDenseFusionScheduleNode: + def __init__(self, forward_node, backward_node, name=""): + assert isinstance(forward_node, FusionFp8DecoderLayerNode) or isinstance( + backward_node, FusionFp8DecoderLayerNode + ) + assert isinstance(forward_node, DenseDecoderLayerNode) or isinstance( + backward_node, DenseDecoderLayerNode + ) + self.forward_node = forward_node + self.backward_node = backward_node + self.name = name + + def forward_backward(self, inputs, output_grad, combine_bw_event_to_wait=None, pp_stream=None): + # Dense forward + MoE backward + if isinstance(self.forward_node, DenseDecoderLayerNode): + paddle.base.core.nvprof_nvtx_push("dense_fw_moe_bw") + + paddle.base.core.nvprof_nvtx_push("dense_attn_moe_combine") + # Note: the incoming combine_bw_event_to_wait is unreliable here, so we record a new event. + combine_bw_event_to_wait = deep_ep.get_event_from_calc_stream(self.backward_node.moe_group.id) + output_grad = self.backward_node.post_process_backward(output_grad, combine_bw_event_to_wait) + output_grad = self.backward_node.combine_backward( + output_grad, previous_event=combine_bw_event_to_wait, async_finish=True, allocate_on_comm_stream=True + ) + combine_bw_event = deep_ep.get_event_from_comm_stream(self.backward_node.moe_group.id) + inputs = self.forward_node.attn_node.forward(inputs) + combine_bw_event.calc_stream_wait(self.backward_node.moe_group.id) + paddle.base.core.nvprof_nvtx_pop() # dense_attn_moe_combine + + paddle.base.core.nvprof_nvtx_push("moe_mlp") + output_grad = self.backward_node.mlp_backward(output_grad) + paddle.base.core.nvprof_nvtx_pop() # moe_mlp + + paddle.base.core.nvprof_nvtx_push("dense_mlp_moe_dispatch") + output_grad = self.backward_node.dispatch_backward( + output_grad, async_finish=True, allocate_on_comm_stream=True + ) + dispatch_bw_event = deep_ep.get_event_from_comm_stream(self.backward_node.moe_group.id) + inputs = self.forward_node.mlp_node.forward(inputs) + dispatch_bw_event.calc_stream_wait(self.backward_node.moe_group.id) + paddle.base.core.nvprof_nvtx_pop() # dense_mlp_moe_dispatch + + paddle.base.core.nvprof_nvtx_push("moe_attn") + output_grad = self.backward_node.attn_backward(output_grad) + paddle.base.core.nvprof_nvtx_pop() # moe_attn + + event_to_wait = deep_ep.get_event_from_calc_stream(self.backward_node.moe_group.id) + paddle.base.core.nvprof_nvtx_pop() # dense_fw_moe_bw + + # Dense backward + MoE forward + else: + paddle.base.core.nvprof_nvtx_push("dense_bw_moe_fw") + + paddle.base.core.nvprof_nvtx_push("moe_attn") + inputs = self.forward_node.attn_forward(inputs) + attn_fw_event = deep_ep.get_event_from_calc_stream(self.forward_node.moe_group.id) + paddle.base.core.nvprof_nvtx_pop() # moe_attn + + paddle.base.core.nvprof_nvtx_push("dense_mlp_moe_dispatch") + output_grad =
self.backward_node.mlp_node.backward(output_grad) + inputs = self.forward_node.dispatch_forward( + inputs, previous_event=attn_fw_event, async_finish=True, allocate_on_comm_stream=True + ) + dispatch_fw_event = deep_ep.get_event_from_comm_stream(self.forward_node.moe_group.id) + dispatch_fw_event.calc_stream_wait(self.forward_node.moe_group.id) + paddle.base.core.nvprof_nvtx_pop() # dense_mlp_moe_dispatch + + paddle.base.core.nvprof_nvtx_push("moe_mlp") + inputs = self.forward_node.mlp_forward(inputs) + paddle.base.core.nvprof_nvtx_pop() # moe_mlp + + paddle.base.core.nvprof_nvtx_push("dense_attn_moe_combine") + inputs = self.forward_node.combine_forward( + inputs, async_finish=True, allocate_on_comm_stream=True + ) + combine_fw_event = deep_ep.get_event_from_comm_stream(self.forward_node.moe_group.id) + output_grad = self.backward_node.attn_node.backward(output_grad) + combine_fw_event.calc_stream_wait(self.forward_node.moe_group.id) + paddle.base.core.nvprof_nvtx_pop() # dense_attn_moe_combine + + paddle.base.core.nvprof_nvtx_push("moe_post") + inputs = self.forward_node.post_process_forward(inputs) + paddle.base.core.nvprof_nvtx_pop() # moe_post + + event_to_wait = deep_ep.get_event_from_calc_stream(self.forward_node.moe_group.id) + paddle.base.core.nvprof_nvtx_pop() # dense_bw_moe_fw + + return inputs, output_grad, event_to_wait + + +def build_overlapped_nodes(forward_chunk, backward_chunk): + overlap_element_class = ( + FusionFp8DecoderLayerNode if DSV3_USE_FP8_GEMM else DecoderLayerNode, + DenseDecoderLayerNode + ) + forward_decoder_layer_num = 0 + backward_decoder_layer_num = 0 + assert isinstance(forward_chunk, ScheduleChunk) and isinstance(backward_chunk, ScheduleChunk) + for n in forward_chunk.nodes: + if isinstance(n, overlap_element_class): + forward_decoder_layer_num += 1 + for n in reversed(backward_chunk.nodes): + if isinstance(n, overlap_element_class): + backward_decoder_layer_num += 1 + + overlap_layers_num = min(forward_decoder_layer_num, backward_decoder_layer_num) + forward_pre_overlap_layers = [] + forward_post_overlap_layers = [] + forward_overlap_layers = [] + is_pre = True + for n in forward_chunk.nodes: + if not isinstance(n, overlap_element_class): + if is_pre: + forward_pre_overlap_layers.append(n) + else: + forward_post_overlap_layers.append(n) + else: + is_pre = False + if len(forward_overlap_layers) == overlap_layers_num: + forward_post_overlap_layers.append(n) + else: + forward_overlap_layers.append(n) + forward_pre_node = ScheduleChunk(forward_pre_overlap_layers) + forward_post_node = ScheduleChunk(forward_post_overlap_layers) + + backward_pre_overlap_layers = [] + backward_post_overlap_layers = [] + backward_overlap_layers = [] + is_pre = True + for n in reversed(backward_chunk.nodes): + if not isinstance(n, overlap_element_class): + if is_pre: + backward_pre_overlap_layers.append(n) + else: + backward_post_overlap_layers.append(n) + else: + is_pre = False + if len(backward_overlap_layers) == overlap_layers_num: + backward_post_overlap_layers.append(n) + else: + backward_overlap_layers.append(n) + + backward_pre_node = ScheduleChunk(list(reversed(backward_pre_overlap_layers))) + backward_post_node = ScheduleChunk(list(reversed(backward_post_overlap_layers))) + + overlap_node = OverlapedScheduleChunk(forward_overlap_layers, backward_overlap_layers, use_fuion=DSV3_USE_FP8_GEMM) + return forward_pre_node, backward_pre_node, overlap_node, forward_post_node, backward_post_node + + class DeepseekV2EmbeddingPipe(nn.Layer): def __init__(self, config: 
DeepseekV2Config): super(DeepseekV2EmbeddingPipe, self).__init__() @@ -160,6 +1394,7 @@ def forward(self, args): # [seq_len * bs / n, num_head * head_dim] (n is mp parallelism) inputs_embeds = ScatterOp.apply(inputs_embeds) embeds_res = [inputs_embeds] + mtp_embeds = [] for depth in range(self.config.num_nextn_predict_layers): inputs_embeds_mtp = paddle.concat( [ @@ -171,12 +1406,19 @@ def forward(self, args): if self.sequence_parallel: inputs_embeds_mtp = inputs_embeds_mtp.reshape([-1, inputs_embeds_mtp.shape[-1]]) inputs_embeds_mtp = ScatterOp.apply(inputs_embeds_mtp) - embeds_res.append(inputs_embeds_mtp) - # if not self.sequence_parallel - # mtp_embeds: [B*num_nextn_predict_layers, seq_len, hidden_size] - # else: - # mtp_embeds: [B*seq_len*num_nextn_predict_layers, hidden_size] - inputs_embeds = paddle.concat(embeds_res, axis=-1) + mtp_embeds.append(inputs_embeds_mtp) + + if self.config.send_mtp_embed: + embeds_res.extend(mtp_embeds) + # if not self.sequence_parallel + # mtp_embeds: [B*num_nextn_predict_layers, seq_len, hidden_size] + # else: + # mtp_embeds: [B*seq_len*num_nextn_predict_layers, hidden_size] + inputs_embeds = paddle.concat(embeds_res, axis=-1) + else: + global global_inputs_embeds_mtp_queue + cloned_mtp_embeds = [t.detach() for t in mtp_embeds] + global_inputs_embeds_mtp_queue.put(cloned_mtp_embeds) return return_args(inputs_embeds, attention_mask, attn_mask_startend_row_indices, position_ids) else: if self.sequence_parallel: @@ -184,15 +1426,18 @@ def forward(self, args): inputs_embeds = ScatterOp.apply(inputs_embeds) return return_args(inputs_embeds, attention_mask, attn_mask_startend_row_indices, position_ids) + def build_schedule_node(self): + return ScheduleNode(self.forward, name="DeepseekV2EmbeddingPipe") + class DeepseekV2DecoderLayerPipe(DeepseekV2DecoderLayer): def forward(self, args): hidden_states, attention_mask, attn_mask_startend_row_indices, position_ids = parse_args(args) - if self.config.num_nextn_predict_layers > 0: + if self.config.send_mtp_embed: batch_size, _, hidden_size = hidden_states.shape batch_size_mtp = hidden_size // (self.config.num_nextn_predict_layers + 1) - inputs_embeds_mtp = hidden_states[..., -batch_size_mtp:] + inputs_embeds_mtp = hidden_states[..., batch_size_mtp:] hidden_states = hidden_states[..., :batch_size_mtp] has_gradient = not hidden_states.stop_gradient @@ -235,19 +1480,285 @@ def forward(self, args): attn_mask_startend_row_indices=attn_mask_startend_row_indices, ) - if self.config.num_nextn_predict_layers > 0: + if self.config.send_mtp_embed: hidden_states = paddle.concat([hidden_states, inputs_embeds_mtp], axis=-1) return return_args(hidden_states, attention_mask, attn_mask_startend_row_indices, position_ids) + def attn_compute(self, args): + hidden_states, attention_mask, attn_mask_startend_row_indices, position_ids = parse_args(args) + assert attention_mask is None + assert attn_mask_startend_row_indices is None + assert position_ids is None + assert self.config.send_mtp_embed + + batch_size, _, hidden_size = hidden_states.shape + batch_size_mtp = hidden_size // (self.config.num_nextn_predict_layers + 1) + inputs_embeds_mtp = hidden_states[..., batch_size_mtp:] + hidden_states = hidden_states[..., :batch_size_mtp] + + def attn_compute_func(hidden_states): + hidden_states, residual = self.self_attn_compute(hidden_states) + l_aux, _, intermediate_hidden_states, token_indices, token_probs = self.pre_dispatch_compute(hidden_states) + return (hidden_states, residual, l_aux, intermediate_hidden_states, token_indices, 
token_probs) + + has_gradient = not hidden_states.stop_gradient + if self.enable_recompute and self.config.recompute_granularity == "full" and has_gradient: + # for pretrain + outputs = recompute( + attn_compute_func, + hidden_states, + use_reentrant=self.config.recompute_use_reentrant, + ) + else: + outputs = attn_compute_func(hidden_states) + + return (inputs_embeds_mtp, *outputs) + + def attn_compute_for_fusion(self, args): + hidden_states, attention_mask, attn_mask_startend_row_indices, position_ids = parse_args(args) + assert attention_mask is None + assert attn_mask_startend_row_indices is None + assert position_ids is None + + send_mtp_embed = self.config.send_mtp_embed + + if send_mtp_embed: + # slice from holy tensor + batch_size, _, hidden_size = hidden_states.shape + batch_size_mtp = hidden_size // (self.config.num_nextn_predict_layers + 1) + inputs_embeds_mtp = hidden_states[..., batch_size_mtp:] + hidden_states = hidden_states[..., :batch_size_mtp] + + hidden_states, residual = self.self_attn_compute(hidden_states) + _, _, d_model = hidden_states.shape + + if self.using_post_norm_recompute: + probs, routing_map, l_aux, _, norm_out = self.mlp.router(hidden_states) + else: + probs, routing_map, l_aux, _ = self.mlp.router(hidden_states) + + # common return values + ret = ( + hidden_states, + residual, + probs, + routing_map, + l_aux, + ) + # append mtp embed if needed + ret = (inputs_embeds_mtp, *ret) if send_mtp_embed else ret + # append norm_out if using post_norm recompute + ret = (*ret, norm_out) if self.using_post_norm_recompute else ret + + return ret + + def mlp_compute(self, inputs): + if isinstance(inputs, list): + inputs = tuple(inputs) + send_mtp_embed = self.config.send_mtp_embed + + if send_mtp_embed: + ( + inputs_embeds_mtp, + hidden_states, + residual, + l_aux, + intermediate_hidden_states, + dispatched_indices, + dispatched_probs, + ) = inputs + else: + ( + hidden_states, + residual, + l_aux, + intermediate_hidden_states, + dispatched_indices, + dispatched_probs, + ) = inputs + has_gradient = not intermediate_hidden_states.stop_gradient + if self.enable_recompute and self.config.recompute_granularity == "full" and has_gradient: + expert_output = recompute( + self.expert_forward_compute, + intermediate_hidden_states, + dispatched_indices, + dispatched_probs, + use_reentrant=self.config.recompute_use_reentrant, + ) + else: + expert_output = self.expert_forward_compute( + intermediate_hidden_states, dispatched_indices, dispatched_probs + ) + if send_mtp_embed: + return (inputs_embeds_mtp, hidden_states, residual, l_aux, expert_output) + else: + return (hidden_states, residual, l_aux, expert_output) + + def post_process_compute(self, inputs): + send_mtp_embed = self.config.send_mtp_embed + + if isinstance(inputs, list): + inputs = tuple(inputs) + if send_mtp_embed: + (inputs_embeds_mtp, hidden_states, residual, l_aux, combine_output) = inputs + else: + (hidden_states, residual, l_aux, combine_output) = inputs + has_gradient = not hidden_states.stop_gradient + if self.enable_recompute and self.config.recompute_granularity == "full" and has_gradient: + hidden_states = recompute( + self.post_combine_compute, + residual, + hidden_states, + combine_output, + l_aux, + use_reentrant=self.config.recompute_use_reentrant, + ) + else: + hidden_states = self.post_combine_compute( + residual, + hidden_states, + combine_output, + l_aux, + ) + if send_mtp_embed: + hidden_states = paddle.concat([hidden_states, inputs_embeds_mtp], axis=-1) + + return return_args(hidden_states) + + def 
post_process_compute_for_fusion(self, inputs): + send_mtp_embed = self.config.send_mtp_embed + + if isinstance(inputs, list): + inputs = tuple(inputs) + + if send_mtp_embed: + (inputs_embeds_mtp, hidden_states, residual, l_aux, final_hidden_states) = inputs + else: + (hidden_states, residual, l_aux, final_hidden_states) = inputs + + final_hidden_states = self.mlp.post_process(hidden_states, final_hidden_states, l_aux) + + hidden_states = residual + final_hidden_states + + hidden_states = (hidden_states,) + + if type(hidden_states) is tuple and len(hidden_states) == 1: + hidden_states = hidden_states[0] + + if send_mtp_embed: + hidden_states = paddle.concat([hidden_states, inputs_embeds_mtp], axis=-1) + + return return_args(hidden_states) + + def attn_compute_dense(self, args): + hidden_states, attention_mask, attn_mask_startend_row_indices, position_ids = parse_args(args) + assert attention_mask is None + assert attn_mask_startend_row_indices is None + assert position_ids is None + + if self.config.send_mtp_embed: + batch_size, _, hidden_size = hidden_states.shape + batch_size_mtp = hidden_size // (self.config.num_nextn_predict_layers + 1) + inputs_embeds_mtp = hidden_states[..., batch_size_mtp:] + hidden_states = hidden_states[..., :batch_size_mtp] + + hidden_states, residual = self.self_attn_compute(hidden_states) + + ret = (hidden_states, residual) + ret = (inputs_embeds_mtp, *ret) if self.config.send_mtp_embed else ret + return ret + + def mlp_compute_dense(self, inputs): + if self.config.send_mtp_embed: + (inputs_embeds_mtp, hidden_states, residual) = inputs + else: + (hidden_states, residual) = inputs + + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + if self.config.send_mtp_embed: + hidden_states = paddle.concat([hidden_states, inputs_embeds_mtp], axis=-1) + + return hidden_states + + def build_schedule_node(self): + if isinstance(self.mlp, DeepseekV2MoE): + self.mlp.update_flex_token() + if self.mlp.using_flex_token: + if DSV3_USE_FP8_GEMM: + attn_and_gate_node = ScheduleNode(self.attn_compute_for_fusion, name="attn_and_gate_node") + + # recompute_fwd_gate_up_ may be 1, 0 or -1, 1 means recompute, 0 means disable recompute, -1 means adaptive recompute. 
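As a reading aid for the three modes named in the comment above, here is a minimal sketch of the selection logic (the helper name `select_recompute_mode` is illustrative and not part of this PR); the assignment that follows implements the same mapping inline:

```python
def select_recompute_mode(layer_idx, recompute_fwd_gate_up_list, adaptive_ratio):
    """Illustrative only: map config fields to the 1 / 0 / -1 recompute modes.

    1  -> always recompute the forward gate_up projection for this layer
    0  -> never recompute it
    -1 -> adaptive recompute, driven by the remained-O1 recompute ratio
    """
    if layer_idx in recompute_fwd_gate_up_list:
        return 1
    # Layers not forced to recompute fall back to adaptive mode when a
    # non-zero adaptive ratio is configured; otherwise recompute stays off.
    return -1 if adaptive_ratio else 0
```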
+ recompute_fwd_gate_up_ = 1 if self.layer_idx in self.config.recompute_fwd_gate_up_list else 0 + if recompute_fwd_gate_up_ == 0 and self.config.adaptive_remained_O1_recompute_ratio: + recompute_fwd_gate_up_ = -1 + + fp8_fusion_moe_node = FusionMoeNode( + self.mlp, + recompute_fwd_gate_up=recompute_fwd_gate_up_, + is_split_group_gemm=self.config.is_split_group_gemm, + mlp_fwd_subbatch_rows=self.config.mlp_fwd_subbatch_rows, + mlp_bwd_subbatch_rows=self.config.mlp_bwd_subbatch_rows, + output_subbatch_rows=self.config.output_subbatch_rows, + name="fp8_fusion_moe_node", + ) + post_process_node = PostProcessNode( + self.config.send_mtp_embed, + self.mlp.training, + self.mlp.alpha, + self.config, + self.mlp.shared_experts, + self.config.using_post_norm_recompute, + output_mtp_embed_first=isinstance(self, DeepseekV2MTPLayer), + name="post_process_node", + ) + return FusionFp8DecoderLayerNode( + attn_and_gate_node=attn_and_gate_node, + fp8_fusion_moe_node=fp8_fusion_moe_node, + post_process_node=post_process_node, + mlp_layer=self.mlp, + send_mtp_embed=self.config.send_mtp_embed, + using_post_norm_recompute=self.config.using_post_norm_recompute, + name="FusionFp8DecoderLayerNode", + ) + else: + attn_node = ScheduleNode(self.attn_compute, name="attn_node") + mlp_node = ScheduleNode(self.mlp_compute, name="mlp_node") + post_process_node = ScheduleNode(self.post_process_compute, name="post_process_node") + return DecoderLayerNode( + attn_node=attn_node, + dispatch_node=None, + mlp_node=mlp_node, + combine_node=None, + post_process_node=post_process_node, + mlp_layer=self.mlp, + name="DecoderLayerNode", + ) + + attn_node = ScheduleNode(self.attn_compute_dense, name="attn_node") + mlp_node = ScheduleNode(self.mlp_compute_dense, name="mlp_node") + return DenseDecoderLayerNode( + attn_node=attn_node, + mlp_node=mlp_node, + name="DenseDecoderLayerNode", + ) + class DeepseekV2MTPLayerPipe(DeepseekV2MTPLayer): def forward(self, args): hidden_states, attention_mask, attn_mask_startend_row_indices, position_ids = parse_args(args) - hidden_states_list = paddle.split(hidden_states, self.config.num_nextn_predict_layers + 1, axis=-1) - hidden_states_main_model = hidden_states_list[0] - inputs_embeds_cur_depth_list = hidden_states_list[1:] + if self.config.send_mtp_embed: + hidden_states_list = paddle.split(hidden_states, self.config.num_nextn_predict_layers + 1, axis=-1) + hidden_states_main_model = hidden_states_list[0] + inputs_embeds_cur_depth_list = hidden_states_list[1:] + else: + hidden_states_main_model = hidden_states + global global_inputs_embeds_mtp_queue + inputs_embeds_cur_depth_list = global_inputs_embeds_mtp_queue.get() + has_gradient = not hidden_states_main_model.stop_gradient if attention_mask is not None and attention_mask.dtype == paddle.int32: @@ -299,6 +1810,70 @@ def forward(self, args): hidden_states = paddle.concat(output_list, axis=-1) return return_args(hidden_states, attention_mask, attn_mask_startend_row_indices, position_ids) + def attn_compute_for_fusion(self, args): + hidden_states, attention_mask, attn_mask_startend_row_indices, position_ids = parse_args(args) + assert attention_mask is None + assert attn_mask_startend_row_indices is None + assert position_ids is None + assert self.config.num_nextn_predict_layers == 1 + + if self.config.send_mtp_embed: + hidden_states_list = paddle.split(hidden_states, self.config.num_nextn_predict_layers + 1, axis=-1) + hidden_states_main_model = hidden_states_list[0] + inputs_embeds_cur_depth_list = hidden_states_list[1:] + else: + 
hidden_states_main_model = hidden_states + global global_inputs_embeds_mtp_queue + inputs_embeds_cur_depth_list = global_inputs_embeds_mtp_queue.get() + + hidden_states = hidden_states_main_model + nextn_hidden_state = inputs_embeds_cur_depth_list[0] + + # mtp compute + hidden_states = self.hnorm(hidden_states) + nextn_hidden_state = self.enorm(nextn_hidden_state) + + hidden_states = self.eh_proj(paddle.concat([hidden_states, nextn_hidden_state], axis=-1)) + + # attention compute + hidden_states, residual = self.self_attn_compute(hidden_states) + + if self.using_post_norm_recompute: + probs, routing_map, l_aux, _, norm_out = self.mlp.router(hidden_states) + else: + probs, routing_map, l_aux, _ = self.mlp.router(hidden_states) + + # common return values + ret = ( + hidden_states_main_model, + hidden_states, + residual, + probs, + routing_map, + l_aux, + ) + ret = (*ret, norm_out) if self.using_post_norm_recompute else ret + + return ret + + def build_schedule_node(self): + if isinstance(self.mlp, DeepseekV2MoE): + self.mlp.update_flex_token() + if ( + self.mlp.using_flex_token and + DSV3_USE_FP8_GEMM and + self.config.num_nextn_predict_layers == 1 + ): + prev_send_mtp_embed = self.config.send_mtp_embed + self.config.send_mtp_embed = True # must be True in MTP node + + node = DeepseekV2DecoderLayerPipe.build_schedule_node(self) + assert isinstance(node, FusionFp8DecoderLayerNode) + + self.config.send_mtp_embed = prev_send_mtp_embed + return node + return ScheduleNode(self.forward, name="DeepseekV2MTPLayerPipe") + class DeepseekV2RMSNormPipe(nn.Layer): def __init__(self, config): @@ -321,10 +1896,13 @@ def forward(self, args): else: return self.norm(hidden_states) + def build_schedule_node(self): + return ScheduleNode(self.forward, name="DeepseekV2RMSNormPipe") + class DeepseekV2LMHeadPipe(DeepseekV2LMHead): - def __init__(self, config): - super(DeepseekV2LMHeadPipe, self).__init__(config) + def __init__(self, config, embedding_weight=None): + super(DeepseekV2LMHeadPipe, self).__init__(config, embedding_weight=embedding_weight) @property def embedding_weight(self): @@ -340,6 +1918,9 @@ def forward(self, args: Union[Tuple, paddle.Tensor]): logits = super().forward(hidden_states) return logits + def build_schedule_node(self): + return ScheduleNode(self.forward, name="DeepseekV2LMHeadPipe") + class DeepseekV2PretrainingCriterionPipe(DeepseekV2PretrainingCriterion): def forward(self, logits, labels): @@ -348,9 +1929,14 @@ def forward(self, logits, labels): logits = logits[0] loss = super().forward(logits, labels, mtp_logits=mtp_logits) else: + if isinstance(logits, (tuple, list)): + logits = logits[0] loss = super().forward(logits, labels) return loss + def build_schedule_node(self): + return ScheduleNode(self.forward, name="DeepseekV2PretrainingCriterionPipe") + class DeepseekV2ForCausalLMPipe(PipelinePretrainedModel, PipelineLayer): """DeepseekV2ForPretraining adapted for pipeline parallelism. @@ -371,6 +1957,9 @@ class DeepseekV2ForCausalLMPipe(PipelinePretrainedModel, PipelineLayer): # DONOT Add base_model_prefix !!!! 
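The save/flip/restore of `config.send_mtp_embed` in `DeepseekV2MTPLayerPipe.build_schedule_node` above could also be written as a small context manager; a hedged sketch follows (the helper name `override_config_attr` is hypothetical, not an existing paddleformers utility):

```python
from contextlib import contextmanager

@contextmanager
def override_config_attr(config, attr, value):
    """Temporarily set `config.<attr>` to `value`, restoring the previous value on exit."""
    prev = getattr(config, attr)
    setattr(config, attr, value)
    try:
        yield config
    finally:
        setattr(config, attr, prev)

# Hypothetical usage inside build_schedule_node:
# with override_config_attr(self.config, "send_mtp_embed", True):
#     node = DeepseekV2DecoderLayerPipe.build_schedule_node(self)
```

The `finally` block guarantees the flag is restored even if node construction raises, which the inline save/restore pattern does not.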
+ def step_flex_token(self, cur_step): + set_global_step(cur_step) + @classmethod def _prepare_pipeline_inputs_func(cls, inputs): first_stage_keys = ["input_ids", "attention_mask", "attn_mask_startend_row_indices", "position_ids"] @@ -408,6 +1997,10 @@ def __init__(self, config: DeepseekV2Config): assert len(self.no_recompute_layers) == 0, "for pp with full recompute, no_recompute_layers is not support" virtual_pp_degree = getattr(self.config, "virtual_pp_degree", 1) + use_dualpipev = getattr(self.config, "use_dualpipev", False) + if use_dualpipev: + assert LocalSharedLayerDesc is not None, "LocalSharedLayerDesc is None, please update your paddle." + shared_class = LocalSharedLayerDesc if use_dualpipev else SharedLayerDesc def get_hcg(): return fleet.get_hybrid_communicate_group() @@ -422,7 +2015,7 @@ def get_hcg(): if config.tie_word_embeddings: self.add_sequential_layer( - SharedLayerDesc( + shared_class( "DeepseekV2_shared_weight", DeepseekV2EmbeddingPipe, shared_weight_attr="embedding_weight", @@ -435,6 +2028,43 @@ def get_hcg(): LayerDesc(DeepseekV2EmbeddingPipe, config=config), self._base_model.base_model_prefix ) + def compute_recompute_fwd_gate_up_list(pp_nums, all_dl_nums, dense_dl_nums, recompute_fwd_gate_up): + all_layers_nums = all_dl_nums + 4 # embedding, rms, lm_head, mtp + segment_size = all_layers_nums // pp_nums + boundary = math.ceil((1 + dense_dl_nums) / segment_size) * segment_size + recompute_fwd_gate_up_list = [dense_dl_nums] + for idx in range(boundary - 1, all_dl_nums, segment_size): + recompute_fwd_gate_up_list.append(idx) + + # If `recompute_fwd_gate_up` is a Boolean value and is True, means all O1 will be recomputed. + # Otherwise `recompute_fwd_gate_up` should be an integer representing how many O1 are recomputed. + assert isinstance(recompute_fwd_gate_up, (int, bool)) + if type(recompute_fwd_gate_up) is bool: + enable_k_o1_rc = segment_size if recompute_fwd_gate_up is True else 0 + else: + enable_k_o1_rc = recompute_fwd_gate_up + + ret = [] + for i in range(len(recompute_fwd_gate_up_list)): + for k in range(min(segment_size, enable_k_o1_rc)): + ret.append(recompute_fwd_gate_up_list[i] + k) + return ret + + pp_nums = ( + self.config["pipeline_parallel_degree"] * 2 + if self.config.use_dualpipev + else self.config["pipeline_parallel_degree"] + ) + recompute_fwd_gate_up_list = compute_recompute_fwd_gate_up_list( + pp_nums, + self.config.num_hidden_layers, + self.config.first_k_dense_replace, + self.config.recompute_fwd_gate_up, + ) + + logger.info(f"recompute_fwd_gate_up_list: {recompute_fwd_gate_up_list}") + config.recompute_fwd_gate_up_list = recompute_fwd_gate_up_list + for i in range(config.num_hidden_layers): self.add_sequential_layer( LayerDesc( @@ -455,7 +2085,7 @@ def get_hcg(): if config.tie_word_embeddings: self.add_sequential_layer( - SharedLayerDesc( + shared_class( "DeepseekV2_shared_weight", DeepseekV2LMHeadPipe, shared_weight_attr="embedding_weight", @@ -491,11 +2121,69 @@ def get_hcg(): "partition": False, }, num_virtual_pipeline_stages=virtual_pp_degree, + use_dualpipev=use_dualpipev, ) # You should call init here, since there is a diamond inheritance problem self.apply(self._init_weights) # DON'T init PipelinePretrainedModel # PipelinePretrainedModel.__init__(self.super(), config=config) + def fp8_quant_weight(self, batch_mode=False): + """fp8_quant_weight""" + with paddle.no_grad(): + for i, layer in self._sub_layers.items(): + if isinstance( + layer, paddle.distributed.fleet.meta_parallel.parallel_layers.pp_layers.PipelineLayerChunk + ): + for 
i, sub_layer in layer.named_sublayers(): + if isinstance(sub_layer, DeepseekV2DecoderLayer) and hasattr(sub_layer, "fp8_quant_weight"): + sub_layer.fp8_quant_weight(batch_mode) + if isinstance(layer, DeepseekV2DecoderLayer) and hasattr(layer, "fp8_quant_weight"): + layer.fp8_quant_weight(batch_mode) + def get_loss_fn(self, config): + return DeepseekV2PretrainingCriterionPipe(config) + + def overlapped_forward_backward( + self, + forward_chunk, # the module of the forward chunk + forward_inputs, + forward_loss_fn_node, + backward_chunk, # the module of the backward chunk (may be unused) + backward_loss_fn_node, + backward_input_grads, + scaler, + combine_bw_event_to_wait=None, + pp_stream=None, + ): + if backward_loss_fn_node is not None: + if scaler: + backward_input_grads = backward_loss_fn_node.backward(scaler=scaler) + else: + backward_input_grads = backward_loss_fn_node.backward() + + ( + forward_pre_node, + backward_pre_node, + overlap_node, + forward_post_node, + backward_post_node, + ) = build_overlapped_nodes(forward_chunk, backward_chunk) + forward_inputs = forward_pre_node.forward(forward_inputs) + backward_input_grads = backward_pre_node.backward(backward_input_grads) + forward_inputs, backward_input_grads, _ = overlap_node.forward_backward( + forward_inputs, + backward_input_grads, + combine_bw_event_to_wait=combine_bw_event_to_wait, + pp_stream=pp_stream, + ) + forward_inputs = forward_post_node.forward(forward_inputs) + backward_input_grads = backward_post_node.backward(backward_input_grads) + + if forward_loss_fn_node is not None: + forward_loss = forward_loss_fn_node.forward(forward_inputs) + else: + forward_loss = None + + forward_inputs = [forward_inputs] if isinstance(forward_inputs, paddle.Tensor) else forward_inputs + return forward_inputs, forward_loss, backward_input_grads diff --git a/paddleformers/transformers/deepseek_v3/modeling.py b/paddleformers/transformers/deepseek_v3/modeling.py index 51c0d1978fe..8f6ed05c3e7 100644 --- a/paddleformers/transformers/deepseek_v3/modeling.py +++ b/paddleformers/transformers/deepseek_v3/modeling.py @@ -25,13 +25,25 @@ import paddle -from ..deepseek_v2.modeling import ( - DeepseekV2ForSequenceClassification, - DeepseekV2LMHead, - DeepseekV2Model, - DeepseekV2PretrainedModel, - DeepseekV2PretrainingCriterion, -) +import os +if os.getenv("DSV3_FAST_PRETRAIN", "False").lower() not in ("1", "true", "yes"): + from ..deepseek_v2.modeling import ( + DeepseekV2ForSequenceClassification, + DeepseekV2LMHead, + DeepseekV2Model, + DeepseekV2PretrainedModel, + DeepseekV2PretrainingCriterion, + ) +else: + from ..deepseek_v2.modeling import ( + DeepseekV2ForSequenceClassification, + DeepseekV2LMHead, + DeepseekV2PretrainingCriterion, + ) + + from ..deepseek_v2.modeling_fast import DeepseekV2ModelFast as DeepseekV2Model + from ..deepseek_v2.modeling_fast import DeepseekV2PretrainedModelFast as DeepseekV2PretrainedModel + from ..model_outputs import CausalLMOutputWithPast from ..model_utils import register_base_model from .configuration import DeepseekV3Config
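Because `os.getenv` always returns a string, boolean-like flags such as `DSV3_FAST_PRETRAIN` in the import gate above are easiest to handle through a tiny parser. Below is a minimal sketch under the assumption that "1"/"true"/"yes"/"on" (case-insensitive) mean enabled; the helper name `env_flag` is illustrative, not an existing paddleformers utility:

```python
import os

def env_flag(name, default=False):
    """Return True when the environment variable looks like an enabled boolean flag."""
    value = os.getenv(name)
    if value is None:
        return default
    return value.strip().lower() in ("1", "true", "yes", "on")

# Hypothetical usage for the import gate above:
# if env_flag("DSV3_FAST_PRETRAIN"):
#     from ..deepseek_v2.modeling_fast import DeepseekV2ModelFast as DeepseekV2Model
```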