diff --git a/examples/experiments/deepseek_v3_pretrain/config/__init__.py b/examples/experiments/deepseek_v3_pretrain/config/__init__.py new file mode 100644 index 00000000000..b9cb4467041 --- /dev/null +++ b/examples/experiments/deepseek_v3_pretrain/config/__init__.py @@ -0,0 +1,100 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import sys +from typing import TYPE_CHECKING + +from paddleformers.utils.lazy_import import _LazyModule + +import_structure = { + "configuration": ["DeepseekV2FastConfig"], + "modeling": [ + "masked_fill", + "DeepseekV2Attention", + "MoEGate", + "FakeGate", + "DeepseekV2ForCausalLM", + "_make_causal_mask", + "is_casual_mask", + "DeepseekV2MoE", + "DeepseekV2MoEFlexToken", + "scaled_dot_product_attention", + "DeepseekV2RotaryEmbedding", + "rotate_half", + "DeepseekV2MTPLayer", + "DeepseekV2RMSNorm", + "DeepseekV2YarnRotaryEmbedding", + "parallel_matmul", + "DeepseekV2PretrainedModel", + "AddAuxiliaryLoss", + "apply_rotary_pos_emb", + "assign_kv_heads", + "DeepseekV2ForSequenceClassification", + "_expand_2d_mask", + "DeepseekV2ModelFast", + "repeat_kv", + "yarn_find_correction_dim", + "yarn_linear_ramp_mask", + "DeepseekV2DynamicNTKScalingRotaryEmbedding", + "DeepseekV2MLP", + "yarn_get_mscale", + "DeepseekV2LMHead", + "DeepseekV2DecoderLayer", + "DeepseekV2PretrainingCriterionFast", + "yarn_find_correction_range", + "get_triangle_upper_mask", + "DeepseekV2LinearScalingRotaryEmbedding", + "set_global_step", + "get_global_step", + ], + "modeling_auto": [ + "DeepseekV2LMHeadAuto", + "DeepseekV2ForCausalLMAuto", + "DeepseekV2ModelAuto", + "DeepseekV2PretrainedModelAuto", + ], + "modeling_pp": ["DeepseekV2ForCausalLMPipe"], + "mfu_utils": ["DeepSeekProjection"], + "kernel": [ + "act_quant", + "weight_dequant", + "fp8_gemm", + "weight_dequant_kernel", + "act_quant_kernel", + "fp8_gemm_kernel", + ], + "tokenizer_fast": ["DeepseekTokenizerFast"], + "fp8_linear": [ + "Linear", + "ColumnParallelLinear", + "RowParallelLinear", + "ColumnSequenceParallelLinear", + "RowSequenceParallelLinear", + ], +} + +if TYPE_CHECKING: + from .configuration import * + from .modeling import * + from .modeling_auto import * + from .modeling_pp import * + from .tokenizer_fast import * +else: + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + import_structure, + module_spec=__spec__, + ) diff --git a/examples/experiments/deepseek_v3_pretrain/config/config.json b/examples/experiments/deepseek_v3_pretrain/config/config.json new file mode 100644 index 00000000000..a644bfe1e84 --- /dev/null +++ b/examples/experiments/deepseek_v3_pretrain/config/config.json @@ -0,0 +1,80 @@ +{ + "architectures": [ + "DeepseekV2ForCausalLM" + ], + "attention_bias": false, + "attention_dropout": 0.0, + "auto_map": { + "AutoConfig": "DeepseekV2FastConfig", + "AutoModel": "DeepseekV2ModelFast", + "AutoModelForCausalLM": "DeepseekV2ForCausalLM" + }, + "aux_loss_alpha": 0.001, + "bos_token_id": 0, + "eos_token_id": 1, + 
"ep_size": 1, + "first_k_dense_replace": 3, + "hidden_act": "silu", + "hidden_size": 7168, + "initializer_range": 0.02, + "intermediate_size": 18432, + "kv_lora_rank": 512, + "max_position_embeddings": 163840, + "model_type": "deepseek_v3", + "moe_intermediate_size": 2048, + "moe_layer_freq": 1, + "n_group": 8, + "n_routed_experts": 8, + "n_shared_experts": 1, + "norm_topk_prob": true, + "num_attention_heads": 128, + "num_experts_per_tok": 8, + "num_hidden_layers": 15, + "num_key_value_heads": 128, + "num_nextn_predict_layers": 1, + "pretraining_tp": 1, + "q_lora_rank": 1536, + "qk_nope_head_dim": 128, + "qk_rope_head_dim": 64, + "rms_norm_eps": 1e-06, + "rope_scaling": { + "beta_fast": 32, + "beta_slow": 1, + "factor": 40, + "mscale": 1.0, + "mscale_all_dim": 1.0, + "original_max_position_embeddings": 4096, + "type": "yarn" + }, + "rope_theta": 10000, + "routed_scaling_factor": 2.5, + "scoring_func": "sigmoid", + "seq_aux": true, + "tie_word_embeddings": false, + "topk_group": 4, + "topk_method": "noaux_tc", + "dtype": "bfloat16", + "transformers_version": "4.33.1", + "use_cache": true, + "v_head_dim": 128, + "vocab_size": 129280, + "using_flex_token": true, + "using_fake_gate": true, + "use_fused_rms_norm": true, + "fuse_attention_ffn": true, + "use_fused_rope": true, + "token_drop_steps": 0, + "recompute_fwd_gate_up": true, + "adaptive_remained_O1_recompute_ratio": 0.3, + "using_post_norm_recompute": true, + "is_split_group_gemm": false, + "use_dualpipev": true, + "send_mtp_embed": true, + "offline_quant_expert_weight": false, + "clear_origin_weight_when_offline_quant": false, + "dsv3_use_fp8_gemm": true, + "dsv3_use_atten_recompute": true, + "use_ds_gemm": false, + "dsv3_use_fp8_dispatch": true, + "fa_version": 3 + } \ No newline at end of file diff --git a/examples/experiments/deepseek_v3_pretrain/config/configuration.py b/examples/experiments/deepseek_v3_pretrain/config/configuration.py new file mode 100644 index 00000000000..f281c12c17d --- /dev/null +++ b/examples/experiments/deepseek_v3_pretrain/config/configuration.py @@ -0,0 +1,275 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2023 DeepSeek. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" DeepSeekV2 model configuration""" +from paddleformers.transformers.configuration_utils import PretrainedConfig + +__all__ = [ + "DeepseekV2FastConfig", +] + + +class DeepseekV2FastConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`DeepseekV2ModelFast`]. It is used to instantiate an DeepSeek + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the DeepSeek-V2. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. 
+ + + Args: + vocab_size (`int`, *optional*, defaults to 102400): + Vocabulary size of the Deep model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`DeepseekV2ModelFast`] + hidden_size (`int`, *optional*, defaults to 4096): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 11008): + Dimension of the MLP representations. + moe_intermediate_size (`int`, *optional*, defaults to 1407): + Dimension of the MoE representations. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer decoder. + num_nextn_predict_layers (`int`, *optional*, defaults to 1): + Number of nextn predict layers in the DeepSeekV3 Model. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer decoder. + n_shared_experts (`int`, *optional*, defaults to None): + Number of shared experts, None means dense model. + n_routed_experts (`int`, *optional*, defaults to None): + Number of routed experts, None means dense model. + routed_scaling_factor (`float`, *optional*, defaults to 1.0): + Scaling factor or routed experts. + topk_method (`str`, *optional*, defaults to `gready`): + Topk method used in routed gate. + n_group (`int`, *optional*, defaults to None): + Number of groups for routed experts. + topk_group (`int`, *optional*, defaults to None): + Number of selected groups for each token(for each token, ensuring the selected experts is only within `topk_group` groups). + num_experts_per_tok (`int`, *optional*, defaults to None): + Number of selected experts, None means dense model. + moe_layer_freq (`int`, *optional*, defaults to 1): + The frequency of the MoE layer: one expert layer for every `moe_layer_freq - 1` dense layers. + first_k_dense_replace (`int`, *optional*, defaults to 0): + Number of dense layers in shallow layers(embed->dense->dense->...->dense->moe->moe...->lm_head). + \--k dense layers--/ + norm_topk_prob (`bool`, *optional*, defaults to False): + Whether to normalize the weights of the routed experts. + scoring_func (`str`, *optional*, defaults to 'softmax'): + Method of computing expert weights. + aux_loss_alpha (`float`, *optional*, defaults to 0.001): + Auxiliary loss weight coefficient. + seq_aux = (`bool`, *optional*, defaults to True): + Whether to compute the auxiliary loss for each individual sample. + num_key_value_heads (`int`, *optional*): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to + `num_attention_heads`. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to 2048): + The maximum sequence length that this model might ever be used with. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. 
+ rms_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + pad_token_id (`int`, *optional*): + Padding token id. + bos_token_id (`int`, *optional*, defaults to 1): + Beginning of stream token id. + eos_token_id (`int`, *optional*, defaults to 2): + End of stream token id. + pretraining_tp (`int`, *optional*, defaults to 1): + Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this + document](https://huggingface.co/docs/transformers/parallelism) to understand more about it. This value is + necessary to ensure exact reproducibility of the pretraining results. Please refer to [this + issue](https://github.com/pytorch/pytorch/issues/76232). + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether to tie weight embeddings + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings. + rope_scaling (`Dict`, *optional*): + Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling + strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is + `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update + `max_position_embeddings` to the expected new maximum. + attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`): + Whether to use a bias in the query, key, value and output projection layers during self-attention. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + speculate_model_type (`str`, defaults to `None`, *optional*, defaults to `False`): + The model type for speculate. Support ['eagle', 'mtp'] Now. 
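+        The additional keyword arguments accepted by `__init__` (for example `using_flex_token`, `use_dualpipev`,
+        `dsv3_use_fp8_gemm`, `dsv3_use_fp8_dispatch`) toggle the DualPipe and FP8 specific code paths of this
+        example; see the bundled `config.json` for the values used during pretraining.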
+ + ```python + >>> from paddleformers.transformers import DeepseekV2ModelFast, DeepseekV2FastConfig + + >>> # Initializing a Deepseek-V2 style configuration + >>> configuration = DeepseekV2FastConfig() + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "deepseek_v3" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=102400, + hidden_size=4096, + intermediate_size=11008, + moe_intermediate_size=1407, + num_hidden_layers=30, + num_nextn_predict_layers=0, + num_nextn_predict_lambda=1.0, + num_attention_heads=32, + num_key_value_heads=32, + n_shared_experts=None, + n_routed_experts=None, + ep_size=1, + routed_scaling_factor=1.0, + kv_lora_rank=512, + q_lora_rank=1536, + qk_rope_head_dim=64, + v_head_dim=128, + qk_nope_head_dim=128, + topk_method="gready", + n_group=None, + topk_group=None, + num_experts_per_tok=None, + moe_layer_freq=1, + first_k_dense_replace=0, + norm_topk_prob=False, + scoring_func="softmax", + aux_loss_alpha=0.001, + seq_aux=True, + hidden_act="silu", + max_position_embeddings=2048, + seq_length=32768, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + pad_token_id=None, + bos_token_id=100000, + eos_token_id=100001, + pretraining_tp=1, + tie_word_embeddings=False, + rope_theta=10000.0, + rope_scaling=None, + attention_bias=False, + attention_dropout=0.0, + speculate_model_type=False, + using_flex_token=False, + use_dualpipev=False, + send_mtp_embed=False, + using_post_norm_recompute=False, + stepped_recompute_fwd_gate_up=False, + recompute_fwd_gate_up=0, + recompute_fa3=0, + is_split_group_gemm=False, + fakse_gate_restrict_balance=False, + adaptive_remained_O1_recompute_ratio=0, + offline_quant_expert_weight=True, + clear_origin_weight_when_offline_quant=True, + mlp_bwd_subbatch_rows=0, + mlp_fwd_subbatch_rows=0, + output_subbatch_rows=0, + dsv3_use_fp8_gemm=True, + dsv3_use_atten_recompute=True, + use_ds_gemm=False, + dsv3_use_fp8_dispatch=True, + fa_version=3, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.seq_length = seq_length + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.moe_intermediate_size = moe_intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_nextn_predict_layers = num_nextn_predict_layers + self.num_nextn_predict_lambda = num_nextn_predict_lambda + self.num_attention_heads = num_attention_heads + self.n_shared_experts = n_shared_experts + self.n_routed_experts = n_routed_experts + self.ep_size = ep_size + self.routed_scaling_factor = routed_scaling_factor + self.kv_lora_rank = kv_lora_rank + self.q_lora_rank = q_lora_rank + self.qk_rope_head_dim = qk_rope_head_dim + self.v_head_dim = v_head_dim + self.qk_nope_head_dim = qk_nope_head_dim + self.topk_method = topk_method + self.n_group = n_group + self.topk_group = topk_group + self.num_experts_per_tok = num_experts_per_tok + self.moe_layer_freq = moe_layer_freq + self.first_k_dense_replace = first_k_dense_replace + self.norm_topk_prob = norm_topk_prob + self.scoring_func = scoring_func + self.aux_loss_alpha = aux_loss_alpha + self.seq_aux = seq_aux + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.pretraining_tp = pretraining_tp + self.use_cache = 
use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + self.speculate_model_type = speculate_model_type + self.use_fp8 = False + self.using_flex_token = using_flex_token + self.use_dualpipev = use_dualpipev + self.send_mtp_embed = send_mtp_embed + self.using_post_norm_recompute = using_post_norm_recompute + self.recompute_fwd_gate_up = recompute_fwd_gate_up + self.recompute_fa3 = recompute_fa3 + self.stepped_recompute_fwd_gate_up = stepped_recompute_fwd_gate_up + self.is_split_group_gemm = is_split_group_gemm + self.fakse_gate_restrict_balance = fakse_gate_restrict_balance + self.adaptive_remained_O1_recompute_ratio = adaptive_remained_O1_recompute_ratio + self.offline_quant_expert_weight = offline_quant_expert_weight + self.clear_origin_weight_when_offline_quant = clear_origin_weight_when_offline_quant + self.mlp_bwd_subbatch_rows = mlp_bwd_subbatch_rows + self.mlp_fwd_subbatch_rows = mlp_fwd_subbatch_rows + self.output_subbatch_rows = output_subbatch_rows + self.dsv3_use_fp8_gemm = dsv3_use_fp8_gemm + self.dsv3_use_atten_recompute = dsv3_use_atten_recompute + self.use_ds_gemm = use_ds_gemm + self.dsv3_use_fp8_dispatch = dsv3_use_fp8_dispatch + self.fa_version = fa_version + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) diff --git a/examples/experiments/deepseek_v3_pretrain/config/pretrain_argument.json b/examples/experiments/deepseek_v3_pretrain/config/pretrain_argument.json new file mode 100644 index 00000000000..0c8d4aefed9 --- /dev/null +++ b/examples/experiments/deepseek_v3_pretrain/config/pretrain_argument.json @@ -0,0 +1,53 @@ +{ + "model_name_or_path": "./config/", + "tokenizer_name_or_path": "deepseek-ai/DeepSeek-V3", + "input_dir": "./data", + "output_dir": "./checkpoints/pretrain_ckpts", + "per_device_train_batch_size": 1, + "gradient_accumulation_steps": 24, + "per_device_eval_batch_size": 1, + "tensor_parallel_degree": 1, + "pipeline_parallel_degree": 4, + "pipeline_parallel_config": "use_dualpipev", + "sharding_parallel_degree": 2, + "sharding_parallel_config": "split_param", + "sharding_comm_buffer_size_MB": 2048, + "expert_parallel_degree": 2, + "sharding": "stage1", + "virtual_pp_degree": 1, + "sequence_parallel": 0, + "use_flash_attention": true, + "max_seq_length": 4097, + "learning_rate": 3e-05, + "min_learning_rate": 3e-06, + "warmup_steps": 30, + "logging_steps": 1, + "max_steps": 200, + "save_steps": 5000, + "eval_steps": 1000, + "weight_decay": 0.01, + "bf16": true, + "fp16_opt_level": "O2", + "warmup_ratio": 0.01, + "max_grad_norm": 1.0, + "amp_master_grad": 1, + "dataloader_num_workers": 8, + "continue_training": 0, + "do_train": true, + "do_eval": true, + "do_predict": false, + "disable_tqdm": true, + "recompute": false, + "distributed_dataloader": 1, + "unified_checkpoint": true, + "save_total_limit": 2, + "skip_profile_timer": false, + "use_fused_rms_norm": true, + "fuse_attention_ffn": true, + "use_fused_rope": true, + "save_sharded_model": false, + "load_sharded_model": false, + "use_expert_parallel": true, + "unified_checkpoint_config": "skip_save_model_weight", + "offload_optim": true + } \ No newline at end of file diff --git a/examples/experiments/deepseek_v3_pretrain/data/indexed_dataset.py b/examples/experiments/deepseek_v3_pretrain/data/indexed_dataset.py new file mode 100644 index 00000000000..1b1cf7403de --- /dev/null +++ 
b/examples/experiments/deepseek_v3_pretrain/data/indexed_dataset.py @@ -0,0 +1,972 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# https://github.com/NVIDIA/Megatron-LM/blob/060415572f4365a2e895f8036c4e37dad0efbdf5/megatron/data/indexed_dataset.py +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +# copied from fairseq/fairseq/data/indexed_dataset.py +# Removed IndexedRawTextDataset since it relied on Fairseq dictionary +# other slight modifications to remove fairseq dependencies +# Added document index to index file and made it accessible. +# An empty sentence no longer separates documents. + +import os +import shutil +import struct +import time +from dataclasses import fields +from functools import lru_cache +from itertools import accumulate + +import numpy as np +import paddle + + +def print_rank_0(*args, **kwargs): + if paddle.distributed.get_rank() == 0: + print(*args, **kwargs) + + +def __best_fitting_dtype(vocab_size=None): + if vocab_size is not None and vocab_size < 65500: + return np.uint16 + else: + return np.int32 + + +def get_available_dataset_impl(): + return ["lazy", "mmap"] + + +def make_dataset(path, impl, skip_warmup=False): + if CompatibleIndexedDataset.exists(path): + print("Using old dataset (.npy & .npz)") + return CompatibleIndexedDataset(path) + elif not IndexedDataset.exists(path): + print(f"Dataset does not exist: {path}") + print("Path should be a basename that both .idx and .bin can be appended to get full filenames.") + return None + elif impl == "lazy" and IndexedDataset.exists(path): + return IndexedDataset(path) + elif impl == "mmap" and MMapIndexedDataset.exists(path): + return MMapIndexedDataset(path, skip_warmup) + print(f"Unknown dataset implementation: {impl}") + return None + + +def make_sft_dataset(path, dataclass, skip_warmup=False, impl="mmap"): + if impl != "mmap": + raise ValueError("SFT Indexed Dataset only support mmap memory-mapped method temporarily") + + print_rank_0(" > building dataset index ...") + start_time = time.time() + sft_indexed_dataset = SFTMMapIndexedDataset(path, dataclass, skip_warmup) + print_rank_0(" > finished creating SFT indexed dataset in {:4f} " "seconds".format(time.time() - start_time)) + print_rank_0(" number of samples: {}".format(len(sft_indexed_dataset.doc_idx) - 1)) + + return sft_indexed_dataset + + +def dataset_exists(path, impl): + if impl == "mmap": + return MMapIndexedDataset.exists(path) + else: + return IndexedDataset.exists(path) + + +def read_longs(f, n): + a = np.empty(n, dtype=np.int64) + f.readinto(a) + return a + + +def write_longs(f, a): + f.write(np.array(a, dtype=np.int64)) + + +def read_shorts(f, n): + a = np.empty(n, dtype=np.int32) + f.readinto(a) + return a + + +def write_shorts(f, a): + f.write(np.array(a, dtype=np.int32)) + + +dtypes = { + 1: np.uint8, + 2: np.int8, + 3: np.int16, + 4: 
np.int32, + 5: np.int64, + 6: np.float64, + 7: np.float32, + 8: np.uint16, + 9: np.uint32, + 10: np.uint64, +} + + +def code(dtype): + for k in dtypes.keys(): + if dtypes[k] == dtype: + return k + raise ValueError(dtype) + + +def index_file_path(prefix_path): + return prefix_path + ".idx" + + +def sft_index_file_path(prefix_path): + return os.path.join(prefix_path, "index.idx") + + +def sft_data_file_path(prefix_path, dataclass): + file_path_list = [] + for field in fields(dataclass): + file_path = os.path.join(prefix_path, f"{field.name}.bin") + file_path_list.append(file_path) + return file_path_list + + +def data_file_path(prefix_path): + return prefix_path + ".bin" + + +def loss_mask_file_path(prefix_path): + return prefix_path + ".lsm" + + +def create_doc_idx(sizes): + doc_idx = [0] + for i, s in enumerate(sizes): + if s == 0: + doc_idx.append(i + 1) + return doc_idx + + +class IndexedDataset(paddle.io.Dataset): + """Loader for IndexedDataset""" + + _HDR_MAGIC = b"TNTIDX\x00\x00" + + def __init__(self, path): + super().__init__() + self.path = path + self.data_file = None + self.read_index(path) + + def read_index(self, path): + with open(index_file_path(path), "rb") as f: + magic = f.read(8) + assert magic == self._HDR_MAGIC, ( + "Index file doesn't match expected format. " "Make sure that --dataset-impl is configured properly." + ) + version = f.read(8) + assert struct.unpack("= self._len: + raise IndexError("index out of range") + + def __del__(self): + if self.data_file: + self.data_file.close() + + # @lru_cache(maxsize=8) + def __getitem__(self, idx): + if not self.data_file: + self.read_data(self.path) + if isinstance(idx, int): + i = idx + self.check_index(i) + tensor_size = self.sizes[self.dim_offsets[i] : self.dim_offsets[i + 1]] + a = np.empty(tensor_size, dtype=self.dtype) + self.data_file.seek(self.data_offsets[i] * self.element_size) + self.data_file.readinto(a) + return a + elif isinstance(idx, slice): + start, stop, step = idx.indices(len(self)) + if step != 1: + raise ValueError("Slices into indexed_dataset must be contiguous") + sizes = self.sizes[self.dim_offsets[start] : self.dim_offsets[stop]] + size = sum(sizes) + a = np.empty(size, dtype=self.dtype) + self.data_file.seek(self.data_offsets[start] * self.element_size) + self.data_file.readinto(a) + offsets = list(accumulate(sizes)) + sents = np.split(a, offsets[:-1]) + return sents + + def get(self, idx, offset=0, length=None): + """Retrieves a single item from the dataset with the option to only + return a portion of the item. + + get(idx) is the same as [idx] but get() does not support slicing. 
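+        `offset` and `length` select a contiguous token span inside item `idx`.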
+ """ + if not self.data_file: + self.read_data(self.path) + size = self.sizes[idx] + ptr = self.data_offsets[idx] + if length is None: + length = size - offset + ptr += offset + a = np.empty(length, dtype=self.dtype) + self.data_file.seek(ptr * self.element_size) + self.data_file.readinto(a) + return a + + def __len__(self): + return self._len + + def num_tokens(self, index): + return self.sizes[index] + + def size(self, index): + return self.sizes[index] + + @staticmethod + def exists(path): + return os.path.exists(index_file_path(path)) and os.path.exists(data_file_path(path)) + + @property + def supports_prefetch(self): + return False # avoid prefetching to save memory + + @property + def doc_idx(self): + return self._doc_idx + + def get_doc_idx(self): + return self._doc_idx + + def set_doc_idx(self, doc_idx_): + self._doc_idx = doc_idx_ + + +class IndexedDatasetBuilder(object): + element_sizes = { + np.uint8: 1, + np.int8: 1, + np.int16: 2, + np.uint16: 2, + np.int32: 4, + np.int64: 8, + np.float32: 4, + np.float64: 8, + } + + def __init__(self, out_file, dtype=np.int32): + self.out_file = open(out_file, "wb") + self.dtype = dtype + self.data_offsets = [0] + self.dim_offsets = [0] + self.sizes = [] + self.element_size = self.element_sizes[self.dtype] + self.doc_idx = [0] + + def add_item(self, tensor): + tensor = np.array(tensor, dtype=self.dtype) + bytes = self.out_file.write(tensor) + self.data_offsets.append(self.data_offsets[-1] + bytes / self.element_size) + for s in tensor.shape: + self.sizes.append(s) + self.dim_offsets.append(self.dim_offsets[-1] + len(tensor.shape)) + del bytes + + def end_document(self): + self.doc_idx.append(len(self.sizes)) + + def merge_file_(self, another_file): + index = IndexedDataset(another_file) + assert index.dtype == self.dtype + + doc_offset = len(self.sizes) + + begin = self.data_offsets[-1] + for data_offset in index.data_offsets[1:]: + self.data_offsets.append(begin + data_offset) + self.sizes.extend(index.sizes) + + begin = self.dim_offsets[-1] + for dim_offset in index.dim_offsets[1:]: + self.dim_offsets.append(begin + dim_offset) + + self.doc_idx.extend((doc_offset + index.doc_idx)[1:]) + + with open(data_file_path(another_file), "rb") as f: + while True: + data = f.read(1024) + if data: + self.out_file.write(data) + else: + break + + def finalize(self, index_file): + self.out_file.close() + index = open(index_file, "wb") + index.write(b"TNTIDX\x00\x00") + index.write(struct.pack(" 1 and not add_sequence_len: + self._sizes.append(tensor.size) + add_sequence_len = True + self._data_file_dict[key].write(tensor.tobytes(order="C")) + + def end_document(self): + self._doc_idx.append(len(self._sizes)) + + def finalize(self, index_file): + for key, filename in self._data_file_dict.items(): + filename.close() + with SFTMMapIndexedDataset.Index.writer(index_file, self._dtype) as index: + index.write(self._sizes, self._doc_idx) + + +class MMapIndexedDatasetBuilder(object): + def __init__(self, out_file, dtype, loss_mask_file=None): + self._data_file = open(out_file, "wb") + self._loss_mask_file = None + if loss_mask_file is not None: + self._loss_mask_file = open(loss_mask_file, "wb") + self._dtype = dtype + self._sizes = [] + self._doc_idx = [0] + + def flush_loss_mask_item(self, loss_mask_lst): + for loss_mask in loss_mask_lst: + tensor = np.array(loss_mask, dtype=np.uint8) + self._loss_mask_file.write(tensor.tobytes(order="C")) + + def add_item(self, tensor): + tensor = np.array(tensor, dtype=self._dtype) + 
self._data_file.write(tensor.tobytes(order="C")) + self._sizes.append(tensor.size) + + def add_doc(self, tensor, sizes): + np_array = np.array(tensor, dtype=self._dtype) + self._data_file.write(np_array.tobytes(order="C")) + self._sizes.extend(sizes) + self._doc_idx.append(len(self._sizes)) + + def end_document(self): + self._doc_idx.append(len(self._sizes)) + + def merge_file_(self, another_file): + # Concatenate index + index = MMapIndexedDataset.Index(index_file_path(another_file)) + assert index.dtype == self._dtype + + offset = len(self._sizes) + self._sizes.extend(index.sizes) + self._doc_idx.extend((offset + index.doc_idx)[1:]) + + # Concatenate data + with open(data_file_path(another_file), "rb") as f: + shutil.copyfileobj(f, self._data_file) + + def finalize(self, index_file): + self._data_file.close() + + with MMapIndexedDataset.Index.writer(index_file, self._dtype) as index: + index.write(self._sizes, self._doc_idx) + print("Total sentences num: %d" % len(self._sizes)) + print("Total documents num: %d" % (len(self._doc_idx) - 1)) + print("Total tokens num: %d" % sum(self._sizes)) + print("Average tokens per sentence: %.2f" % (sum(self._sizes) / len(self._sizes))) + print("Average tokens per document: %.2f" % (sum(self._sizes) / (len(self._doc_idx) - 1))) + + +def get_indexed_dataset_(data_prefix, data_impl, skip_warmup): + + print_rank_0(" > building dataset index ...") + + start_time = time.time() + indexed_dataset = make_dataset(data_prefix, data_impl, skip_warmup) + assert indexed_dataset.sizes.shape[0] == indexed_dataset.doc_idx[-1] + print_rank_0(" > finished creating indexed dataset in {:4f} " "seconds".format(time.time() - start_time)) + + print_rank_0(" > indexed dataset stats:") + print_rank_0(" number of documents: {}".format(indexed_dataset.doc_idx.shape[0] - 1)) + print_rank_0(" number of sentences: {}".format(indexed_dataset.sizes.shape[0])) + + return indexed_dataset + + +class CompatibleIndexedDataset(paddle.io.Dataset): + def __init__(self, path): + super().__init__() + + self._path = path + + # All document ids, extend as 1-D array. + self._token_ids = np.load(path + "_ids.npy", mmap_mode="r", allow_pickle=True) + process_data = np.load(path + "_idx.npz") + self._sizes = process_data["lens"] + self._pointers = np.empty(len(self._sizes) + 1, dtype=np.int64) + self._pointers[0] = 0 + np.cumsum(self._sizes, out=self._pointers[1:]) + self._doc_idx = process_data["docs"] + + def __getstate__(self): + return self._path + + def __len__(self): + return len(self._sizes) + + # @lru_cache(maxsize=8) + def __getitem__(self, idx): + if isinstance(idx, int): + size = self._sizes[idx] + ptr = self._pointers[idx] + np_array = self._token_ids[ptr : ptr + size] + return np_array + + elif isinstance(idx, slice): + start, stop, step = idx.indices(len(self)) + if step != 1: + raise ValueError("Slices into indexed_dataset must be contiguous") + ptr = self._pointers[start] + sizes = self._sizes[idx] + offsets = list(accumulate(sizes)) + total_size = sum(sizes) + np_array = self._token_ids[ptr : ptr + total_size] + sents = np.split(np_array, offsets[:-1]) + return sents + + def get(self, idx, offset=0, length=None): + """Retrieves a single item from the dataset with the option to only + return a portion of the item. + + get(idx) is the same as [idx] but get() does not support slicing. 
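+        Returns the selected token span plus a `None` placeholder kept for interface compatibility.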
+ """ + size = self._sizes[idx] + ptr = self._pointers[idx] + + if length is None: + length = size - offset + ptr += offset + np_array = self._token_ids[ptr : ptr + length] + return np_array, None + + @property + def sizes(self): + return self._sizes + + @property + def doc_idx(self): + return self._doc_idx + + def get_doc_idx(self): + return self._doc_idx + + def set_doc_idx(self, doc_idx_): + self._doc_idx = doc_idx_ + + @staticmethod + def exists(path): + return os.path.isfile(path + "_ids.npy") and os.path.isfile(path + "_idx.npz") diff --git a/examples/experiments/deepseek_v3_pretrain/load_hf_ckpt.py b/examples/experiments/deepseek_v3_pretrain/load_hf_ckpt.py new file mode 100644 index 00000000000..802b8e205f9 --- /dev/null +++ b/examples/experiments/deepseek_v3_pretrain/load_hf_ckpt.py @@ -0,0 +1,377 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import json +import re +import sys +from collections import defaultdict +from typing import List, Optional + +import paddle + +from paddleformers.utils.log import logger + +try: + from safetensors import safe_open +except: + safe_open = None + +_LAYER_RE = re.compile(r"^_layers\.(\d+)\.(\d+)(?:\.(.*))?$") +_EXPERT_W1_RE = re.compile(r"^mlp\.experts\.(\d+)\.w1(?:\.weight)?$") +_EXPERT_W2_RE = re.compile(r"^mlp\.experts\.(\d+)\.w2(?:\.weight)?$") +_SHARE_EXPERT_W1_RE = re.compile(r"^mlp\.shared_experts\.w1(?:\.weight)?$") +_SHARE_EXPERT_W2_RE = re.compile(r"^mlp\.shared_experts\.w2(?:\.weight)?$") + +_EXPERT_W1_RE_v2 = re.compile(r"^mlp\.experts\.(\d+)\.gate_up_fused_proj(?:\.weight)?$") +_SHARE_EXPERT_W1_RE_v2 = re.compile(r"^mlp\.shared_experts\.gate_up_fused_proj(?:\.weight)?$") +_LAYER_RE_v2 = re.compile(r"_layers.deepseek_v2.layers\.(\d+)\.(.*)$") + +custom_name_map = { + "self_attn.input_layernorm.weight": "input_layernorm.weight", + "self_attn.fused_rms_norm_linear.rms_norm_weight": "input_layernorm.weight", + "self_attn.memory_recompute_att.kv_ln_weight": "self_attn.kv_a_layernorm.weight", + "self_attn.fused_rms_norm_linear.kv_down_weight": "self_attn.kv_a_proj_with_mqa.weight", + "self_attn.memory_recompute_att.kv_up_weight": "self_attn.kv_b_proj.weight", + "self_attn.memory_recompute_att.q_ln_weight": "self_attn.q_a_layernorm.weight", + "self_attn.fused_rms_norm_linear.q_down_weight": "self_attn.q_a_proj.weight", + "self_attn.memory_recompute_att.q_up_weight": "self_attn.q_b_proj.weight", +} + + +def paddle_name_to_hf_names_ds_v2(paddle_name: str) -> List[str]: + """ + Convert Paddle model parameter names to Hugging Face format name lists + + Args: + paddle_name: Parameter name in Paddle format + + Returns: + List of parameter names in Hugging Face format (may be split into multiple parameters) + """ + if paddle_name == "_layers.deepseek_v2.embed_tokens.weight": + return ["model.embed_tokens.weight"] + + if paddle_name == "_layers.deepseek_v2.norm.weight": + return ["model.norm.weight"] + + if paddle_name == "_layers.lm_head.weight": + return ["lm_head.weight"] + + 
m = _LAYER_RE_v2.match(paddle_name) + if not m: + return [] + + rest = m.group(2) or "" + layer_id = m.group(1) + if rest in custom_name_map: + rest = custom_name_map[rest] + out_name = "model.layers." + layer_id + "." + rest + + if rest == "mlp.gate_up_fused_proj.weight" or rest == "mlp.w1": + return [ + "model.layers." + layer_id + ".mlp.gate_proj.weight", + "model.layers." + layer_id + ".mlp.up_proj.weight", + ] + + if rest == "mlp.w2": + return ["model.layers." + layer_id + ".mlp.down_proj.weight"] + + if rest == "mlp.shared_experts.gate_up_fused_proj.weight": + return [ + "model.layers." + layer_id + ".mlp.shared_experts.gate_proj.weight", + "model.layers." + layer_id + ".mlp.shared_experts.up_proj.weight", + ] + + if m := _EXPERT_W1_RE_v2.match(rest): + expert_id = m.group(1) + return [ + "model.layers." + layer_id + ".mlp.experts." + expert_id + ".gate_proj.weight", + "model.layers." + layer_id + ".mlp.experts." + expert_id + ".up_proj.weight", + ] + + if m := _EXPERT_W1_RE.match(rest): + expert_id = m.group(1) + return [ + "model.layers." + layer_id + ".mlp.experts." + expert_id + ".gate_proj.weight", + "model.layers." + layer_id + ".mlp.experts." + expert_id + ".up_proj.weight", + ] + + if m := _EXPERT_W2_RE.match(rest): + expert_id = m.group(1) + return ["model.layers." + layer_id + ".mlp.experts." + expert_id + ".down_proj.weight"] + + if m := _SHARE_EXPERT_W1_RE.match(rest): + return [ + "model.layers." + layer_id + ".mlp.shared_experts.gate_proj.weight", + "model.layers." + layer_id + ".mlp.shared_experts.up_proj.weight", + ] + + if m := _SHARE_EXPERT_W2_RE.match(rest): + return ["model.layers." + layer_id + ".mlp.shared_experts.down_proj.weight"] + + return [out_name] + + +def paddle_name_to_hf_names(paddle_name: str) -> List[str]: + """ + Convert Paddle model parameter names to Hugging Face format name lists + + Args: + paddle_name: Parameter name in Paddle format + + Returns: + List of parameter names in Hugging Face format (may be split into multiple parameters) + """ + if paddle_name == "_layers.local_shared_layers.DeepseekV2_shared_weight.embed_tokens.weight": + return ["model.embed_tokens.weight"] + + if paddle_name == "_layers.deepseek_v2.embed_tokens.weight": + return ["model.embed_tokens.weight"] + + m = _LAYER_RE.match(paddle_name) + + if not m: + return [] + else: + rest = m.group(3) or "" + + segment_id = int(m.group(1)) + id_in_segment = int(m.group(2)) + + hf_prefix = _get_hf_prefix(segment_id, id_in_segment) + + if rest in custom_name_map: + return [f"{hf_prefix}.{custom_name_map[rest]}"] + + if expert_names := _handle_expert_weights(hf_prefix, rest): + return expert_names + + if shared_mlp_names := _handle_shared_expert_weights(hf_prefix, rest): + return shared_mlp_names + + if mlp_names := _handle_mlp_weights(hf_prefix, rest): + return mlp_names + + if rest == "mlp.gate_up_fused_proj.weight" or rest == "mlp.w1": + return [hf_prefix + ".mlp.gate_proj.weight", hf_prefix + ".mlp.up_proj.weight"] + + if rest == "mlp.w2": + return [hf_prefix + ".mlp.down_proj.weight"] + + if rest == "mlp.shared_experts.gate_up_fused_proj.weight": + return [hf_prefix + ".mlp.shared_experts.gate_proj.weight", hf_prefix + ".mlp.shared_experts.up_proj.weight"] + + if m := _EXPERT_W1_RE_v2.match(rest): + expert_id = m.group(1) + return [ + hf_prefix + ".mlp.experts." + expert_id + ".gate_proj.weight", + hf_prefix + ".mlp.experts." + expert_id + ".up_proj.weight", + ] + + if m := _EXPERT_W1_RE.match(rest): + expert_id = m.group(1) + return [ + hf_prefix + ".mlp.experts." 
+ expert_id + ".gate_proj.weight", + hf_prefix + ".mlp.experts." + expert_id + ".up_proj.weight", + ] + + if m := _EXPERT_W2_RE.match(rest): + expert_id = m.group(1) + return [hf_prefix + ".mlp.experts." + expert_id + ".down_proj.weight"] + + if m := _SHARE_EXPERT_W1_RE.match(rest): + return [hf_prefix + ".mlp.shared_experts.gate_proj.weight", hf_prefix + ".mlp.shared_experts.up_proj.weight"] + + if m := _SHARE_EXPERT_W2_RE.match(rest): + return [hf_prefix + ".mlp.shared_experts.down_proj.weight"] + + return [f"{hf_prefix}.{rest}"] if rest else [hf_prefix] + + +def _get_hf_prefix(segment_id: int, id_in_segment: int) -> str: + """Generate hierarchical prefix in Hugging Face format""" + # Special layer mappings + # special_cases = {(0, 0): "model", (60, 2): "model.layers.61", (60, 3): "model"} + # special_cases = {(0, 0): "model", (28, 2): "model.layers.61", (28, 3): "model"} + # special_cases = {(0, 0): "model", (28, 2): "model.layers.61", (4, 1): "model"} + # special_cases = {(0, 0): "model", (28, 2): "model", (28,3): "lm_head"} + special_cases = {(0, 0): "model", (60, 2): "model.layers.61", (60, 3): "model", (60, 4): "lm_head"} + + if (segment_id, id_in_segment) in special_cases: + return special_cases[(segment_id, id_in_segment)] + + # General layer calculation + layer_idx = segment_id + id_in_segment - 1 + return f"model.layers.{layer_idx}" + + +def _handle_expert_weights(hf_prefix: str, rest: str) -> Optional[List[str]]: + if m := _EXPERT_W1_RE.match(rest): + expert_id = int(m.group(1)) + return [ + f"{hf_prefix}.mlp.experts.{expert_id}.gate_proj.weight", + f"{hf_prefix}.mlp.experts.{expert_id}.up_proj.weight", + ] + + if m := _EXPERT_W2_RE.match(rest): + expert_id = int(m.group(1)) + return [f"{hf_prefix}.mlp.experts.{expert_id}.down_proj.weight"] + + return None + + +def _handle_shared_expert_weights(hf_prefix: str, rest: str) -> Optional[List[str]]: + if _SHARE_EXPERT_W1_RE.match(rest): + return [ + f"{hf_prefix}.mlp.shared_experts.gate_proj.weight", + f"{hf_prefix}.mlp.shared_experts.up_proj.weight", + ] + + if _SHARE_EXPERT_W2_RE.match(rest): + return [f"{hf_prefix}.mlp.shared_experts.down_proj.weight"] + + return None + + +def _handle_mlp_weights(hf_prefix: str, rest: str) -> Optional[List[str]]: + if rest == "mlp.w1": + return [f"{hf_prefix}.mlp.gate_proj.weight", f"{hf_prefix}.mlp.up_proj.weight"] + + if rest == "mlp.w2": + return [f"{hf_prefix}.mlp.down_proj.weight"] + + return None + + +def prepare_tensor(tensor, dst_shape, *, force_transpose=False): + if isinstance(tensor, list): + t = paddle.concat( + [ + paddle.transpose(tensor[0], perm=[1, 0]).contiguous(), + paddle.transpose(tensor[1], perm=[1, 0]).contiguous(), + ], + axis=-1, + ) + if t.shape != dst_shape: + logger.warning( + f"Prepare_tensor: shape not match. base tensor shape: {tensor[0].shape}, {tensor[1].shape}, t.shape: {t.shape}, dst_shape: {dst_shape}" + ) + sys.exit() + return t + + if force_transpose: + return tensor.T.contiguous() + + if tensor.shape == dst_shape: + return tensor + if len(tensor.shape) == 2 and paddle.transpose(tensor, perm=[1, 0]).contiguous().shape == dst_shape: + return paddle.transpose(tensor, perm=[1, 0]).contiguous() + + logger.warning("Prepare_tensor: shape not match.") + sys.exit() + + +def load_huggingface_ckpt(model, huggingface_ckpt_path): + ckpt_pre = huggingface_ckpt_path + + # 1. Load parameter file mapping table + weight_map_path = ckpt_pre + "/model.safetensors.index.json" + with open(weight_map_path, "r") as f: + weight_map = json.load(f)["weight_map"] + + # 2. 
Create inverse index: file -> parameter list + file_to_params = defaultdict(list) + for param_name, filename in weight_map.items(): + file_to_params[filename].append(param_name) + + # 3. Collect file list that model needs + required_files = set() + file_to_pd_param_name = defaultdict(list) + pd_param_name_to_file = defaultdict(list) + for pd_name, p in model.named_parameters(): + hf_name = paddle_name_to_hf_names(pd_name) + if hf_name[0] in weight_map: + filename = weight_map[hf_name[0]] + required_files.add(filename) + file_to_pd_param_name[filename].append(pd_name) + pd_param_name_to_file[pd_name].append(filename) + else: + logger.warning(f"Warning: {pd_name} -> {hf_name[0]} not found in weight map") + import sys + + sys.exit() + + if len(hf_name) > 1: + if hf_name[1] in weight_map: + filename = weight_map[hf_name[1]] + required_files.add(filename) + file_to_pd_param_name[filename].append(pd_name) + if filename != pd_param_name_to_file[pd_name][0]: + pd_param_name_to_file[pd_name].append(filename) + else: + logger.warning(f"Warning: {pd_name} -> {hf_name[1]} not found in weight map") + + # 4. Group file and load + check_list = [] + logger.info("Start load huggingface ckpt") + for i, filename in enumerate(required_files): + try: + with safe_open(ckpt_pre + filename, framework="paddle", device="cpu") as f: + # Load all parameters in file + pd_params = file_to_pd_param_name[filename] + for pd_param in pd_params: + if pd_param in check_list: + continue + + hf_name = paddle_name_to_hf_names(pd_param) + if len(hf_name) == 1: + tensor = f.get_tensor(hf_name[0]) + + force_transpose = False + + model.state_dict()[pd_param].set_value( + paddle.cast( + prepare_tensor( + tensor, model.state_dict()[pd_param].shape, force_transpose=force_transpose + ), + model.state_dict()[pd_param].dtype, + ) + ) + else: + files = pd_param_name_to_file[pd_param] + if len(files) == 1: + tensor0 = f.get_tensor(hf_name[0]) + tensor1 = f.get_tensor(hf_name[1]) + else: + if weight_map[hf_name[0]] == filename: + tensor0 = f.get_tensor(hf_name[0]) + with safe_open( + ckpt_pre + weight_map[hf_name[1]], framework="paddle", device="cpu" + ) as f_other: + tensor1 = f_other.get_tensor(hf_name[1]) + else: + with safe_open( + ckpt_pre + weight_map[hf_name[0]], framework="paddle", device="cpu" + ) as f_other: + tensor0 = f_other.get_tensor(hf_name[0]) + tensor1 = f.get_tensor(hf_name[1]) + model.state_dict()[pd_param].set_value( + prepare_tensor([tensor0, tensor1], model.state_dict()[pd_param].shape) + ) + check_list.append(pd_param) + + except Exception as e: + logger.warning(f"Error loading {filename}: {str(e)}") + raise diff --git a/examples/experiments/deepseek_v3_pretrain/modeling.py b/examples/experiments/deepseek_v3_pretrain/modeling.py new file mode 100644 index 00000000000..ef95b829399 --- /dev/null +++ b/examples/experiments/deepseek_v3_pretrain/modeling.py @@ -0,0 +1,3039 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2023 DeepSeek. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Paddle DeepSeek model.""" + +from __future__ import annotations + +import contextlib +import math +import warnings +from functools import partial +from typing import List, Optional, Tuple, Union + +import paddle +import paddle.distributed as dist +import paddle.distributed.fleet.meta_parallel as mpu +import paddle.nn.functional as F +from paddle import Tensor, nn +from paddle.distributed import fleet +from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker +from paddle.distributed.fleet.recompute.recompute import recompute +from paddle.jit import to_static +from paddle.utils import try_import + +try: + from paddle.incubate.nn.functional import fused_rotary_position_embedding +except ImportError: + fused_rotary_position_embedding = None + +try: + from paddle.distributed.fleet.utils.sequence_parallel_utils import ( + GatherOp, + ScatterOp, + mark_as_sequence_parallel_parameter, + ) +except: + pass + +from paddle import _C_ops + +try: + from paddle.nn.functional.flash_attention import flash_attention +except: + flash_attention = None + +from config.configuration import DeepseekV2FastConfig +from moe_gate import PretrainedMoEGate +from moe_layer import MoELayer +from moe_utils import get_env_device +from paddle.distributed.fleet.meta_parallel.zero_bubble_utils import WeightGradStore + +from paddleformers.transformers.activations import ACT2FN +from paddleformers.transformers.conversion_utils import ( + StateDictNameMapping, + init_name_mappings, +) +from paddleformers.transformers.deepseek_v2 import ( + DeepseekV2DynamicNTKScalingRotaryEmbedding, + DeepseekV2LinearScalingRotaryEmbedding, + DeepseekV2RotaryEmbedding, + _expand_2d_mask, + _make_causal_mask, +) +from paddleformers.transformers.deepseek_v2 import fp8_linear as linear_utils +from paddleformers.transformers.deepseek_v2 import ( + is_casual_mask, + rotate_half, + scaled_dot_product_attention, + yarn_find_correction_range, + yarn_get_mscale, + yarn_linear_ramp_mask, +) +from paddleformers.transformers.deepseek_v2.fp8_linear import Linear as Linear_ +from paddleformers.transformers.fp8_utils import ( + FP8KeepXLinear, + FP8Linear, + FP8LinearFunction, + FP8LinearFunctionBase, + FP8Mlp, + cache_fp8_weight, + set_parameter_color, +) +from paddleformers.transformers.llama.modeling import get_use_casual_mask +from paddleformers.transformers.model_outputs import BaseModelOutputWithPastAndMTP +from paddleformers.transformers.model_utils import ( + PretrainedModel, + dtype_guard, + register_base_model, +) +from paddleformers.transformers.utils import device_guard +from paddleformers.utils.initializer import kaiming_uniform_ +from paddleformers.utils.log import logger + +try: + import fused_ln + from paddle.incubate.nn.functional import swiglu +except ImportError: + + def swiglu(x, y=None): + if y is None: + x, y = paddle.chunk(x, chunks=2, axis=-1) + return F.silu(x) * y + + +try: + from paddle.incubate.nn.functional import fused_partial_rope +except ImportError: + fused_partial_rope = None + + +__all__ = [ + "DeepseekV2LMHead", + "set_global_step", + "get_global_step", + "DeepseekV2PretrainingCriterionFast", + 
"DeepseekV2ModelFast", + "DeepseekV2PretrainedModelFast", +] + +global_step = 0 + + +def set_global_step(cur_step): + global global_step + global_step = cur_step + + +def get_global_step(): + global global_step + return global_step + + +def rms_norm_fused(x_in, w, eps, use_fast_ln=False): + if use_fast_ln: + fast_ln = try_import("fast_ln") + return fast_ln.fast_rms_norm(x_in, w, eps)[0] + else: + fused_ln = try_import("fused_ln") + return fused_ln.fused_rms_norm(x_in, w, eps)[0] + + +def cast_if_needed(x, dtype): + """ + cast_if_needed + """ + return x.cast(dtype) if x.dtype != dtype else x + + +def fusion_rms_norm(hidden_states, weight, variance_epsilon, use_fast_ln=False): + if get_env_device() == "npu": + return paddle.base.core.eager._run_custom_op("rms_norm_npu", hidden_states, weight, variance_epsilon)[0] + if get_env_device() == "mlu": + return paddle.base.core.eager._run_custom_op("rms_norm_mlu", hidden_states, weight, variance_epsilon)[0] + elif get_env_device() == "gcu": + return paddle.base.core.eager._run_custom_op("rms_norm_gcu", hidden_states, weight, variance_epsilon)[0] + elif get_env_device() == "intel_hpu": + return paddle.incubate.nn.functional.fused_rms_norm( + hidden_states, weight, None, variance_epsilon, hidden_states.dim() - 1 + )[0] + elif get_env_device() == "xpu": + try: + import paddle_xpu_nn # noqa: F821 + + return paddle_xpu_nn.xpu_rms_norm(hidden_states, weight, variance_epsilon)[0] + except ImportError: + raise NotImplementedError( + f"Implementation of fused_rms_norm is not available on {get_env_device()}. Please install paddle_xpu to use this feature" + ) + return rms_norm_fused(hidden_states, weight, variance_epsilon, use_fast_ln) + + +class LMHeadFunction(paddle.autograd.PyLayer): + @staticmethod + def forward(ctx, x, weight, transpose_y): + out = paddle.matmul(x, weight, transpose_y=transpose_y) + + ctx.save_for_backward(x, weight, transpose_y) + return out + + @staticmethod + def backward(ctx, dout): + if dout.dtype == paddle.float32: + dout = dout.cast(paddle.bfloat16) + + x, weight, transpose_y = ctx.saved_tensor() + + dx = paddle.matmul(dout, weight, transpose_y=not transpose_y) + if transpose_y: + with paddle.amp.auto_cast(False): + paddle._C_ops.fused_linear_param_grad_add( + dout.reshape([-1, dout.shape[-1]]), + x.reshape([-1, x.shape[-1]]), + weight.main_grad, + None, + True, + False, + ) + else: + with paddle.amp.auto_cast(False): + paddle._C_ops.fused_linear_param_grad_add( + x.reshape([-1, x.shape[-1]]), + dout.reshape([-1, dout.shape[-1]]), + weight.main_grad, + None, + True, + False, + ) + return dx, None + + +def parallel_matmul(x: Tensor, y: Tensor, transpose_y=False, tensor_parallel_output=True): + is_fleet_init = True + tensor_parallel_degree = 1 + try: + hcg = fleet.get_hybrid_communicate_group() + model_parallel_group = hcg.get_model_parallel_group() + tensor_parallel_degree = hcg.get_model_parallel_world_size() + except AttributeError: + is_fleet_init = False + + if paddle.in_dynamic_mode(): + y_is_distributed = y.is_distributed + else: + y_is_distributed = tensor_parallel_degree > 1 + + if is_fleet_init and tensor_parallel_degree > 1 and y_is_distributed: + # if not running under distributed.launch, it will raise AttributeError: 'Fleet' object has no attribute '_hcg' + input_parallel = paddle.distributed.collective._c_identity(x, group=model_parallel_group) + logits = paddle.matmul(input_parallel, y, transpose_y=transpose_y) + + if tensor_parallel_output: + return logits + + return paddle.distributed.collective._c_concat(logits, 
group=model_parallel_group) + + else: + logits = LMHeadFunction.apply(x, y, transpose_y=transpose_y) + return logits + + +class DeepseekV2MLP(nn.Layer): + def __init__(self, config: DeepseekV2FastConfig, hidden_size=None, intermediate_size=None, is_moe=False): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size if hidden_size is None else hidden_size + self.intermediate_size = config.intermediate_size if intermediate_size is None else intermediate_size + self.fuse_attention_ffn = config.fuse_attention_ffn + Linear = FP8Linear if self.config.dsv3_use_fp8_gemm else Linear_ + + def linear_dtype_gaurd(): + if config.use_fp8: + return dtype_guard("float8_e4m3fn") + else: + return contextlib.nullcontext() + + if config.sequence_parallel: + ColumnParallelLinear = linear_utils.ColumnSequenceParallelLinear + RowParallelLinear = linear_utils.RowSequenceParallelLinear + else: + ColumnParallelLinear = linear_utils.ColumnParallelLinear + RowParallelLinear = linear_utils.RowParallelLinear + + with linear_dtype_gaurd(): + if config.tensor_parallel_degree > 1 and not is_moe: + self.gate_proj = ColumnParallelLinear( + self.hidden_size, + self.intermediate_size, + gather_output=False, + has_bias=False, + ) + self.up_proj = ColumnParallelLinear( + self.hidden_size, + self.intermediate_size, + gather_output=False, + has_bias=False, + ) + self.down_proj = RowParallelLinear( + self.intermediate_size, + self.hidden_size, + input_is_parallel=True, + has_bias=False, + ) + else: + if config.fuse_attention_ffn: + self.gate_up_fused_proj = Linear(self.hidden_size, self.intermediate_size * 2, bias_attr=False) + else: + self.gate_proj = Linear(self.hidden_size, self.intermediate_size, bias_attr=False) + self.up_proj = Linear(self.hidden_size, self.intermediate_size, bias_attr=False) + self.down_proj = Linear(self.intermediate_size, self.hidden_size, bias_attr=False) + + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + if self.fuse_attention_ffn: + x = swiglu(self.gate_up_fused_proj(x)) + else: + x = swiglu(self.gate_proj(x), self.up_proj(x)) + out = self.down_proj(x) + return out + + +class MoEGate(PretrainedMoEGate): + def __init__( + self, + config, + num_experts, + expert_hidden_size, + using_post_norm_recompute=False, + norm_weight=None, + norm_eps=None, + **kwargs + ): + super().__init__(config, num_experts, expert_hidden_size, **kwargs) + # [hidden_size, n_expert] + + self.scoring_func = config.scoring_func + self.topk_method = config.topk_method + + self.weight = paddle.create_parameter( + shape=[expert_hidden_size, num_experts], + dtype=paddle.float32, + is_bias=False, + # default_initializer=nn.initializer.Constant(1.0), + ) + + self.config = config + self.using_post_norm_recompute = using_post_norm_recompute + + if config.topk_method == "noaux_tc": + self.e_score_correction_bias = paddle.create_parameter( + shape=[num_experts], + dtype=paddle.float32, + default_initializer=nn.initializer.Constant(0.0), + ) + self.e_score_correction_bias.is_distributed = True + + if self.using_post_norm_recompute: + assert norm_weight is not None and norm_eps is not None + self.norm_weight = norm_weight + self.norm_eps = norm_eps + self.using_flex_token = False + + def forward(self, hidden_states): + """ + Args: + hidden_states (_type_): [batch_size * seq_len, hidden_size] + """ + _, _, h_dim = hidden_states.shape + + # compute gating score + if self.using_post_norm_recompute: + logits, norm_out = FusedNormGateFunc.apply(hidden_states, self.norm_weight, self.weight, 
self.norm_eps) + if hasattr(self.config, "using_fake_gate") and self.config.using_fake_gate: + logits = FakeGate.apply( + hidden_states, + self.weight, + self.config.fakse_gate_restrict_balance, + self.config.num_experts_per_tok, + ) + else: + with paddle.amp.auto_cast(False): + hidden_states = hidden_states.cast(self.weight.dtype) + if hasattr(self.config, "using_fake_gate") and self.config.using_fake_gate: + logits = FakeGate.apply( + hidden_states, + self.weight, + self.config.fakse_gate_restrict_balance, + self.config.num_experts_per_tok, + ) + else: + logits = F.linear(hidden_states, self.weight, None) + + scores = self.gate_score_func(logits=logits) + scores = scores.cast(paddle.float32) + + # Compute all possible return values + if self.using_flex_token: + scores, routing_map, exp_counts, l_aux, l_zloss = self.topkgating_nodrop(scores) + ret = (scores, routing_map, l_aux, l_zloss) + else: + ret = self.topkgating(scores) # (capacity, combine_weights, dispatch_mask, exp_counts, l_aux, l_zloss) + + # Append norm_out if needed + if self.using_post_norm_recompute: + ret = (*ret, norm_out) + + return ret + + +class DeepseekV2MoE(MoELayer): + """ + A mixed expert module containing shared experts. + """ + + def __init__(self, config: DeepseekV2FastConfig, norm_weight=None, norm_eps=None): + assert config.tensor_parallel_degree <= 1, "tensor_parallel_degree should be 1" + + self.using_post_norm_recompute = config.using_post_norm_recompute + if self.using_post_norm_recompute: + assert norm_weight is not None and norm_eps is not None + + gate = MoEGate( + config=config, + num_experts=config.n_routed_experts, + expert_hidden_size=config.hidden_size, + top_k=config.num_experts_per_tok, + topk_method=config.topk_method, + n_group=config.n_group, + topk_group=config.topk_group, + norm_topk_prob=config.norm_topk_prob, + routed_scaling_factor=config.routed_scaling_factor, + drop_tokens=False, + using_post_norm_recompute=self.using_post_norm_recompute, + norm_weight=norm_weight, + norm_eps=norm_eps, + ) + DeepseekV2MLPClass = FP8Mlp if config.dsv3_use_fp8_gemm else DeepseekV2MLP + + super().__init__( + config=config, + moe_num_experts=config.n_routed_experts, + expert_class=DeepseekV2MLPClass, + expert_kwargs={ + "config": config, + "intermediate_size": config.moe_intermediate_size, + "is_moe": True, + }, + gate=gate, + capacity=2.0, + moe_group="expert", + using_post_norm_recompute=self.using_post_norm_recompute, + ) + + if config.offline_quant_expert_weight and config.clear_origin_weight_when_offline_quant: + moe_grad_group = fleet.get_hybrid_communicate_group().expert_grad_comm_group + expert_w1_list = [expert.w1 for expert in self.experts if expert is not None] + expert_w2_list = [expert.w2 for expert in self.experts if expert is not None] + for p in expert_w1_list: + setattr(p, "color", {"color": "moe_expert", "group": moe_grad_group}) + for p in expert_w2_list: + setattr(p, "color", {"color": "moe_expert", "group": moe_grad_group}) + + self.alpha = config.aux_loss_alpha + if config.n_shared_experts is not None: + intermediate_size = config.moe_intermediate_size * config.n_shared_experts + if self.using_post_norm_recompute: + assert DeepseekV2MLPClass is FP8Mlp + self.shared_experts = DeepseekV2MLPClass( + config=config, + intermediate_size=intermediate_size, + is_moe=False, + using_post_norm_recompute=self.using_post_norm_recompute, + norm_weight=norm_weight, + norm_eps=norm_eps, + recompute_fwd_gate_up=True, + ) + else: + self.shared_experts = DeepseekV2MLPClass( + config=config, 
intermediate_size=intermediate_size, is_moe=False + ) + set_parameter_color([self.shared_experts.w1, self.shared_experts.w2], "shared_expert") + + def fp8_quant_weight(self, batch_mode=False, quant_transpose=None): + """Quantize weights in FP8 format. + + Args: + batch_mode: If True, quantize all weights in batch mode using the first expert's weights. + If False, quantize each expert's weights individually. + """ + + def quantize_weights(weight_list, weight_obj=None, quant_transpose=None): + """Helper function to quantize a list of weights.""" + if weight_obj is None: + weight_obj = weight_list[0] + if hasattr(weight_obj, "fp8_weight_stacked") or hasattr(weight_obj, "fp8_weight_stacked_transpose"): + return + + if quant_transpose is None: + fp8_weight, fp8_scale = paddle.incubate.nn.functional.fused_stack_transpose_quant( + weight_list, transpose=False + ) + setattr(weight_obj, "fp8_weight_stacked", fp8_weight) + setattr(weight_obj, "fp8_scale_stacked", fp8_scale) + + fp8_weight_t, fp8_scale_t = paddle.incubate.nn.functional.fused_stack_transpose_quant( + weight_list, transpose=True + ) + setattr(weight_obj, "fp8_weight_stacked_transpose", fp8_weight_t) + setattr(weight_obj, "fp8_scale_stacked_transpose", fp8_scale_t) + elif quant_transpose is False: + # Only quantize without transpose + fp8_weight, fp8_scale = paddle.incubate.nn.functional.fused_stack_transpose_quant( + weight_list, transpose=False + ) + setattr(weight_obj, "fp8_weight_stacked", fp8_weight) + setattr(weight_obj, "fp8_scale_stacked", fp8_scale) + elif quant_transpose is True: + # Only quantize with transpose + fp8_weight_t, fp8_scale_t = paddle.incubate.nn.functional.fused_stack_transpose_quant( + weight_list, transpose=True + ) + setattr(weight_obj, "fp8_weight_stacked_transpose", fp8_weight_t) + setattr(weight_obj, "fp8_scale_stacked_transpose", fp8_scale_t) + else: + raise ValueError("Invalid value for `quant_transpose`.") + + if batch_mode: + # Batch mode: process all experts' weights together + expert_w1_list = [expert.w1 for expert in self.experts if expert is not None] + expert_w2_list = [expert.w2 for expert in self.experts if expert is not None] + + if expert_w1_list: + quantize_weights(expert_w1_list, expert_w1_list[0], quant_transpose) + if expert_w2_list: + quantize_weights(expert_w2_list, expert_w2_list[0], quant_transpose) + else: + # Individual mode: process each expert's weights separately + for expert in self.experts: + if expert is not None: + quantize_weights([expert.w1], quant_transpose=quant_transpose) + quantize_weights([expert.w2], quant_transpose=quant_transpose) + + if self.config.n_shared_experts is not None: + self.shared_experts.fp8_quant_weight(quant_transpose) + + def forward(self, hidden_states): + if self.using_post_norm_recompute: + super().update_flex_token() + if self.using_flex_token: + probs, routing_map, l_aux, l_zloss, norm_out = self.router(hidden_states) + final_hidden_states, l_aux, l_zloss = super().forward( + norm_out, probs=probs, routing_map=routing_map, l_aux=l_aux, l_zloss=l_zloss + ) + else: + capacity, topk_weight, topk_ids, token_priority, l_aux, l_zloss, norm_out = self.gate(hidden_states) + final_hidden_states, l_aux, l_zloss = super().forward( + norm_out, + capacity=capacity, + topk_weight=topk_weight, + topk_ids=topk_ids, + token_priority=token_priority, + l_aux=l_aux, + l_zloss=l_zloss, + ) + final_hidden_states = self.post_process(hidden_states, final_hidden_states, l_aux) + else: + final_hidden_states, l_aux, l_zloss = super().forward(hidden_states) + 
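The forward above hands routing to the gate (the flex-token path returns `probs`/`routing_map`, the classic path returns capacities and top-k ids); the actual selection lives in `PretrainedMoEGate.topkgating`/`topkgating_nodrop`, outside the lines shown here. As a rough, hedged sketch of the bias-corrected sigmoid top-k routing that a `noaux_tc` gate with an `e_score_correction_bias` performs (group-limited routing and capacity/drop handling omitted; names below are illustrative, not the real API):

import paddle
import paddle.nn.functional as F

def sketch_noaux_tc_route(logits, e_score_correction_bias, top_k,
                          routed_scaling_factor=1.0, norm_topk_prob=True):
    # logits: [num_tokens, num_experts]; bias: [num_experts], used only for *selection*.
    scores = F.sigmoid(logits.astype("float32"))                      # scoring_func == "sigmoid"
    _, topk_idx = paddle.topk(scores + e_score_correction_bias, k=top_k, axis=-1)
    topk_weight = paddle.take_along_axis(scores, topk_idx, axis=-1)   # weights come from the raw scores
    if norm_topk_prob:
        topk_weight = topk_weight / (topk_weight.sum(axis=-1, keepdim=True) + 1e-20)
    return topk_idx, topk_weight * routed_scaling_factor

The combine step then scatters `topk_weight` back over the selected experts' outputs.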
final_hidden_states = self.post_process(hidden_states, final_hidden_states, l_aux) + return final_hidden_states + + def post_process(self, hidden_states, final_hidden_states, l_aux): + if self.training and self.alpha > 0.0: + l_aux = l_aux * self.alpha + final_hidden_states = AddAuxiliaryLoss.apply(final_hidden_states, l_aux) + + if self.config.n_shared_experts is not None: + shared_expert_output = self.shared_experts(hidden_states) + final_hidden_states = final_hidden_states + shared_expert_output + return final_hidden_states + + +# Copied from transformers.models.llama.modeling_llama.LlamaAttention with Llama->DeepseekV2 +class DeepseekV2Attention(nn.Layer): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: DeepseekV2FastConfig, layerwise_recompute: bool = False, recompute_fa3: bool = False): + super().__init__() + self.config = config + self.attention_dropout = config.attention_dropout + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.q_lora_rank = config.q_lora_rank + self.qk_rope_head_dim = config.qk_rope_head_dim + self.kv_lora_rank = config.kv_lora_rank + self.v_head_dim = config.v_head_dim + self.qk_nope_head_dim = config.qk_nope_head_dim + self.q_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim + + self.is_causal = True + self.fuse_rope = config.use_fused_rope + + if config.num_nextn_predict_layers > 0: + self.seq_length = config.seq_length - config.num_nextn_predict_layers + else: + self.seq_length = config.seq_length + self.sequence_parallel = config.sequence_parallel + + self.recompute_fa3 = recompute_fa3 + self.fa_version = config.fa_version + + self.input_layernorm = DeepseekV2RMSNorm(config) + + # Note that we will actually perform a recompute only if both enable_recompute and layerwise_recompute are set to True + # Enable_recompute defaults to False and is controlled by Trainer + self.enable_recompute = False + self.layerwise_recompute = layerwise_recompute + self.recompute_granularity = config.recompute_granularity + + def linear_dtype_gaurd(): + if config.use_fp8: + return dtype_guard("float8_e4m3fn") + else: + return contextlib.nullcontext() + + # Note (@DrownFish19): For tensor parallel we consider that q_a_proj and kv_a_proj_with_mqa + # are the small weight and cannot achieve performance gain. So we use the original + # linear layers. We use the tensor parallel linear layers for q_proj,q_b_proj and kv_b_proj + # for which are the large weight and can achieve performance gain. 
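`post_process` above attaches the scaled auxiliary loss with `AddAuxiliaryLoss`, which is defined outside the lines shown here. The assumed semantics are a straight-through PyLayer: activations pass through unchanged, and backward additionally emits a unit gradient for the auxiliary loss so it reaches the gate parameters without perturbing the forward value. A minimal sketch under that assumption (class name is illustrative):

import paddle

class AddAuxiliaryLossSketch(paddle.autograd.PyLayer):
    # forward: y = x; backward: pass the incoming grad through and feed 1.0 to `loss` if it needs grad.
    @staticmethod
    def forward(ctx, x, loss):
        assert paddle.numel(loss) == 1
        ctx.dtype = loss.dtype
        ctx.required_aux_loss = not loss.stop_gradient
        return x

    @staticmethod
    def backward(ctx, grad_output):
        grad_loss = paddle.ones([1], dtype=ctx.dtype) if ctx.required_aux_loss else None
        return grad_output, grad_loss

# usage mirrors post_process: hidden = AddAuxiliaryLossSketch.apply(hidden, l_aux * alpha)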
+ + self._init_rope() + self.softmax_scale = self.q_head_dim ** (-0.5) + Linear = FP8Linear if self.config.dsv3_use_fp8_gemm else Linear_ + + # fmt: off + if self.config.tensor_parallel_degree > 1: + # for tensor parallel + if config.sequence_parallel: + ColumnParallelLinear = linear_utils.ColumnSequenceParallelLinear + RowParallelLinear = linear_utils.RowSequenceParallelLinear + else: + ColumnParallelLinear = linear_utils.ColumnParallelLinear + RowParallelLinear = linear_utils.RowParallelLinear + + if self.q_lora_rank is None: + with linear_dtype_gaurd(): + self.q_proj = ColumnParallelLinear(self.hidden_size, self.num_heads * self.q_head_dim, has_bias=False, gather_output=True) + else: + with linear_dtype_gaurd(): + self.q_a_proj = Linear(self.hidden_size, config.q_lora_rank, bias_attr=config.attention_bias) + self.q_b_proj = ColumnParallelLinear(config.q_lora_rank, self.num_heads * self.q_head_dim, has_bias=False, gather_output=True) + self.q_a_layernorm = DeepseekV2RMSNorm(config=config, hidden_size=config.q_lora_rank, use_sequence_parallel=False) + + with linear_dtype_gaurd(): + self.kv_a_proj_with_mqa = paddle.nn.Linear(self.hidden_size, config.kv_lora_rank + config.qk_rope_head_dim, bias_attr=config.attention_bias) + self.kv_b_proj = ColumnParallelLinear(config.kv_lora_rank, self.num_heads * (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim), has_bias=False, gather_output=True) + self.o_proj = RowParallelLinear(self.num_heads * self.v_head_dim, self.hidden_size, has_bias=config.attention_bias, input_is_parallel=False) + self.kv_a_layernorm = DeepseekV2RMSNorm(config=config, hidden_size=config.kv_lora_rank, use_sequence_parallel=False) + else: + # for without tensor parallel + if self.config.dsv3_use_atten_recompute: + self.fused_rms_norm_linear = FusedRMSLinear(self.hidden_size, config.q_lora_rank, config.kv_lora_rank + config.qk_rope_head_dim, 1e-6) + kv_up_dim = self.num_heads * (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim) + self.memory_recompute_att = MemroyRecomputeAttn(config.q_lora_rank, config.kv_lora_rank, config.q_lora_rank, self.num_heads * self.q_head_dim, config.kv_lora_rank, kv_up_dim, self.rotary_emb, self.num_heads, self.q_head_dim, self.qk_nope_head_dim, self.v_head_dim, self.qk_rope_head_dim, 1e-6, self.kv_lora_rank, self.softmax_scale, recompute_fa3=self.recompute_fa3, fa_version=self.fa_version) + self.o_proj = FP8KeepXLinear(self.num_heads * self.v_head_dim, self.hidden_size, bias_attr=config.attention_bias) + else: + + if self.q_lora_rank is None: + with linear_dtype_gaurd(): + self.q_proj = Linear(self.hidden_size, self.num_heads * self.q_head_dim, bias_attr=False) + else: + with linear_dtype_gaurd(): + self.q_a_proj = Linear(self.hidden_size, config.q_lora_rank, bias_attr=config.attention_bias) + self.q_b_proj = Linear(config.q_lora_rank, self.num_heads * self.q_head_dim, bias_attr=False) + self.q_a_layernorm = DeepseekV2RMSNorm(config=config, hidden_size=config.q_lora_rank) + + with linear_dtype_gaurd(): + self.kv_a_proj_with_mqa = paddle.nn.Linear(self.hidden_size, config.kv_lora_rank + config.qk_rope_head_dim, bias_attr=config.attention_bias) + self.kv_b_proj = Linear(config.kv_lora_rank, self.num_heads * (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim), bias_attr=False) + self.o_proj = Linear(self.num_heads * self.v_head_dim, self.hidden_size, bias_attr=config.attention_bias) + self.kv_a_layernorm = DeepseekV2RMSNorm(config=config, hidden_size=config.kv_lora_rank) + + # fmt: on + self.softmax_scale = self.q_head_dim ** 
(-0.5) + if self.config.rope_scaling is not None: + mscale_all_dim = self.config.rope_scaling.get("mscale_all_dim", 0) + scaling_factor = self.config.rope_scaling["factor"] + if mscale_all_dim: + mscale = yarn_get_mscale(scaling_factor, mscale_all_dim) + self.softmax_scale = self.softmax_scale * mscale * mscale + + self.attn_func = scaled_dot_product_attention + + def fp8_quant_weight(self, quant_transpose=None): + + if self.config.dsv3_use_atten_recompute: + self.o_proj.fp8_quant_weight(quant_transpose=quant_transpose) + self.memory_recompute_att.fp8_quant_weight(quant_transpose=quant_transpose) + self.fused_rms_norm_linear.fp8_quant_weight(quant_transpose=quant_transpose) + + def _init_rope(self): + if self.config.rope_scaling is None: + self.rotary_emb = DeepseekV2RotaryEmbedding( + self.qk_rope_head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.rope_theta, + ) + else: + scaling_type = self.config.rope_scaling["type"] + scaling_factor = self.config.rope_scaling["factor"] + if scaling_type == "linear": + self.rotary_emb = DeepseekV2LinearScalingRotaryEmbedding( + self.qk_rope_head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + ) + elif scaling_type == "dynamic": + self.rotary_emb = DeepseekV2DynamicNTKScalingRotaryEmbedding( + self.qk_rope_head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + ) + elif scaling_type == "yarn": + kwargs = { + key: self.config.rope_scaling[key] + for key in [ + "original_max_position_embeddings", + "beta_fast", + "beta_slow", + "mscale", + "mscale_all_dim", + ] + if key in self.config.rope_scaling + } + self.rotary_emb = DeepseekV2YarnRotaryEmbedding( + self.qk_rope_head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + **kwargs, + ) + else: + raise ValueError(f"Unknown RoPE scaling type {scaling_type}") + + def _shape(self, tensor: paddle.Tensor, seq_len: int, bsz: int): + return tensor.reshape([bsz, seq_len, self.num_heads, self.v_head_dim]).transpose([1, 0, 2, 3]) + + def forward( + self, + hidden_states: paddle.Tensor, + position_ids: Optional[Tuple[paddle.Tensor]] = None, + past_key_value: Optional[Tuple[paddle.Tensor]] = None, + attention_mask: Optional[paddle.Tensor] = None, + output_attentions: bool = False, + use_cache: bool = False, + attn_mask_startend_row_indices: Optional[paddle.Tensor] = None, + **kwargs, + ) -> Tuple[paddle.Tensor, Optional[paddle.Tensor], Optional[Tuple[paddle.Tensor]]]: + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. 
Please make sure use `attention_mask` instead.`" + ) + bsz, q_len, _ = hidden_states.shape + + # DeepSeekV2 q_lora_rank=1536 + # DeepSeekV2-lite q_lora_rank=None + if self.config.dsv3_use_atten_recompute: + + q_t1, compressed_kv = self.fused_rms_norm_linear(hidden_states) + + outputs = self.memory_recompute_att(q_t1, compressed_kv, position_ids) + + if self.v_head_dim * self.num_heads != outputs.shape[-1]: + outputs = outputs.reshape([bsz, q_len, self.num_heads, -1]) + outputs = outputs[..., : self.v_head_dim] + outputs = outputs.reshape([bsz, q_len, -1]) + else: + hidden_states = self.input_layernorm(hidden_states) + if self.q_lora_rank is None: + q = self.q_proj(hidden_states) + else: + q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))) + + if self.sequence_parallel: + target_query_shape = [-1, self.seq_length, self.num_heads, self.q_head_dim] + target_key_value_shape = [-1, self.seq_length, self.num_heads, self.qk_nope_head_dim + self.v_head_dim] + else: + target_query_shape = [0, 0, self.num_heads, self.q_head_dim] + target_key_value_shape = [0, 0, self.num_heads, self.qk_nope_head_dim + self.v_head_dim] + + q = q.reshape(shape=target_query_shape) + q_nope, q_pe = paddle.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], axis=-1) + + # DeepSeekV2 kv_lora_rank+qk_rope_head_dim=512+64 + compressed_kv = self.kv_a_proj_with_mqa(hidden_states) + compressed_kv, k_pe = paddle.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], axis=-1) + if self.sequence_parallel: + k_pe = GatherOp.apply(k_pe) + k_pe = k_pe.reshape([-1, q_len, 1, self.qk_rope_head_dim]).expand( + [-1, q_len, self.num_heads, self.qk_rope_head_dim] + ) + + # self.q_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim = 128+64 + # self.num_heads * (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim) = config.qk_nope_head_dim + self.v_head_dim = 128+128 + kv = self.kv_b_proj(self.kv_a_layernorm(compressed_kv)).reshape(shape=target_key_value_shape) + + k_nope, value_states = paddle.split(kv, [self.qk_nope_head_dim, self.v_head_dim], axis=-1) + kv_seq_len = value_states.shape[1] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-3] + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + cos = cos[None, :, None, :] + sin = sin[None, :, None, :] + q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids, self.fuse_rope) + + query_states = paddle.concat([q_nope, q_pe], axis=-1) + key_states = paddle.concat([k_nope, k_pe], axis=-1) + + # [bs, seq_len, num_head, head_dim] + if past_key_value is not None: + # reuse k, v, self_attention + key_states = paddle.concat([past_key_value[0], key_states], axis=1) + value_states = paddle.concat([past_key_value[1], value_states], axis=1) + past_key_value = (key_states, value_states) if use_cache else None + + has_gradient = not (query_states.stop_gradient and key_states.stop_gradient and value_states.stop_gradient) + if ( + self.enable_recompute + and self.layerwise_recompute + and has_gradient + and self.recompute_granularity == "core_attn" + ): + outputs = recompute( + self.attn_func, + query_states, + self.config, + key_states, + value_states, + attention_mask, + output_attentions, + attn_mask_startend_row_indices=attn_mask_startend_row_indices, + softmax_scale=self.softmax_scale, + training=self.training, + sequence_parallel=self.sequence_parallel, + use_reentrant=self.config.recompute_use_reentrant, + ) + else: + outputs = self.attn_func( + query_states, + self.config, + key_states, + value_states, + 
attention_mask, + output_attentions, + attn_mask_startend_row_indices=attn_mask_startend_row_indices, + softmax_scale=self.softmax_scale, + training=self.training, + sequence_parallel=self.sequence_parallel, + ) + if output_attentions: + attn_output, attn_weights = outputs + else: + attn_output = outputs + + # if sequence_parallel is true, out shape are [q_len / n, bs, num_head * head_dim] + # else their shape are [bs, q_len, num_head * head_dim], n is mp parallelism. + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + outputs = (attn_output,) + + if output_attentions: + outputs += (attn_weights,) + + if use_cache: + outputs += (past_key_value,) + + if type(outputs) is tuple and len(outputs) == 1: + outputs = outputs[0] + + return outputs + + +class DeepseekV2DecoderLayer(nn.Layer): + def __init__( + self, + config: DeepseekV2FastConfig, + layer_idx: int, + layerwise_recompute: bool = False, + recompute_fa3: bool = False, + ): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.enable_recompute = False + self.layerwise_recompute = layerwise_recompute + self.recompute_granularity = config.recompute_granularity + self.using_post_norm_recompute = config.using_post_norm_recompute + + self.hidden_size = config.hidden_size + + self.self_attn = DeepseekV2Attention( + config=config, layerwise_recompute=layerwise_recompute, recompute_fa3=recompute_fa3 + ) + + DeepseekV2MLPClass = FP8Mlp if self.config.dsv3_use_fp8_gemm else DeepseekV2MLP + + self.input_layernorm = DeepseekV2RMSNorm(config) + self.post_attention_layernorm = DeepseekV2RMSNorm(config) + + if ( + config.n_routed_experts is not None + and layer_idx >= config.first_k_dense_replace + and layer_idx % config.moe_layer_freq == 0 + ): + self.mlp = ( + DeepseekV2MoE( + config, self.post_attention_layernorm.weight, self.post_attention_layernorm.variance_epsilon + ) + if config.using_post_norm_recompute + else DeepseekV2MoE(config) + ) + else: + self.mlp = DeepseekV2MLPClass(config, recompute_fwd_gate_up=True) + + def fp8_quant_weight(self, batch_mode=False, quant_transpose=None): + """fp8_quant_weight""" + if isinstance(self.mlp, DeepseekV2MoE): + # logger.info(f"fp8 quant weight for mlp {type(self.mlp)}") + self.mlp.fp8_quant_weight(batch_mode, quant_transpose=quant_transpose) + self.self_attn.fp8_quant_weight(quant_transpose=quant_transpose) + elif isinstance(self.mlp, FP8Mlp): + self.self_attn.fp8_quant_weight(quant_transpose=quant_transpose) + + def forward( + self, + hidden_states: paddle.Tensor, + position_ids: Optional[paddle.Tensor] = None, + attention_mask: Optional[paddle.Tensor] = None, + output_attentions: Optional[bool] = False, + past_key_value: Optional[Tuple[paddle.Tensor]] = None, + use_cache: Optional[bool] = False, + attn_mask_startend_row_indices: Optional[paddle.Tensor] = None, + **kwargs, + ) -> Tuple[paddle.Tensor, Optional[Tuple[paddle.Tensor, paddle.Tensor]]]: + """ + Args: + hidden_states (`paddle.Tensor`): input to the layer of shape `(batch, seq_len, embed_axis)` + attention_mask (`paddle.Tensor`, *optional*): + attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, + query_sequence_length, key_sequence_length)` if default attention is used. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. 
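When YaRN rope scaling is active with a non-zero `mscale_all_dim`, the attention module above multiplies `softmax_scale` by `mscale * mscale`. `yarn_get_mscale` is defined outside the lines shown here; the sketch below assumes the standard YaRN formula and plugs in example values (q_head_dim = 128 + 64, rope factor 40) to show the resulting scale:

import math

def yarn_get_mscale_sketch(scale=1.0, mscale=1.0):
    # Standard YaRN attention-temperature correction (assumed to match yarn_get_mscale).
    if scale <= 1:
        return 1.0
    return 0.1 * mscale * math.log(scale) + 1.0

q_head_dim = 128 + 64                                # qk_nope_head_dim + qk_rope_head_dim
m = yarn_get_mscale_sketch(scale=40, mscale=1.0)     # assuming factor=40, mscale_all_dim=1.0
softmax_scale = q_head_dim ** -0.5 * m * m
print(round(m, 4), round(softmax_scale, 4))          # 1.3689 0.1352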
+ use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(paddle.Tensor)`, *optional*): cached past key and value projection states + """ + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" + ) + residual = hidden_states + + # Self Attention + has_gradient = not hidden_states.stop_gradient + if ( + self.enable_recompute + and self.layerwise_recompute + and has_gradient + and self.recompute_granularity == "full_attn" + ): + outputs = recompute( + self.self_attn, + hidden_states=hidden_states, + position_ids=position_ids, + attention_mask=attention_mask, + output_attentions=output_attentions, + past_key_value=past_key_value, + use_cache=use_cache, + attn_mask_startend_row_indices=attn_mask_startend_row_indices, + **kwargs, + ) + else: + outputs = self.self_attn( + hidden_states=hidden_states, + position_ids=position_ids, + attention_mask=attention_mask, + output_attentions=output_attentions, + past_key_value=past_key_value, + use_cache=use_cache, + attn_mask_startend_row_indices=attn_mask_startend_row_indices, + **kwargs, + ) + + if type(outputs) is tuple: + hidden_states = outputs[0] + else: + hidden_states = outputs + + if output_attentions: + self_attn_weights = outputs[1] + + if use_cache: + present_key_value = outputs[2 if output_attentions else 1] + + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + + if not (self.using_post_norm_recompute and isinstance(self.mlp, DeepseekV2MoE)): + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + if type(outputs) is tuple and len(outputs) == 1: + outputs = outputs[0] + + return outputs + + def self_attn_compute(self, hidden_states, **kwargs): + residual = hidden_states + + # Self Attention + has_gradient = not hidden_states.stop_gradient + if ( + self.enable_recompute + and self.layerwise_recompute + and has_gradient + and self.recompute_granularity == "full_attn" + ): + outputs = recompute( + self.self_attn, + hidden_states=hidden_states, + position_ids=None, + attention_mask=None, + output_attentions=False, + past_key_value=None, + use_cache=False, + attn_mask_startend_row_indices=None, + **kwargs, + ) + else: + outputs = self.self_attn( + hidden_states=hidden_states, + position_ids=None, + attention_mask=None, + output_attentions=False, + past_key_value=None, + use_cache=False, + attn_mask_startend_row_indices=None, + **kwargs, + ) + + if type(outputs) is tuple: + hidden_states = outputs[0] + else: + hidden_states = outputs + + hidden_states = residual + hidden_states + + residual = hidden_states + + if not (self.using_post_norm_recompute and isinstance(self.mlp, DeepseekV2MoE)): + hidden_states = self.post_attention_layernorm(hidden_states) + + return hidden_states, residual + + def pre_dispatch_compute(self, hidden_states): + l_aux, l_zloss, intermediate_hidden_states, token_indices, token_probs = self.mlp.pre_dispatch_compute( + hidden_states + ) + + return l_aux, l_zloss, intermediate_hidden_states, token_indices, token_probs + + def expert_forward_compute(self, intermediate_hidden_states, dispatched_indices, 
dispatched_probs): + (global_input_tokens, token_permuted_indices, prob_permuted_indices) = self.mlp.post_dispatch_compute( + intermediate_hidden_states, dispatched_indices, dispatched_probs + ) + + expert_output = self.mlp.expert_forward(global_input_tokens) + + expert_output = self.mlp.pre_combine_compute( + expert_output, token_permuted_indices, prob_permuted_indices, dispatched_probs + ) + + return expert_output + + def post_combine_compute(self, residual, hidden_states, final_hidden_states, l_aux): + final_hidden_states = self.mlp.post_combine_compute(final_hidden_states) + + final_hidden_states = self.mlp.post_process(hidden_states, final_hidden_states, l_aux) + + final_hidden_states = residual + final_hidden_states + + outputs = (final_hidden_states,) + + if type(outputs) is tuple and len(outputs) == 1: + outputs = outputs[0] + + return outputs + + +class DeepseekV2MTPLayer(DeepseekV2DecoderLayer): + def __init__( + self, + config: DeepseekV2FastConfig, + layer_idx: int, + layerwise_recompute: bool = False, + ): + super(DeepseekV2MTPLayer, self).__init__(config, layer_idx, layerwise_recompute) + + self.enorm = DeepseekV2RMSNorm(config) + self.hnorm = DeepseekV2RMSNorm(config) + self.eh_proj = nn.Linear(2 * config.hidden_size, config.hidden_size, bias_attr=False) + + def forward( + self, + hidden_states: paddle.Tensor, + nextn_hidden_state: paddle.Tensor, + position_ids: Optional[paddle.Tensor] = None, + attention_mask: Optional[paddle.Tensor] = None, + output_attentions: Optional[bool] = False, + past_key_value: Optional[Tuple[paddle.Tensor]] = None, + use_cache: Optional[bool] = False, + attn_mask_startend_row_indices: Optional[paddle.Tensor] = None, + **kwargs, + ) -> Tuple[paddle.Tensor, Optional[Tuple[paddle.Tensor, paddle.Tensor]]]: + hidden_states = self.hnorm(hidden_states) + nextn_hidden_state = self.enorm(nextn_hidden_state) + + concat_h = paddle.concat([nextn_hidden_state, hidden_states], axis=-1) + hidden_states = FP8LinearFunction.apply(concat_h, self.eh_proj) + + layer_outputs = super(DeepseekV2MTPLayer, self).forward( + hidden_states, + position_ids, + attention_mask, + output_attentions, + past_key_value, + use_cache, + attn_mask_startend_row_indices, + **kwargs, + ) + + if type(layer_outputs) is tuple: + hidden_states = layer_outputs[0] + else: + hidden_states = layer_outputs + + return hidden_states + + +class DeepseekV2PretrainedModelFast(PretrainedModel): + config_class = DeepseekV2FastConfig + base_model_prefix = "deepseek_v2" + _no_split_modules = ["DeepseekV2DecoderLayer"] + + def _get_model_flops(self, batch_size=1, seq_length=None, **kwargs): + from paddleformers.transformers.deepseek_v2.mfu_utils import DeepSeekProjection + + # self._ + mfu_cal_proj = DeepSeekProjection(self.config) + if seq_length is None: + if hasattr(self.config, "seq_length"): + seq_length = self.config.seq_length + else: + seq_length = 2048 + + return mfu_cal_proj.get_num_flop_per_token() + + def _get_hardware_flops(self, *args, **kwargs): + return self._get_model_flops(*args, **kwargs) + + @classmethod + def _get_name_mappings(cls, config: DeepseekV2FastConfig) -> list[StateDictNameMapping]: + mappings: list[StateDictNameMapping] = [] + model_mappings = [ + ["embed_tokens.weight"], + ["norm.weight"], + ] + # last one layer contains MTP (eagle) parameters for inference + for layer_index in range(config.num_hidden_layers + config.num_nextn_predict_layers): + layer_mappings = [ + [f"layers.{layer_index}.self_attn.q_proj.weight", None, "transpose"], + 
[f"layers.{layer_index}.self_attn.q_a_proj.weight", None, "transpose"], + [f"layers.{layer_index}.self_attn.q_a_layernorm.weight"], + [f"layers.{layer_index}.self_attn.q_b_proj.weight", None, "transpose"], + [f"layers.{layer_index}.self_attn.kv_a_proj_with_mqa.weight", None, "transpose"], + [f"layers.{layer_index}.self_attn.kv_a_layernorm.weight"], + [f"layers.{layer_index}.self_attn.kv_b_proj.weight", None, "transpose"], + [f"layers.{layer_index}.self_attn.o_proj.weight", None, "transpose"], + [f"layers.{layer_index}.mlp.gate_proj.weight", None, "transpose"], + [f"layers.{layer_index}.mlp.up_proj.weight", None, "transpose"], + [f"layers.{layer_index}.mlp.down_proj.weight", None, "transpose"], + [f"layers.{layer_index}.input_layernorm.weight"], + [f"layers.{layer_index}.post_attention_layernorm.weight"], + ] + model_mappings.extend(layer_mappings) + + # MoE parameters + model_mappings.append([f"layers.{layer_index}.mlp.gate.weight", None, "transpose"]) + model_mappings.append([f"layers.{layer_index}.mlp.gate.e_score_correction_bias"]) + for expert_idx in range(config.n_routed_experts): + expert_mappings = [ + [f"layers.{layer_index}.mlp.experts.{expert_idx}.gate_proj.weight", None, "transpose"], + [f"layers.{layer_index}.mlp.experts.{expert_idx}.up_proj.weight", None, "transpose"], + [f"layers.{layer_index}.mlp.experts.{expert_idx}.down_proj.weight", None, "transpose"], + ] + model_mappings.extend(expert_mappings) + model_mappings.append([f"layers.{layer_index}.mlp.shared_experts.gate_proj.weight", None, "transpose"]) + model_mappings.append([f"layers.{layer_index}.mlp.shared_experts.up_proj.weight", None, "transpose"]) + model_mappings.append([f"layers.{layer_index}.mlp.shared_experts.down_proj.weight", None, "transpose"]) + + # MTP (eagle) parameters for inference + if layer_index >= config.num_hidden_layers: + model_mappings.append([f"layers.{layer_index}.embed_tokens.weight"]) + model_mappings.append([f"layers.{layer_index}.enorm.weight"]) + model_mappings.append([f"layers.{layer_index}.hnorm.weight"]) + model_mappings.append([f"layers.{layer_index}.eh_proj.weight", None, "transpose"]) + model_mappings.append([f"layers.{layer_index}.shared_head.norm.weight"]) + model_mappings.append([f"layers.{layer_index}.shared_head.head.weight", None, "transpose"]) + + init_name_mappings(mappings=model_mappings) + if cls.base_model_class.__name__ not in config.architectures: + for mapping in model_mappings: + mapping[0] = "model." + mapping[0] + mapping[1] = f"{cls.base_model_prefix}." 
+ mapping[1] + if not config.tie_word_embeddings: + model_mappings.append(["lm_head.weight", "lm_head.weight", "transpose"]) + + mappings = [StateDictNameMapping(*mapping, index=index) for index, mapping in enumerate(model_mappings)] + return mappings + + @classmethod + def _get_tensor_parallel_mappings(cls, config: DeepseekV2FastConfig, is_split=True): + from paddleformers.transformers.conversion_utils import split_or_merge_func + + fn = split_or_merge_func( + is_split=is_split, + tensor_parallel_degree=config.tensor_parallel_degree, + tensor_parallel_rank=config.tensor_parallel_rank, + num_attention_heads=config.num_attention_heads, + ) + + def get_tensor_parallel_split_mappings(num_layers): + final_actions = {} + + base_actions = { + # Row Linear + "embed_tokens.weight": partial(fn, is_column=False), + "layers.0.self_attn.o_proj.weight": partial(fn, is_column=False), + } + if config.use_fp8: + base_actions["layers.0.self_attn.o_proj.weight.weight_scale_inv"] = partial(fn, is_column=False) + + if config.tie_word_embeddings: + base_actions["lm_head.weight"] = partial(fn, is_column=False) + else: + base_actions["lm_head.weight"] = partial(fn, is_column=True) + + if not config.vocab_size % config.tensor_parallel_degree == 0: + base_actions.pop("lm_head.weight") + base_actions.pop("embed_tokens.weight") + + # Column Linear + base_actions["layers.0.self_attn.q_proj.weight"] = partial(fn, is_column=True) + base_actions["layers.0.self_attn.q_proj.bias"] = partial(fn, is_column=True) + base_actions["layers.0.self_attn.q_b_proj.weight"] = partial(fn, is_column=True) + + # if we have enough num_key_value_heads to split, then split it. + if config.num_key_value_heads % config.tensor_parallel_degree == 0: + base_actions["layers.0.self_attn.kv_b_proj.weight"] = partial(fn, is_column=True) + if config.use_fp8: + base_actions["layers.0.self_attn.kv_b_proj.weight.weight_scale_inv"] = partial(fn, is_column=True) + + # dense mlp + base_actions["layers.0.mlp.up_proj.weight"] = partial(fn, is_column=True) + base_actions["layers.0.mlp.gate_proj.weight"] = partial(fn, is_column=True) + base_actions["layers.0.mlp.down_proj.weight"] = partial(fn, is_column=False) + if config.use_fp8: + base_actions["layers.0.mlp.up_proj.weight.weight_scale_inv"] = partial(fn, is_column=True) + base_actions["layers.0.mlp.gate_proj.weight.weight_scale_inv"] = partial(fn, is_column=True) + base_actions["layers.0.mlp.down_proj.weight.weight_scale_inv"] = partial(fn, is_column=False) + + # moe unit routed experts + moe_group = dist.fleet.get_hybrid_communicate_group().get_data_parallel_group() + expert_parallel_degree = dist.get_world_size(moe_group) + if expert_parallel_degree <= 1: + for e_i in range(config.n_routed_experts): + base_actions[f"layers.0.mlp.experts.{e_i}.up_proj.weight"] = partial(fn, is_column=True) + base_actions[f"layers.0.mlp.experts.{e_i}.gate_proj.weight"] = partial(fn, is_column=True) + base_actions[f"layers.0.mlp.experts.{e_i}.down_proj.weight"] = partial(fn, is_column=False) + + # moe unit shared experts + base_actions["layers.0.mlp.shared_experts.gate_proj.weight"] = partial(fn, is_column=True) + base_actions["layers.0.mlp.shared_experts.up_proj.weight"] = partial(fn, is_column=True) + base_actions["layers.0.mlp.shared_experts.down_proj.weight"] = partial(fn, is_column=False) + if config.use_fp8: + base_actions["layers.0.mlp.shared_experts.gate_proj.weight.weight_scale_inv"] = partial( + fn, is_column=True + ) + base_actions["layers.0.mlp.shared_experts.up_proj.weight.weight_scale_inv"] = partial( + fn, 
is_column=True + ) + base_actions["layers.0.mlp.shared_experts.down_proj.weight.weight_scale_inv"] = partial( + fn, is_column=False + ) + + for key, action in base_actions.items(): + if "layers.0." in key: + for i in range(num_layers): + final_actions[key.replace("layers.0.", f"layers.{i}.")] = action + final_actions[key] = action + + # for MTP (eagle) parameters for inference + base_actions.pop("embed_tokens.weight") + base_actions.pop("lm_head.weight") + base_actions["layers.0.embed_tokens.weight"] = partial(fn, is_column=False) + base_actions["layers.0.shared_head.head.weight"] = partial(fn, is_column=True) + for key, action in base_actions.items(): + if "layers.0." in key: + for i in range( + config.num_hidden_layers, config.num_hidden_layers + config.num_nextn_predict_layers + ): + final_actions[key.replace("layers.0.", f"layers.{i}.")] = action + else: + final_actions[key] = action + + return final_actions + + mappings = get_tensor_parallel_split_mappings(config.num_hidden_layers) + + return mappings + + def _init_weights(self, layer): + if self.config.tensor_parallel_degree > 1: + rng_tracker = get_rng_state_tracker().rng_state + Linear = FP8Linear if self.config.dsv3_use_fp8_gemm else Linear_ + + if isinstance( + layer, + ( + nn.Linear, + nn.Embedding, + mpu.VocabParallelEmbedding, + mpu.RowParallelLinear, + mpu.ColumnParallelLinear, + linear_utils.RowSequenceParallelLinear, + linear_utils.ColumnSequenceParallelLinear, + Linear, + ), + ): + # In the dygraph mode, use the `set_value` to reset the parameter directly, + # and reset the `state_dict` to update parameter in static mode. + if isinstance(layer.weight, paddle.Tensor): + if layer.weight.is_distributed: + with rng_tracker(): + layer.weight.set_value( + paddle.tensor.normal( + mean=0.0, + std=self.config.initializer_range + if hasattr(self.config, "initializer_range") + else self.config.initializer_range, + shape=layer.weight.shape, + ) + ) + else: + layer.weight.set_value( + paddle.tensor.normal( + mean=0.0, + std=self.config.initializer_range + if hasattr(self.config, "initializer_range") + else self.config.initializer_range, + shape=layer.weight.shape, + ) + ) + + # set bias to zeros + if getattr(layer, "bias", None) is not None: + layer.bias.set_value(paddle.zeros(shape=layer.bias.shape)) + + if isinstance(layer, nn.Embedding): + if layer._padding_idx is not None: + layer.weight.data[layer._padding_idx].fill_(0) + + if isinstance(layer, MoEGate): + kaiming_uniform_(layer.weight, a=math.sqrt(5)) + + moe_grad_group = fleet.get_hybrid_communicate_group().expert_grad_comm_group + if moe_grad_group is not None and moe_grad_group.nranks > 1: + for p in layer.parameters(): + if hasattr(p, "color") and "color" in p.color: + if p.color["color"] == "moe_expert": + paddle.distributed.broadcast(p, src=moe_grad_group.ranks[0], group=moe_grad_group) + + def step_flex_token(self, cur_step): + set_global_step(cur_step) + + +@register_base_model +class DeepseekV2ModelFast(DeepseekV2PretrainedModelFast): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. 
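`_get_tensor_parallel_mappings` above only records which axis each weight is sharded on (`is_column=True` for the output dimension of column-parallel layers such as q_b_proj/kv_b_proj/gate_proj/up_proj, `is_column=False` for row-parallel ones such as o_proj/down_proj and for the vocab axis of embed_tokens); the actual split/merge is performed by `split_or_merge_func` from `conversion_utils`. A rough sketch of what one split action does, under those assumptions and not using the real helper:

import paddle

def split_weight_sketch(weight, tensor_parallel_degree, tensor_parallel_rank, is_column):
    # Column-parallel: shard the last axis (output features); row-parallel: shard the first axis.
    axis = -1 if is_column else 0
    return paddle.split(weight, num_or_sections=tensor_parallel_degree, axis=axis)[tensor_parallel_rank]

w = paddle.randn([512, 1024])
shard = split_weight_sketch(w, tensor_parallel_degree=4, tensor_parallel_rank=1, is_column=True)
print(shard.shape)  # [512, 256]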
Each layer is a [`DeepseekV2DecoderLayer`] + + Args: + config: DeepseekV2FastConfig + """ + + def __init__(self, config: DeepseekV2FastConfig): + super().__init__(config) + + self.config = config + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + # Recompute defaults to False and is controlled by Trainer + self.enable_recompute = False + self.recompute_granularity = config.recompute_granularity + self.no_recompute_layers = config.no_recompute_layers if config.no_recompute_layers is not None else [] + + if config.tensor_parallel_degree > 1 and config.vocab_size % config.tensor_parallel_degree == 0: + self.embed_tokens = mpu.VocabParallelEmbedding(config.vocab_size, config.hidden_size) + else: + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + + self.layers = nn.LayerList( + [ + DeepseekV2DecoderLayer(config, layer_idx, layer_idx not in self.no_recompute_layers) + for layer_idx in range(config.num_hidden_layers) + ] + ) + for layer_idx in range(config.num_hidden_layers, config.num_hidden_layers + config.num_nextn_predict_layers): + self.layers.append(DeepseekV2MTPLayer(config, layer_idx, layer_idx not in self.no_recompute_layers)) + + self.norm = DeepseekV2RMSNorm(config) + + self.enable_recompute = False + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @staticmethod + def _prepare_decoder_attention_mask(attention_mask, input_shape, past_key_values_length, dtype): + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + if len(attention_mask.shape) == 2: + expanded_attn_mask = _expand_2d_mask(attention_mask, dtype, tgt_length=input_shape[-1]) + # For decoding phase in generation, seq_length = 1, we don't need to add causal mask + if input_shape[-1] > 1: + combined_attention_mask = _make_causal_mask( + input_shape, + past_key_values_length=past_key_values_length, + ) + expanded_attn_mask = expanded_attn_mask & combined_attention_mask + # [bsz, seq_len, seq_len] -> [bsz, 1, seq_len, seq_len] + elif len(attention_mask.shape) == 3: + expanded_attn_mask = attention_mask.unsqueeze(1).astype("bool") + # if attention_mask is already 4-D, do nothing + else: + expanded_attn_mask = attention_mask + else: + expanded_attn_mask = _make_causal_mask( + input_shape, + past_key_values_length=past_key_values_length, + ) + # Convert bool attention_mask to float attention mask, which will be added to attention_scores later + if get_env_device() == "xpu": + x = paddle.to_tensor(0.0, dtype="float32") + y = paddle.to_tensor(-1.7005809656952787e38, dtype="float32") + expanded_attn_mask = paddle.where(expanded_attn_mask, x, y) + else: + expanded_attn_mask = paddle.where(expanded_attn_mask.cast("bool"), 0.0, paddle.finfo(dtype).min).astype( + dtype + ) + return expanded_attn_mask + + @paddle.jit.not_to_static + def recompute_training_full( + self, + layer_module: nn.Layer, + hidden_states: Tensor, + position_ids: Optional[Tensor], + attention_mask: Tensor, + output_attentions: bool, + past_key_value: Tensor, + use_cache: bool, + attn_mask_startend_row_indices: Optional[Tensor] = None, + ): + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + hidden_states = recompute( + create_custom_forward(layer_module), + hidden_states, + position_ids, + attention_mask, + output_attentions, + past_key_value, + use_cache, + attn_mask_startend_row_indices, + 
use_reentrant=self.config.recompute_use_reentrant, + ) + + return hidden_states + + def forward( + self, + input_ids: paddle.Tensor = None, + position_ids: Optional[paddle.Tensor] = None, + attention_mask: Optional[paddle.Tensor] = None, + inputs_embeds: Optional[paddle.Tensor] = None, + use_cache: Optional[bool] = None, + past_key_values: Optional[List[paddle.Tensor]] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + attn_mask_startend_row_indices: Optional[Tensor] = None, + **kwargs, + ) -> Union[Tuple, BaseModelOutputWithPastAndMTP]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape[:2] + elif inputs_embeds is not None: + batch_size, seq_length = inputs_embeds.shape[:2] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if self.config.num_nextn_predict_layers > 0: + seq_length -= self.config.num_nextn_predict_layers + + if attention_mask is not None: + attention_mask = attention_mask[ + :, :, : -self.config.num_nextn_predict_layers, : -self.config.num_nextn_predict_layers + ] + + if self.enable_recompute and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`transformers." 
+ ) + use_cache = False + + if past_key_values is None: + past_key_values = tuple([None] * len(self.layers)) + # NOTE: to make cache can be clear in-time + past_key_values = list(past_key_values) + + seq_length_with_past = seq_length + past_key_values_length = 0 + if past_key_values[0] is not None: + past_key_values_length = past_key_values[0][0].shape[1] + seq_length_with_past += past_key_values_length + + if position_ids is None: + position_ids = paddle.arange( + past_key_values_length, seq_length + past_key_values_length, dtype=paddle.int64 + ) + position_ids = position_ids.unsqueeze(0) + + if inputs_embeds is None: + # [bs, seq_len, dim] + inputs_embeds = self.embed_tokens(input_ids) + + # embed positions + if attn_mask_startend_row_indices is not None or get_use_casual_mask(): + attention_mask = None + else: + # [bs, seq_len] + attention_mask = ( + paddle.ones((batch_size, seq_length_with_past), dtype=paddle.bool) + if attention_mask is None + else attention_mask + ) + attention_mask = self._prepare_decoder_attention_mask( + attention_mask, (batch_size, seq_length), past_key_values_length, inputs_embeds.dtype + ) # [bs, 1, seq_len, seq_len] + if self.config.use_flash_attention: + attention_mask = None if is_casual_mask(attention_mask) else attention_mask + + if self.config.num_nextn_predict_layers > 0: + inputs_embeds_extra = inputs_embeds[:, -self.config.num_nextn_predict_layers :, :] # [B, S, D] + inputs_embeds = inputs_embeds[:, : -self.config.num_nextn_predict_layers, :] + inputs_embeds_ori = inputs_embeds + + if self.config.sequence_parallel: + # [bs, seq_len, num_head * head_dim] -> [bs * seq_len, num_head * head_dim] + bs, seq_len, hidden_size = inputs_embeds.shape + inputs_embeds = paddle.reshape(inputs_embeds, [bs * seq_len, hidden_size]) + # [seq_len * bs / n, num_head * head_dim] (n is mp parallelism) + inputs_embeds = ScatterOp.apply(inputs_embeds) + + # embed positions + hidden_states = inputs_embeds + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = () if use_cache else None + mtp_outputs = [] + + for idx in range(self.config.num_hidden_layers): + decoder_layer = self.layers[idx] + + if output_hidden_states: + all_hidden_states += (hidden_states,) + + past_key_value = past_key_values[idx] if past_key_values is not None else None + + has_gradient = not hidden_states.stop_gradient + if ( + self.enable_recompute + and idx not in self.no_recompute_layers + and has_gradient + and self.recompute_granularity == "full" + ): + layer_outputs = self.recompute_training_full( + decoder_layer, + hidden_states, + position_ids, + attention_mask, + output_attentions, + past_key_value, + use_cache, + attn_mask_startend_row_indices, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + position_ids, + attention_mask, + output_attentions, + past_key_value, + use_cache, + attn_mask_startend_row_indices, + ) + + # NOTE: clear outdate cache after it has been used for memory saving + past_key_value = past_key_values[idx] = None + if type(layer_outputs) is tuple: + hidden_states = layer_outputs[0] + else: + hidden_states = layer_outputs + + if use_cache: + next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + if self.config.num_nextn_predict_layers > 0: + mtp_outputs.append(hidden_states) + + for nextn in range(self.config.num_nextn_predict_layers): + decoder_layer = self.layers[nextn + 
self.config.num_hidden_layers] + + if self.config.sequence_parallel: + hidden_states = GatherOp.apply(hidden_states) + hidden_states = hidden_states.reshape([-1, seq_length, hidden_states.shape[-1]]) + + inputs_embeds_cur_depth = paddle.concat( + [inputs_embeds_ori[:, (nextn + 1) :, :], inputs_embeds_extra[:, : (nextn + 1), :]], axis=1 + ) + + past_key_value = None + layer_outputs = decoder_layer( + hidden_states, + inputs_embeds_cur_depth, + position_ids, + attention_mask, + output_attentions, + past_key_value, + use_cache, + attn_mask_startend_row_indices, + ) + + if isinstance(layer_outputs, (tuple, list)): + hidden_states = layer_outputs[0] + else: + hidden_states = layer_outputs + + mtp_outputs.append(hidden_states) + mtp_outputs = [self.norm(hidden_states) for hidden_states in mtp_outputs] + hidden_states, mtp_outputs = mtp_outputs[0], mtp_outputs[1:] + else: + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + + if not return_dict: + return tuple( + v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, mtp_outputs] if v is not None + ) + return BaseModelOutputWithPastAndMTP( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + mtp_outputs=mtp_outputs, + ) + + +class DeepseekV2PretrainingCriterionFast(nn.Layer): + """ + Criterion for Mixtral. + It calculates the final loss. + """ + + def __init__(self, config: DeepseekV2FastConfig): + super(DeepseekV2PretrainingCriterionFast, self).__init__() + self.ignore_index = getattr(config, "ignore_index", -100) + self.config = config + self.enable_parallel_cross_entropy = config.tensor_parallel_degree > 1 and config.tensor_parallel_output + + if self.enable_parallel_cross_entropy: # and False: # and lm_head is distributed + self.loss_func = mpu.ParallelCrossEntropy(ignore_index=self.ignore_index) + else: + self.loss_func = paddle.nn.CrossEntropyLoss(reduction="none", ignore_index=self.ignore_index) + + def forward(self, prediction_scores, masked_lm_labels, router_loss=None, mtp_logits=None): + if self.enable_parallel_cross_entropy: + if prediction_scores.shape[-1] == self.config.vocab_size: + warnings.warn( + f"enable_parallel_cross_entropy, the vocab_size should be splitted: {prediction_scores.shape[-1]}, {self.config.vocab_size}" + ) + self.loss_func = paddle.nn.CrossEntropyLoss(reduction="none", ignore_index=self.ignore_index) + + def compute_loss(preds, labels): + with paddle.amp.auto_cast(False): + masked_lm_loss = FastCrossEntropyFunction.apply(preds, labels.unsqueeze(2)) + binary_sequence = paddle.where( + masked_lm_loss > 0, paddle.ones_like(masked_lm_loss), paddle.zeros_like(masked_lm_loss) + ) + count = paddle.sum(binary_sequence) + loss = paddle.where( + count == 0, + paddle.sum(masked_lm_loss * binary_sequence), + paddle.sum(masked_lm_loss * binary_sequence) / count, + ) + return loss + + def add_loss(main_loss, loss): + return main_loss + loss - loss.detach() + + if mtp_logits is not None and self.config.num_nextn_predict_layers > 0: + assert len(mtp_logits) == self.config.num_nextn_predict_layers + masked_lm_labels_ori = masked_lm_labels + masked_lm_labels = masked_lm_labels[:, : -self.config.num_nextn_predict_layers] + seq_length = masked_lm_labels.shape[1] + loss = compute_loss(prediction_scores, masked_lm_labels) + + mtp_loss_res = [] + for depth in 
range(self.config.num_nextn_predict_layers): + prediction_scores_cur_depth = mtp_logits[depth] + masked_lm_labels_cur_depth = masked_lm_labels_ori[:, (depth + 1) : (depth + 1 + seq_length)] + res_cur_depth = compute_loss(prediction_scores_cur_depth, masked_lm_labels_cur_depth) + mtp_loss_res.append(res_cur_depth) + loss = add_loss(loss, self.config.num_nextn_predict_lambda * sum([x for x in mtp_loss_res]) / len(mtp_loss_res)) # fmt: skip + + else: + loss = compute_loss(prediction_scores, masked_lm_labels) + + if router_loss is not None and isinstance(router_loss, paddle.Tensor): + loss = add_loss(loss, router_loss) + + return loss + + +class DeepseekV2YarnRotaryEmbedding(DeepseekV2RotaryEmbedding): + def __init__( + self, + dim, + max_position_embeddings=2048, + base=10000, + scaling_factor=1.0, + original_max_position_embeddings=4096, + beta_fast=32, + beta_slow=1, + mscale=1, + mscale_all_dim=0, + ): + self.scaling_factor = scaling_factor + self.original_max_position_embeddings = original_max_position_embeddings + self.beta_fast = beta_fast + self.beta_slow = beta_slow + self.mscale = mscale + self.mscale_all_dim = mscale_all_dim + super().__init__(dim, max_position_embeddings, base) + + def _set_cos_sin_cache(self, seq_len): + with paddle.amp.auto_cast(False): + self.max_seq_len_cached = seq_len + dim = self.dim + + freq_extra = 1.0 / (self.base ** (paddle.arange(0, dim, 2, dtype=paddle.float32) / dim)) + freq_inter = 1.0 / ( + self.scaling_factor * self.base ** (paddle.arange(0, dim, 2, dtype=paddle.float32) / dim) + ) + + low, high = yarn_find_correction_range( + self.beta_fast, + self.beta_slow, + dim, + self.base, + self.original_max_position_embeddings, + ) + inv_freq_mask = 1.0 - yarn_linear_ramp_mask(low, high, dim // 2) + self.inv_freq = freq_inter * (1 - inv_freq_mask) + freq_extra * inv_freq_mask + + t = paddle.arange(seq_len, dtype=paddle.float32) + + freqs = paddle.outer(t, paddle.cast(self.inv_freq, dtype="float32")) + + _mscale = float( + yarn_get_mscale(self.scaling_factor, self.mscale) + / yarn_get_mscale(self.scaling_factor, self.mscale_all_dim) + ) + + emb = paddle.concat((freqs, freqs), axis=-1) + self.cos_cached = emb.cos() * _mscale + self.sin_cached = emb.sin() * _mscale + + +class DeepseekV2RMSNorm(nn.Layer): + def __init__(self, config: DeepseekV2FastConfig, hidden_size=None, eps=1e-6, use_sequence_parallel=True): + """DeepseekV2RMSNorm is equivalent to T5LayerNorm + + Args: + config (DeepseekV2FastConfig): config dict of DeepseekV2 + hidden_size (_type_): history_states size + eps (_type_, optional): eps value. Defaults to 1e-6. + use_sequence_parallel (bool, optional): A switch to disable sequence parallelism for inputs that are not in tensor parallel mode. + By default, this is set to True. 
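The pretraining criterion above folds the MTP and router losses in through its inner `add_loss(main, x) = main + x - x.detach()`: the reported loss value stays equal to the main language-model loss, while gradients from the extra terms still flow to their producers. A small runnable check of that identity (toy values, independent of the model):

import paddle

def add_loss_sketch(main_loss, extra_loss):
    # Value equals main_loss (extra_loss - extra_loss.detach() == 0); gradient includes extra_loss.
    return main_loss + extra_loss - extra_loss.detach()

x = paddle.to_tensor(2.0, stop_gradient=False)
main, extra = 3.0 * x, x * x
total = add_loss_sketch(main, extra)
total.backward()
print(float(total), float(x.grad))  # 6.0 7.0 -> value from `main` only, gradient 3 + 2*x from both terms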
+ """ + super().__init__() + self.config = config + self.hidden_size = hidden_size if hidden_size is not None else config.hidden_size + self.variance_epsilon = eps + + self.weight = paddle.create_parameter( + shape=[self.hidden_size], + dtype=paddle.get_default_dtype(), + default_initializer=nn.initializer.Constant(1.0), + ) + + if config.sequence_parallel and use_sequence_parallel: + mark_as_sequence_parallel_parameter(self.weight) + + def forward(self, hidden_states): + if self.config.use_fused_rms_norm: + return fusion_rms_norm(hidden_states, self.weight, self.variance_epsilon, self.config.use_fast_layer_norm) + + with paddle.amp.auto_cast(False): + hidden_states = hidden_states.astype("float32") + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = paddle.rsqrt(variance + self.variance_epsilon) * hidden_states + + if self.weight.dtype in [paddle.float16, paddle.bfloat16]: + hidden_states = paddle.cast(hidden_states, self.weight.dtype) + return hidden_states * self.weight + + def extra_repr(self): + return f"hidden_size={self.hidden_size}, dtype={self.weight.dtype}" + + +class DeepseekV2RotaryEmbedding(nn.Layer): + def __init__(self, dim, max_position_embeddings=2048, base=10000): + super().__init__() + + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + # [dim / 2] + with device_guard("cpu"): + self.inv_freq = 1.0 / ( + self.base ** (paddle.cast(paddle.arange(0, self.dim, 2), dtype="float32") / self.dim) + ) + self._set_cos_sin_cache(seq_len=max_position_embeddings) + + self.max_seq_len_cached = None + + def _set_cos_sin_cache(self, seq_len): + self.max_seq_len_cached = seq_len + # [seq_len] + t = paddle.arange(seq_len, dtype="float32") + # [seq_len, axis/2] + freqs = paddle.einsum("i,j->ij", t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + # [seq_len, axis] + emb = paddle.concat([freqs, freqs], axis=-1) + # [1, seqlen, 1, axis] + self.cos_cached = emb.cos()[None, :, None, :] + self.sin_cached = emb.sin()[None, :, None, :] + + def forward(self, x, seq_len=None): + # x: [bs, num_attention_heads, seq_len, head_size] + if self.max_seq_len_cached is None or seq_len > self.max_seq_len_cached: + self._set_cos_sin_cache(seq_len) + cos = self.cos_cached[:seq_len] + sin = self.sin_cached[:seq_len] + return ( + cos.cast(x.dtype) if cos.dtype != x.dtype else cos, + sin.cast(x.dtype) if sin.dtype != x.dtype else sin, + ) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids, fuse_rope=False): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`): + The position indices of the tokens corresponding to the query and key tensors. For example, this can be + used to pass offsetted position ids when working with a KV-cache. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. 
Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + b, s, h, d = q.shape + q = q.reshape([b, s, h, d // 2, 2]).transpose([0, 1, 2, 4, 3]).reshape([b, s, h, d]) + + b, s, h, d = k.shape + k = k.reshape([b, s, h, d // 2, 2]).transpose([0, 1, 2, 4, 3]).reshape([b, s, h, d]) + + if (get_env_device() == "xpu" or get_env_device() == "gpu") and fuse_rope: + q_embed, k_embed, _ = fused_rotary_position_embedding( + q, + k, + None, + sin=sin, + cos=cos, + position_ids=position_ids, + use_neox_rotary_style=False, + ) + return q_embed, k_embed + + if position_ids is None: + # Note: Only for MixtralForCausalLMPipe model pretraining + cos = cos[:, : q.shape[1], :, :] # [bs, seq_len, 1, axis] + sin = sin[:, : q.shape[1], :, :] # [bs, seq_len, 1, axis] + else: + cos = cos.squeeze(axis=[0, 2]) # [seq_len, axis] + sin = sin.squeeze(axis=[0, 2]) # [seq_len, axis] + cos = cos[position_ids].unsqueeze(2) # [bs, seq_len, 1, axis] + sin = sin[position_ids].unsqueeze(2) # [bs, seq_len, 1, axis] + + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +class FusedNormGateFunc(paddle.autograd.PyLayer): + """recompute of postnorm and gate""" + + _current_norm_output = None + _current_invar = None + + @classmethod + def set_temporary_vars(cls, norm_output, invar): + FusedNormGateFunc._current_norm_output = norm_output + FusedNormGateFunc._current_invar = invar + + @classmethod + def clear_temporary_vars(cls): + FusedNormGateFunc._current_norm_output = None + FusedNormGateFunc._current_invar = None + + @staticmethod + def forward(ctx, x, rms_norm_weight, moe_gate_weight, eps): + ctx.dtype = paddle.float32 + norm_output, invar = fused_ln.fused_rms_norm(x, rms_norm_weight, eps) + with paddle.amp.auto_cast(False): + gate_logits = F.linear(cast_if_needed(norm_output, ctx.dtype), cast_if_needed(moe_gate_weight, ctx.dtype)) + + ctx.save_for_backward(x, rms_norm_weight, moe_gate_weight, eps) + return gate_logits, norm_output + + @staticmethod + def backward(ctx, d_gate_logits, d_norm_output): + x, rms_norm_weight, moe_gate_weight, eps = ctx.saved_tensor() + # recompute rmsnorm + norm_output = FusedNormGateFunc._current_norm_output + invar = FusedNormGateFunc._current_invar + if norm_output is None or invar is None: + norm_output, invar = fused_ln.fused_rms_norm(x, rms_norm_weight, eps) + d_norm_output_linear, d_moe_gate_weight = paddle._C_ops.matmul_grad( + cast_if_needed(norm_output, ctx.dtype), + cast_if_needed(moe_gate_weight, ctx.dtype), + d_gate_logits, + False, + False, + ) + d_norm_output_linear, d_moe_gate_weight = cast_if_needed( + d_norm_output_linear, norm_output.dtype + ), cast_if_needed(d_moe_gate_weight, moe_gate_weight.dtype) + + d_norm_output = d_norm_output + d_norm_output_linear + dx, d_rms_norm_weight = fused_ln.fused_rms_norm_grad_func(x, rms_norm_weight, invar, d_norm_output, eps) + + return dx, d_rms_norm_weight, d_moe_gate_weight + + +class TemporaryVarContext: + def __init__(self, norm_output, invar): + self.norm_output = norm_output + self.invar = invar + + def __enter__(self): + FusedNormGateFunc.set_temporary_vars(self.norm_output, self.invar) + + def __exit__(self, 
exc_type, exc_val, exc_tb): + FusedNormGateFunc.clear_temporary_vars() + + +def balance_expert_assignment(n, m, k): + assert k * n % m == 0 + matrix = paddle.zeros((n, m), dtype=paddle.int32) + for row in range(n): + start_col = row % m + for i in range(k): + col = (start_col + i) % m + matrix[row, col] = 1 + return matrix + + +class FakeGate(paddle.autograd.PyLayer): + @staticmethod + def forward(ctx, hidden_states, weight, fakse_gate_restrict_balance=False, num_experts_per_tok=8): + expert_num = weight.shape[1] + bsz, seq, _ = hidden_states.shape + + ctx.x_shape = hidden_states.shape + ctx.x_dtype = hidden_states.dtype + ctx.y_shape = weight.shape + ctx.y_dtype = weight.dtype + if fakse_gate_restrict_balance: + return paddle.reshape( + balance_expert_assignment(bsz * seq, expert_num, num_experts_per_tok), [bsz, seq, expert_num] + ) + else: + return paddle.randn([bsz, seq, expert_num]).cast(weight.dtype) + + @staticmethod + def backward(ctx, grad_output): + return paddle.zeros(ctx.x_shape, dtype=ctx.x_dtype), paddle.zeros(ctx.y_shape, dtype=ctx.y_dtype) + + +class AddAuxiliaryLoss(paddle.autograd.PyLayer): + """ + The trick function of adding auxiliary (aux) loss, + which includes the gradient of the aux loss during backpropagation. + """ + + @staticmethod + def forward(ctx, x, loss): + ctx.dtype = loss.dtype + ctx.required_aux_loss = not loss.stop_gradient + return x.clone() # clone to avoid inplace problem when using overlap + + @staticmethod + def backward(ctx, grad_output): + grad_loss = None + if ctx.required_aux_loss: + grad_loss = paddle.ones(1, dtype=ctx.dtype) + return grad_output, grad_loss + + +@to_static(backend="CINN") +def qkv_pre_process_no_fuse( + q, kv, k_pe, rotary_emb, num_heads, q_head_dim, qk_nope_head_dim, v_head_dim, qk_rope_head_dim, position_ids +): + bsz, q_len, _ = q.shape + + target_query_shape = [0, 0, num_heads, q_head_dim] + target_key_value_shape = [0, 0, num_heads, qk_nope_head_dim + v_head_dim] + + q = q.reshape(shape=target_query_shape) + q_nope = q[..., :qk_nope_head_dim] + q_pe = q[..., qk_nope_head_dim:] + + # DeepSeekV2 kv_lora_rank+qk_rope_head_dim=512+64 + + kv = kv.reshape(shape=target_key_value_shape) + + k_pe = k_pe.reshape([-1, q_len, 1, qk_rope_head_dim]).expand([-1, q_len, num_heads, qk_rope_head_dim]) + + # self.q_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim = 128+64 + # self.num_heads * (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim) = config.qk_nope_head_dim + self.v_head_dim = 128+128 + k_nope = kv[..., :qk_nope_head_dim] + value_states = kv[..., qk_nope_head_dim:] + + kv_seq_len = value_states.shape[1] + + cos, sin = rotary_emb(value_states, seq_len=kv_seq_len) + cos = cos[None, :, None, :] + sin = sin[None, :, None, :] + q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids, False) + + query_states = paddle.concat([q_nope, q_pe], axis=-1) + key_states = paddle.concat([k_nope, k_pe], axis=-1) + + return query_states, key_states, value_states + + +@to_static(backend="CINN") +def rearrange_kv(kv, k_pe, qk_nope_head_dim, num_heads): + k_nope = kv[..., :qk_nope_head_dim] + value_states = kv[..., qk_nope_head_dim:] + + k_pe = k_pe.expand([k_pe.shape[0], k_pe.shape[1], num_heads, k_pe.shape[3]]) + key_states = paddle.concat([k_nope, k_pe], axis=-1) + + return key_states, value_states + + +def qkv_pre_process( + q, kv, k_pe, rotary_emb, num_heads, q_head_dim, qk_nope_head_dim, v_head_dim, qk_rope_head_dim, position_ids +): + if (fused_partial_rope is None) or (position_ids is not None): + return 
qkv_pre_process_no_fuse( + q, + kv, + k_pe, + rotary_emb, + num_heads, + q_head_dim, + qk_nope_head_dim, + v_head_dim, + qk_rope_head_dim, + position_ids, + ) + + bsz, q_len, _ = q.shape + + target_query_shape = [0, 0, num_heads, q_head_dim] + target_key_value_shape = [0, 0, num_heads, qk_nope_head_dim + v_head_dim] + + q = q.reshape(shape=target_query_shape) + kv = kv.reshape(shape=target_key_value_shape) + k_pe = k_pe.reshape([-1, q_len, 1, qk_rope_head_dim]) + + value_states = kv[..., qk_nope_head_dim:] + + kv_seq_len = value_states.shape[1] + + cos, sin = rotary_emb(value_states, seq_len=kv_seq_len) + cos = cos[None, :, None, :] + sin = sin[None, :, None, :] + + query_states = fused_partial_rope(q, cos, sin) + k_pe = fused_partial_rope(k_pe, cos, sin) + + key_states, value_states = rearrange_kv(kv, k_pe, qk_nope_head_dim, num_heads) + + return query_states, key_states, value_states + + +def manul_fwd( + q_init, + kv_init, + q_ln_weight, + kv_ln_weight, + q_up_weight, + kv_up_weight, + rotary_emb, + num_heads, + q_head_dim, + qk_nope_head_dim, + v_head_dim, + qk_rope_head_dim, + position_ids, + eps, + kv_lora_rank, + softmax_scale, +): + + q_ln_t, q_ln_invar = fused_ln.fused_rms_norm(q_init, q_ln_weight, eps) + q = paddle.matmul(q_ln_t, q_up_weight) + + compressed_kv, k_pe = paddle.split(kv_init, [kv_lora_rank, qk_rope_head_dim], axis=-1) + + kv_ln_t, kv_ln_invar = fused_ln.fused_rms_norm(compressed_kv, kv_ln_weight, eps) + + kv = paddle.matmul(kv_ln_t, kv_up_weight) + + query_states, key_states, value_states = qkv_pre_process( + q, kv, k_pe, rotary_emb, num_heads, q_head_dim, qk_nope_head_dim, v_head_dim, qk_rope_head_dim, position_ids + ) + + q_head_dim = query_states.shape[-1] + softmax_scale = softmax_scale * (q_head_dim**0.5) + query_states = query_states * softmax_scale + + attn_out, _, softmax_lse, seed_offset = _C_ops.flash_attn( + query_states, + key_states, + query_states, + None, + None, + 0.0, + True, + False, + False, + "", + ) + + return attn_out + + +class MemroyRecomputeAttnFunc(paddle.autograd.PyLayer): + @staticmethod + def forward( + ctx, + q_init, + kv_init, + q_ln_weight, + kv_ln_weight, + q_up_weight, + kv_up_weight, + rotary_emb, + num_heads, + q_head_dim, + qk_nope_head_dim, + v_head_dim, + qk_rope_head_dim, + position_ids, + eps, + kv_lora_rank, + softmax_scale, + recompute_fa3=False, + fa_version=3, + ): + + bsz = q_init.shape[0] + q_ln_t, q_ln_invar = fused_ln.fused_rms_norm(q_init, q_ln_weight, eps) + # q = paddle.matmul(q_ln_t, q_up_weight) + q_orig_shape = q_ln_t.shape + q = FP8LinearFunctionBase.compute_fp8_linear( + q_ln_t.reshape([-1, q_orig_shape[-1]]), q_up_weight, weight_transpose=True, return_transpose_only=True + ) + q = q.reshape(q_orig_shape[:-1] + [q_up_weight.shape[-1]]) + + compressed_kv, k_pe = paddle.split(kv_init, [kv_lora_rank, qk_rope_head_dim], axis=-1) + + kv_ln_t, kv_ln_invar = fused_ln.fused_rms_norm(compressed_kv, kv_ln_weight, eps) + # kv = paddle.matmul(kv_ln_t, kv_up_weight) + kv_orig_shape = kv_ln_t.shape + kv = FP8LinearFunctionBase.compute_fp8_linear( + kv_ln_t.reshape([-1, kv_orig_shape[-1]]), kv_up_weight, weight_transpose=True, return_transpose_only=True + ) + kv = kv.reshape(kv_orig_shape[:-1] + [kv_up_weight.shape[-1]]) + + query_states, key_states, value_states = qkv_pre_process( + q, + kv, + k_pe, + rotary_emb, + num_heads, + q_head_dim, + qk_nope_head_dim, + v_head_dim, + qk_rope_head_dim, + position_ids, + ) + + q_head_dim = query_states.shape[-1] + + if fa_version == 2: + softmax_scale = softmax_scale * 
(q_head_dim**0.5) + query_states = query_states * softmax_scale + kv_seq_len = value_states.shape[1] + v_num_heads = value_states.shape[2] + value_padding = paddle.zeros( + [bsz, kv_seq_len, v_num_heads, q_head_dim - v_head_dim], + dtype=value_states.dtype, + ) + value_states_pad = paddle.concat([value_states, value_padding], axis=-1) + + attn_out, _, softmax_lse, seed_offset = _C_ops.flash_attn( + query_states, + key_states, + value_states_pad, + None, + None, + 0.0, + True, + False, + False, + "", + ) + + elif fa_version == 3: + attn_out, softmax_lse = _C_ops.flash_attn_v3( + query_states, + key_states, + value_states, + None, # q_v_ + None, # q_descale_ + None, # k_descale_ + None, # v_descale_ + softmax_scale, + True, + -1, # window_size_left + -1, # window_size_right + 0.0, # softcap + 1, # num_splits + False, # manual_set_pack_gqa + False, # pack_gqa_ + 0, # sm_margin + ) + else: + assert False, f"invalid {fa_version=}" + + if fa_version == 2: + ctx.save_for_backward( + q_init, + kv_init, + attn_out, + softmax_lse, + seed_offset, + q_ln_weight, + kv_ln_weight, + q_up_weight, + kv_up_weight, + rotary_emb, + num_heads, + q_head_dim, + qk_nope_head_dim, + v_head_dim, + qk_rope_head_dim, + position_ids, + eps, + kv_lora_rank, + softmax_scale, + ) + elif fa_version == 3: + if recompute_fa3: + ctx.save_for_backward( + q_init, + kv_init, + None, + None, + q_ln_weight, + kv_ln_weight, + q_up_weight, + kv_up_weight, + rotary_emb, + num_heads, + q_head_dim, + qk_nope_head_dim, + v_head_dim, + qk_rope_head_dim, + position_ids, + eps, + kv_lora_rank, + softmax_scale, + recompute_fa3, + ) + else: + ctx.save_for_backward( + q_init, + kv_init, + attn_out, + softmax_lse, + q_ln_weight, + kv_ln_weight, + q_up_weight, + kv_up_weight, + rotary_emb, + num_heads, + q_head_dim, + qk_nope_head_dim, + v_head_dim, + qk_rope_head_dim, + position_ids, + eps, + kv_lora_rank, + softmax_scale, + recompute_fa3, + ) + else: + assert False, f"invalid {fa_version=}" + + ctx.fa_version = fa_version + + return attn_out + + @staticmethod + def backward(ctx, dout): + fa_version = ctx.fa_version + if fa_version == 2: + ( + q_init, + kv_init, + attn_out, + softmax_lse, + seed_offset, + q_ln_weight, + kv_ln_weight, + q_up_weight, + kv_up_weight, + rotary_emb, + num_heads, + q_head_dim, + qk_nope_head_dim, + v_head_dim, + qk_rope_head_dim, + position_ids, + eps, + kv_lora_rank, + softmax_scale, + ) = ctx.saved_tensor() + elif fa_version == 3: + ( + q_init, + kv_init, + attn_out, + softmax_lse, + q_ln_weight, + kv_ln_weight, + q_up_weight, + kv_up_weight, + rotary_emb, + num_heads, + q_head_dim, + qk_nope_head_dim, + v_head_dim, + qk_rope_head_dim, + position_ids, + eps, + kv_lora_rank, + softmax_scale, + recompute_fa3, + ) = ctx.saved_tensor() + else: + assert False, f"invalid {fa_version=}" + + if fa_version == 2: + assert "recompute_fa3" not in locals() + assert attn_out is not None and softmax_lse is not None + if fa_version == 3 and not recompute_fa3: + assert attn_out is not None and softmax_lse is not None + + q_ln_t, q_ln_invar = fused_ln.fused_rms_norm(q_init, q_ln_weight, eps) + + q_ln_fp8, q_ln_scale, q_ln_trans_fp8, q_ln_trans_scale = paddle.incubate.nn.functional.fp8_quant_blockwise( + q_ln_t.reshape([-1, q_ln_t.shape[-1]]), + output_scale_transpose=True, + quant_method="1x128", + input_transpose=True, + ) + + q_orig_shape = q_ln_t.shape + q = FP8LinearFunctionBase.compute_fp8_linear( + (q_ln_fp8, q_ln_scale), q_up_weight, weight_transpose=True, return_transpose_only=True + ) + q = q.reshape(q_orig_shape[:-1] + 
[q_up_weight.shape[-1]]) + + compressed_kv, k_pe = paddle.split(kv_init, [kv_lora_rank, qk_rope_head_dim], axis=-1) + + kv_ln_t, kv_ln_invar = fused_ln.fused_rms_norm(compressed_kv, kv_ln_weight, eps) + + kv_ln_fp8, kv_ln_scale, kv_ln_trans_fp8, kv_ln_trans_scale = paddle.incubate.nn.functional.fp8_quant_blockwise( + kv_ln_t.reshape([-1, kv_ln_t.shape[-1]]), + output_scale_transpose=True, + quant_method="1x128", + input_transpose=True, + ) + kv_orig_shape = kv_ln_t.shape + kv = FP8LinearFunctionBase.compute_fp8_linear( + (kv_ln_fp8, kv_ln_scale), kv_up_weight, weight_transpose=True, return_transpose_only=True + ) + kv = kv.reshape(kv_orig_shape[:-1] + [kv_up_weight.shape[-1]]) + + paddle.base.core._set_has_grad(True) + q.stop_gradient = False + kv.stop_gradient = False + k_pe.stop_gradient = False + query_states, key_states, value_states = qkv_pre_process( + q, + kv, + k_pe, + rotary_emb, + num_heads, + q_head_dim, + qk_nope_head_dim, + v_head_dim, + qk_rope_head_dim, + position_ids, + ) + + if fa_version == 2: + q_head_dim = query_states.shape[-1] + query_states = query_states * softmax_scale + + bsz = value_states.shape[0] + kv_seq_len = value_states.shape[1] + v_num_heads = value_states.shape[2] + value_padding = paddle.zeros( + [bsz, kv_seq_len, v_num_heads, q_head_dim - v_head_dim], + dtype=value_states.dtype, + ) + value_states_pad = paddle.concat([value_states, value_padding], axis=-1) + + with paddle.no_grad(): + + q_grad, k_grad, v_grad = _C_ops.flash_attn_grad( + query_states, + key_states, + value_states_pad, + attn_out, + softmax_lse.view("bfloat16"), + seed_offset, + None, + dout, + 0.0, + True, + ) + + v_grad = v_grad[..., :v_head_dim] + q_grad = q_grad * softmax_scale + elif fa_version == 3: + # recompute fa3 + if recompute_fa3: + with paddle.no_grad(): + attn_out, softmax_lse = _C_ops.flash_attn_v3( + query_states, + key_states, + value_states, + None, # q_v_ + None, # q_descale_ + None, # k_descale_ + None, # v_descale_ + softmax_scale, + True, + -1, # window_size_left + -1, # window_size_right + 0.0, # softcap + 1, # num_splits + False, # manual_set_pack_gqa + False, # pack_gqa_ + 0, # sm_margin + ) + with paddle.no_grad(): + q_grad, k_grad, v_grad = _C_ops.flash_attn_v3_grad( + query_states, + key_states, + value_states, + attn_out, + softmax_lse.view("bfloat16"), + dout, + softmax_scale, + True, + -1, + -1, + 0.0, + 0, + ) + else: + assert False, f"invalid {fa_version=}" + + d_q, d_kv, d_k_pe = paddle.grad( + outputs=[query_states, key_states, value_states], + inputs=[q, kv, k_pe], + grad_outputs=[q_grad, k_grad, v_grad], + create_graph=False, + retain_graph=False, + ) + + paddle.base.core._set_has_grad(False) + + # call up proj + if hasattr(kv_up_weight, "main_grad"): + d_kv_fp8, d_kv_scale, d_kv_t_fp8, d_kv_t_scale = paddle.incubate.nn.functional.fp8_quant_blockwise( + d_kv.reshape([-1, d_kv.shape[-1]]), + output_scale_transpose=True, + quant_method="1x128", + input_transpose=True, + ) + + d_kv_ln_t = FP8LinearFunctionBase.compute_fp8_linear( + (d_kv_fp8, d_kv_scale), kv_up_weight, weight_transpose=False + ) + d_kv_ln_t = d_kv_ln_t.reshape(d_kv.shape[:-1] + [kv_up_weight.shape[0]]) + + def kv_up_weight_grad(kv_ln_trans_fp8, kv_ln_trans_scale, d_kv_t_fp8, d_kv_t_scale, kv_up_weight): + FP8LinearFunctionBase.kitchen_gemm( + kv_ln_trans_fp8, + kv_ln_trans_scale, + d_kv_t_fp8, + d_kv_t_scale, + True, + True, + kv_up_weight.main_grad, + paddle.float32, + ) + + if WeightGradStore.enabled: + + WeightGradStore.put( + partial( + kv_up_weight_grad, kv_ln_trans_fp8, 
kv_ln_trans_scale, d_kv_t_fp8, d_kv_t_scale, kv_up_weight + ) + ) + else: + kv_up_weight_grad(kv_ln_trans_fp8, kv_ln_trans_scale, d_kv_t_fp8, d_kv_t_scale, kv_up_weight) + + d_kv_up_weight = None + + else: + d_kv_ln_t, d_kv_up_weight = _C_ops.matmul_grad(kv_ln_t, kv_up_weight, d_kv, False, False) + + d_compressed_kv, d_kv_ln_weight = fused_ln.fused_rms_norm_grad_func( + compressed_kv, kv_ln_weight, kv_ln_invar, d_kv_ln_t, eps + ) + + d_kv_init = paddle.concat([d_compressed_kv, d_k_pe], axis=-1) + + if hasattr(q_up_weight, "main_grad"): + + d_q_fp8, d_q_scale, d_q_t_fp8, d_q_t_scale = paddle.incubate.nn.functional.fp8_quant_blockwise( + d_q.reshape([-1, d_q.shape[-1]]), + output_scale_transpose=True, + quant_method="1x128", + input_transpose=True, + ) + # d_q_ln_t = paddle.matmul(d_q, q_up_weight, transpose_y=True) + + d_q_ln_t = FP8LinearFunctionBase.compute_fp8_linear( + (d_q_fp8, d_q_scale), q_up_weight, weight_transpose=False + ) + d_q_ln_t = d_q_ln_t.reshape(d_q.shape[:-1] + [q_up_weight.shape[0]]) + + def q_up_weight_grad(q_ln_trans_fp8, q_ln_trans_scale, d_q_t_fp8, d_q_t_scale, q_up_weight): + FP8LinearFunctionBase.kitchen_gemm( + q_ln_trans_fp8, + q_ln_trans_scale, + d_q_t_fp8, + d_q_t_scale, + True, + True, + q_up_weight.main_grad, + paddle.float32, + ) + + if WeightGradStore.enabled: + WeightGradStore.put( + partial(q_up_weight_grad, q_ln_trans_fp8, q_ln_trans_scale, d_q_t_fp8, d_q_t_scale, q_up_weight) + ) + else: + q_up_weight_grad(q_ln_trans_fp8, q_ln_trans_scale, d_q_t_fp8, d_q_t_scale, q_up_weight) + + d_q_up_weight = None + + else: + d_q_ln_t, d_q_up_weight = _C_ops.matmul_grad(q_ln_t, q_up_weight, d_q, False, False) + + d_q_init, d_q_ln_weight = fused_ln.fused_rms_norm_grad_func(q_init, q_ln_weight, q_ln_invar, d_q_ln_t, eps) + + return d_q_init, d_kv_init, d_q_ln_weight, d_kv_ln_weight, d_q_up_weight, d_kv_up_weight + + +class MemroyRecomputeAttn(paddle.nn.Layer): + def __init__( + self, + q_norm_hidden_size, + kv_norm_hidden_size, + q_up_in_dim, + q_up_out_dim, + kv_up_in_dim, + kv_up_out_dim, + rotary_emb, + num_heads, + q_head_dim, + qk_nope_head_dim, + v_head_dim, + qk_rope_head_dim, + eps, + kv_lora_rank, + softmax_scale, + recompute_fa3=False, + fa_version=3, + ) -> None: + super().__init__() + self._dtype = self._helper.get_default_dtype() + + self.q_ln_weight = paddle.create_parameter( + shape=[q_norm_hidden_size], + dtype=self._dtype, + default_initializer=nn.initializer.Constant(1.0), + ) + self.kv_ln_weight = paddle.create_parameter( + shape=[kv_norm_hidden_size], + dtype=self._dtype, + default_initializer=nn.initializer.Constant(1.0), + ) + + self.q_up_weight = self.create_parameter( + shape=[q_up_in_dim, q_up_out_dim], + dtype=self._dtype, + is_bias=False, + ) + + self.kv_up_weight = self.create_parameter( + shape=[kv_up_in_dim, kv_up_out_dim], + dtype=self._dtype, + is_bias=False, + ) + ( + self.rotary_emb, + self.num_heads, + self.q_head_dim, + self.qk_nope_head_dim, + self.v_head_dim, + self.qk_rope_head_dim, + self.eps, + self.kv_lora_rank, + self.softmax_scale, + self.recompute_fa3, + self.fa_version, + ) = ( + rotary_emb, + num_heads, + q_head_dim, + qk_nope_head_dim, + v_head_dim, + qk_rope_head_dim, + eps, + kv_lora_rank, + softmax_scale, + recompute_fa3, + fa_version, + ) + set_parameter_color([self.q_up_weight, self.kv_up_weight], "memory_attn") + + def fp8_quant_weight(self, quant_transpose=None): + cache_fp8_weight(self.q_up_weight, quant_transpose=quant_transpose) + cache_fp8_weight(self.kv_up_weight, quant_transpose=quant_transpose) + + def 
forward(self, q_init, kv_init, position_ids): + + seq_len = q_init.shape[1] + + if self.rotary_emb.max_seq_len_cached is None or seq_len > self.rotary_emb.max_seq_len_cached: + self.rotary_emb._set_cos_sin_cache(seq_len) + + return MemroyRecomputeAttnFunc.apply( + q_init, + kv_init, + self.q_ln_weight, + self.kv_ln_weight, + self.q_up_weight, + self.kv_up_weight, + self.rotary_emb, + self.num_heads, + self.q_head_dim, + self.qk_nope_head_dim, + self.v_head_dim, + self.qk_rope_head_dim, + position_ids, + self.eps, + self.kv_lora_rank, + self.softmax_scale, + recompute_fa3=self.recompute_fa3, + fa_version=self.fa_version, + ) + + +class FusedRMSLinearFunc(paddle.autograd.PyLayer): + @staticmethod + def forward(ctx, x, rms_norm_weight, q_down_weight, kv_down_weight, eps): + + hidden_states, invar = fused_ln.fused_rms_norm(x, rms_norm_weight, eps) + + h_fp8, h_scale = paddle.incubate.nn.functional.fp8_quant_blockwise( + hidden_states.reshape([-1, hidden_states.shape[-1]]), output_scale_transpose=True, quant_method="1x128" + ) + + h_orig_shape = hidden_states.shape + q = FP8LinearFunctionBase.compute_fp8_linear( + (h_fp8, h_scale), q_down_weight, weight_transpose=True, return_transpose_only=True + ) + q = q.reshape(h_orig_shape[:-1] + [q_down_weight.shape[-1]]) + + kv = paddle.matmul(hidden_states, kv_down_weight) + + ctx.save_for_backward(x, rms_norm_weight, q_down_weight, kv_down_weight) + ctx.eps = eps + return q, kv + + @staticmethod + def backward(ctx, d_q, d_kv): + x, rms_norm_weight, q_down_weight, kv_down_weight = ctx.saved_tensor() + eps = ctx.eps + hidden_states, invar = fused_ln.fused_rms_norm(x, rms_norm_weight, eps) + + h_t_fp8, h_t_scale = paddle.incubate.nn.functional.fp8_quant_blockwise( + hidden_states.reshape([-1, hidden_states.shape[-1]]), + output_scale_transpose=True, + quant_method="1x128", + input_transpose=True, + return_transpose_only=True, + ) + + h_grad, d_kv_down_weight = _C_ops.matmul_grad(hidden_states, kv_down_weight, d_kv, False, False) + + if hasattr(q_down_weight, "main_grad"): + d_q_fp8, d_q_scale, d_q_t_fp8, d_q_t_scale = paddle.incubate.nn.functional.fp8_quant_blockwise( + d_q.reshape([-1, d_q.shape[-1]]), + output_scale_transpose=True, + quant_method="1x128", + input_transpose=True, + ) + FP8LinearFunctionBase.compute_fp8_linear( + (d_q_fp8, d_q_scale), q_down_weight, weight_transpose=False, out=h_grad.view([-1, h_grad.shape[-1]]) + ) + + def q_down_weight_grad(h_t_fp8, h_t_scale, d_q_t_fp8, d_q_t_scale, q_down_weight): + FP8LinearFunctionBase.kitchen_gemm( + h_t_fp8, h_t_scale, d_q_t_fp8, d_q_t_scale, True, True, q_down_weight.main_grad, paddle.float32 + ) + + if WeightGradStore.enabled: + WeightGradStore.put( + partial(q_down_weight_grad, h_t_fp8, h_t_scale, d_q_t_fp8, d_q_t_scale, q_down_weight) + ) + else: + q_down_weight_grad(h_t_fp8, h_t_scale, d_q_t_fp8, d_q_t_scale, q_down_weight) + + d_q_down_weight = None + + else: + h_grad_0, d_q_down_weight = _C_ops.matmul_grad(hidden_states, q_down_weight, d_q, False, False) + h_grad = h_grad + h_grad_0 + + dx, d_rms_norm_weight = fused_ln.fused_rms_norm_grad_func(x, rms_norm_weight, invar, h_grad, eps) + + return dx, d_rms_norm_weight, d_q_down_weight, d_kv_down_weight + + +class FusedRMSLinear(paddle.nn.Layer): + def __init__(self, hidden_size, q_out_dim, kv_outdim, eps=1e-6) -> None: + super().__init__() + self._dtype = self._helper.get_default_dtype() + + self.rms_norm_weight = paddle.create_parameter( + shape=[hidden_size], + dtype=self._dtype, + default_initializer=nn.initializer.Constant(1.0), + ) + + 
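+ # Down-projection weights applied to the RMS-normalized input: the q path runs through the
+ # FP8 blockwise-quantized GEMM in FusedRMSLinearFunc, while the kv path uses a regular matmul.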
self.q_down_weight = self.create_parameter( + shape=[hidden_size, q_out_dim], + dtype=self._dtype, + is_bias=False, + ) + + self.kv_down_weight = self.create_parameter( + shape=[hidden_size, kv_outdim], + dtype=self._dtype, + is_bias=False, + ) + self.eps = eps + set_parameter_color([self.q_down_weight], "rms_linear") + + def fp8_quant_weight(self, quant_transpose=None): + cache_fp8_weight(self.q_down_weight, quant_transpose=quant_transpose) + + def forward(self, x): + + return FusedRMSLinearFunc.apply(x, self.rms_norm_weight, self.q_down_weight, self.kv_down_weight, self.eps) + + +class FusedRMSLinearSingleFunc(paddle.autograd.PyLayer): + @staticmethod + def forward(ctx, x, rms_norm_weight, linear_weight, eps): + + hidden_states, invar = fused_ln.fused_rms_norm(x, rms_norm_weight, eps) + q = paddle.matmul(hidden_states, linear_weight) + + ctx.save_for_backward(x, rms_norm_weight, linear_weight, eps) + return q + + @staticmethod + def backward(ctx, d_q, d_kv): + x, rms_norm_weight, linear_weight, eps = ctx.saved_tensor() + hidden_states, invar = fused_ln.fused_rms_norm(x, rms_norm_weight, eps) + + h_grad, d_linear_weight = _C_ops.matmul_grad(hidden_states, linear_weight, d_q, False, False) + + dx, d_rms_norm_weight = fused_ln.fused_rms_norm_grad_func(x, rms_norm_weight, invar, h_grad, eps) + + return dx, d_rms_norm_weight, d_linear_weight + + +class FusedRMSLinearSingle(paddle.nn.Layer): + def __init__(self, hidden_size, q_out_dim, kv_outdim, eps=1e-6) -> None: + super().__init__() + self._dtype = self._helper.get_default_dtype() + + self.rms_norm_weight = paddle.create_parameter( + shape=[hidden_size], + dtype=self._dtype, + default_initializer=nn.initializer.Constant(1.0), + ) + + self.linear_weight = self.create_parameter( + shape=[hidden_size, q_out_dim], + dtype=self._dtype, + is_bias=False, + ) + self.eps = eps + + def forward(self, x): + + return FusedRMSLinearFunc.apply(x, self.rms_norm_weight, self.linear_weight, self.eps) + + +class FastCrossEntropyFunction(paddle.autograd.PyLayer): + @staticmethod + def forward(ctx, preds, labels): + softmax_val, loss = paddle._C_ops.cross_entropy_with_softmax(preds, labels, False, True, False, -100, -1) + + ctx.save_for_backward(labels, softmax_val) + return loss + + @staticmethod + def backward(ctx, dout): + labels, softmax_val = ctx.saved_tensor() + + preds_grad = paddle.incubate.nn.functional.cross_entropy_with_softmax_bwd_w_downcast( + labels, softmax_val.cast(paddle.float32), dout.cast(paddle.float32) + ) + + return preds_grad, None + + +class DeepseekV2LMHead(nn.Layer): + def __init__(self, config: DeepseekV2FastConfig, embedding_weight=None): + super(DeepseekV2LMHead, self).__init__() + self.config = config + + if config.num_nextn_predict_layers > 0: + self.seq_length = config.seq_length - config.num_nextn_predict_layers + else: + self.seq_length = config.seq_length + + if config.tensor_parallel_degree > 1 and config.vocab_size % config.tensor_parallel_degree == 0: + vocab_size = config.vocab_size // config.tensor_parallel_degree + else: + vocab_size = config.vocab_size + + if embedding_weight is not None: + self.transpose_y = True + self.weight = embedding_weight + else: + self.transpose_y = False + self.weight = self.create_parameter( + shape=[config.hidden_size, vocab_size], + dtype=paddle.get_default_dtype(), + default_initializer=nn.initializer.XavierNormal(1.0), + ) + # Must set distributed attr for Tensor Parallel ! 
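+ # vocab_size is only divided by tensor_parallel_degree above when the vocabulary splits evenly,
+ # so it differs from config.vocab_size exactly when the weight is actually sharded across TP ranks.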
+ self.weight.is_distributed = True if (vocab_size != config.vocab_size) else False + if get_env_device() == "xpu": + try: + from paddle_xpu.layers.nn import ( # noqa: F401 + parallel_matmul as xpu_parallel_matmul, + ) + + self.xpu_parallel_matmul = xpu_parallel_matmul() + except ImportError: + self.xpu_parallel_matmul = None + + def forward(self, hidden_states, tensor_parallel_output=None): + if self.config.sequence_parallel: + hidden_states = GatherOp.apply(hidden_states) + hidden_states = paddle.reshape_(hidden_states, [-1, self.seq_length, self.config.hidden_size]) + + if tensor_parallel_output is None: + tensor_parallel_output = self.config.tensor_parallel_output + + if get_env_device() == "xpu" and self.xpu_parallel_matmul is not None: + logits = self.xpu_parallel_matmul( + hidden_states, + self.weight, + transpose_y=False, + tensor_parallel_output=tensor_parallel_output, + training=self.training, + ) + else: + logits = parallel_matmul( + hidden_states, self.weight, transpose_y=self.transpose_y, tensor_parallel_output=tensor_parallel_output + ) + return logits + + def extra_repr(self): + return f"hidden_size={self.weight.shape[0]}, vocab_size={self.weight.shape[1]}, dtype={self.weight.dtype}" diff --git a/examples/experiments/deepseek_v3_pretrain/modeling_pp.py b/examples/experiments/deepseek_v3_pretrain/modeling_pp.py new file mode 100644 index 00000000000..b8c0df45d81 --- /dev/null +++ b/examples/experiments/deepseek_v3_pretrain/modeling_pp.py @@ -0,0 +1,2324 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
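+
+"""Pipeline-parallel definition of the DeepseekV2 model used by this pretraining example.
+
+Each MoE decoder layer is decomposed into ScheduleNode pieces (attention + router, token
+dispatch, expert MLP, token combine, post-process) so that the forward of one micro-batch
+can be overlapped with the backward of another under the DualPipeV-style schedule, with
+deep_ep all-to-all traffic issued on a separate communication stream and weight gradients
+optionally deferred through WeightGradStore.
+"""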
+ +import math +from typing import OrderedDict, Tuple, Union + +import paddle +import paddle.distributed.fleet as fleet +import paddle.nn as nn +from paddle.distributed.fleet.meta_parallel import ( + LayerDesc, + LocalSharedLayerDesc, + PipelineLayer, + ScheduleChunk, + ScheduleNode, + SharedLayerDesc, +) +from paddle.distributed.fleet.meta_parallel.zero_bubble_utils import WeightGradStore + +try: + from paddle.distributed.fleet.meta_parallel.zero_bubble_utils import EventStore +except ImportError: + EventStore = None + +from config.configuration import DeepseekV2FastConfig +from modeling import DeepseekV2DecoderLayer, DeepseekV2LMHead +from modeling import DeepseekV2ModelFast as DeepseekV2Model +from modeling import DeepseekV2MoE, DeepseekV2MTPLayer +from modeling import DeepseekV2PretrainedModelFast as DeepseekV2PretrainedModel +from modeling import ( + DeepseekV2PretrainingCriterionFast, + DeepseekV2RMSNorm, + TemporaryVarContext, + set_global_step, +) +from moe_utils import get_env_device +from paddle.distributed.fleet.recompute.recompute import recompute +from paddle.distributed.fleet.utils.sequence_parallel_utils import ScatterOp + +from paddleformers.transformers.model_utils import PipelinePretrainedModel +from paddleformers.utils.log import logger + +try: + import paddle.distributed.communication.deep_ep as deep_ep +except ImportError: + deep_ep = None + +from moe_layer import FusionMoeNode + +from paddleformers.transformers.fp8_utils import ( + FP8LinearFunction, + FP8LinearFunctionBase, +) +from paddleformers.transformers.fused_a2a import ( + fused_combine_backward_func, + fused_combine_forward_func, + fused_dispatch_backward_func, + fused_dispatch_forward_func, +) + +__all__ = [ + "DeepseekV2ForCausalLMPipe", +] + +import queue + +global_inputs_embeds_mtp_queue = queue.Queue() + + +def parse_args(args): + if isinstance(args, (tuple, list)): + if len(args) == 4: + hidden_states, attention_mask, attn_mask_startend_row_indices, position_ids = args + + elif len(args) == 3: + hidden_states, attention_mask, attn_mask_startend_row_indices = args + position_ids = None + elif len(args) == 2: + hidden_states, attention_mask = args + attn_mask_startend_row_indices, position_ids = None, None + else: # len(args) == 1: + hidden_states = args[0] + attention_mask, attn_mask_startend_row_indices, position_ids = None, None, None + else: + hidden_states = args + attention_mask, attn_mask_startend_row_indices, position_ids = None, None, None + + if position_ids is not None: + position_ids.stop_gradient = True + + if attention_mask is not None: + attention_mask.stop_gradient = True + + if attn_mask_startend_row_indices is not None: + attn_mask_startend_row_indices.stop_gradient = True + + return hidden_states, attention_mask, attn_mask_startend_row_indices, position_ids + + +def return_args(hidden_states, attention_mask=None, attn_mask_startend_row_indices=None, position_ids=None): + ret = (hidden_states,) + + if attention_mask is not None: + ret += (attention_mask.clone(),) + if attn_mask_startend_row_indices is not None: + ret += (attn_mask_startend_row_indices.clone(),) + if position_ids is not None: + ret += (position_ids.clone(),) + if len(ret) == 1: + ret = ret[0] + + return ret + + +def get_attr(layer, name): + if getattr(layer, name, None) is not None: + return getattr(layer, name, None) + else: + return get_attr(layer._layer, name) + + +def calc_stream_wait(group_id): + comm_event = deep_ep.get_event_from_comm_stream(group_id) + comm_event.calc_stream_wait(group_id) + + +class TensorMeta: + 
"""Recording the meta info of forward inputs, to avoid 0-size problems""" + + def __init__(self, tensor): + self.shape = tensor.shape + self.dtype = tensor.dtype + + +class PostProcessNode(ScheduleNode): + def __init__( + self, + send_mtp_embed, + training, + alpha, + config, + shared_experts=None, + using_post_norm_recompute=False, + output_mtp_embed_first=False, + name="PostProcessNode", + ): + self.send_mtp_embed = send_mtp_embed + self.shared_experts = shared_experts + self.traning = training + self.config = config + self.alpha = alpha + self.using_post_norm_recompute = using_post_norm_recompute + self.output_mtp_embed_first = output_mtp_embed_first + self.name = name + + if self.using_post_norm_recompute: + assert self.shared_experts is not None + assert self.shared_experts.norm_weight is not None and self.shared_experts.norm_eps is not None + + def forward_without_residual(self, inputs): + + if isinstance(inputs, list): + inputs = tuple(inputs) + if self.using_post_norm_recompute: + if self.send_mtp_embed: + (inputs_embeds_mtp, hidden_states, residual, l_aux, final_hidden_states, norm_out) = inputs + else: + (hidden_states, residual, l_aux, final_hidden_states, norm_out) = inputs + else: + if self.send_mtp_embed: + (inputs_embeds_mtp, hidden_states, residual, l_aux, final_hidden_states) = inputs + else: + (hidden_states, residual, l_aux, final_hidden_states) = inputs + + with paddle.no_grad(): + if self.shared_experts is not None: + if self.using_post_norm_recompute: + _, _, _, shared_expert_output = FP8LinearFunctionBase.fp8_mlp_fwd( + norm_out, self.shared_experts.w1, self.shared_experts.w2 + ) + norm_out = None + del norm_out + else: + _, _, _, shared_expert_output = FP8LinearFunctionBase.fp8_mlp_fwd( + hidden_states, self.shared_experts.w1, self.shared_experts.w2 + ) + residual = residual + shared_expert_output + + self.x = hidden_states + self.l_aux = l_aux + + hidden_states = residual + hidden_states.stop_gradient = False + + if self.send_mtp_embed: + assert not self.output_mtp_embed_first, "forward_without_residual doesn't support output_mtp_embed_first" + hidden_states = paddle.concat([hidden_states, inputs_embeds_mtp], axis=-1) + self.mtp_embed_shape = ( + inputs_embeds_mtp.shape + ) # Save the shape of mtp_embed, used for backward propagation + + return return_args(hidden_states) + + def forward(self, inputs): + + if isinstance(inputs, list): + inputs = tuple(inputs) + + if self.using_post_norm_recompute: + if self.send_mtp_embed: + (inputs_embeds_mtp, hidden_states, residual, l_aux, final_hidden_states, norm_out) = inputs + else: + (hidden_states, residual, l_aux, final_hidden_states, norm_out) = inputs + else: + if self.send_mtp_embed: + (inputs_embeds_mtp, hidden_states, residual, l_aux, final_hidden_states) = inputs + else: + (hidden_states, residual, l_aux, final_hidden_states) = inputs + + with paddle.no_grad(): + if self.shared_experts is not None: + if self.using_post_norm_recompute: + _, _, _, shared_expert_output = FP8LinearFunctionBase.fp8_mlp_fwd( + norm_out, self.shared_experts.w1, self.shared_experts.w2 + ) + norm_out = None + del norm_out + else: + _, _, _, shared_expert_output = FP8LinearFunctionBase.fp8_mlp_fwd( + hidden_states, self.shared_experts.w1, self.shared_experts.w2 + ) + final_hidden_states = final_hidden_states + shared_expert_output + + self.x = hidden_states + self.l_aux = l_aux + hidden_states = residual + final_hidden_states + + if self.send_mtp_embed: + if self.output_mtp_embed_first: + hidden_states = paddle.concat([inputs_embeds_mtp, 
hidden_states], axis=-1) + else: + hidden_states = paddle.concat([hidden_states, inputs_embeds_mtp], axis=-1) + self.mtp_embed_shape = ( + inputs_embeds_mtp.shape + ) # Save the shape of mtp_embed shape, used for backward propagation + + return return_args(hidden_states) + + @paddle.no_grad() + def backward(self, output_grad): + (do3,) = output_grad + + if self.send_mtp_embed: + # Split gradient: first part of do3 corresponds to hidden_states, second part corresponds to inputs_embeds_mtp + hidden_size = do3.shape[-1] - self.mtp_embed_shape[-1] + if self.output_mtp_embed_first: + hidden_states_grad = do3[..., hidden_size:] + inputs_embeds_mtp_grad = do3[..., :hidden_size] + else: + hidden_states_grad = do3[..., :hidden_size] + inputs_embeds_mtp_grad = do3[..., hidden_size:] + else: + hidden_states_grad = do3 + inputs_embeds_mtp_grad = None + + if self.using_post_norm_recompute: + dx, norm_out, invar = FP8LinearFunctionBase.fp8_mlp_bwd_norm_rc( + hidden_states_grad, + self.x, + self.shared_experts.norm_weight, + self.shared_experts.norm_eps, + self.shared_experts.w1, + self.shared_experts.w2, + ) + else: + dx = FP8LinearFunctionBase.fp8_mlp_bwd( + hidden_states_grad, self.x, self.shared_experts.w1, self.shared_experts.w2, True + ) + + self.x = None + + residual_grad = hidden_states_grad + l_aux_grad = paddle.ones(1, dtype=self.l_aux.dtype) * self.alpha + final_hidden_states_grad = hidden_states_grad + + if self.using_post_norm_recompute: + if self.send_mtp_embed: + return ( + inputs_embeds_mtp_grad, + dx, + residual_grad, + l_aux_grad, + final_hidden_states_grad, + norm_out, + invar, + ) + else: + return (dx, residual_grad, l_aux_grad, final_hidden_states_grad, norm_out, invar) + else: + if self.send_mtp_embed: + return (inputs_embeds_mtp_grad, dx, residual_grad, l_aux_grad, final_hidden_states_grad) + else: + return (dx, residual_grad, l_aux_grad, final_hidden_states_grad) + + +class DecoderLayerNode(ScheduleNode): + def __init__( + self, + attn_node, + dispatch_node, + mlp_node, + combine_node, + post_process_node, + mlp_layer, + name="DecoderLayerNode", + ): + super().__init__(fwd_func=None, name=name) + assert (dispatch_node is None and combine_node is None) or ( + dispatch_node is not None and combine_node is not None + ) + self.attn_node = attn_node + self.dispatch_node = dispatch_node + self.mlp_node = mlp_node + self.combine_node = combine_node + self.post_process_node = post_process_node + + self.mlp_layer = mlp_layer + self.moe_group = mlp_layer.moe_group + self.moe_num_experts = mlp_layer.moe_num_experts + + self.states = None + self.hidden_states_meta = None + self.dispatched_probs_meta = None + self.combine_output_meta = None + + def dispatch_forward(self, inputs, previous_event=None, allocate_on_comm_stream=False): + paddle.base.core.nvprof_nvtx_push("raw_dispatch_forward") + if isinstance(inputs, list): + inputs = tuple(inputs) + ( + inputs_embeds_mtp, + hidden_states, + residual, + l_aux, + intermediate_hidden_states, + token_indices, + token_probs, + ) = inputs + + with paddle.no_grad(): + intermediate_hidden_states, dispatched_probs, states, _ = fused_dispatch_forward_func( + intermediate_hidden_states, + token_indices, + token_probs, + self.moe_num_experts, + self.moe_group, + previous_event=previous_event, + async_finish=True, + allocate_on_comm_stream=allocate_on_comm_stream, + ) + dispatched_indices = states["dispatched_indices"] + self.mlp_layer.set_tokens_per_expert(states["tokens_per_expert"]) + dispatched_indices.stop_gradient = True + 
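+ # The dispatched activations and routing probabilities take part in autograd, so their
+ # stop_gradient flags are cleared below, while the integer routing indices stay gradient-free.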
intermediate_hidden_states.stop_gradient = False + dispatched_probs.stop_gradient = False + self.states = states + self.hidden_states_meta = TensorMeta(intermediate_hidden_states) + self.dispatched_probs_meta = TensorMeta(dispatched_probs) + + inputs = ( + inputs_embeds_mtp, + hidden_states, + residual, + l_aux, + intermediate_hidden_states, + dispatched_indices, + dispatched_probs, + ) + paddle.base.core.nvprof_nvtx_pop() + return inputs + + def combine_forward(self, inputs, previous_event=None): + paddle.base.core.nvprof_nvtx_push("raw_combine_forward") + if isinstance(inputs, list): + inputs = tuple(inputs) + (inputs_embeds_mtp, hidden_states, residual, l_aux, expert_output) = inputs + + with paddle.no_grad(): + combine_output = fused_combine_forward_func( + expert_output, self.moe_group, self.states, previous_event=previous_event, async_finish=True + ) + combine_output.stop_gradient = False + self.combine_output_meta = TensorMeta(combine_output) + inputs = (inputs_embeds_mtp, hidden_states, residual, l_aux, combine_output) + paddle.base.core.nvprof_nvtx_pop() + return inputs + + def dispatch_backward(self, output_grad): + paddle.base.core.nvprof_nvtx_push("raw_dispatch_backward") + ( + inputs_embeds_mtp_grad, + hidden_states_grad, + residual_grad, + l_aux_grad, + intermediate_hidden_states_grad, + dispatched_indices_grad, + dispatched_probs_grad, + ) = output_grad + + if intermediate_hidden_states_grad is None: + intermediate_hidden_states_grad = paddle.zeros( + self.hidden_states_meta.shape, self.hidden_states_meta.dtype + ) + if dispatched_probs_grad is None: + dispatched_probs_grad = paddle.zeros(self.dispatched_probs_meta.shape, self.dispatched_probs_meta.dtype) + with paddle.no_grad(): + intermediate_hidden_states_grad, token_indices_grad, token_probs_grad = fused_dispatch_backward_func( + intermediate_hidden_states_grad, + dispatched_probs_grad, + self.moe_group, + self.states["handle"], + async_finish=True, + ) + + output_grad = ( + inputs_embeds_mtp_grad, + hidden_states_grad, + residual_grad, + l_aux_grad, + intermediate_hidden_states_grad, + token_indices_grad, + token_probs_grad, + ) + paddle.base.core.nvprof_nvtx_pop() + return output_grad + + def combine_backward(self, output_grad): + paddle.base.core.nvprof_nvtx_push("raw_combine_backward") + ( + inputs_embeds_mtp_grad, + hidden_states_grad, + residual_grad, + l_aux_grad, + combine_output_grad, + ) = output_grad + + if combine_output_grad is None: + combine_output_grad = paddle.zeros(self.combine_output_meta.shape, self.combine_output_meta.dtype) + with paddle.no_grad(): + expert_output_grad = fused_combine_backward_func( + combine_output_grad, self.moe_group, self.states["handle"], async_finish=True + ) + + output_grad = ( + inputs_embeds_mtp_grad, + hidden_states_grad, + residual_grad, + l_aux_grad, + expert_output_grad, + ) + paddle.base.core.nvprof_nvtx_pop() + return output_grad + + def forward(self, inputs): + inputs = self.attn_node.forward(inputs) + + if self.dispatch_node is None: + inputs = self.dispatch_forward(inputs) + calc_stream_wait(self.moe_group.id) + else: + inputs = self.dispatch_node.forward(inputs) + + inputs = self.mlp_node.forward(inputs) + + if self.combine_node is None: + inputs = self.combine_forward(inputs) + calc_stream_wait(self.moe_group.id) + else: + inputs = self.combine_node.forward(inputs) + + inputs = self.post_process_node.forward(inputs) + return inputs + + def backward(self, output_grad=None, scaler=None): + assert (output_grad is not None) and (scaler is None) + + output_grad = 
self.post_process_node.backward(output_grad) + + if self.combine_node is None: + output_grad = self.combine_backward(output_grad) + calc_stream_wait(self.moe_group.id) + else: + output_grad = self.combine_node.backward(output_grad) + + output_grad = self.mlp_node.backward(output_grad) + + if self.dispatch_node is None: + output_grad = self.dispatch_backward(output_grad) + calc_stream_wait(self.moe_group.id) + else: + output_grad = self.dispatch_node.backward(output_grad) + + output_grad = self.attn_node.backward(output_grad) + return output_grad + + +class OverlapedScheduleChunk: + def __init__(self, forward_nodes, backward_nodes, use_fuion=True): + assert len(forward_nodes) == len(backward_nodes) + self.nodes = [] + for f, b in zip(forward_nodes, backward_nodes): + schedule_node_class = OverlapedScheduleNode + if use_fuion: + schedule_node_class = OverlapedFUsionScheduleNode + if isinstance(f, DenseDecoderLayerNode) or isinstance(b, DenseDecoderLayerNode): + schedule_node_class = OverlapedDenseFusionScheduleNode + self.nodes.append(schedule_node_class(f, b, f"OverlapedNode_{len(self.nodes)}")) + + def forward_backward(self, inputs, output_grad, combine_bw_event_to_wait=None, pp_stream=None): + event_to_wait = combine_bw_event_to_wait + for i, n in enumerate(self.nodes): + pp_stream_t = pp_stream + if i + 1 != len(self.nodes): + pp_stream_t = None + + inputs, output_grad, event_to_wait = n.forward_backward( + inputs, output_grad, combine_bw_event_to_wait=event_to_wait, pp_stream=pp_stream_t + ) + return inputs, output_grad, None + + +class DecoderBackwardScheduleChunk: + def __init__(self, nodes): + self.nodes = nodes + + def backward(self, output_grad, combine_bw_event_to_wait=None, pp_stream=None): + event_to_wait = combine_bw_event_to_wait + for i, n in enumerate(self.nodes): + pp_stream_t = pp_stream if i + 1 == len(self.nodes) else None + output_grad, event_to_wait = n.backward_for_fusion( + output_grad, combine_bw_event_to_wait=event_to_wait, pp_stream=pp_stream_t + ) + return output_grad + + +class OverlapedScheduleNode: + def __init__(self, forward_node, backward_node, name=""): + assert isinstance(forward_node, DecoderLayerNode) and isinstance(backward_node, DecoderLayerNode) + self.forward_node = forward_node + self.backward_node = backward_node + self.name = name + + def forward_backward(self, inputs, output_grad, event_to_wait=None): + paddle.base.core.nvprof_nvtx_push("forward_backward") + output_grad = self.backward_node.post_process_node.backward(output_grad) + + output_grad = self.backward_node.combine_backward(output_grad) + inputs = self.forward_node.attn_node.forward(inputs) + + calc_stream_wait(self.backward_node.moe_group.id) + attn_compute_event = deep_ep.get_event_from_calc_stream(self.forward_node.moe_group.id) + output_grad = self.backward_node.mlp_node.backward(output_grad) + inputs = self.forward_node.dispatch_forward( + inputs, previous_event=attn_compute_event, allocate_on_comm_stream=True + ) + + calc_stream_wait(self.forward_node.moe_group.id) + output_grad = self.backward_node.dispatch_backward(output_grad) + inputs = self.forward_node.mlp_node.forward(inputs) + + calc_stream_wait(self.backward_node.moe_group.id) + inputs = self.forward_node.combine_forward(inputs) + output_grad = self.backward_node.attn_node.backward(output_grad) + + calc_stream_wait(self.forward_node.moe_group.id) + inputs = self.forward_node.post_process_node.forward(inputs) + paddle.base.core.nvprof_nvtx_pop() + return inputs, output_grad + + +class 
FusionFp8DecoderLayerNode(ScheduleNode): + def __init__( + self, + attn_and_gate_node, + fp8_fusion_moe_node, + post_process_node, + mlp_layer, + send_mtp_embed, + using_post_norm_recompute=False, + stepped_recompute_fwd_gate_up=False, + dsv3_use_fp8_dispatch=True, + name="", + ): + self.attn_and_gate_node = attn_and_gate_node + self.fp8_fusion_moe_node = fp8_fusion_moe_node + self.post_process_node = post_process_node + self.send_mtp_embed = send_mtp_embed + + self.using_post_norm_recompute = using_post_norm_recompute + self.stepped_recompute_fwd_gate_up = stepped_recompute_fwd_gate_up + self.name = name + + self.moe_group = mlp_layer.moe_group + self.dsv3_use_fp8_dispatch = dsv3_use_fp8_dispatch + + def attn_forward(self, inputs): + inputs = self.attn_and_gate_node.forward(inputs) + + if self.send_mtp_embed: + if self.using_post_norm_recompute: + inputs_embeds_mtp, hidden_states, residual, probs, routing_map, l_aux, norm_out = inputs + else: + inputs_embeds_mtp, hidden_states, residual, probs, routing_map, l_aux = inputs + else: + if self.using_post_norm_recompute: + hidden_states, residual, probs, routing_map, l_aux, norm_out = inputs + else: + hidden_states, residual, probs, routing_map, l_aux = inputs + + if self.using_post_norm_recompute: + hs_2d, token_indices, token_probs = self.fp8_fusion_moe_node.dispatch_quant_node.forward( + norm_out, probs, routing_map + ) + # common return values + ret = (hidden_states, residual, l_aux, hs_2d, token_indices, token_probs, norm_out) + + # append mtp embed if needed + ret = (inputs_embeds_mtp, *ret) if self.send_mtp_embed else ret + return ret + else: + hs_2d, token_indices, token_probs = self.fp8_fusion_moe_node.dispatch_quant_node.forward( + hidden_states, probs, routing_map + ) + + # common return values + ret = (hidden_states, residual, l_aux, hs_2d, token_indices, token_probs) + + # append mtp embed if needed + ret = (inputs_embeds_mtp, *ret) if self.send_mtp_embed else ret + return ret + + def dispatch_forward(self, inputs, previous_event=None, async_finish=False, allocate_on_comm_stream=False): + if self.using_post_norm_recompute: + if self.send_mtp_embed: + inputs_embeds_mtp, hidden_states, residual, l_aux, hs_2d, token_indices, token_probs, norm_out = inputs + else: + hidden_states, residual, l_aux, hs_2d, token_indices, token_probs, norm_out = inputs + else: + if self.send_mtp_embed: + inputs_embeds_mtp, hidden_states, residual, l_aux, hs_2d, token_indices, token_probs = inputs + else: + hidden_states, residual, l_aux, hs_2d, token_indices, token_probs = inputs + + (hs_dispatched, dispatched_indices, dispatched_probs,) = self.fp8_fusion_moe_node.dispatch_node.forward( + hs_2d, + token_indices, + token_probs, + previous_event=previous_event, + async_finish=async_finish, + allocate_on_comm_stream=allocate_on_comm_stream, + ) + + ret = (hidden_states, residual, l_aux, hs_dispatched, dispatched_indices, dispatched_probs) + + # append mtp embed if needed + ret = (inputs_embeds_mtp, *ret) if self.send_mtp_embed else ret + ret = (*ret, norm_out) if self.using_post_norm_recompute else ret + return ret + + def mlp_forward(self, inputs): + if self.using_post_norm_recompute: + if self.send_mtp_embed: + ( + inputs_embeds_mtp, + hidden_states, + residual, + l_aux, + hs_dispatched, + dispatched_indices, + dispatched_probs, + norm_out, + ) = inputs + else: + hidden_states, residual, l_aux, hs_dispatched, dispatched_indices, dispatched_probs, norm_out = inputs + else: + if self.send_mtp_embed: + ( + inputs_embeds_mtp, + hidden_states, + residual, + 
l_aux, + hs_dispatched, + dispatched_indices, + dispatched_probs, + ) = inputs + else: + hidden_states, residual, l_aux, hs_dispatched, dispatched_indices, dispatched_probs = inputs + + hidden_states_out = self.fp8_fusion_moe_node.mlp_node.forward( + hs_dispatched, dispatched_indices, dispatched_probs + ) + ret = (hidden_states, residual, l_aux, hidden_states_out) + + # append mtp embed if needed + ret = (inputs_embeds_mtp, *ret) if self.send_mtp_embed else ret + ret = (*ret, norm_out) if self.using_post_norm_recompute else ret + return ret + + def combine_forward(self, inputs, async_finish=False, previous_event=None, allocate_on_comm_stream=False): + if self.using_post_norm_recompute: + if self.send_mtp_embed: + (inputs_embeds_mtp, hidden_states, residual, l_aux, hidden_states_out, norm_out) = inputs + else: + (hidden_states, residual, l_aux, hidden_states_out, norm_out) = inputs + else: + if self.send_mtp_embed: + (inputs_embeds_mtp, hidden_states, residual, l_aux, hidden_states_out) = inputs + else: + (hidden_states, residual, l_aux, hidden_states_out) = inputs + + output_combine = self.fp8_fusion_moe_node.combine_node.forward( + hidden_states_out, + async_finish=async_finish, + previous_event=previous_event, + allocate_on_comm_stream=allocate_on_comm_stream and previous_event is not None, + ) + + ret = (hidden_states, residual, l_aux, output_combine) + + # append mtp embed if needed + ret = (inputs_embeds_mtp, *ret) if self.send_mtp_embed else ret + ret = (*ret, norm_out) if self.using_post_norm_recompute else ret + return ret + + def post_process_forward(self, inputs, with_residual=True): + if self.using_post_norm_recompute: + if self.send_mtp_embed: + (inputs_embeds_mtp, hidden_states, residual, l_aux, output_combine, norm_out) = inputs + else: + (hidden_states, residual, l_aux, output_combine, norm_out) = inputs + else: + if self.send_mtp_embed: + (inputs_embeds_mtp, hidden_states, residual, l_aux, output_combine) = inputs + else: + (hidden_states, residual, l_aux, output_combine) = inputs + final_hidden_states = self.fp8_fusion_moe_node.combine_quant_node.forward(output_combine) + + inputs = (hidden_states, residual, l_aux, final_hidden_states) + inputs = (inputs_embeds_mtp, *inputs) if self.send_mtp_embed else inputs + inputs = (*inputs, norm_out) if self.using_post_norm_recompute else inputs + + if with_residual: + inputs = self.post_process_node.forward(inputs) + else: + inputs = self.post_process_node.forward_without_residual(inputs) + return inputs + + def post_process_backward(self, output_grad, event_to_wait=None): + grad = self.post_process_node.backward(output_grad) + + if self.using_post_norm_recompute: + if self.send_mtp_embed: + ( + inputs_embeds_mtp_grad, + hidden_states_grad, + residual_grad, + l_aux_grad, + final_hidden_states_grad, + norm_out, + invar, + ) = grad + else: + hidden_states_grad, residual_grad, l_aux_grad, final_hidden_states_grad, norm_out, invar = grad + else: + if self.send_mtp_embed: + inputs_embeds_mtp_grad, hidden_states_grad, residual_grad, l_aux_grad, final_hidden_states_grad = grad + else: + hidden_states_grad, residual_grad, l_aux_grad, final_hidden_states_grad = grad + + output_combine_grad, quant_event = self.fp8_fusion_moe_node.combine_quant_node.backward( + final_hidden_states_grad, event_to_wait + ) + + ret = (hidden_states_grad, residual_grad, l_aux_grad, output_combine_grad, quant_event) + ret = (inputs_embeds_mtp_grad, *ret) if self.send_mtp_embed else ret + ret = (*ret, norm_out, invar) if self.using_post_norm_recompute else ret + 
return ret + + def combine_backward(self, output_grad, previous_event=None, async_finish=False, allocate_on_comm_stream=False): + if self.using_post_norm_recompute: + if self.send_mtp_embed: + ( + inputs_embeds_mtp_grad, + hidden_states_grad, + residual_grad, + l_aux_grad, + output_combine_grad, + quant_event, + norm_out, + invar, + ) = output_grad + else: + ( + hidden_states_grad, + residual_grad, + l_aux_grad, + output_combine_grad, + quant_event, + norm_out, + invar, + ) = output_grad + else: + if self.send_mtp_embed: + ( + inputs_embeds_mtp_grad, + hidden_states_grad, + residual_grad, + l_aux_grad, + output_combine_grad, + quant_event, + ) = output_grad + else: + ( + hidden_states_grad, + residual_grad, + l_aux_grad, + output_combine_grad, + quant_event, + ) = output_grad + + if self.dsv3_use_fp8_dispatch and quant_event is not None: + combine_backward_wait_event = quant_event + else: + combine_backward_wait_event = previous_event + hidden_states_out_grad = self.fp8_fusion_moe_node.combine_node.backward( + output_combine_grad, + async_finish=async_finish, + previous_event=combine_backward_wait_event, + allocate_on_comm_stream=allocate_on_comm_stream and quant_event is not None, + ) + + ret = (hidden_states_grad, residual_grad, l_aux_grad, hidden_states_out_grad) + ret = (inputs_embeds_mtp_grad, *ret) if self.send_mtp_embed else ret + ret = (*ret, norm_out, invar) if self.using_post_norm_recompute else ret + return ret + + def mlp_backward(self, output_grad): + if self.using_post_norm_recompute: + if self.send_mtp_embed: + ( + inputs_embeds_mtp_grad, + hidden_states_grad, + residual_grad, + l_aux_grad, + hidden_states_out_grad, + norm_out, + invar, + ) = output_grad + else: + hidden_states_grad, residual_grad, l_aux_grad, hidden_states_out_grad, norm_out, invar = output_grad + else: + if self.send_mtp_embed: + ( + inputs_embeds_mtp_grad, + hidden_states_grad, + residual_grad, + l_aux_grad, + hidden_states_out_grad, + ) = output_grad + else: + hidden_states_grad, residual_grad, l_aux_grad, hidden_states_out_grad = output_grad + hs_dispatched_grad, dispatched_probs_grad = self.fp8_fusion_moe_node.mlp_node.backward(hidden_states_out_grad) + + ret = (hidden_states_grad, residual_grad, l_aux_grad, hs_dispatched_grad, dispatched_probs_grad) + ret = (inputs_embeds_mtp_grad, *ret) if self.send_mtp_embed else ret + ret = (*ret, norm_out, invar) if self.using_post_norm_recompute else ret + return ret + + def dispatch_backward(self, output_grad, async_finish=False, previous_event=None, allocate_on_comm_stream=False): + if self.using_post_norm_recompute: + if self.send_mtp_embed: + ( + inputs_embeds_mtp_grad, + hidden_states_grad, + residual_grad, + l_aux_grad, + hs_dispatched_grad, + dispatched_probs_grad, + norm_out, + invar, + ) = output_grad + else: + ( + hidden_states_grad, + residual_grad, + l_aux_grad, + hs_dispatched_grad, + dispatched_probs_grad, + norm_out, + invar, + ) = output_grad + else: + if self.send_mtp_embed: + ( + inputs_embeds_mtp_grad, + hidden_states_grad, + residual_grad, + l_aux_grad, + hs_dispatched_grad, + dispatched_probs_grad, + ) = output_grad + else: + hidden_states_grad, residual_grad, l_aux_grad, hs_dispatched_grad, dispatched_probs_grad = output_grad + + hs_grad, token_probs_grad = self.fp8_fusion_moe_node.dispatch_node.backward( + hs_dispatched_grad, + dispatched_probs_grad, + async_finish=async_finish, + previous_event=previous_event, + allocate_on_comm_stream=allocate_on_comm_stream and previous_event is not None, + ) + + ret = (hidden_states_grad, residual_grad, 
l_aux_grad, hs_grad, token_probs_grad) + ret = (inputs_embeds_mtp_grad, *ret) if self.send_mtp_embed else ret + ret = (*ret, norm_out, invar) if self.using_post_norm_recompute else ret + return ret + + def attn_backward(self, output_grad): + if self.using_post_norm_recompute: + if self.send_mtp_embed: + ( + inputs_embeds_mtp_grad, + hidden_states_grad, + residual_grad, + l_aux_grad, + hs_grad, + token_probs_grad, + norm_out, + invar, + ) = output_grad + inputs_embeds_mtp_grad_shape = hidden_states_grad.shape + inputs_embeds_mtp_grad_shape[-1] = -1 + inputs_embeds_mtp_grad = inputs_embeds_mtp_grad.view(inputs_embeds_mtp_grad_shape) + else: + hidden_states_grad, residual_grad, l_aux_grad, hs_grad, token_probs_grad, norm_out, invar = output_grad + else: + if self.send_mtp_embed: + ( + inputs_embeds_mtp_grad, + hidden_states_grad, + residual_grad, + l_aux_grad, + hs_grad, + token_probs_grad, + ) = output_grad + inputs_embeds_mtp_grad_shape = hidden_states_grad.shape + inputs_embeds_mtp_grad_shape[-1] = -1 + inputs_embeds_mtp_grad = inputs_embeds_mtp_grad.view(inputs_embeds_mtp_grad_shape) + else: + hidden_states_grad, residual_grad, l_aux_grad, hs_grad, token_probs_grad = output_grad + + hidden_states_grad_, probs_grad, routing_map_grad = self.fp8_fusion_moe_node.dispatch_quant_node.backward( + hs_grad, token_probs_grad + ) + + output_grad = (residual_grad, probs_grad, routing_map_grad, l_aux_grad) + + output_grad = ( + (hidden_states_grad, *output_grad, hidden_states_grad_) + if self.using_post_norm_recompute + else (hidden_states_grad + hidden_states_grad_, *output_grad) + ) + output_grad = (inputs_embeds_mtp_grad, *output_grad) if self.send_mtp_embed else output_grad + + if self.using_post_norm_recompute: + with TemporaryVarContext(norm_out, invar): + output_grad = self.attn_and_gate_node.backward(output_grad) + else: + output_grad = self.attn_and_gate_node.backward(output_grad) + return output_grad + + def backward_for_fusion(self, output_grad, combine_bw_event_to_wait=None, pp_stream=None): + paddle.base.core.nvprof_nvtx_push("backward") + if combine_bw_event_to_wait is None: + combine_bw_event_to_wait = deep_ep.get_event_from_calc_stream(self.moe_group.id) + + paddle.base.core.nvprof_nvtx_push("post_process_backward") + output_grad = self.post_process_backward(output_grad, combine_bw_event_to_wait) + paddle.base.core.nvprof_nvtx_pop() + + paddle.base.core.nvprof_nvtx_push("combine_backward") + output_grad = self.combine_backward( + output_grad, previous_event=combine_bw_event_to_wait, async_finish=True, allocate_on_comm_stream=True + ) + combine_backward_event = deep_ep.get_event_from_comm_stream(self.moe_group.id) + combine_backward_event.calc_stream_wait(self.moe_group.id) + paddle.base.core.nvprof_nvtx_pop() + + if WeightGradStore.enabled: + paddle.base.core.nvprof_nvtx_push("mlp_backward") + output_grad = self.mlp_backward(output_grad) + paddle.base.core.nvprof_nvtx_pop() + + paddle.base.core.nvprof_nvtx_push("dispatch_backward") + output_grad = self.dispatch_backward(output_grad) + paddle.base.core.nvprof_nvtx_pop() + + paddle.base.core.nvprof_nvtx_push("attn_backward") + output_grad = self.attn_backward(output_grad) + paddle.base.core.nvprof_nvtx_pop() + + event_to_wait = None + + else: + paddle.base.core.nvprof_nvtx_push("mlp_backward_dx") + assert WeightGradStore.funcs_queue.empty() + WeightGradStore.enabled = True + output_grad = self.mlp_backward(output_grad) + WeightGradStore.enabled = False + WeightGradStore.flush() + output_grad_event = 
deep_ep.get_event_from_calc_stream(self.moe_group.id) + paddle.base.core.nvprof_nvtx_pop() + + paddle.base.core.nvprof_nvtx_push("dispatch_backward") + output_grad = self.dispatch_backward( + output_grad, async_finish=True, previous_event=output_grad_event, allocate_on_comm_stream=True + ) + dispatch_backward_event = deep_ep.get_event_from_comm_stream(self.moe_group.id) + paddle.base.core.nvprof_nvtx_pop() + + paddle.base.core.nvprof_nvtx_push("mlp_backward_dw") + WeightGradStore.pop() + assert WeightGradStore.funcs_queue.empty() + paddle.base.core.nvprof_nvtx_pop() + + paddle.base.core.nvprof_nvtx_push("attn_backward_dx") + dispatch_backward_event.calc_stream_wait(self.moe_group.id) + WeightGradStore.enabled = True + output_grad = self.attn_backward(output_grad) + WeightGradStore.enabled = False + WeightGradStore.flush() + event_to_wait = deep_ep.get_event_from_calc_stream(self.moe_group.id) + paddle.base.core.nvprof_nvtx_pop() + + paddle.base.core.nvprof_nvtx_push("attn_backward_dw") + WeightGradStore.pop() + assert WeightGradStore.funcs_queue.empty() + paddle.base.core.nvprof_nvtx_pop() + + paddle.base.core.nvprof_nvtx_pop() + return output_grad, event_to_wait + + def forward(self, inputs): + if self.stepped_recompute_fwd_gate_up: + self.fp8_fusion_moe_node.mlp_node.set_recompute_fwd_gate_up(True) + inputs = self.attn_forward(inputs) + inputs = self.dispatch_forward(inputs) + inputs = self.mlp_forward(inputs) + inputs = self.combine_forward(inputs) + inputs = self.post_process_forward(inputs) + return inputs + + def backward(self, output_grad=None, scaler=None): + assert (output_grad is not None) and (scaler is None) + output_grad = self.post_process_backward(output_grad) + output_grad = self.combine_backward(output_grad) + output_grad = self.mlp_backward(output_grad) + # todo(phlrain): overlap here + output_grad = self.dispatch_backward(output_grad) + output_grad = self.attn_backward(output_grad) + return output_grad + + +class DenseDecoderLayerNode(ScheduleNode): + def __init__( + self, + attn_node, + mlp_node, + name="DenseDecoderLayerNode", + ): + super().__init__(fwd_func=None, name=name) + self.attn_node = attn_node + self.mlp_node = mlp_node + + def forward(self, inputs): + inputs = self.attn_node.forward(inputs) + inputs = self.mlp_node.forward(inputs) + return inputs + + def backward(self, output_grad=None, scaler=None): + assert (output_grad is not None) and (scaler is None) + output_grad = self.mlp_node.backward(output_grad) + output_grad = self.attn_node.backward(output_grad) + return output_grad + + +class OverlapedFUsionScheduleNode: + def __init__(self, forward_node, backward_node, name=""): + assert isinstance(forward_node, FusionFp8DecoderLayerNode) and isinstance( + backward_node, FusionFp8DecoderLayerNode + ) + self.forward_node = forward_node + self.backward_node = backward_node + self.name = name + + def forward_backward(self, inputs, output_grad, combine_bw_event_to_wait=None, pp_stream=None): + paddle.base.core.nvprof_nvtx_push("forward_backward") + + combine_bwd_event = deep_ep.get_event_from_calc_stream(self.backward_node.moe_group.id) + + paddle.base.core.nvprof_nvtx_push("attn_forward") + inputs = self.forward_node.attn_forward(inputs) + paddle.base.core.nvprof_nvtx_pop() + attn_compute_event = deep_ep.get_event_from_calc_stream(self.forward_node.moe_group.id) + + paddle.base.core.nvprof_nvtx_push("post_process_backward") + output_grad = self.backward_node.post_process_backward(output_grad, combine_bw_event_to_wait) + paddle.base.core.nvprof_nvtx_pop() + + 
paddle.base.core.nvprof_nvtx_push("combine_backward") + if combine_bw_event_to_wait is not None: + output_grad = self.backward_node.combine_backward( + output_grad, previous_event=combine_bw_event_to_wait, async_finish=True, allocate_on_comm_stream=True + ) + else: + output_grad = self.backward_node.combine_backward( + output_grad, previous_event=combine_bwd_event, async_finish=True, allocate_on_comm_stream=True + ) + # get combine event + combine_backward_event = deep_ep.get_event_from_comm_stream(self.backward_node.moe_group.id) + paddle.base.core.nvprof_nvtx_pop() + + combine_backward_event.calc_stream_wait(self.backward_node.moe_group.id) + paddle.base.core.nvprof_nvtx_push("mlp_backward_dx") + assert WeightGradStore.funcs_queue.empty() + WeightGradStore.enabled = True + output_grad = self.backward_node.mlp_backward(output_grad) + WeightGradStore.enabled = False + WeightGradStore.flush() + paddle.base.core.nvprof_nvtx_pop() + + output_grad_event = deep_ep.get_event_from_calc_stream(self.backward_node.moe_group.id) + + paddle.base.core.nvprof_nvtx_push("dispatch_forward") + inputs = self.forward_node.dispatch_forward( + inputs, previous_event=attn_compute_event, async_finish=True, allocate_on_comm_stream=True + ) + paddle.base.core.nvprof_nvtx_pop() + dispatch_forward_event = deep_ep.get_event_from_comm_stream(self.forward_node.moe_group.id) + + paddle.base.core.nvprof_nvtx_push("dispatch_backward") + output_grad = self.backward_node.dispatch_backward( + output_grad, async_finish=True, previous_event=output_grad_event, allocate_on_comm_stream=True + ) + paddle.base.core.nvprof_nvtx_pop() + # get dispatch backward event + dispatch_backward_event = deep_ep.get_event_from_comm_stream(self.backward_node.moe_group.id) + + paddle.base.core.nvprof_nvtx_push("dispatch_backward_dw") + WeightGradStore.pop() + assert WeightGradStore.funcs_queue.empty() + paddle.base.core.nvprof_nvtx_pop() + + dispatch_forward_event.calc_stream_wait(self.forward_node.moe_group.id) + paddle.base.core.nvprof_nvtx_push("mlp_forward") + inputs = self.forward_node.mlp_forward(inputs) + paddle.base.core.nvprof_nvtx_pop() + + if pp_stream is not None: + paddle.base.core.nvprof_nvtx_push("post_process_forward") + + final_out = self.forward_node.post_process_node.forward_without_residual(inputs) + paddle.base.core.nvprof_nvtx_pop() + + final_out_event = deep_ep.get_event_from_calc_stream(self.forward_node.moe_group.id) + + paddle.base.core.nvprof_nvtx_push("combine_forward") + inputs = self.forward_node.combine_forward( + inputs, previous_event=final_out_event, async_finish=True, allocate_on_comm_stream=True + ) + paddle.base.core.nvprof_nvtx_pop() + + combine_forward_event = deep_ep.get_event_from_comm_stream(self.forward_node.moe_group.id) + + combine_fwd_out = inputs[-2] if self.forward_node.using_post_norm_recompute else inputs[-1] + + if pp_stream is not None: + send_recv_stream = paddle.device.Stream(stream_base=pp_stream) + + paddle.base.core.nvprof_nvtx_push("pp stream add") + + with paddle.device.stream_guard(send_recv_stream): + combine_forward_event.current_stream_wait() + final_out_event.current_stream_wait() + + # TODO: check correct + # if final_out.shape[-1] != combine_fwd_out.shape[-1]: + # final_out[:, :, : combine_fwd_out.shape[-1]] += combine_fwd_out # Directly broadcast and add + # else: + # final_out += combine_fwd_out + inputs = final_out + combine_fwd_out + + final_out._record_stream() + combine_fwd_out._record_stream() + + paddle.base.core.nvprof_nvtx_pop() + + 
dispatch_backward_event.calc_stream_wait(self.backward_node.moe_group.id) + + paddle.base.core.nvprof_nvtx_push("attn_backward") + assert WeightGradStore.funcs_queue.empty() + WeightGradStore.enabled = True + output_grad = self.backward_node.attn_backward(output_grad) + event_to_wait = deep_ep.get_event_from_calc_stream(self.backward_node.moe_group.id) + + if EventStore is not None: + EventStore.set(event_to_wait) + + WeightGradStore.enabled = False + WeightGradStore.flush() + WeightGradStore.pop() + assert WeightGradStore.funcs_queue.empty() + + paddle.base.core.nvprof_nvtx_pop() + + # residual add + if pp_stream is None: + combine_forward_event.calc_stream_wait(self.forward_node.moe_group.id) + + final_out = self.forward_node.post_process_node.forward_without_residual(inputs) + if final_out.shape[-1] != combine_fwd_out.shape[-1]: + final_out[:, :, : combine_fwd_out.shape[-1]] += combine_fwd_out + else: + final_out += combine_fwd_out + inputs = final_out + combine_fwd_out._record_stream() + + paddle.base.core.nvprof_nvtx_pop() + return inputs, output_grad, event_to_wait + + +class OverlapedDenseFusionScheduleNode: + def __init__(self, forward_node, backward_node, name=""): + assert isinstance(forward_node, FusionFp8DecoderLayerNode) or isinstance( + backward_node, FusionFp8DecoderLayerNode + ) + assert isinstance(forward_node, DenseDecoderLayerNode) or isinstance(backward_node, DenseDecoderLayerNode) + self.forward_node = forward_node + self.backward_node = backward_node + self.name = name + + def forward_backward(self, inputs, output_grad, combine_bw_event_to_wait=None, pp_stream=None): + # Dense forward + MoE backward + if isinstance(self.forward_node, DenseDecoderLayerNode): + paddle.base.core.nvprof_nvtx_push("dense_fw_moe_bw") + + paddle.base.core.nvprof_nvtx_push("dense_attn_moe_combine") + # Note: the input combine_bw_event_to_wait is unreliable, we need to record a new event here. 
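+ # Overlap pattern for this branch: the MoE combine/dispatch backward communication
+ # is launched asynchronously on the comm stream, the dense attention/MLP forward
+ # runs on the calc stream in the meantime, and each comm event is waited on just
+ # before its result is consumed.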
+ combine_bw_event_to_wait = deep_ep.get_event_from_calc_stream(self.backward_node.moe_group.id) + output_grad = self.backward_node.post_process_backward(output_grad, combine_bw_event_to_wait) + output_grad = self.backward_node.combine_backward( + output_grad, previous_event=combine_bw_event_to_wait, async_finish=True, allocate_on_comm_stream=True + ) + combine_bw_event = deep_ep.get_event_from_comm_stream(self.backward_node.moe_group.id) + inputs = self.forward_node.attn_node.forward(inputs) + combine_bw_event.calc_stream_wait(self.backward_node.moe_group.id) + paddle.base.core.nvprof_nvtx_pop() # dense_attn_moe_combine + + paddle.base.core.nvprof_nvtx_push("moe_mlp") + assert WeightGradStore.funcs_queue.empty() + WeightGradStore.enabled = True + output_grad = self.backward_node.mlp_backward(output_grad) + WeightGradStore.enabled = False + WeightGradStore.flush() + paddle.base.core.nvprof_nvtx_pop() # moe_mlp + + paddle.base.core.nvprof_nvtx_push("dense_mlp_moe_dispatch") + output_grad = self.backward_node.dispatch_backward( + output_grad, async_finish=True, allocate_on_comm_stream=True + ) + dispatch_bw_event = deep_ep.get_event_from_comm_stream(self.backward_node.moe_group.id) + paddle.base.core.nvprof_nvtx_push("dispatch_backward_dw") + WeightGradStore.pop() + assert WeightGradStore.funcs_queue.empty() + paddle.base.core.nvprof_nvtx_pop() + inputs = self.forward_node.mlp_node.forward(inputs) + dispatch_bw_event.calc_stream_wait(self.backward_node.moe_group.id) + paddle.base.core.nvprof_nvtx_pop() # dense_mlp_moe_dispatch + + paddle.base.core.nvprof_nvtx_push("moe_attn") + output_grad = self.backward_node.attn_backward(output_grad) + paddle.base.core.nvprof_nvtx_pop() # moe_attn + + event_to_wait = deep_ep.get_event_from_calc_stream(self.backward_node.moe_group.id) + paddle.base.core.nvprof_nvtx_pop() # dense_fw_moe_bw + + # Dense backward + MoE forward + else: + paddle.base.core.nvprof_nvtx_push("dense_bw_moe_fw") + + paddle.base.core.nvprof_nvtx_push("moe_attn") + inputs = self.forward_node.attn_forward(inputs) + attn_fw_event = deep_ep.get_event_from_calc_stream(self.forward_node.moe_group.id) + paddle.base.core.nvprof_nvtx_pop() # moe_attn + + paddle.base.core.nvprof_nvtx_push("dense_mlp_moe_dispatch") + if combine_bw_event_to_wait is not None: + combine_bw_event_to_wait.calc_stream_wait(self.forward_node.moe_group.id) + output_grad = self.backward_node.mlp_node.backward(output_grad) + inputs = self.forward_node.dispatch_forward( + inputs, previous_event=attn_fw_event, async_finish=True, allocate_on_comm_stream=True + ) + dispatch_fw_event = deep_ep.get_event_from_comm_stream(self.forward_node.moe_group.id) + dispatch_fw_event.calc_stream_wait(self.forward_node.moe_group.id) + paddle.base.core.nvprof_nvtx_pop() # dense_mlp_moe_dispatch + + paddle.base.core.nvprof_nvtx_push("moe_mlp") + inputs = self.forward_node.mlp_forward(inputs) + paddle.base.core.nvprof_nvtx_pop() # moe_mlp + + paddle.base.core.nvprof_nvtx_push("dense_attn_moe_combine") + inputs = self.forward_node.combine_forward(inputs, async_finish=True, allocate_on_comm_stream=True) + combine_fw_event = deep_ep.get_event_from_comm_stream(self.forward_node.moe_group.id) + output_grad = self.backward_node.attn_node.backward(output_grad) + combine_fw_event.calc_stream_wait(self.forward_node.moe_group.id) + paddle.base.core.nvprof_nvtx_pop() # dense_attn_moe_combine + + paddle.base.core.nvprof_nvtx_push("moe_post") + inputs = self.forward_node.post_process_forward(inputs) + paddle.base.core.nvprof_nvtx_pop() # moe_post + + 
event_to_wait = deep_ep.get_event_from_calc_stream(self.forward_node.moe_group.id) + paddle.base.core.nvprof_nvtx_pop() # dense_bw_moe_fw + + return inputs, output_grad, event_to_wait + + +def build_overlapped_nodes(config: DeepseekV2FastConfig, forward_chunk, backward_chunk): + overlap_element_class = ( + FusionFp8DecoderLayerNode if config.dsv3_use_fp8_gemm else DecoderLayerNode, + DenseDecoderLayerNode, + ) + forward_decoder_layer_num = 0 + backward_decoder_layer_num = 0 + assert isinstance(forward_chunk, ScheduleChunk) and isinstance(backward_chunk, ScheduleChunk) + for n in forward_chunk.nodes: + if isinstance(n, overlap_element_class): + forward_decoder_layer_num += 1 + for n in reversed(backward_chunk.nodes): + if isinstance(n, overlap_element_class): + backward_decoder_layer_num += 1 + + overlap_layers_num = min(forward_decoder_layer_num, backward_decoder_layer_num) + forward_pre_overlap_layers = [] + forward_post_overlap_layers = [] + forward_overlap_layers = [] + is_pre = True + for n in forward_chunk.nodes: + if not isinstance(n, overlap_element_class): + if is_pre: + forward_pre_overlap_layers.append(n) + else: + forward_post_overlap_layers.append(n) + else: + is_pre = False + if len(forward_overlap_layers) == overlap_layers_num: + forward_post_overlap_layers.append(n) + else: + forward_overlap_layers.append(n) + forward_pre_node = ScheduleChunk(forward_pre_overlap_layers) + forward_post_node = ScheduleChunk(forward_post_overlap_layers) + + backward_pre_overlap_layers = [] + backward_post_overlap_layers = [] + backward_overlap_layers = [] + is_pre = True + for n in reversed(backward_chunk.nodes): + if not isinstance(n, overlap_element_class): + if is_pre: + backward_pre_overlap_layers.append(n) + else: + backward_post_overlap_layers.append(n) + else: + is_pre = False + if len(backward_overlap_layers) == overlap_layers_num: + backward_post_overlap_layers.append(n) + else: + backward_overlap_layers.append(n) + + backward_pre_node = ScheduleChunk(list(reversed(backward_pre_overlap_layers))) + backward_post_node = ScheduleChunk(list(reversed(backward_post_overlap_layers))) + + if not forward_chunk.nodes and all(isinstance(n, FusionFp8DecoderLayerNode) for n in backward_chunk.nodes): + backward_post_node = DecoderBackwardScheduleChunk(backward_post_overlap_layers) + + overlap_node = OverlapedScheduleChunk( + forward_overlap_layers, backward_overlap_layers, use_fuion=config.dsv3_use_fp8_gemm + ) + return forward_pre_node, backward_pre_node, overlap_node, forward_post_node, backward_post_node + + +class EmbeddingFunction(paddle.autograd.PyLayer): + @staticmethod + def forward(ctx, x, weight): + out = paddle.nn.functional.embedding( + x, weight=weight, padding_idx=None, max_norm=None, norm_type=2.0, sparse=False, scale_grad_by_freq=False + ) + + ctx.save_for_backward(x, weight) + return out + + @staticmethod + def backward(ctx, dout): + x, weight = ctx.saved_tensor() + + if hasattr(weight, "main_grad"): + paddle.incubate.nn.functional.embedding_grad_add_to_(x, weight.main_grad, dout) + else: + paddle.incubate.nn.functional.embedding_grad_add_to_(x, weight.grad, dout) + + return None, None + + +class DeepseekV2EmbeddingPipe(nn.Layer): + def __init__(self, config: DeepseekV2FastConfig): + super(DeepseekV2EmbeddingPipe, self).__init__() + self.config = config + self.sequence_parallel = config.sequence_parallel + self.hidden_size = config.hidden_size + if config.tensor_parallel_degree > 1 and config.vocab_size % config.tensor_parallel_degree == 0: + self.embed_tokens = 
fleet.meta_parallel.VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + weight_attr=paddle.ParamAttr(initializer=nn.initializer.XavierNormal()), + ) + else: + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size) + + @property + def embedding_weight(self): + return get_attr(self.embed_tokens, "weight") + + def forward(self, args): + """_summary_ + + Args: + input (_type_): _description_ + + Returns: + _type_: _description_ + """ + input_ids, attention_mask, attn_mask_startend_row_indices, position_ids = parse_args(args) + inputs_embeds = EmbeddingFunction.apply(input_ids, self.embed_tokens.weight) + + batch_size, seq_length = input_ids.shape + if self.config.num_nextn_predict_layers > 0: + seq_length -= self.config.num_nextn_predict_layers + + if attention_mask is not None: + attention_mask = attention_mask[ + :, :, : -self.config.num_nextn_predict_layers, : -self.config.num_nextn_predict_layers + ] + + if attention_mask is not None: + assert ( + attn_mask_startend_row_indices is None + ), "attention_mask and attn_mask_startend_row_indices can not be set at same time" + + attention_mask = DeepseekV2Model._prepare_decoder_attention_mask( + attention_mask, (batch_size, seq_length), 0, inputs_embeds.dtype + ) + attention_mask.stop_gradient = True + if get_env_device() == "npu": + attention_mask = attention_mask.astype("bool") + elif get_env_device() == "npu": + attention_mask = paddle.tril(paddle.ones((seq_length, seq_length), dtype="bool")) + attention_mask.stop_gradient = True + + if self.config.num_nextn_predict_layers > 0: + inputs_embeds_extra = inputs_embeds[:, -self.config.num_nextn_predict_layers :, :] # [B, S, D] + inputs_embeds = inputs_embeds[:, : -self.config.num_nextn_predict_layers, :] + inputs_embeds_ori = inputs_embeds + batch_size, seq_length, _ = inputs_embeds.shape + + if self.sequence_parallel: + # [bs, seq_len, num_head * head_dim] -> [bs * seq_len, num_head * head_dim] + inputs_embeds = paddle.reshape(inputs_embeds, [-1, inputs_embeds.shape[-1]]) + # [seq_len * bs / n, num_head * head_dim] (n is mp parallelism) + inputs_embeds = ScatterOp.apply(inputs_embeds) + embeds_res = [inputs_embeds] + mtp_embeds = [] + for depth in range(self.config.num_nextn_predict_layers): + inputs_embeds_mtp = paddle.concat( + [ + inputs_embeds_ori[:, (depth + 1) :, :], + inputs_embeds_extra[:, : (depth + 1), :], + ], + axis=1, + ) + if self.sequence_parallel: + inputs_embeds_mtp = inputs_embeds_mtp.reshape([-1, inputs_embeds_mtp.shape[-1]]) + inputs_embeds_mtp = ScatterOp.apply(inputs_embeds_mtp) + mtp_embeds.append(inputs_embeds_mtp) + + if self.config.send_mtp_embed: + embeds_res.extend(mtp_embeds) + # if not self.sequence_parallel + # mtp_embeds: [B*num_nextn_predict_layers, seq_len, hidden_size] + # else: + # mtp_embeds: [B*seq_len*num_nextn_predict_layers, hidden_size] + inputs_embeds = paddle.concat(embeds_res, axis=-1) + else: + global global_inputs_embeds_mtp_queue + cloned_mtp_embeds = [t.detach() for t in mtp_embeds] + global_inputs_embeds_mtp_queue.put(cloned_mtp_embeds) + return return_args(inputs_embeds, attention_mask, attn_mask_startend_row_indices, position_ids) + else: + if self.sequence_parallel: + inputs_embeds = inputs_embeds.reshape([-1, inputs_embeds.shape[-1]]) + inputs_embeds = ScatterOp.apply(inputs_embeds) + return return_args(inputs_embeds, attention_mask, attn_mask_startend_row_indices, position_ids) + + def build_schedule_node(self): + return ScheduleNode(self.forward, name="DeepseekV2EmbeddingPipe") + + +class 
DeepseekV2DecoderLayerPipe(DeepseekV2DecoderLayer): + def forward(self, args): + hidden_states, attention_mask, attn_mask_startend_row_indices, position_ids = parse_args(args) + + if self.config.send_mtp_embed: + batch_size, _, hidden_size = hidden_states.shape + batch_size_mtp = hidden_size // (self.config.num_nextn_predict_layers + 1) + inputs_embeds_mtp = hidden_states[..., batch_size_mtp:] + hidden_states = hidden_states[..., :batch_size_mtp] + + has_gradient = not hidden_states.stop_gradient + + if attention_mask is not None and attention_mask.dtype == paddle.int32: + attention_mask, attn_mask_startend_row_indices, position_ids = ( + None, + attention_mask, + attn_mask_startend_row_indices, + ) + elif attention_mask is not None and attention_mask.dtype == paddle.int64: + attention_mask, attn_mask_startend_row_indices, position_ids = None, None, attention_mask + elif attn_mask_startend_row_indices is not None and attn_mask_startend_row_indices.dtype == paddle.int64: + attn_mask_startend_row_indices, position_ids = None, attn_mask_startend_row_indices + + if self.enable_recompute and self.config.recompute_granularity == "full" and has_gradient: + if attention_mask is not None or attn_mask_startend_row_indices is not None: + hidden_states = recompute( + super().forward, + hidden_states, + position_ids=position_ids, + attention_mask=attention_mask, + attn_mask_startend_row_indices=attn_mask_startend_row_indices, + use_reentrant=False, + ) + else: + # for pretrain + hidden_states = recompute( + super().forward, + hidden_states, + position_ids=position_ids, + attn_mask_startend_row_indices=attn_mask_startend_row_indices, + use_reentrant=self.config.recompute_use_reentrant, + ) + else: + hidden_states = super().forward( + hidden_states, + position_ids=position_ids, + attention_mask=attention_mask, + attn_mask_startend_row_indices=attn_mask_startend_row_indices, + ) + + if self.config.send_mtp_embed: + hidden_states = paddle.concat([hidden_states, inputs_embeds_mtp], axis=-1) + + return return_args(hidden_states, attention_mask, attn_mask_startend_row_indices, position_ids) + + def attn_compute(self, args): + hidden_states, attention_mask, attn_mask_startend_row_indices, position_ids = parse_args(args) + assert attention_mask is None + assert attn_mask_startend_row_indices is None + assert position_ids is None + assert self.config.send_mtp_embed + + batch_size, _, hidden_size = hidden_states.shape + batch_size_mtp = hidden_size // (self.config.num_nextn_predict_layers + 1) + inputs_embeds_mtp = hidden_states[..., batch_size_mtp:] + hidden_states = hidden_states[..., :batch_size_mtp] + + def attn_compute_func(hidden_states): + hidden_states, residual = self.self_attn_compute(hidden_states) + l_aux, _, intermediate_hidden_states, token_indices, token_probs = self.pre_dispatch_compute(hidden_states) + return (hidden_states, residual, l_aux, intermediate_hidden_states, token_indices, token_probs) + + has_gradient = not hidden_states.stop_gradient + if self.enable_recompute and self.config.recompute_granularity == "full" and has_gradient: + # for pretrain + outputs = recompute( + attn_compute_func, + hidden_states, + use_reentrant=self.config.recompute_use_reentrant, + ) + else: + outputs = attn_compute_func(hidden_states) + + return (inputs_embeds_mtp, *outputs) + + def attn_compute_for_fusion(self, args): + hidden_states, attention_mask, attn_mask_startend_row_indices, position_ids = parse_args(args) + assert attention_mask is None + assert attn_mask_startend_row_indices is None + assert 
position_ids is None + + send_mtp_embed = self.config.send_mtp_embed + + if send_mtp_embed: + # slice from holy tensor + batch_size, _, hidden_size = hidden_states.shape + batch_size_mtp = hidden_size // (self.config.num_nextn_predict_layers + 1) + inputs_embeds_mtp = hidden_states[..., batch_size_mtp:] + hidden_states = hidden_states[..., :batch_size_mtp] + + hidden_states, residual = self.self_attn_compute(hidden_states) + _, _, d_model = hidden_states.shape + + if self.using_post_norm_recompute: + probs, routing_map, l_aux, _, norm_out = self.mlp.router(hidden_states) + else: + probs, routing_map, l_aux, _ = self.mlp.router(hidden_states) + + # common return values + ret = ( + hidden_states, + residual, + probs, + routing_map, + l_aux, + ) + # append mtp embed if needed + ret = (inputs_embeds_mtp, *ret) if send_mtp_embed else ret + # append norm_out if using post_norm recompute + ret = (*ret, norm_out) if self.using_post_norm_recompute else ret + + return ret + + def mlp_compute(self, inputs): + if isinstance(inputs, list): + inputs = tuple(inputs) + send_mtp_embed = self.config.send_mtp_embed + + if send_mtp_embed: + ( + inputs_embeds_mtp, + hidden_states, + residual, + l_aux, + intermediate_hidden_states, + dispatched_indices, + dispatched_probs, + ) = inputs + else: + ( + hidden_states, + residual, + l_aux, + intermediate_hidden_states, + dispatched_indices, + dispatched_probs, + ) = inputs + has_gradient = not intermediate_hidden_states.stop_gradient + if self.enable_recompute and self.config.recompute_granularity == "full" and has_gradient: + expert_output = recompute( + self.expert_forward_compute, + intermediate_hidden_states, + dispatched_indices, + dispatched_probs, + use_reentrant=self.config.recompute_use_reentrant, + ) + else: + expert_output = self.expert_forward_compute( + intermediate_hidden_states, dispatched_indices, dispatched_probs + ) + if send_mtp_embed: + return (inputs_embeds_mtp, hidden_states, residual, l_aux, expert_output) + else: + return (hidden_states, residual, l_aux, expert_output) + + def post_process_compute(self, inputs): + send_mtp_embed = self.config.send_mtp_embed + + if isinstance(inputs, list): + inputs = tuple(inputs) + if send_mtp_embed: + (inputs_embeds_mtp, hidden_states, residual, l_aux, combine_output) = inputs + else: + (hidden_states, residual, l_aux, combine_output) = inputs + has_gradient = not hidden_states.stop_gradient + if self.enable_recompute and self.config.recompute_granularity == "full" and has_gradient: + hidden_states = recompute( + self.post_combine_compute, + residual, + hidden_states, + combine_output, + l_aux, + use_reentrant=self.config.recompute_use_reentrant, + ) + else: + hidden_states = self.post_combine_compute( + residual, + hidden_states, + combine_output, + l_aux, + ) + if send_mtp_embed: + hidden_states = paddle.concat([hidden_states, inputs_embeds_mtp], axis=-1) + + return return_args(hidden_states) + + def post_process_compute_for_fusion(self, inputs): + send_mtp_embed = self.config.send_mtp_embed + + if isinstance(inputs, list): + inputs = tuple(inputs) + + if send_mtp_embed: + (inputs_embeds_mtp, hidden_states, residual, l_aux, final_hidden_states) = inputs + else: + (hidden_states, residual, l_aux, final_hidden_states) = inputs + + final_hidden_states = self.mlp.post_process(hidden_states, final_hidden_states, l_aux) + + hidden_states = residual + final_hidden_states + + hidden_states = (hidden_states,) + + if type(hidden_states) is tuple and len(hidden_states) == 1: + hidden_states = hidden_states[0] + + if 
send_mtp_embed: + hidden_states = paddle.concat([hidden_states, inputs_embeds_mtp], axis=-1) + + return return_args(hidden_states) + + def attn_compute_dense(self, args): + hidden_states, attention_mask, attn_mask_startend_row_indices, position_ids = parse_args(args) + assert attention_mask is None + assert attn_mask_startend_row_indices is None + assert position_ids is None + + if self.config.send_mtp_embed: + batch_size, _, hidden_size = hidden_states.shape + batch_size_mtp = hidden_size // (self.config.num_nextn_predict_layers + 1) + inputs_embeds_mtp = hidden_states[..., batch_size_mtp:] + hidden_states = hidden_states[..., :batch_size_mtp] + + hidden_states, residual = self.self_attn_compute(hidden_states) + + ret = (hidden_states, residual) + ret = (inputs_embeds_mtp, *ret) if self.config.send_mtp_embed else ret + return ret + + def mlp_compute_dense(self, inputs): + if self.config.send_mtp_embed: + (inputs_embeds_mtp, hidden_states, residual) = inputs + else: + (hidden_states, residual) = inputs + + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + if self.config.send_mtp_embed: + hidden_states = paddle.concat([hidden_states, inputs_embeds_mtp], axis=-1) + + return hidden_states + + def build_schedule_node(self): + if isinstance(self.mlp, DeepseekV2MoE): + self.mlp.update_flex_token() + if self.mlp.using_flex_token: + if self.config.dsv3_use_fp8_gemm: + attn_and_gate_node = ScheduleNode(self.attn_compute_for_fusion, name="attn_and_gate_node") + + # recompute_fwd_gate_up_ may be 1, 0 or -1. 1 means recompute, 0 means disable recompute, -1 means adaptive recompute. + recompute_fwd_gate_up_ = 1 if self.layer_idx in self.config.recompute_fwd_gate_up_list else 0 + if recompute_fwd_gate_up_ == 0 and self.config.adaptive_remained_O1_recompute_ratio: + recompute_fwd_gate_up_ = -1 + + fp8_fusion_moe_node = FusionMoeNode( + self.mlp, + recompute_fwd_gate_up=recompute_fwd_gate_up_, + is_split_group_gemm=self.config.is_split_group_gemm, + mlp_fwd_subbatch_rows=self.config.mlp_fwd_subbatch_rows, + mlp_bwd_subbatch_rows=self.config.mlp_bwd_subbatch_rows, + output_subbatch_rows=self.config.output_subbatch_rows, + dsv3_use_fp8_dispatch=self.config.dsv3_use_fp8_dispatch, + name="fp8_fusion_moe_node", + ) + post_process_node = PostProcessNode( + self.config.send_mtp_embed, + self.mlp.training, + self.mlp.alpha, + self.config, + self.mlp.shared_experts, + self.config.using_post_norm_recompute, + output_mtp_embed_first=isinstance(self, DeepseekV2MTPLayer), + name="post_process_node", + ) + return FusionFp8DecoderLayerNode( + attn_and_gate_node=attn_and_gate_node, + fp8_fusion_moe_node=fp8_fusion_moe_node, + post_process_node=post_process_node, + mlp_layer=self.mlp, + send_mtp_embed=self.config.send_mtp_embed, + using_post_norm_recompute=self.config.using_post_norm_recompute, + stepped_recompute_fwd_gate_up=self.config.stepped_recompute_fwd_gate_up, + dsv3_use_fp8_dispatch=self.config.dsv3_use_fp8_dispatch, + name="FusionFp8DecoderLayerNode", + ) + else: + attn_node = ScheduleNode(self.attn_compute, name="attn_node") + mlp_node = ScheduleNode(self.mlp_compute, name="mlp_node") + post_process_node = ScheduleNode(self.post_process_compute, name="post_process_node") + return DecoderLayerNode( + attn_node=attn_node, + dispatch_node=None, + mlp_node=mlp_node, + combine_node=None, + post_process_node=post_process_node, + mlp_layer=self.mlp, + name="DecoderLayerNode", + ) + + attn_node = ScheduleNode(self.attn_compute_dense, name="attn_node") + mlp_node = 
ScheduleNode(self.mlp_compute_dense, name="mlp_node") + return DenseDecoderLayerNode( + attn_node=attn_node, + mlp_node=mlp_node, + name="DenseDecoderLayerNode", + ) + + +class DeepseekV2MTPLayerPipe(DeepseekV2MTPLayer): + def forward(self, args): + hidden_states, attention_mask, attn_mask_startend_row_indices, position_ids = parse_args(args) + + if self.config.send_mtp_embed: + hidden_states_list = paddle.split(hidden_states, self.config.num_nextn_predict_layers + 1, axis=-1) + hidden_states_main_model = hidden_states_list[0] + inputs_embeds_cur_depth_list = hidden_states_list[1:] + else: + hidden_states_main_model = hidden_states + global global_inputs_embeds_mtp_queue + inputs_embeds_cur_depth_list = global_inputs_embeds_mtp_queue.get() + + has_gradient = not hidden_states_main_model.stop_gradient + + if attention_mask is not None and attention_mask.dtype == paddle.int32: + attention_mask, attn_mask_startend_row_indices, position_ids = ( + None, + attention_mask, + attn_mask_startend_row_indices, + ) + elif attention_mask is not None and attention_mask.dtype == paddle.int64: + attention_mask, attn_mask_startend_row_indices, position_ids = None, None, attention_mask + elif attn_mask_startend_row_indices is not None and attn_mask_startend_row_indices.dtype == paddle.int64: + attn_mask_startend_row_indices, position_ids = None, attn_mask_startend_row_indices + + output_list = [hidden_states_main_model] + hidden_states = hidden_states_main_model + for depth in range(self.config.num_nextn_predict_layers): + inputs_embeds_cur_depth = inputs_embeds_cur_depth_list[depth] + if self.enable_recompute and self.config.recompute_granularity == "full" and has_gradient: + if attention_mask is not None or attn_mask_startend_row_indices is not None: + hidden_states = recompute( + super().forward, + hidden_states, + inputs_embeds_cur_depth, + position_ids=position_ids, + attention_mask=attention_mask, + attn_mask_startend_row_indices=attn_mask_startend_row_indices, + use_reentrant=False, + ) + else: + # for pretrain + hidden_states = recompute( + super().forward, + hidden_states, + inputs_embeds_cur_depth, + position_ids=position_ids, + attn_mask_startend_row_indices=attn_mask_startend_row_indices, + use_reentrant=self.config.recompute_use_reentrant, + ) + else: + hidden_states = super().forward( + hidden_states, + inputs_embeds_cur_depth, + position_ids=position_ids, + attention_mask=attention_mask, + attn_mask_startend_row_indices=attn_mask_startend_row_indices, + ) + output_list.append(hidden_states) + + hidden_states = paddle.concat(output_list, axis=-1) + return return_args(hidden_states, attention_mask, attn_mask_startend_row_indices, position_ids) + + def attn_compute_for_fusion(self, args): + hidden_states, attention_mask, attn_mask_startend_row_indices, position_ids = parse_args(args) + assert attention_mask is None + assert attn_mask_startend_row_indices is None + assert position_ids is None + assert self.config.num_nextn_predict_layers == 1 + + if self.config.send_mtp_embed: + hidden_states_list = paddle.split(hidden_states, self.config.num_nextn_predict_layers + 1, axis=-1) + hidden_states_main_model = hidden_states_list[0] + inputs_embeds_cur_depth_list = hidden_states_list[1:] + else: + hidden_states_main_model = hidden_states + global global_inputs_embeds_mtp_queue + inputs_embeds_cur_depth_list = global_inputs_embeds_mtp_queue.get() + + hidden_states = hidden_states_main_model + nextn_hidden_state = inputs_embeds_cur_depth_list[0] + + # mtp compute + hidden_states = 
self.hnorm(hidden_states) + nextn_hidden_state = self.enorm(nextn_hidden_state) + + concat_h = paddle.concat([nextn_hidden_state, hidden_states], axis=-1) + hidden_states = FP8LinearFunction.apply(concat_h, self.eh_proj) + + # attention compute + hidden_states, residual = self.self_attn_compute(hidden_states) + + if self.using_post_norm_recompute: + probs, routing_map, l_aux, _, norm_out = self.mlp.router(hidden_states) + else: + probs, routing_map, l_aux, _ = self.mlp.router(hidden_states) + + # common return values + ret = ( + hidden_states_main_model, + hidden_states, + residual, + probs, + routing_map, + l_aux, + ) + ret = (*ret, norm_out) if self.using_post_norm_recompute else ret + + return ret + + def build_schedule_node(self): + if isinstance(self.mlp, DeepseekV2MoE): + self.mlp.update_flex_token() + if ( + self.mlp.using_flex_token + and self.config.dsv3_use_fp8_gemm + and self.config.num_nextn_predict_layers == 1 + ): + prev_send_mtp_embed = self.config.send_mtp_embed + self.config.send_mtp_embed = True # must be True in MTP node + + node = DeepseekV2DecoderLayerPipe.build_schedule_node(self) + assert isinstance(node, FusionFp8DecoderLayerNode) + + self.config.send_mtp_embed = prev_send_mtp_embed + return node + return ScheduleNode(self.forward, name="DeepseekV2MTPLayerPipe") + + +class DeepseekV2RMSNormPipe(nn.Layer): + def __init__(self, config): + super().__init__() + self.config = config + self.norm = DeepseekV2RMSNorm(config) + + def forward(self, args): + hidden_states, attention_mask, attn_mask_startend_row_indices, position_ids = parse_args(args) + + if self.config.num_nextn_predict_layers > 0: + hidden_states_list = paddle.split(hidden_states, self.config.num_nextn_predict_layers + 1, axis=-1) + hidden_states = hidden_states_list[0] + hidden_states_mtp = hidden_states_list[-self.config.num_nextn_predict_layers :] + + output_list = [self.norm(hidden_states)] + for hidden_states in hidden_states_mtp: + output_list.append(self.norm(hidden_states)) + return output_list + else: + return self.norm(hidden_states) + + def build_schedule_node(self): + return ScheduleNode(self.forward, name="DeepseekV2RMSNormPipe") + + +class DeepseekV2LMHeadPipe(DeepseekV2LMHead): + def __init__(self, config, embedding_weight=None): + super(DeepseekV2LMHeadPipe, self).__init__(config, embedding_weight=embedding_weight) + + @property + def embedding_weight(self): + return get_attr(self, "weight") + + def forward(self, args: Union[Tuple, paddle.Tensor]): + if self.config.num_nextn_predict_layers > 0: + logits = [] + for _hidden_states in args: + logits.append(super().forward(_hidden_states)) + return logits + hidden_states = args + logits = super().forward(hidden_states) + return logits + + def build_schedule_node(self): + return ScheduleNode(self.forward, name="DeepseekV2LMHeadPipe") + + +class DeepseekV2PretrainingCriterionPipe(DeepseekV2PretrainingCriterionFast): + def forward(self, logits, labels): + if self.config.num_nextn_predict_layers > 0: + mtp_logits = logits[1:] + logits = logits[0] + loss = super().forward(logits, labels, mtp_logits=mtp_logits) + else: + if isinstance(logits, (tuple, list)): + logits = logits[0] + loss = super().forward(logits, labels) + return loss + + def build_schedule_node(self): + return ScheduleNode(self.forward, name="DeepseekV2PretrainingCriterionPipe") + + +class DeepseekV2ForCausalLMPipe(PipelinePretrainedModel, PipelineLayer): + """DeepseekV2ForPretraining adapted for pipeline parallelism. 
+ + The largest change is flattening the DeepseekV2Model class so we can express it as a + sequence of layers including embedding, transformer layers, and output. + """ + + config_class = DeepseekV2FastConfig + _base_model = DeepseekV2PretrainedModel + _get_tensor_parallel_mappings = DeepseekV2PretrainedModel._get_tensor_parallel_mappings + _init_weights = DeepseekV2PretrainedModel._init_weights + _keys_to_ignore_on_load_unexpected = DeepseekV2PretrainedModel._keys_to_ignore_on_load_unexpected + _get_model_flops = DeepseekV2PretrainedModel._get_model_flops + _get_hardware_flops = DeepseekV2PretrainedModel._get_hardware_flops + + _tied_weights_keys = ["lm_head.weight"] + + # DONOT Add base_model_prefix !!!! + + def step_flex_token(self, cur_step): + set_global_step(cur_step) + + @classmethod + def _prepare_pipeline_inputs_func(cls, inputs): + first_stage_keys = ["input_ids", "attention_mask", "attn_mask_startend_row_indices", "position_ids"] + last_stage_keys = ["labels"] + + def get_expected_keys(inputs, keys): + ret = tuple([inputs.pop(k) if k in inputs else None for k in keys]) + if len(ret) == 1: + ret = ret[0] + return ret + + if type(inputs) is dict or type(inputs) is OrderedDict: + return [ + get_expected_keys(inputs, first_stage_keys), + get_expected_keys(inputs, last_stage_keys), + ] + + keys = list(inputs[0].keys()) + inputs_batch = {key: [data.pop(key) for data in inputs] for key in keys} + return [ + get_expected_keys(inputs_batch, first_stage_keys), + get_expected_keys(inputs_batch, last_stage_keys), + ] + + def __init__(self, config: DeepseekV2FastConfig): + self.config = config + + # Note that we will actually perform a recompute only if both enable_recompute and layerwise_recompute are set to True + # Enable_recompute defaults to False and is controlled by Trainer + self.enable_recompute = False + self.recompute_granularity = self.config.recompute_granularity + self.pp_recompute_interval = self.config.pp_recompute_interval + self.no_recompute_layers = config.no_recompute_layers if config.no_recompute_layers is not None else [] + if self.recompute_granularity == "full": + assert len(self.no_recompute_layers) == 0, "for pp with full recompute, no_recompute_layers is not support" + + virtual_pp_degree = getattr(self.config, "virtual_pp_degree", 1) + use_dualpipev = getattr(self.config, "use_dualpipev", False) + if use_dualpipev: + assert LocalSharedLayerDesc is not None, "LocalSharedLayerDesc is None, please update your paddle." 
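+ # With use_dualpipev, tied embedding/LM-head weights are shared through
+ # LocalSharedLayerDesc (requires a newer Paddle, per the assert above);
+ # otherwise the regular SharedLayerDesc is used.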
+ shared_class = LocalSharedLayerDesc if use_dualpipev else SharedLayerDesc + + def get_hcg(): + return fleet.get_hybrid_communicate_group() + + hcg = get_hcg() + tensor_parallel_degree = max(hcg.get_model_parallel_world_size(), 1) + tensor_parallel_rank = max(hcg.get_model_parallel_rank(), 0) + + # TODO: fix tensor_parallel_degree rewrite in here + config.tensor_parallel_degree = tensor_parallel_degree + config.tensor_parallel_rank = tensor_parallel_rank + + if config.tie_word_embeddings: + self.add_sequential_layer( + shared_class( + "DeepseekV2_shared_weight", + DeepseekV2EmbeddingPipe, + shared_weight_attr="embedding_weight", + config=config, + ), + self._base_model.base_model_prefix, + ) + else: + self.add_sequential_layer( + LayerDesc(DeepseekV2EmbeddingPipe, config=config), self._base_model.base_model_prefix + ) + + def compute_recompute_fwd_gate_up_list(pp_nums, all_dl_nums, dense_dl_nums, recompute_fwd_gate_up): + all_layers_nums = all_dl_nums + 4 # embedding, rms, lm_head, mtp + segment_size = all_layers_nums // pp_nums + boundary = math.ceil((1 + dense_dl_nums) / segment_size) * segment_size + recompute_fwd_gate_up_list = [dense_dl_nums] + for idx in range(boundary - 1, all_dl_nums, segment_size): + recompute_fwd_gate_up_list.append(idx) + + # If `recompute_fwd_gate_up` is a Boolean value and is True, means all O1 will be recomputed. + # Otherwise `recompute_fwd_gate_up` should be an integer representing how many O1 are recomputed. + assert isinstance(recompute_fwd_gate_up, (int, bool)) + if type(recompute_fwd_gate_up) is bool: + enable_k_o1_rc = segment_size if recompute_fwd_gate_up is True else 0 + else: + enable_k_o1_rc = recompute_fwd_gate_up + + ret = [] + for i in range(len(recompute_fwd_gate_up_list)): + for k in range(min(segment_size, enable_k_o1_rc)): + ret.append(recompute_fwd_gate_up_list[i] + k) + return ret + + def compute_recompute_fa3_list(pp_nums, all_dl_nums, recompute_fa3): + all_layers_nums = all_dl_nums + 4 # embedding, rms, lm_head, mtp + segment_size = all_layers_nums // pp_nums + recompute_fa3_list = [0] + for idx in range(segment_size - 1, all_dl_nums, segment_size): + recompute_fa3_list.append(idx) + + # If `recompute_fa3` is a Boolean value and is True, means all O1 will be recomputed. + # Otherwise `recompute_fa3` should be an integer representing how many O1 are recomputed. 
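+ # The loop below expands each index in recompute_fa3_list by the first
+ # min(segment_size, enable_k_o1_rc) layer offsets, mirroring
+ # compute_recompute_fwd_gate_up_list above.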
+ assert isinstance(recompute_fa3, (int, bool)) + if type(recompute_fa3) is bool: + enable_k_o1_rc = segment_size if recompute_fa3 is True else 0 + else: + enable_k_o1_rc = recompute_fa3 + + ret = [] + for i in range(len(recompute_fa3_list)): + for k in range(min(segment_size, enable_k_o1_rc)): + ret.append(recompute_fa3_list[i] + k) + return ret + + pp_nums = ( + self.config["pipeline_parallel_degree"] * 2 + if self.config.use_dualpipev + else self.config["pipeline_parallel_degree"] + ) + recompute_fwd_gate_up_list = compute_recompute_fwd_gate_up_list( + pp_nums, + self.config.num_hidden_layers, + self.config.first_k_dense_replace, + self.config.recompute_fwd_gate_up, + ) + recompute_fa3_list = compute_recompute_fa3_list( + pp_nums, self.config.num_hidden_layers, self.config.recompute_fa3 + ) + + logger.info(f"recompute_fa3_list: {recompute_fa3_list}") + logger.info(f"recompute_fwd_gate_up_list: {recompute_fwd_gate_up_list}") + config.recompute_fwd_gate_up_list = recompute_fwd_gate_up_list + + for i in range(config.num_hidden_layers): + self.add_sequential_layer( + LayerDesc( + DeepseekV2DecoderLayerPipe, + config=config, + layer_idx=i, + layerwise_recompute=i not in self.no_recompute_layers, + recompute_fa3=i in recompute_fa3_list, + ), + f"{self._base_model.base_model_prefix}.layers.{i}", + ) + for i in range(config.num_nextn_predict_layers): + self.add_sequential_layer( + LayerDesc(DeepseekV2MTPLayerPipe, config=config, layer_idx=config.num_hidden_layers + i), + f"{self._base_model.base_model_prefix}.layers.{config.num_hidden_layers + i}", + ) + + self.add_sequential_layer(LayerDesc(DeepseekV2RMSNormPipe, config=config), self._base_model.base_model_prefix) + + if config.tie_word_embeddings: + self.add_sequential_layer( + shared_class( + "DeepseekV2_shared_weight", + DeepseekV2LMHeadPipe, + shared_weight_attr="embedding_weight", + config=config, + **{"transpose_y": True}, + ), + "lm_head", + ) + else: + self.add_sequential_layer(LayerDesc(DeepseekV2LMHeadPipe, config=config), "lm_head") + + recompute_interval = 0 + if self.enable_recompute and self.recompute_granularity == "full": + assert self.config.pp_recompute_interval <= config.num_hidden_layers // ( + virtual_pp_degree * get_hcg().topology().get_dim_size("pipe") + ), "pp recompute interval should smaller than num layers of each pp chunk" + recompute_interval = self.config.pp_recompute_interval + + seg_method = "layer:DeepseekV2DecoderLayer|DeepseekV2MTPLayerPipe" + if config.num_hidden_layers % get_hcg().topology().get_dim_size("pipe") != 0: + seg_method = "uniform" + + PipelineLayer.__init__( + self, + layers=self.get_sequential_layers(), + loss_fn=self.get_loss_fn(config), + topology=get_hcg().topology(), + seg_method=seg_method, + recompute_interval=recompute_interval, + recompute_ctx={ + "mp_group": get_hcg().get_model_parallel_group(), + "offload": False, + "partition": False, + }, + num_virtual_pipeline_stages=virtual_pp_degree, + use_dualpipev=use_dualpipev, + ) + # You should call init here, since there is a diamond inheritance problem + self.apply(self._init_weights) + # DON'T init PipelinePretrainedModel + # PipelinePretrainedModel.__init__(self.super(), config=config) + + def fp8_quant_weight(self, batch_mode=False, quant_transpose=True): + """fp8_quant_weight""" + with paddle.no_grad(): + for i, layer in self._sub_layers.items(): + if isinstance( + layer, paddle.distributed.fleet.meta_parallel.parallel_layers.pp_layers.PipelineLayerChunk + ): + for i, sub_layer in layer.named_sublayers(): + if isinstance(sub_layer, 
DeepseekV2DecoderLayer) and hasattr(sub_layer, "fp8_quant_weight"): + sub_layer.fp8_quant_weight(batch_mode, quant_transpose) + if isinstance(layer, DeepseekV2DecoderLayer) and hasattr(layer, "fp8_quant_weight"): + layer.fp8_quant_weight(batch_mode, quant_transpose) + + def get_loss_fn(self, config): + return DeepseekV2PretrainingCriterionPipe(config) + + def overlapped_forward_backward( + self, + forward_chunk, # the module of the forward chunk + forward_inputs, + forward_loss_fn_node, + backward_chunk, # the module of the backward chunk, maybe not used + backward_loss_fn_node, + backward_input_grads, + scaler, + combine_bw_event_to_wait=None, + pp_stream=None, + ): + if backward_loss_fn_node is not None: + if scaler: + backward_input_grads = backward_loss_fn_node.backward(scaler=scaler) + else: + backward_input_grads = backward_loss_fn_node.backward() + + ( + forward_pre_node, + backward_pre_node, + overlap_node, + forward_post_node, + backward_post_node, + ) = build_overlapped_nodes(self.config, forward_chunk, backward_chunk) + forward_inputs = forward_pre_node.forward(forward_inputs) + backward_input_grads = backward_pre_node.backward(backward_input_grads) + forward_inputs, backward_input_grads, _ = overlap_node.forward_backward( + forward_inputs, + backward_input_grads, + combine_bw_event_to_wait=combine_bw_event_to_wait, + pp_stream=pp_stream, + ) + forward_inputs = forward_post_node.forward(forward_inputs) + backward_input_grads = backward_post_node.backward(backward_input_grads) + + if forward_loss_fn_node is not None: + forward_loss = forward_loss_fn_node.forward(forward_inputs) + else: + forward_loss = None + + forward_inputs = [forward_inputs] if isinstance(forward_inputs, paddle.Tensor) else forward_inputs + return forward_inputs, forward_loss, backward_input_grads diff --git a/examples/experiments/deepseek_v3_pretrain/moe_gate.py b/examples/experiments/deepseek_v3_pretrain/moe_gate.py new file mode 100644 index 00000000000..49ff7cf0b7d --- /dev/null +++ b/examples/experiments/deepseek_v3_pretrain/moe_gate.py @@ -0,0 +1,452 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) Microsoft Corporation. +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
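+
+# MoE gate (router) implementations: top-1, top-2 and top-k token-to-expert routing,
+# including the group-limited greedy and "noaux_tc" selection strategies (see
+# `topk_method` below), with capacity handling and auxiliary-/z-loss computation.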
+from __future__ import annotations + +from typing import Tuple + +import paddle +import paddle.distributed as dist +import paddle.nn as nn +import paddle.nn.functional as F + +from paddleformers.transformers import MoEGateMixin + + +class PretrainedMoEGate(nn.Layer, MoEGateMixin): + def __init__(self, config, num_experts, expert_hidden_size, **kwargs): + super(PretrainedMoEGate, self).__init__() + + self.config = config + + self.num_experts = num_experts + self.expert_hidden_size = expert_hidden_size + + # force keep in float32 when using amp + self._cast_to_low_precision = False + + self.capacity_factor = kwargs.pop("capacity_factor", 1.0) + self.eval_capacity_factor = kwargs.pop("eval_capacity_factor", 1.0) + self.min_capacity = kwargs.pop("min_capacity", 1.0) + self.max_capacity = kwargs.pop("max_capacity", pow(2, 32)) + + self.group = kwargs.pop("group", None) + self.global_aux_loss = kwargs.pop("global_aux_loss", False) + if self.global_aux_loss: + assert self.group is not None, "group is required when global_aux_loss is True" + self.rank = dist.get_rank(self.group) + + self.expert_drop = kwargs.pop("expert_drop", False) + self.noisy_gate_policy = kwargs.pop("noisy_gate_policy", None) + self.drop_tokens = kwargs.pop("drop_tokens", True) + self.use_rts = kwargs.pop("use_rts", True) + self.top2_2nd_expert_sampling = kwargs.pop("top2_2nd_expert_sampling", True) + + self.drop_policy = kwargs.pop("drop_policy", "probs") + # Qwen2MoE: greedy + # DeepSeekV2&V3: group_limited_greedy for training, and noaux_tc for inference + self.topk_method = kwargs.pop("topk_method", "greedy") + self.top_k = kwargs.pop("top_k", 2) + self.n_group = kwargs.pop("n_group", 1) # for group_limited_greedy + self.topk_group = kwargs.pop("topk_group", 1) # for group_limited_greedy + self.norm_topk_prob = kwargs.pop("norm_topk_prob", False) + self.routed_scaling_factor = kwargs.pop("routed_scaling_factor", 1.0) + + # for flex token moe layer + self.using_flex_token = kwargs.pop("using_flex_token", False) + + def _priority(self, topk_idx: paddle.Tensor, capacity: int) -> paddle.Tensor: + """_summary_ + The priority is the cumulative sum of the expert indices. + + This method is used in hunyuan model + Args: + topk_idx (paddle.Tensor): [batch_size * seq_len, topk] + + Returns: + paddle.Tensor: cumsum locations + """ + _, k = topk_idx.shape + # Shape: [seq_len * k] + chosen_expert = topk_idx.reshape([-1]) + # Shape: [seq_len * k, num_experts]. + token_priority = F.one_hot(chosen_expert, self.num_experts).cast(paddle.int32) + token_priority = paddle.logical_and(token_priority > 0, token_priority.cumsum(axis=0) < capacity) + # Shape: [seq_len, num_experts]. 
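+ # Collapse the k routing slots of each token and return a 0/1 mask of the
+ # (token, expert) assignments that stayed within capacity.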
+ token_priority = token_priority.reshape([-1, k, self.num_experts]).sum(axis=1) + + return (token_priority > 0.0).astype("float32") + + def _topk_greedy(self, scores: paddle.Tensor, k: int) -> Tuple[paddle.Tensor, paddle.Tensor]: + """_summary_ + + Args: + scores (paddle.Tensor): [bsz*seq_len, n_experts] + k (int): select the top k experts + + Returns: + Tuple[paddle.Tensor, paddle.Tensor]: topk_weight, topk_idx + topk_weight: [bsz*seq_len, k] + topk_idx: [bsz*seq_len, k] + """ + topk_weight, topk_idx = paddle.topk(scores, k=k, axis=-1, sorted=True) + return topk_weight, topk_idx + + def _topk_group_limited_greedy( + self, scores: paddle.Tensor, k: int, n_group: int, topk_group: int + ) -> Tuple[paddle.Tensor, paddle.Tensor]: + """_summary_ + + Args: + scores (paddle.Tensor): [bsz*seq_len, n_experts] + k (int): select the top k experts in each group + n_groups (int): the number of groups for all experts + topk_group (int): the number of groups selected + + Returns: + Tuple[paddle.Tensor, paddle.Tensor]: topk_weight, topk_idx + topk_weight: [bsz*seq_len, k] + topk_idx: [bsz*seq_len, k] + + Note: the group size is normal greater than the number of k + """ + bsz_seq_len, n_experts = scores.shape + assert n_experts % n_group == 0, "n_experts must be divisible by n_groups" + + group_scores = scores.reshape([0, n_group, -1]).max(axis=-1) # [n, n_group] + group_idx = paddle.topk(group_scores, k=topk_group, axis=-1, sorted=True)[1] # [n, top_k_group] + group_mask = paddle.zeros_like(group_scores).put_along_axis(group_idx, paddle.ones([], dtype="float32"), axis=-1) # fmt:skip + score_mask = ( + group_mask.unsqueeze(-1).expand([bsz_seq_len, n_group, n_experts // n_group]).reshape([bsz_seq_len, -1]) + ) # [n, e] + tmp_scores = scores * score_mask # [n, e] + topk_weight, topk_idx = paddle.topk(tmp_scores, k=k, axis=-1, sorted=True) + + return topk_weight, topk_idx + + def _topk_noaux_tc( + self, scores: paddle.Tensor, k: int, n_group: int, topk_group: int + ) -> Tuple[paddle.Tensor, paddle.Tensor]: + """_summary_ + + Args: + scores (paddle.Tensor): [bsz*seq_len, n_experts] + k (int): select the top k experts in each group + n_groups (int): the number of groups for all experts + topk_group (int): the number of groups selected + + Returns: + Tuple[paddle.Tensor, paddle.Tensor]: topk_weight, topk_idx + topk_weight: [bsz*seq_len, k] + topk_idx: [bsz*seq_len, k] + + Note: the group size is normal greater than the number of k + """ + bsz_seq_len, n_experts = scores.shape + assert n_experts % n_group == 0, "n_experts must be divisible by n_groups" + + assert self.e_score_correction_bias is not None, "e_score_correction_bias is None" + scores_for_choice = scores.reshape([bsz_seq_len, -1]) + self.e_score_correction_bias.unsqueeze(0) + reshape_tmp_rst = scores_for_choice.reshape([bsz_seq_len, self.n_group, -1]) + top_k = min(reshape_tmp_rst.shape[2], 2) + group_scores = reshape_tmp_rst.topk(top_k, axis=-1)[0].sum(axis=-1) # fmt:skip [n, n_group] + group_idx = paddle.topk(group_scores, k=topk_group, axis=-1, sorted=True)[1] # [n, top_k_group] + group_mask = paddle.zeros_like(group_scores).put_along_axis(group_idx, paddle.ones([], dtype="float32"), axis=-1) # fmt:skip + score_mask = ( + group_mask.unsqueeze(-1).expand([bsz_seq_len, n_group, n_experts // n_group]).reshape([bsz_seq_len, -1]) + ) # [n, e] + tmp_scores = scores_for_choice * score_mask # [n, e] + topk_weight, topk_idx = paddle.topk(tmp_scores, k=k, axis=-1, sorted=True) + topk_weight = scores.take_along_axis(topk_idx, axis=1) if not self.training 
else topk_weight + + return topk_weight, topk_idx + + def top1gating( + self, + logits: paddle.Tensor, + used_token: paddle.Tensor = None, + ) -> Tuple[int, paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor]: + """Implements Top1Gating on logits.""" + if self.noisy_gate_policy == "RSample": + logits += self.gumbel_rsample(logits.shape) + + gates = self.gate_score_func(logits=logits) + capacity = self._capacity(gates, self.capacity_factor, self.max_capacity, self.min_capacity) + + # Create a mask for 1st's expert per token + # noisy gating + # Only save the position of the maximum value + indices1_s = paddle.argmax(logits if self.noisy_gate_policy == "RSample" else gates, axis=1) + # Convert the position of the maximum value to a one-hot vector [s, e] + mask1 = self._one_hot_to_float(indices1_s, num_classes=self.num_experts) + + # mask only used tokens + if used_token is not None: + mask1 = paddle.einsum( + "s,se->se", used_token, mask1 + ) # Element-wise multiply used_token with mask1 to obtain a new mask1 + + # gating decisions + exp_counts = paddle.sum(mask1, axis=0) # Calculate the number of tokens for each expert + + # if we don't want to drop any tokens + if not self.drop_tokens: + new_capacity = paddle.max(exp_counts) # Calculate the number of tokens for each expert + # Communicate across expert processes to pick the maximum capacity. + if self.group is not None: + dist.all_reduce( + new_capacity, op=dist.ReduceOp.MAX, group=self.group + ) # Calculate the maximum value among expert processes + # Make sure the capacity value does not exceed the number of tokens. + capacity = int(min(new_capacity, paddle.tensor(mask1.size(0)))) + + l_aux = self._cal_aux_loss(gates, mask1) + l_zloss = self._cal_z_loss(logits) + + # Random Token Selection + if self.use_rts: + mask1_rand = mask1 * self.uniform_sample(mask1) + else: + mask1_rand = mask1 + + assert ( + logits.shape[0] >= self.min_capacity + ), "No. of tokens (batch-size) should be greater than min_capacity. Either set min_capacity to 0 or increase your batch size." + + _, top_idx = paddle.topk(mask1_rand, k=capacity, axis=0) # Select top_capacity tokens + + new_mask1 = mask1 * paddle.zeros_like(mask1).put_along_axis(top_idx, paddle.ones([], dtype="float32"), axis=0) + mask1 = new_mask1 + + # Compute locations in capacity buffer + locations1 = paddle.cumsum(mask1, axis=0) - 1 # Compute the position of each token in mask1 + + # Store the capacity location for each token + locations1_s = paddle.sum(locations1 * mask1, axis=1).cast(paddle.int64) + + # Normalize gate probabilities + mask1_float = mask1.cast(paddle.float32) + gates = gates / gates * mask1_float + + locations1_sc = self._one_hot_to_float(locations1_s, capacity) + combine_weights = paddle.einsum("se,sc->sec", gates, locations1_sc) + dispatch_mask = combine_weights.cast(paddle.bool).detach() + + return capacity, combine_weights, dispatch_mask, exp_counts, l_aux, l_zloss + + def top2gating( + self, + logits: paddle.Tensor, + ) -> Tuple[int, paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor]: + # everything is in fp32 in this function + gates = self.gate_score_func(logits=logits) + + # Create a mask for 1st's expert per token. + indices1_s = paddle.argmax(gates, axis=1) # [S, 1] + mask1 = self._one_hot_to_int64(indices1_s, self.num_experts) # [S, E] + + if self.top2_2nd_expert_sampling: + # Create a mask for 2nd's expert per token using Gumbel-max trick. 
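+            # Adding i.i.d. Gumbel noise to the logits and taking the argmax draws a sample
+            # from the softmax distribution, so the 2nd expert is sampled stochastically
+            # rather than picked greedily.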
+ # https://timvieira.github.io/blog/post/2014/07/31/gumbel-max-trick/ + logits += self.gumbel_rsample(logits) + + # Replace top-expert with min value + logits_except1 = logits.masked_fill(mask1.cast(paddle.bool), float("-inf")) # [S, E] + indices2_s = paddle.argmax(logits_except1, axis=1) # [S, 1] + mask2 = self._one_hot_to_int64(indices2_s, self.num_experts) # [S, E] + + # Note: mask1 and mask2 can be combined to form a single mask. + # mask = paddle.concat([mask1, mask2], axis=0) + # locations = paddle.cumsum(mask, axis=0) - 1 + # locations1, locations2 = locations.split(2, axis=0) + # Compute locations in capacity buffer. + locations1 = paddle.cumsum(mask1, axis=0) - 1 # [S, E] + locations2 = paddle.cumsum(mask2, axis=0) - 1 # [S, E] + # Update 2nd's location by accounting for locations of 1st. + locations2 += paddle.sum(mask1, axis=0, keepdim=True) + + l_aux = self._cal_aux_loss(gates, mask1) + l_zloss = self._cal_z_loss(logits) + + # gating decisions + exp_counts = paddle.sum(mask1 + mask2, axis=0) + if self.drop_tokens: + # Calculate configured capacity and remove locations outside capacity from mask + capacity = self._capacity(gates, self.capacity_factor, self.max_capacity, self.min_capacity) + # Remove locations outside capacity from mask. + mask1 *= (locations1 < capacity).cast(paddle.int64) + mask2 *= (locations2 < capacity).cast(paddle.int64) + else: + # Do not drop tokens - set capacity according to current expert assignments + new_capacity = paddle.max(exp_counts) + if self.group is not None: + dist.all_reduce(new_capacity, op=dist.ReduceOp.MAX, group=self.group) + capacity = int(new_capacity) + + # Store the capacity location for each token. + locations1_s = paddle.sum(locations1 * mask1, axis=1) + locations2_s = paddle.sum(locations2 * mask2, axis=1) + + # Normalize gate probabilities + mask1_float = mask1.cast(paddle.float32) + mask2_float = mask2.cast(paddle.float32) + gates1_s = paddle.einsum("se,se->s", gates, mask1_float) + gates2_s = paddle.einsum("se,se->s", gates, mask2_float) + denom_s = gates1_s + gates2_s + # Avoid divide-by-zero + denom_s = paddle.clip(denom_s, min=paddle.finfo(denom_s.dtype).eps) + gates1_s /= denom_s + gates2_s /= denom_s + + # Calculate combine_weights and dispatch_mask + gates1 = paddle.einsum("s,se->se", gates1_s, mask1_float) + gates2 = paddle.einsum("s,se->se", gates2_s, mask2_float) + locations1_sc = self._one_hot_to_float(locations1_s, capacity) + locations2_sc = self._one_hot_to_float(locations2_s, capacity) + combine1_sec = paddle.einsum("se,sc->sec", gates1, locations1_sc) + combine2_sec = paddle.einsum("se,sc->sec", gates2, locations2_sc) + combine_weights = combine1_sec + combine2_sec + dispatch_mask = combine_weights.cast(paddle.bool) + + return capacity, combine_weights, dispatch_mask, exp_counts, l_aux, l_zloss + + def topkgating( + self, + gates: paddle.Tensor, + ) -> Tuple[int, paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor]: + """Implements TopKGating on logits.""" + batch_size, seq_len, d_model = gates.shape + gates_ori = gates + gates = gates.reshape([-1, d_model]) + + l_zloss = self._cal_z_loss(gates) + + # get topk gates + if self.topk_method == "greedy": + top_gate, top_idx = self._topk_greedy(gates, k=self.top_k) + elif self.topk_method == "group_limited_greedy": + top_gate, top_idx = self._topk_group_limited_greedy( + gates, k=self.top_k, n_group=self.n_group, topk_group=self.topk_group + ) + elif self.topk_method == "noaux_tc": + top_gate, top_idx = self._topk_noaux_tc( + gates, 
k=self.top_k, n_group=self.n_group, topk_group=self.topk_group + ) + # norm gate to sum 1 + if self.top_k > 1 and self.norm_topk_prob: + denominator = top_gate.sum(axis=-1, keepdim=True) + 1e-20 + top_gate = top_gate / denominator + top_gate = top_gate * self.routed_scaling_factor + + # get topk mask + mask = paddle.zeros_like(gates).put_along_axis(top_idx, paddle.ones([], dtype="float32"), axis=1) + if hasattr(self.config, "seq_aux") and self.config.seq_aux: + l_aux = self._cal_seq_aux_loss(gates_ori, self.top_k, top_idx) + else: + l_aux = self._cal_aux_loss(gates, mask) + + exp_counts = paddle.sum(mask.cast(paddle.int64), axis=0) + + if self.drop_tokens: + # Calculate configured capacity and remove locations outside capacity from mask + capacity = self._capacity( + gates, + self.capacity_factor * self.top_k, + self.max_capacity, + self.min_capacity, + ) + + # update mask and locations by capacity + if self.drop_policy == "probs": + topk_masked_gates = paddle.zeros_like(gates).put_along_axis(top_idx, top_gate, axis=1) + capacity_probs, capacity_indices = paddle.topk(topk_masked_gates, k=capacity, axis=0, sorted=False) + token_priority = self._priority(capacity_indices, capacity) + + elif self.drop_policy == "position": + token_priority = self._priority(top_idx, capacity) + else: + raise ValueError(f"Invalid drop_policy: {self.drop_policy}") + else: + # Do not drop tokens - set capacity according to current expert assignments + local_capacity = paddle.max(exp_counts) + if self.group is not None: + dist.all_reduce(local_capacity, op=dist.ReduceOp.MAX, group=self.group) + capacity = int(local_capacity) + token_priority = self._priority(top_idx, capacity) + + # normalize gates + gates_masked = gates * mask + if self.training: + gates_s = paddle.sum(gates_masked, axis=-1, keepdim=True) + denom_s = paddle.clip(gates_s, min=paddle.finfo(gates_masked.dtype).eps) + if self.norm_topk_prob: + gates_masked = gates_masked / denom_s + + return ( + capacity, + gates_masked.take_along_axis(top_idx, axis=-1), + top_idx, + token_priority.take_along_axis(top_idx, axis=-1), + l_aux, + l_zloss, + ) + + def topkgating_nodrop(self, gates: paddle.Tensor): + """Implements TopKGating on logits.""" + batch_size, seq_len, d_model = gates.shape + gates_ori = gates + gates = gates.reshape([-1, d_model]) + + l_zloss = self._cal_z_loss(gates) + + # get topk gates + if self.topk_method == "greedy": + top_gate, top_idx = self._topk_greedy(gates, k=self.top_k) + elif self.topk_method == "group_limited_greedy": + top_gate, top_idx = self._topk_group_limited_greedy( + gates, k=self.top_k, n_group=self.n_group, topk_group=self.topk_group + ) + elif self.topk_method == "noaux_tc": + top_gate, top_idx = self._topk_noaux_tc( + gates, k=self.top_k, n_group=self.n_group, topk_group=self.topk_group + ) + + # norm gate to sum 1 + # if self.top_k > 1 and self.norm_topk_prob: + # denominator = top_gate.sum(axis=-1, keepdim=True) + 1e-20 + # top_gate = top_gate / denominator + # top_gate = top_gate * self.routed_scaling_factor + + # get topk mask + mask = paddle.zeros_like(gates).put_along_axis(top_idx, paddle.ones([], dtype="float32"), axis=1) + + gates_masked = gates * mask + gates_s = paddle.sum(gates_masked, axis=-1, keepdim=True) + denom_s = paddle.clip(gates_s, min=paddle.finfo(gates_masked.dtype).eps) + + if self.norm_topk_prob: + gates_masked = gates_masked / denom_s + + gates_masked *= self.routed_scaling_factor + + if hasattr(self.config, "seq_aux") and self.config.seq_aux: + l_aux = self._cal_seq_aux_loss(gates_ori, 
self.top_k, top_idx) + else: + l_aux = self._cal_aux_loss(gates, mask) + + exp_counts = paddle.sum(mask.cast(paddle.int64), axis=0) + # topk_masked_gates = paddle.zeros_like(gates).put_along_axis(top_idx, top_gate, axis=1) + return gates_masked, mask, exp_counts, l_aux, l_zloss diff --git a/examples/experiments/deepseek_v3_pretrain/moe_layer.py b/examples/experiments/deepseek_v3_pretrain/moe_layer.py new file mode 100644 index 00000000000..a51d181012d --- /dev/null +++ b/examples/experiments/deepseek_v3_pretrain/moe_layer.py @@ -0,0 +1,1140 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) Microsoft Corporation. +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +import numpy as np +import paddle +import paddle.distributed as dist +from moe_gate import PretrainedMoEGate +from moe_utils import ( + UnZipNode, + ZipNode, + merge_subbatch_cast, + tokens_zip_unique_add_with_subbatch, +) +from paddle import nn +from token_dispatcher import MoEFlexTokenDispatcherFast as MoEFlexTokenDispatcher +from token_dispatcher import PreDispatchNode + +from paddleformers.transformers import _AllToAll +from paddleformers.transformers.fp8_utils import ( + FP8GroupGemmMlpFunctionNode, + extract_first_if_tuple, +) +from paddleformers.transformers.fused_a2a import ( + CombineNode, + DispatchNode, + get_buffer, + get_hidden_bytes, +) +from paddleformers.transformers.moe_utils import offload, reload +from paddleformers.utils.log import logger + +try: + import paddle.distributed.communication.deep_ep as deep_ep +except ImportError: + deep_ep = None + +try: + import TokenDispatcherUtils as TDU +except ImportError: + TDU = None + + +def record_stream_for_multi_input(x): + if isinstance(x, (tuple, list)): + for i in range(len(x)): + x[i]._record_stream() + else: + x._record_stream() + + +def stop_gradient_for_multi_input(x): + if isinstance(x, (tuple, list)): + x[0].stop_gradient = False + else: + x.stop_gradient = False + + +class MoELayer(nn.Layer): + def __init__( + self, + config, + moe_num_experts: int, + expert_class: nn.Layer, + expert_kwargs: dict, + gate: PretrainedMoEGate, + capacity: int = 1.0, + moe_group: str = "data", + all_to_all_dropout=0.0, + using_post_norm_recompute=False, + ): + super().__init__() + + self.config = config + + self.moe_num_experts = moe_num_experts + self.capacity = capacity + + try: + dist.fleet.get_hybrid_communicate_group() + is_fleet_init = True + except AttributeError: + is_fleet_init = False + + if is_fleet_init and dist.get_world_size() > 1: + if moe_group == "data": + self.moe_group = dist.fleet.get_hybrid_communicate_group().get_data_parallel_group() + elif moe_group == "expert": + self.moe_group = dist.fleet.get_hybrid_communicate_group().expert_parallel_group + self.moe_rank = dist.get_rank(self.moe_group) + self.moe_rank = 0 if self.moe_rank < 0 else 
self.moe_rank + self.expert_parallel_degree = dist.get_world_size(self.moe_group) + self.expert_parallel_degree = 1 if self.expert_parallel_degree < 0 else self.expert_parallel_degree + self.moe_num_experts_per_device = self._parse_moe_expert_parallel( + self.moe_num_experts, self.expert_parallel_degree + ) + self.is_dummy_moe = False if self.expert_parallel_degree > 1 else True + else: + # when moe_group is dummy, we don't need to use all_to_all + self.moe_group = None + self.moe_rank = 0 + self.expert_parallel_degree = 1 + self.moe_num_experts_per_device = self.moe_num_experts + self.is_dummy_moe = True + + self.all_to_all_dropout = all_to_all_dropout + self.enable_recompute = False + + self.experts = nn.LayerList([]) + for i in range(self.moe_num_experts): + if i // self.moe_num_experts_per_device == self.moe_rank: + self.experts.append(expert_class(**expert_kwargs)) + else: + self.experts.append(None) + + self.gate = gate + self.gate.group = self.moe_group + # for flex token moe layer + self.router = gate + self.ep_size = dist.get_world_size(self.moe_group) + self.moe_router_topk = gate.top_k + self.num_local_experts = moe_num_experts // self.ep_size + if self.moe_group is not None: + self.token_dispatcher = MoEFlexTokenDispatcher( + self.num_local_experts, self.moe_router_topk, self.moe_num_experts, self.moe_group + ) + self.token_drop_steps = config.token_drop_steps if hasattr(config, "token_drop_steps") else None + self.using_flex_token = False + + self.using_post_norm_recompute = using_post_norm_recompute + self._post_init() + + def update_flex_token(self): + from modeling import get_global_step + + if ( + (not hasattr(self.config, "using_flex_token")) + or (not self.config.using_flex_token) + or (get_global_step() < self.token_drop_steps) + ): + self.using_flex_token = False + self.router.using_flex_token = False + else: + if not self.using_flex_token: + logger.info("Changing to flex token moe mode") + self.using_flex_token = True + self.router.using_flex_token = True + + def _parse_moe_expert_parallel(self, moe_num_experts, expert_parallel_degree): + assert ( + moe_num_experts >= expert_parallel_degree + ), f"expert moe_num_experts={moe_num_experts} >= moe_world_size={expert_parallel_degree}" + assert ( + moe_num_experts % expert_parallel_degree == 0 + ), f"expert moe_num_experts={moe_num_experts} % moe_world_size={expert_parallel_degree} == 0" + moe_num_experts_per_device = moe_num_experts // expert_parallel_degree + return moe_num_experts_per_device + + def _post_init(self): + for p in self.gate.parameters(): + p.is_gate = True + + for k in self.experts: + if k is not None: + for p in k.parameters(): + p.expert = not self.is_dummy_moe + p.no_sync = not self.is_dummy_moe + # logger.info(f"expert param={p.name}, no-sync={p.no_sync}") + + def forward( + self, + hidden_states: paddle.Tensor, + probs=None, + routing_map=None, + capacity=None, + topk_weight=None, + topk_ids=None, + token_priority=None, + l_aux=None, + l_zloss=None, + ): + self.update_flex_token() + + if self.using_flex_token: + return self.forward_flex_token(hidden_states, probs, routing_map, l_aux, l_zloss) + else: + return self.forward_drop_token( + hidden_states, capacity, topk_weight, topk_ids, token_priority, l_aux, l_zloss + ) + + def forward_drop_token( + self, + hidden_state: paddle.Tensor, + capacity=None, + topk_weight=None, + topk_ids=None, + token_priority=None, + l_aux=None, + l_zloss=None, + ): + """MoE Layer forward function + 1. Gate Forward. + 2. Dispatch export. + 3. Experts Forward. 
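+        4. Combine expert outputs back into the original token order.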
+ + Args: + hidden_state: MoE Layer input + + Returns: + final_out: MoE Layer main output. + l_aux: MoE auxiliary loss. l_zloss: MoE z loss.""" + batch_size, seq_len, d_model = hidden_state.shape + + reshaped_input = hidden_state.reshape([-1, d_model]) + + # self.l_aux : + # topk_weight : se + # topk_ids : sk + # token_priority : se + # self.exp_counts : + if self.using_post_norm_recompute: + assert ( + capacity is not None + and topk_weight is not None + and topk_ids is not None + and token_priority is not None + and l_aux is not None + and l_zloss is not None + ) + else: + capacity, topk_weight, topk_ids, token_priority, l_aux, l_zloss = self.gate(hidden_state) + + """MoE expert dispatch from: https://huggingface.co/deepseek-ai/DeepSeek-V3/blob/main/modeling_deepseek.py""" + cnts = paddle.zeros([topk_ids.shape[0], len(self.experts)], dtype=topk_ids.dtype) + cnts = cnts.put_along_axis(topk_ids, 1, axis=1) + + tokens_per_expert = cnts.sum(axis=0) + idxs = topk_ids.reshape([topk_ids.shape[0] * topk_ids.shape[1]]).argsort() + sorted_tokens = reshaped_input[idxs // topk_ids.shape[1]] + tokens_per_expert = tokens_per_expert.detach() + sorted_tokens_shape = sorted_tokens.shape + + if self.expert_parallel_degree > 1: + tokens_per_ep_rank = tokens_per_expert.reshape([self.expert_parallel_degree, -1]).sum(axis=1) + tokens_per_expert_group = _AllToAll.apply( + [tokens_per_expert.shape[0]], tokens_per_expert, group=self.moe_group + ) + output_splits = ( + tokens_per_expert_group.reshape([self.expert_parallel_degree, -1]).sum(axis=1).cpu().tolist() + ) + input_split_sizes = tokens_per_ep_rank.cpu().tolist() + gathered_tokens = _AllToAll.apply( + [tokens_per_expert_group.sum(axis=0).cpu().item(), sorted_tokens.shape[1]], + sorted_tokens, + out_split_sizes=output_splits, + in_split_sizes=input_split_sizes, + group=self.moe_group, + ) + + tokens_per_expert_post_gather = tokens_per_expert_group.reshape( + [self.expert_parallel_degree, self.moe_num_experts_per_device] + ).sum(axis=0) + gatherd_idxs = np.zeros(shape=(gathered_tokens.shape[0],), dtype=np.int32) + s = 0 + for i, k in enumerate(tokens_per_expert_group.cpu().numpy()): + gatherd_idxs[s : s + k] = i % self.moe_num_experts_per_device + s += k + gatherd_idxs = gatherd_idxs.argsort() + sorted_tokens = gathered_tokens[gatherd_idxs] + tokens_per_expert = tokens_per_expert_post_gather + + outputs = [] + start_idx = 0 + for i, num_tokens in enumerate(tokens_per_expert): + end_idx = start_idx + num_tokens + if num_tokens == 0: + continue + expert = self.experts[i + self.moe_rank * self.moe_num_experts_per_device] + tokens_for_this_expert = sorted_tokens[start_idx:end_idx] + expert_out = expert(tokens_for_this_expert) + outputs.append(expert_out) + start_idx = end_idx + outs = paddle.concat(outputs, axis=0) if len(outputs) > 0 else paddle.to_tensor(0, dtype=sorted_tokens.dtype) + if self.expert_parallel_degree > 1: + new_x = paddle.empty_like(outs) + new_x[gatherd_idxs] = outs + gathered_tokens = _AllToAll.apply( + sorted_tokens_shape, + new_x, + out_split_sizes=input_split_sizes, + in_split_sizes=output_splits, + group=self.moe_group, + ) + outs = gathered_tokens + + new_x = paddle.empty_like(outs) + new_x[idxs] = outs + final_out = ( + new_x.reshape(topk_ids.shape + [-1]) + .astype(topk_weight.dtype) + .multiply_(topk_weight.unsqueeze(-1)) + .multiply_(token_priority.unsqueeze(-1)) + .sum(axis=1) + .astype(new_x.dtype) + .reshape([batch_size, seq_len, -1]) + ) + + return final_out, l_aux, l_zloss + + +class MoEFlexTokenLayer(nn.Layer): + def 
__init__(self, config, moe_num_experts, expert_class, expert_kwargs, gate, moe_group): + + super().__init__() + self.config = config + self.moe_group = moe_group + self.ep_size = dist.get_world_size(self.moe_group) + self.moe_router_topk = gate.top_k + self.moe_num_experts = moe_num_experts + self.num_local_experts = moe_num_experts // self.ep_size + self.token_dispatcher = MoEFlexTokenDispatcher( + self.num_local_experts, self.moe_router_topk, self.moe_num_experts, moe_group + ) + + self.experts = nn.LayerList([expert_class(**expert_kwargs)] * self.num_local_experts) + self.router = gate + + def expert_forward(self, dispatched_input, tokens_per_expert): + outputs = [] + tokens_per_expert = tokens_per_expert.tolist() + # print(f"all tokens: {sum(tokens_per_expert)}, detail: {tokens_per_expert}") + chunks = paddle.split(dispatched_input, num_or_sections=tokens_per_expert, axis=0) + for chunk, expert in zip(chunks, self.experts): + chunk = chunk.contiguous() + # assert chunk.shape[0] != 0, "Cannot dispatch empty input" + outputs += [expert(chunk)] + + return paddle.concat(outputs, axis=0) + + def forward(self, hidden_states: paddle.Tensor): + _, _, d_model = hidden_states.shape + # reshaped_input = hidden_states.reshape([-1, d_model]) + probs, routing_map, l_aux, l_zloss = self.router(hidden_states) + (dispatched_input, tokens_per_expert) = self.token_dispatcher.token_permutation( + hidden_states, probs, routing_map + ) + expert_output = self.expert_forward(dispatched_input, tokens_per_expert) + output, _ = self.token_dispatcher.token_unpermutation(expert_output, None) + return output, l_aux, l_zloss + + def forward_flex_token(self, hidden_states: paddle.Tensor, probs=None, routing_map=None, l_aux=None, l_zloss=None): + _, _, d_model = hidden_states.shape + # reshaped_input = hidden_states.reshape([-1, d_model]) + if self.using_post_norm_recompute: + assert probs is not None and routing_map is not None and l_aux is not None and l_zloss is not None + else: + probs, routing_map, l_aux, l_zloss = self.router(hidden_states) + if hasattr(self.config, "dsv3_use_fp8_gemm") and self.config.dsv3_use_fp8_gemm: + if hasattr(self.config, "dsv3_use_fp8_dispatch") and self.config.dsv3_use_fp8_dispatch: + output = FusionMoe.apply( + hidden_states, + probs, + routing_map, + self, + recompute_fwd_gate_up=self.config.recompute_fwd_gate_up, + is_split_group_gemm=self.config.is_split_group_gemm, + ) + else: + hidden_states, token_indices, token_probs = self.token_dispatcher.pre_dispatch( + hidden_states, probs, routing_map + ) + output = FusionMoe.apply( + hidden_states, + token_indices, + token_probs, + self, + recompute_fwd_gate_up=self.config.recompute_fwd_gate_up, + is_split_group_gemm=self.config.is_split_group_gemm, + ) + else: + ( + dispatched_input, + token_permuted_indices, + prob_permuted_indices, + dispatched_probs, + ) = self.token_dispatcher.token_permutation(hidden_states, probs, routing_map) + + expert_output = self.expert_forward(dispatched_input) + output, _ = self.token_dispatcher.token_unpermutation( + expert_output, token_permuted_indices, prob_permuted_indices, dispatched_probs, None + ) + return output, l_aux, l_zloss + + def get_tokens_per_expert(self): + return self.token_dispatcher._comm_manager.tokens_per_expert_list + + def set_tokens_per_expert(self, tokens_per_expert_list): + self.token_dispatcher._comm_manager.tokens_per_expert_list = tokens_per_expert_list + + def pre_dispatch_compute(self, hidden_states): + _, _, d_model = hidden_states.shape + probs, routing_map, l_aux, l_zloss 
= self.router(hidden_states) + hidden_states, token_indices, token_probs = self.token_dispatcher.pre_dispatch( + hidden_states, probs, routing_map + ) + return l_aux, l_zloss, hidden_states, token_indices, token_probs + + def post_dispatch_compute(self, hidden_states, dispatched_indices, dispatched_probs): + (global_input_tokens, token_permuted_indices, prob_permuted_indices) = self.token_dispatcher.post_dispatch( + hidden_states, dispatched_indices + ) + return (global_input_tokens, token_permuted_indices, prob_permuted_indices) + + def pre_combine_compute(self, hidden_states, token_permuted_indices, prob_permuted_indices, dispatched_probs): + hidden_states = self.token_dispatcher.pre_combine( + hidden_states, token_permuted_indices, prob_permuted_indices, dispatched_probs + ) + return hidden_states + + def post_combine_compute(self, hidden_states): + hidden_states = self.token_dispatcher.post_combine(hidden_states) + return hidden_states + + +class Fp8DispatchQuantNode: + def __init__(self, token_dispatcher, dsv3_use_fp8_dispatch, name="fp8_dispatch_quant_node"): + self.token_dispatcher = token_dispatcher + self.pre_dispatch_node = PreDispatchNode(token_dispatcher) + self.name = name + self.dsv3_use_fp8_dispatch = dsv3_use_fp8_dispatch + + @paddle.no_grad() + def forward(self, hidden_states, probs, routing_map): + # reshape + self.token_dispatcher.hidden_shape = hidden_states.shape + hs_2d = hidden_states.view([-1, self.token_dispatcher.hidden_shape[-1]]) + + if self.dsv3_use_fp8_dispatch: + # quant + hs_fp8, hs_scale = paddle.incubate.nn.functional.fp8_quant_blockwise( + hs_2d, output_scale_transpose=False, quant_method="1x128", input_transpose=False + ) + + # pre_dispatch + token_indices, token_probs = self.pre_dispatch_node.forward(routing_map, probs) + + self.hidden_states_shape = hidden_states.shape + hs_fp8.stop_gradient = False + token_probs.stop_gradient = False + return (hs_fp8, hs_scale), token_indices, token_probs + else: + # pre_dispatch + token_indices, token_probs = self.pre_dispatch_node.forward(routing_map, probs) + + self.hidden_states_shape = hidden_states.shape + hs_2d.stop_gradient = False + token_probs.stop_gradient = False + return hs_2d, token_indices, token_probs + + @paddle.no_grad() + def backward(self, hs_grad, token_probs_grad): + # predispatch grad + probs_grad = self.pre_dispatch_node.backward(token_probs_grad) + token_probs_grad._record_stream() + + # reshape_grad + hs_grad = hs_grad.view(self.hidden_states_shape) + hs_grad._record_stream() + + return hs_grad, probs_grad, None + + +class Fp8DispatchNode: + def __init__(self, token_dispatcher, name="fp8_dispatch_node"): + self.token_dispatcher = token_dispatcher + self.dispatch_act_node = DispatchNode(token_dispatcher) + self.name = name + + @paddle.no_grad() + def forward( + self, + hs_2d, + token_indices, + token_probs, + previous_event=None, + async_finish=False, + allocate_on_comm_stream=False, + ): + # dispatch + hs_2d_dispatched, dispatched_probs, states = self.dispatch_act_node.forward( + hs_2d, + token_indices, + token_probs, + self.token_dispatcher._comm_manager.num_experts, + self.token_dispatcher._comm_manager.group, + previous_event=previous_event, + async_finish=async_finish, + allocate_on_comm_stream=allocate_on_comm_stream, + ) + self.token_dispatcher._comm_manager.handle = states["handle"] + self.token_dispatcher._comm_manager.tokens_per_expert = states["tokens_per_expert"] + dispatched_indices = states["dispatched_indices"] + + stop_gradient_for_multi_input(hs_2d_dispatched) + 
dispatched_probs.stop_gradient = False + return hs_2d_dispatched, dispatched_indices, dispatched_probs + + @paddle.no_grad() + def backward( + self, + hs_dispatched_grad, + dispatched_probs_grad, + previous_event=None, + async_finish=False, + allocate_on_comm_stream=False, + ): + # dispatch grad + hs_grad, _, token_probs_grad = self.dispatch_act_node.backward( + hs_dispatched_grad, + dispatched_probs_grad, + previous_event=previous_event, + async_finish=async_finish, + allocate_on_comm_stream=allocate_on_comm_stream, + ) + return hs_grad, token_probs_grad + + +class Fp8CombineNode: + def __init__(self, token_dispatcher, name="fp8_combine_node"): + self.token_dispatcher = token_dispatcher + self.combine_node = CombineNode(token_dispatcher) + self.name = name + + @paddle.no_grad() + def forward(self, hidden_states_out, previous_event=None, async_finish=False, allocate_on_comm_stream=False): + # combine + output_combine = self.combine_node.forward( + hidden_states_out, + self.token_dispatcher._comm_manager.group, + self.token_dispatcher._comm_manager.handle, + previous_event=previous_event, + async_finish=async_finish, + allocate_on_comm_stream=allocate_on_comm_stream, + ) + output_combine.stop_gradient = False + self.token_dispatcher._comm_manager.handle = None + return output_combine + + @paddle.no_grad() + def backward(self, output_combine_grad, previous_event=None, async_finish=False, allocate_on_comm_stream=False): + # combine grad -> fp8 + hidden_states_out_grad = self.combine_node.backward( + output_combine_grad, + previous_event=previous_event, + async_finish=async_finish, + allocate_on_comm_stream=allocate_on_comm_stream, + ) + return hidden_states_out_grad + + +class Fp8CombineQuantNode: + def __init__(self, token_dispatcher, dsv3_use_fp8_dispatch, moe_group=None, name="fp8_combine_quant_node"): + self.token_dispatcher = token_dispatcher + self.name = name + self.moe_group = moe_group + self.dsv3_use_fp8_dispatch = dsv3_use_fp8_dispatch + + @paddle.no_grad() + def forward(self, output_combine): + # post combine + output = output_combine.reshape(self.token_dispatcher.hidden_shape) + output_combine._record_stream() + self.output_combine_shape = output_combine.shape + output.stop_gradient = False + return output + + @paddle.no_grad() + def backward(self, output_grad, event_to_wait=None): + # post combine grad + if self.dsv3_use_fp8_dispatch: + if event_to_wait is not None: + assert self.moe_group is not None + event_to_wait.comm_stream_wait(self.moe_group.id) + buffer = get_buffer(self.token_dispatcher._comm_manager.group, get_hidden_bytes(output_grad)) + custom_stream = paddle.device.Stream(stream_base=buffer.runtime.get_comm_stream()) + else: + custom_stream = paddle.device.current_stream() + with paddle.device.stream_guard(custom_stream): + output_combine_grad = paddle.reshape(output_grad, [-1, output_grad.shape[-1]]) + # output_combine_grad quant to fp8 + output_combine_grad_fp8, output_combine_grad_scale = paddle.incubate.nn.functional.fp8_quant_blockwise( + output_combine_grad, output_scale_transpose=False, quant_method="1x128", input_transpose=False + ) + output_grad._record_stream() + quant_event = None + if event_to_wait is not None: + quant_event = deep_ep.get_event_from_custom_stream(custom_stream.stream_base) + return (output_combine_grad_fp8, output_combine_grad_scale), quant_event + else: + output_combine_grad = paddle.reshape(output_grad, [-1, output_grad.shape[-1]]) + return output_combine_grad, None + + +class FusionMlpNode: + """ + The FusedMoeLayer class includes 
operations for unzipping, expert computation, and zipping. + """ + + def __init__( + self, + custom_map, + max_topk, + recompute_fwd_gate_up=False, + is_split_group_gemm=True, + dsv3_use_fp8_dispatch=True, + mlp_fwd_subbatch_rows=0, + mlp_bwd_subbatch_rows=0, + output_subbatch_rows=0, + ): + self.token_dispatcher = custom_map.token_dispatcher + self.experts = custom_map.experts + self.unzip_node = UnZipNode() + self.zip_node = ZipNode() + self.experts_group_gemm_node = FP8GroupGemmMlpFunctionNode( + custom_map, + recompute_fwd_gate_up=recompute_fwd_gate_up, + is_split_group_gemm=is_split_group_gemm, + ) + self.dsv3_use_fp8_dispatch = dsv3_use_fp8_dispatch + + self.seq_length = custom_map.config.seq_length + self.num_experts_per_tok = custom_map.config.num_experts_per_tok + self.adaptive_remained_O1_recompute_ratio = custom_map.config.adaptive_remained_O1_recompute_ratio + + self.recompute_fwd_gate_up = recompute_fwd_gate_up + self.dispatched_indices = None + self.dispatched_probs = None + self.tokens_per_expert = None + self.padding_token_per_experts = None + self.router_topk = max_topk + self.mlp_fwd_subbatch_rows = mlp_fwd_subbatch_rows + self.mlp_bwd_subbatch_rows = mlp_bwd_subbatch_rows + self.output_subbatch_rows = output_subbatch_rows + + def set_recompute_fwd_gate_up(self, recompute_fwd_gate_up): + self.experts_group_gemm_node.recompute_fwd_gate_up = recompute_fwd_gate_up + + def reset_statue(self): + """ + Reset the state of the FusionMlpNode object. + + Args: + None. + + Returns: + None. + + """ + self.dispatched_indices = None + self.dispatched_probs = None + self.tokens_per_expert = None + self.padding_token_per_experts = None + self.router_topk = None + + del self.unzip_node + del self.zip_node + self.unzip_node = None + self.zip_node = None + + self.experts_group_gemm_node.reset_statue() + self.experts_group_gemm_node = None + + def prepare_env_subbatch(self, unzipped_tokens=None, unzipped_tokens_scale=None, is_fwd=True): + if is_fwd: + assert unzipped_tokens is not None and unzipped_tokens_scale is not None + self.experts_group_gemm_node.input_fp8 = unzipped_tokens + self.experts_group_gemm_node.input_scale = unzipped_tokens_scale + self.m_indices = self.experts_group_gemm_node.gen_m_indices(self.padding_token_per_experts) + self.experts_group_gemm_node.fwd_subbatch = True + else: + self.m_indices = ( + self.experts_group_gemm_node.gen_m_indices(self.padding_token_per_experts) + if not hasattr(self, "m_indices") + else self.m_indices + ) + self.experts_group_gemm_node.bwd_subbatch = True + reload(self.experts_group_gemm_node.input_fp8) + reload(self.experts_group_gemm_node.input_scale) + + def gemm_forward_subbatch( + self, + unzipped_tokens, + unzipped_tokens_scale, + unzipped_probs, + map_unzipped_indices_to_zipped, + output, + total_zipped_tokens, + padding_token_per_experts, + start_idx=None, + end_idx=None, + output_subbatch_rows=None, + ): + if start_idx is None or end_idx is None: + start_idx = 0 + end_idx = unzipped_tokens.shape[0] + start_idx = max(0, start_idx) + end_idx = min(unzipped_tokens.shape[0], end_idx) + + expert_out = self.experts_group_gemm_node.forward( + (unzipped_tokens[start_idx:end_idx], unzipped_tokens_scale[start_idx:end_idx]), + unzipped_probs[start_idx:end_idx], + padding_token_per_experts, + m_indices=self.m_indices[start_idx:end_idx], + ) + + output = tokens_zip_unique_add_with_subbatch( + output, + expert_out, + map_unzipped_indices_to_zipped[start_idx:end_idx], + total_zipped_tokens, + subbatch_rows=output_subbatch_rows, + ) + return output + 
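+    # A minimal sketch (hypothetical sizes) of how forward() drives gemm_forward_subbatch
+    # when mlp_fwd_subbatch_rows is enabled:
+    #
+    #     total_unzipped_tokens = 4096
+    #     subbatch_rows = 1024
+    #     nparts = (total_unzipped_tokens + subbatch_rows - 1) // subbatch_rows  # 4 parts
+    #     output = paddle.empty([0, hidden_size], dtype=paddle.float32)
+    #     for i in range(nparts):
+    #         start = i * subbatch_rows
+    #         end = min(start + subbatch_rows, total_unzipped_tokens)
+    #         output = self.gemm_forward_subbatch(..., start_idx=start, end_idx=end, ...)
+    #
+    # Each call runs the grouped FP8 GEMM on one row slice and zip-adds that slice's expert
+    # outputs into `output`, so peak activation memory is bounded by subbatch_rows.
+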
+ def gemm_backward_subbatch( + self, + unzipped_grad, + map_unzipped_indices_to_zipped, + total_zipped_tokens, + output, + padding_token_per_experts, + start_idx=None, + end_idx=None, + output_subbatch_rows=None, + reset_status=False, + ): + def split_list_prefix(l, start, end): + prefix_sum = [0] * (len(l) + 1) + for i in range(len(l)): + prefix_sum[i + 1] = prefix_sum[i] + l[i] + + result = [] + for i in range(len(l)): + segment_start = prefix_sum[i] + segment_end = prefix_sum[i + 1] + overlap_start = max(start, segment_start) + overlap_end = min(end, segment_end) + selected = max(0, overlap_end - overlap_start) + result.append(selected) + return result + + if start_idx is None or end_idx is None: + start_idx = 0 + end_idx = extract_first_if_tuple(unzipped_grad).shape[0] + + start_idx = max(0, start_idx) + end_idx = min(extract_first_if_tuple(unzipped_grad).shape[0], end_idx) + + # m_indices = self.experts_group_gemm_node.gen_m_indices(self.tokens_per_expert) + unzipped_inp_grad = ( + (unzipped_grad[0][start_idx:end_idx].contiguous(), unzipped_grad[1][start_idx:end_idx].contiguous()) + if isinstance(unzipped_grad, tuple) + else unzipped_grad[start_idx:end_idx].contiguous() + ) + unzipped_grad, unzipped_probs_grad = self.experts_group_gemm_node.backward( + unzipped_inp_grad, + self.unzipped_probs[start_idx:end_idx].contiguous(), + input_fp8_slice=self.experts_group_gemm_node.input_fp8[start_idx:end_idx].contiguous(), + input_scale_slice=self.experts_group_gemm_node.input_scale[start_idx:end_idx].contiguous(), + tokens_per_expert=split_list_prefix(padding_token_per_experts, start_idx, end_idx), + m_indices=self.m_indices[start_idx:end_idx].contiguous(), + reset_status=reset_status, + ) + + output = tokens_zip_unique_add_with_subbatch( + output, + unzipped_grad, + map_unzipped_indices_to_zipped[start_idx:end_idx], + zipped_rows=total_zipped_tokens, + subbatch_rows=output_subbatch_rows, + ) + + return output, unzipped_probs_grad + + @paddle.no_grad() + def forward(self, hs_2d_dispatched, dispatched_indices, dispatched_probs): + """ + Perform forward computation on input data. + + Args: + hs_fp8_dispatched (Tensor): Input data dispatched to experts. + dispatched_indices (Tensor): Expert indices assigned to input data. + dispatched_probs (Tensor): Probabilities of input data being dispatched to experts. + + Returns: + Tensor: Output data after forward computation. 
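+
+        Note: when mlp_fwd_subbatch_rows is non-zero and enough tokens are received, the
+        grouped GEMM is executed slice by slice and the slices are zip-added back together.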
+ """ + self.tokens_per_expert = self.token_dispatcher._comm_manager.tokens_per_expert + self.dispatched_probs = dispatched_probs + num_experts = len(self.tokens_per_expert) + padding_token_per_experts = [(x + 127) // 128 * 128 for x in self.tokens_per_expert] + self.padding_token_per_experts = padding_token_per_experts + # 1 unzip + self.dispatched_indices = dispatched_indices.to(paddle.int32) + + total_zipped_tokens = extract_first_if_tuple(hs_2d_dispatched).shape[0] + (unzipped_tokens, zipped_expertwise_rowmap, unzipped_probs, unzipped_tokens_scale,) = self.unzip_node.forward( + hs_2d_dispatched, + self.dispatched_indices, + dispatched_probs, + topk=self.router_topk, + num_experts=num_experts, + tokens_per_expert=self.tokens_per_expert, + ) + record_stream_for_multi_input(hs_2d_dispatched) + dispatched_indices._record_stream() + dispatched_probs._record_stream() + + self.unzipped_probs = unzipped_probs.unsqueeze(-1) + + if self.dsv3_use_fp8_dispatch: + total_unzipped_tokens = extract_first_if_tuple(unzipped_tokens).shape[0] + # If adaptive O1 recompute is enabled, determine whether to enable recompute O1 based on the degree of imbalance + if self.recompute_fwd_gate_up == -1: + if ( + total_unzipped_tokens + > self.seq_length * self.num_experts_per_tok * self.adaptive_remained_O1_recompute_ratio + ): + # logger.debug(f"recompute_fwd_gate_up changed to True, Because the receives {unzipped_tokens.shape[0]} Tensors greater then {self.seq_length*self.num_experts_per_tok*self.adaptive_remained_O1_recompute_ratio}.") + self.set_recompute_fwd_gate_up(True) + else: + # logger.debug(f"recompute_fwd_gate_up changed to False, Because the receives {unzipped_tokens.shape[0]} Tensors less then {self.seq_length*self.num_experts_per_tok*self.adaptive_remained_O1_recompute_ratio}.") + self.set_recompute_fwd_gate_up(False) + + # if use_mlp_subbatch is enabled, then split the unzipped_tokens into subbatches + if self.mlp_fwd_subbatch_rows != 0 and total_unzipped_tokens > self.mlp_fwd_subbatch_rows * 2: + assert ( + self.experts_group_gemm_node.recompute_fwd_gate_up + ), "recompute_fwd_gate_up must be true when use_mlp_subbatch = True" + map_unzipped_indices_to_zipped = TDU.tokens_unzip_slice( + extract_first_if_tuple(hs_2d_dispatched), + zipped_expertwise_rowmap, + num_experts, + total_unzipped_tokens, + 0, + total_unzipped_tokens + 1, + ) + if isinstance(hs_2d_dispatched, tuple): + hs_2d_dispatched[0]._clear_to_zero_allocation() + hs_2d_dispatched[1]._clear_to_zero_allocation() + else: + hs_2d_dispatched._clear_to_zero_allocation() + + subbatch_rows = min((total_unzipped_tokens // num_experts) // 128 * 128, self.mlp_fwd_subbatch_rows) + nparts = (total_unzipped_tokens + subbatch_rows - 1) // subbatch_rows + output = paddle.empty([0, extract_first_if_tuple(hs_2d_dispatched).shape[-1]], dtype=paddle.float32) + self.prepare_env_subbatch(unzipped_tokens, unzipped_tokens_scale, True) + logger.info( + f"Enable subbatch_forward!! 
total_zipped_tokens:{total_zipped_tokens}, total_unzipped_tokens:{total_unzipped_tokens}, nparts:{nparts}, subbatch_rows:{subbatch_rows}, output_sub_rows:{self.output_subbatch_rows}" + ) + for i in range(nparts): + start_idx = i * subbatch_rows + end_idx = min(start_idx + subbatch_rows, total_unzipped_tokens) + output = self.gemm_forward_subbatch( + unzipped_tokens, + unzipped_tokens_scale, + unzipped_probs, + map_unzipped_indices_to_zipped, + output, + total_zipped_tokens, + padding_token_per_experts, + start_idx=start_idx, + end_idx=end_idx, + output_subbatch_rows=self.output_subbatch_rows, + ) + + output = merge_subbatch_cast(output, paddle.bfloat16) + output.stop_gradient = False + offload(self.experts_group_gemm_node.input_fp8) + offload(self.experts_group_gemm_node.input_scale) + return output + + # 2 experts + expert_out = self.experts_group_gemm_node.forward( + (unzipped_tokens, unzipped_tokens_scale), unzipped_probs, padding_token_per_experts + ) + else: + # If adaptive O1 recompute is enabled, determine whether to enable recompute O1 based on the degree of imbalance + if self.recompute_fwd_gate_up == -1: + if ( + unzipped_tokens.shape[0] + > self.seq_length * self.num_experts_per_tok * self.adaptive_remained_O1_recompute_ratio + ): + self.set_recompute_fwd_gate_up(True) + else: + self.set_recompute_fwd_gate_up(False) + + # 2 experts + expert_out = self.experts_group_gemm_node.forward( + unzipped_tokens, unzipped_probs, padding_token_per_experts + ) + + # 3 zip + if isinstance(hs_2d_dispatched, tuple): + hs_2d_dispatched[0]._clear_to_zero_allocation() + hs_2d_dispatched[1]._clear_to_zero_allocation() + else: + hs_2d_dispatched._clear_to_zero_allocation() + expert_out_tmp = expert_out.reshape([-1, expert_out.shape[-1]]) + + expert_out_zipped = self.zip_node.forward( + expert_out_tmp, + zipped_expertwise_rowmap, + self.dispatched_indices, + unzipped_probs, + total_zipped_tokens=total_zipped_tokens, + num_experts=num_experts, + ) + + expert_out_zipped.stop_gradient = False + return expert_out_zipped + + @paddle.no_grad() + def backward(self, hidden_states_out_grad): + """ + Backward propagation function. + + Args: + hidden_states_out_grad_fp8 (Tensor): Gradient of hidden states. + + Returns: + Tuple[Tensor, Tensor]: Contains two elements: + - hs_fp8_dispatched_grad (Tensor): Gradient of unzipped hidden states. + - dispatched_probs_grad (Tensor): Gradient of dispatch probabilities. 
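+
+        Note: when mlp_bwd_subbatch_rows is non-zero and the unzipped token count is large,
+        the gradient is likewise computed slice by slice to bound peak memory.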
+ """ + # zip_grad + unzipped_grad = self.zip_node.backward( + hidden_states_out_grad, + self.dispatched_indices, + self.dispatched_probs, + top_k=self.router_topk, + num_experts=len(self.tokens_per_expert), + tokens_per_expert=self.tokens_per_expert, + ) + record_stream_for_multi_input(hidden_states_out_grad) + + total_zipped_tokens = extract_first_if_tuple(hidden_states_out_grad).shape[0] + total_unzipped_tokens = extract_first_if_tuple(unzipped_grad).shape[0] + hidden_states_size = extract_first_if_tuple(hidden_states_out_grad).shape[-1] + num_experts = len(self.tokens_per_expert) + padding_token_per_experts = [(x + 127) // 128 * 128 for x in self.tokens_per_expert] + + if self.mlp_bwd_subbatch_rows != 0 and total_unzipped_tokens > self.mlp_bwd_subbatch_rows * 2: + map_unzipped_indices_to_zipped = TDU.tokens_unzip_slice( + extract_first_if_tuple(hidden_states_out_grad), + self.unzip_node.zipped_expertwise_rowmap, + num_experts, + total_unzipped_tokens, + 0, + total_unzipped_tokens + 1, + ) + if isinstance(hidden_states_out_grad, tuple): + hidden_states_out_grad[0]._clear_to_zero_allocation() + hidden_states_out_grad[1]._clear_to_zero_allocation() + else: + hidden_states_out_grad._clear_to_zero_allocation() + + subbatch_rows = min((total_unzipped_tokens // num_experts) // 128 * 128, self.mlp_bwd_subbatch_rows) + nparts = (total_unzipped_tokens + subbatch_rows - 1) // subbatch_rows + output = paddle.empty([0, hidden_states_size], dtype=paddle.float32) + probs_grad_list = [] + self.prepare_env_subbatch(is_fwd=False) + logger.info( + f"Enable subbatch_backward!! total_zipped_tokens:{total_zipped_tokens}, total_unzipped_tokens:{total_unzipped_tokens}, nparts:{nparts}, subbatch_rows:{subbatch_rows}, output_sub_rows:{self.output_subbatch_rows}" + ) + for i in range(nparts): + reset_status = True if i == nparts - 1 else False # release saved status in the last part. 
+ start_idx = i * subbatch_rows + end_idx = min(start_idx + subbatch_rows, total_unzipped_tokens) + output, probs_grad = self.gemm_backward_subbatch( + unzipped_grad, + map_unzipped_indices_to_zipped, + total_zipped_tokens, + output, + padding_token_per_experts, + start_idx=start_idx, + end_idx=end_idx, + output_subbatch_rows=self.output_subbatch_rows, + reset_status=reset_status, + ) + probs_grad_list.append(probs_grad) + if isinstance(unzipped_grad, tuple): + unzipped_grad[0]._clear_to_zero_allocation() + unzipped_grad[1]._clear_to_zero_allocation() + else: + unzipped_grad._clear_to_zero_allocation() + hs_dispatched_grad = merge_subbatch_cast(output, paddle.bfloat16) + dispatched_probs_grad = TDU.tokens_zip_prob_seq_subbatch( + probs_grad_list, self.unzip_node.zipped_expertwise_rowmap, self.dispatched_indices, subbatch_rows + ) + self.reset_statue() + return hs_dispatched_grad, dispatched_probs_grad + + if isinstance(hidden_states_out_grad, tuple): + hidden_states_out_grad[0]._clear_to_zero_allocation() + hidden_states_out_grad[1]._clear_to_zero_allocation() + else: + hidden_states_out_grad._clear_to_zero_allocation() + + # expert_grad + expert_out, probs_grad = self.experts_group_gemm_node.backward( + unzipped_grad, self.unzipped_probs, padding_token_per_experts + ) + + hs_dispatched_grad, dispatched_probs_grad = self.unzip_node.backward( + expert_out, + total_zipped_tokens, + probs_grad, + self.dispatched_indices, + num_experts=num_experts, + ) + + self.reset_statue() + return hs_dispatched_grad, dispatched_probs_grad + + +class FusionMoeNode: + def __init__( + self, + custom_map, + recompute_fwd_gate_up=False, + is_split_group_gemm=True, + dsv3_use_fp8_dispatch=True, + mlp_fwd_subbatch_rows=0, + mlp_bwd_subbatch_rows=0, + output_subbatch_rows=0, + name="fusion_moe_node", + ): + self.token_dispatcher = custom_map.token_dispatcher + self.moe_router_topk = custom_map.moe_router_topk + self.dsv3_use_fp8_dispatch = dsv3_use_fp8_dispatch + self.dispatch_quant_node = Fp8DispatchQuantNode(self.token_dispatcher, dsv3_use_fp8_dispatch) + self.dispatch_node = Fp8DispatchNode(self.token_dispatcher) + self.mlp_node = FusionMlpNode( + custom_map, + self.moe_router_topk, + recompute_fwd_gate_up=recompute_fwd_gate_up, + is_split_group_gemm=is_split_group_gemm, + dsv3_use_fp8_dispatch=dsv3_use_fp8_dispatch, + mlp_fwd_subbatch_rows=mlp_fwd_subbatch_rows, + mlp_bwd_subbatch_rows=mlp_bwd_subbatch_rows, + output_subbatch_rows=output_subbatch_rows, + ) + self.combine_node = Fp8CombineNode(self.token_dispatcher) + self.combine_quant_node = Fp8CombineQuantNode( + self.token_dispatcher, dsv3_use_fp8_dispatch, custom_map.moe_group + ) + self.name = name + + @paddle.no_grad() + def forward(self, hidden_states, probs, routing_map): + if self.dsv3_use_fp8_dispatch: + (hs_fp8, hs_scale), token_indices, token_probs = self.dispatch_quant_node.forward( + hidden_states, probs, routing_map + ) + ( + (hs_fp8_dispatched, hs_scale_dispatched), + dispatched_indices, + dispatched_probs, + ) = self.dispatch_node.forward((hs_fp8, hs_scale), token_indices, token_probs) + hidden_states_out = self.mlp_node.forward( + (hs_fp8_dispatched, hs_scale_dispatched), dispatched_indices, dispatched_probs + ) + output_combine = self.combine_node.forward(hidden_states_out) + output = self.combine_quant_node.forward(output_combine) + output.stop_gradient = False + return output + else: + hs_2d_dispatched, dispatched_indices, dispatched_probs = self.dispatch_node.forward( + hidden_states, probs, routing_map + ) + hidden_states_out = 
self.mlp_node.forward(hs_2d_dispatched, dispatched_indices, dispatched_probs) + output_combine = self.combine_node.forward(hidden_states_out) + output = self.combine_quant_node.forward(output_combine) + output.stop_gradient = False + return output + + @paddle.no_grad() + def backward(self, output_grad): + output_combine_grad, _ = self.combine_quant_node.backward(output_grad) + hidden_states_out_grad = self.combine_node.backward(output_combine_grad) + + hs_dispatched_grad, dispatched_probs_grad = self.mlp_node.backward(hidden_states_out_grad) + + if self.dsv3_use_fp8_dispatch: + hs_fp8_grad, token_probs_grad = self.dispatch_node.backward(hs_dispatched_grad, dispatched_probs_grad) + hs_grad, probs_grad, routing_map_grad = self.dispatch_quant_node.backward(hs_fp8_grad, token_probs_grad) + return hs_grad, probs_grad, routing_map_grad + else: + hs_bf16_grad, token_probs_grad = self.dispatch_node.backward(hs_dispatched_grad, dispatched_probs_grad) + return hs_bf16_grad, None, token_probs_grad + + +class FusionMoe(paddle.autograd.PyLayer): + @staticmethod + def forward( + ctx, + hidden_states, + probs, + routing_map, + custom_map, + recompute_fwd_gate_up=False, + is_split_group_gemm=True, + dsv3_use_fp8_dispatch=True, + ): + ctx.node = FusionMoeNode( + custom_map, + recompute_fwd_gate_up=recompute_fwd_gate_up, + is_split_group_gemm=is_split_group_gemm, + dsv3_use_fp8_dispatch=dsv3_use_fp8_dispatch, + ) + return ctx.node.forward(hidden_states, probs, routing_map) + + @staticmethod + def backward(ctx, output_grad): + return ctx.node.backward(output_grad) diff --git a/examples/experiments/deepseek_v3_pretrain/moe_utils.py b/examples/experiments/deepseek_v3_pretrain/moe_utils.py new file mode 100644 index 00000000000..d2d2175664d --- /dev/null +++ b/examples/experiments/deepseek_v3_pretrain/moe_utils.py @@ -0,0 +1,436 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2025 DeepSeek +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
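+
+# Utility helpers for the FP8 MoE layer: token permutation/unpermutation around the
+# dispatcher, subbatched zip-add of expert outputs, and small paddle.Tensor patches
+# (_clear_to_zero_allocation, _holder_size) used to release or measure tensor storage.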
+ +import numpy as np +import paddle + +try: + import TokenDispatcherUtils as TDU +except ImportError: + TDU = None + +from paddleformers.transformers.fp8_utils import FP8LinearFunctionBase + +if not hasattr(paddle.Tensor, "_clear_to_zero_allocation"): + + def _clear_to_zero_allocation(self): + """ + _clear_to_zero_allocation + """ + old_shape = self.shape + dst = paddle.empty([0], dtype=self.dtype) + dst_t = dst.value().get_tensor() + src_t = self.value().get_tensor() + src_t._share_data_with(dst_t) + src_t._set_dims(old_shape) + + setattr(paddle.Tensor, "_clear_to_zero_allocation", _clear_to_zero_allocation) + + +if not hasattr(paddle.Tensor, "_holder_size"): + + def _holder_size(self): + """ + _holder_size + """ + if self._is_initialized(): + return int(np.prod(self.shape)) * paddle.core.size_of_dtype(self.dtype) + else: + return 0 + + setattr(paddle.Tensor, "_holder_size", _holder_size) + + +def topk_to_permuted_indices(x, num_tokens_per_expert_list, topk): + x = paddle.flatten(x) + prob_permuted_indices = paddle.concat( + [ + paddle.tensor.search._restrict_nonzero(x == i, total_true_num) + for i, total_true_num in enumerate(num_tokens_per_expert_list) + ] + ).flatten() + token_permuted_indices = prob_permuted_indices // topk + return token_permuted_indices, prob_permuted_indices + + +def permute_fast( + tokens, + token_permuted_indices, + drop_and_pad: bool = False, +): + """Permute the tokens and probs based on the mask. + Tokens with the same designated expert will be grouped together. + The shape of mask is [tokens, num_experts], it indicates which experts were selected + by each token. + + Args: + tokens (paddle.Tensor): The input token tensor, [num_tokens, hidden]. + drop_and_pad (bool, optional): Whether or not the token dispatcher uses token-drop + and pads the number of tokens to the expert capacity. + """ + assert not drop_and_pad, "token-drop and pads is not supported" + # permuted_input = paddle.gather(tokens, token_permuted_indices) + permuted_input = tokens.index_select(axis=0, index=token_permuted_indices) + return permuted_input + + +def unpermute_fast( + permuted_tokens: paddle.Tensor, + token_permuted_indices: paddle.Tensor, + prob_permuted_indices: paddle.Tensor, + restore_shape: paddle.shape, + probs: paddle.Tensor = None, + drop_and_pad: bool = False, +): + """ + Restore the original order of tokens after permutation. If probs are provided, it + will also apply them to the tokens before restoring the order. + + Args: + permuted_tokens (paddle.Tensor): The permuted token tensor. + token_permuted_indices (paddle.Tensor): The indices used to sort the tokens. + restore_shape (paddle.shape): The shape of the unpermuted tensor. + probs (paddle.Tensor, optional): The unpermuted probs tensor, + drop_and_pad (bool, optional): Whether or not the token dispatcher uses token-drop + and pads the number of tokens to the expert capacity. + + Returns: + paddle.Tensor: The tokens restored to their original order. 
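+
+        Note: `prob_permuted_indices` gathers, for every permuted token copy, its matching
+        entry from the flattened `probs`, so each expert copy is weighted by its own routing
+        probability before being scatter-added back.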
+ """ + assert not drop_and_pad, "token-drop and pads is not supported" + _, hidden = restore_shape + if probs is not None: + permuted_probs = paddle.gather(probs.flatten(), prob_permuted_indices) + permuted_tokens = permuted_tokens * permuted_probs.unsqueeze(-1) + # Create an output tensor filled with zeros + output_tokens = paddle.zeros(restore_shape, dtype=permuted_tokens.dtype) + # Scatter add the permuted_input back to the original positions + + output_tokens.put_along_axis_( + axis=0, + indices=token_permuted_indices.unsqueeze(1).expand([-1, hidden]), + values=permuted_tokens, + reduce="add", + include_self=True, + ) + return output_tokens + + +class UnZipNode: + def __init__(self, name="unzip"): + self.name = name + self.unzipped_probs = None + self.zipped_expertwise_rowmap = None + + def reset_statue(self): + self.unzipped_probs = None + self.zipped_expertwise_rowmap = None + + @paddle.no_grad() + def forward( + self, + hs_2d_dispatched, + dispatched_indices, + dispatched_probs, + topk, + num_experts, + tokens_per_expert, + ): + if isinstance(hs_2d_dispatched, tuple): + with paddle.amp.auto_cast(False): + ( + unzipped_tokens, + zipped_expertwise_rowmap, + unzipped_probs, + unzipped_scale, + ) = paddle.nn.functional.moe_permute( + hs_2d_dispatched[0], + hs_2d_dispatched[1], + dispatched_indices, + dispatched_probs, + num_experts=num_experts, + tokens_per_expert=tokens_per_expert, + padding_alignment=128, + ) + else: + with paddle.amp.auto_cast(False): + ( + unzipped_tokens, + zipped_expertwise_rowmap, + unzipped_probs, + unzipped_scale, + ) = paddle.nn.functional.moe_permute( + hs_2d_dispatched, + None, + dispatched_indices, + dispatched_probs, + num_experts=num_experts, + tokens_per_expert=tokens_per_expert, + padding_alignment=128, + ) + self.unzipped_probs = unzipped_probs + self.zipped_expertwise_rowmap = zipped_expertwise_rowmap + return (unzipped_tokens, zipped_expertwise_rowmap, unzipped_probs, unzipped_scale) + + @paddle.no_grad() + def backward(self, dx, total_zipped_tokens, probs_grad, dispatched_indices, num_experts): + with paddle.amp.auto_cast(False): + weighted_zipped_tokens, probs_grad_zipped = paddle.nn.functional.moe_unpermute( + dx, + self.zipped_expertwise_rowmap, + dispatched_indices, + probs_grad, + total_zipped_tokens=total_zipped_tokens, + num_experts=num_experts, + ) + self.reset_statue() + return weighted_zipped_tokens, probs_grad_zipped + + +class ZipNode: + def __init__(self, name="zip"): + self.name = name + + @paddle.no_grad() + def forward( + self, expert_out, zipped_expertwise_rowmap, routemap_topk, unzipped_probs, total_zipped_tokens, num_experts + ): + with paddle.amp.auto_cast(False): + expert_out_zipped, zipped_probs_topk = paddle.nn.functional.moe_unpermute( + expert_out, zipped_expertwise_rowmap, routemap_topk, unzipped_probs, total_zipped_tokens, num_experts + ) + return expert_out_zipped + + @paddle.no_grad() + def backward( + self, + grad_output, + dispatched_indices, + dispatched_probs, + top_k, + num_experts, + tokens_per_expert, + ): + if isinstance(grad_output, tuple): + with paddle.amp.auto_cast(False): + ( + unzipped_grad, + zipped_expertwise_rowmap_grad, + unzipped_probs_grad, + unzipped_scale_grad, + ) = paddle.nn.functional.moe_permute( + grad_output[0], + grad_output[1], + dispatched_indices, + dispatched_probs, + num_experts, + tokens_per_expert, + padding_alignment=128, + ) + return (unzipped_grad, unzipped_scale_grad) + else: + with paddle.amp.auto_cast(False): + ( + unzipped_grad, + zipped_expertwise_rowmap_grad, + 
unzipped_probs_grad, + unzipped_scale_grad, + ) = paddle.nn.functional.moe_permute( + grad_output, + None, + dispatched_indices, + dispatched_probs, + num_experts, + tokens_per_expert, + padding_alignment=128, + ) + + return unzipped_grad + + +class PermuteNode: + def __init__(self, token_dispatcher, name="permute"): + self.token_dispatcher = token_dispatcher + self.name = name + + def reset_status(self): + self.token_permuted_indices = None + self.prob_permuted_indices = None + + def forward(self, hidden_states, hidden_states_scale, dispatched_indices): + self.token_dispatcher._comm_manager.hidden_shape_before_permute = hidden_states.shape + self.hidden_shape_before_permute = hidden_states.shape + self.token_permuted_indices, self.prob_permuted_indices = topk_to_permuted_indices( + dispatched_indices, + self.token_dispatcher._comm_manager.tokens_per_expert, + self.token_dispatcher._comm_manager.router_topk, + ) + hidden_states = permute_fast(hidden_states, self.token_permuted_indices) + # permute scale + hidden_states_scale = permute_fast(hidden_states_scale, self.token_permuted_indices) + + return hidden_states, hidden_states_scale, self.token_permuted_indices, self.prob_permuted_indices + + def backward(self, out_grad, dispatched_probs): + input_dtype = out_grad.dtype + hidden_states_grad = unpermute_fast( + permuted_tokens=out_grad, + token_permuted_indices=self.token_permuted_indices, + prob_permuted_indices=self.prob_permuted_indices, + restore_shape=self.hidden_shape_before_permute, + probs=dispatched_probs, + ) + self.reset_status() + return hidden_states_grad.to(input_dtype) + + +class UnPermuteNode: + def __init__(self, token_dispatcher, name="unpermute"): + self.token_dispatcher = token_dispatcher + self.name = name + + def reset_status(self): + self.token_permuted_indices = None + self.hidden_states = None + self.prob_permuted_indices = None + self.faltten_dispatched_probs = None + self.hidden = None + self.permuted_tokens = None + self.output_tokens = None + + def forward( + self, + hidden_states, + token_permuted_indices, + prob_permuted_indices, + dispatched_probs, + ): + self.token_permuted_indices = token_permuted_indices + self.input_dtype = hidden_states.dtype + self.hidden_states = hidden_states + self.prob_permuted_indices = prob_permuted_indices + self.dispatched_probs_shape = dispatched_probs.shape + # permute + _, self.hidden = self.token_dispatcher._comm_manager.hidden_shape_before_permute + + self.faltten_dispatched_probs = dispatched_probs.flatten() + + self.permuted_probs = paddle.gather(self.faltten_dispatched_probs, self.prob_permuted_indices) + permuted_tokens = self.hidden_states * self.permuted_probs.unsqueeze(-1) + permuted_tokens = permuted_tokens.cast(self.hidden_states.dtype) + + # Create an output tensor filled with zeros + output_tokens = paddle.zeros( + self.token_dispatcher._comm_manager.hidden_shape_before_permute, dtype=self.hidden_states.dtype + ) + # Scatter add the permuted_input back to the original positions + output_tokens.put_along_axis_( + axis=0, + indices=self.token_permuted_indices.cast("int32").unsqueeze(1).expand([-1, self.hidden]), + values=permuted_tokens, + reduce="add", + include_self=True, + ) + with paddle.base.device_guard("cpu"): + self.output_tokens = paddle.empty(shape=output_tokens.shape, dtype=output_tokens.dtype) + + return output_tokens.to(self.input_dtype) + + def backward(self, out_grad, out_grad_scale): + hidden_states_grad = paddle.gather(out_grad, self.token_permuted_indices) + + output_tokens_grad = 
FP8LinearFunctionBase.dequantize_fp8_to_fp32(out_grad, out_grad_scale) + permuted_tokens = self.hidden_states * self.permuted_probs.unsqueeze(-1) + permuted_tokens = permuted_tokens.cast(self.hidden_states.dtype) + + _, permuted_tokens_grad = paddle._C_ops.put_along_axis_grad( + self.output_tokens, + self.token_permuted_indices.cast("int32").unsqueeze(1).expand([-1, self.hidden]), + permuted_tokens, + self.output_tokens, + output_tokens_grad, + 0, + "add", + True, + ) + + permuted_probs_grad = (permuted_tokens_grad * self.hidden_states).sum(axis=-1) + + faltten_dispatched_probs_grad = paddle._C_ops.gather_grad( + self.faltten_dispatched_probs, self.prob_permuted_indices, permuted_probs_grad, 0 + ) + + # dispatched_probs_grad = paddle._C_ops.flatten_grad(self.dispatched_probs, faltten_dispatched_probs_grad) + dispatched_probs_grad = faltten_dispatched_probs_grad.reshape(self.dispatched_probs_shape) + + self.reset_status() + return hidden_states_grad, dispatched_probs_grad + + +def tokens_zip_unique_add_with_subbatch(zipped, unzipped, index_unzipped, zipped_rows, subbatch_rows=None): + """ + tokens_zip_unique_add_with_subbatch + """ + if subbatch_rows is None or subbatch_rows <= 0 or zipped_rows <= 0: + return TDU.tokens_zip_unique_add(zipped, unzipped, index_unzipped, zipped_rows) + else: + if isinstance(zipped, paddle.Tensor): + num_split = (zipped_rows + subbatch_rows - 1) // subbatch_rows + remainder = zipped_rows % subbatch_rows + if remainder == 0: + rows = [subbatch_rows] * num_split + else: + rows = [subbatch_rows] * (num_split - 1) + [remainder] + + if zipped.shape[0] == 0: + dtype = zipped.dtype + hidden_size = zipped.shape[1] + zipped = [paddle.zeros([r, hidden_size], dtype=dtype) for r in rows] + else: + zipped = paddle.split(zipped, rows, axis=0) + return TDU.tokens_zip_unique_add_subbatch(zipped, unzipped, index_unzipped, zipped_rows, subbatch_rows) + + +def merge_subbatch_cast(x, dtype): + if isinstance(x, (list, tuple)): + if len(x) == 1: + x = x[0] + return x.cast(dtype) if x.dtype != dtype else x + else: + return TDU.merge_subbatch_cast(x, dtype) + else: + return x.cast(dtype) if x.dtype != dtype else x + + +def get_env_device(): + """ + Return the device name of running environment. + """ + if paddle.is_compiled_with_cuda(): + return "gpu" + elif "npu" in paddle.device.get_all_custom_device_type(): + return "npu" + elif "mlu" in paddle.device.get_all_custom_device_type(): + return "mlu" + elif "gcu" in paddle.device.get_all_custom_device_type(): + return "gcu" + elif "intel_hpu" in paddle.device.get_all_custom_device_type(): + return "intel_hpu" + elif paddle.is_compiled_with_rocm(): + return "rocm" + elif paddle.is_compiled_with_xpu(): + return "xpu" + return "cpu" diff --git a/examples/experiments/deepseek_v3_pretrain/run.sh b/examples/experiments/deepseek_v3_pretrain/run.sh new file mode 100644 index 00000000000..704e5aeba49 --- /dev/null +++ b/examples/experiments/deepseek_v3_pretrain/run.sh @@ -0,0 +1,23 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +# Download llama model data +# mkdir -p data +# wget https://bj.bcebos.com/paddlenlp/models/transformers/llama/data/llama_openwebtext_100k.bin +# wget https://bj.bcebos.com/paddlenlp/models/transformers/llama/data/llama_openwebtext_100k.idx + +# mpirun sh script/kill_process.sh +# mpirun rm -rf output +nohup bash script/train_gpu.sh ./config/pretrain_argument.json --reorder_pipeline_priority=True > run.log 2>&1 & + diff --git a/examples/experiments/deepseek_v3_pretrain/run_pretrain.py b/examples/experiments/deepseek_v3_pretrain/run_pretrain.py new file mode 100644 index 00000000000..84d123a2f3d --- /dev/null +++ b/examples/experiments/deepseek_v3_pretrain/run_pretrain.py @@ -0,0 +1,623 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import copy +import math +import os +import sys +import time +from dataclasses import dataclass, field +from typing import Optional + +import paddle +from config.configuration import DeepseekV2FastConfig +from load_hf_ckpt import load_huggingface_ckpt +from modeling_pp import DeepseekV2ForCausalLMPipe +from moe_utils import get_env_device + +from paddleformers.data.causal_dataset import ( + build_train_valid_test_datasets, + check_data_split, + print_rank_0, +) +from paddleformers.trainer import ( + FP8QuantWeightCallback, + PdArgumentParser, + StepFlexToken, + Trainer, + TrainingArguments, + get_last_checkpoint, + set_seed, + speed_metrics, +) +from paddleformers.transformers import ( + AutoTokenizer, + CosineAnnealingWithWarmupDecay, + LinearAnnealingWithWarmupDecay, +) +from paddleformers.transformers.configuration_utils import LlmMetaConfig, llmmetaclass +from paddleformers.transformers.deepseek_v2 import DeepseekV2ForCausalLM +from paddleformers.utils.batch_sampler import DistributedBatchSampler +from paddleformers.utils.log import logger + +# Pretaining Environment Variables to support sharding stage1 overlap optimization. +os.environ["USE_CASUAL_MASK"] = "True" + + +from paddleformers.trainer.utils.doc import add_start_docstrings + + +@dataclass +@llmmetaclass +@add_start_docstrings(TrainingArguments.__doc__) +class PreTrainingArguments(TrainingArguments): + min_learning_rate: float = field( + default=1e-5, + metadata={"help": "Minimum learning rate deacyed to."}, + ) + decay_steps: float = field( + default=None, + metadata={ + "help": "The steps use to control the learing rate. If the step > decay_steps, will use the min_learning_rate." + }, + ) + enable_linear_fused_grad_add: bool = field( + default=False, + metadata={ + "help": "Enable fused linear grad add strategy, which will reduce elementwise add for grad accumulation in the backward of nn.Linear ." + }, + ) + # NOTE(gongenlei): new add autotuner_benchmark + autotuner_benchmark: bool = field( + default=False, + metadata={"help": "Weather to run benchmark by autotuner. 
True for from_scratch and pad_max_length."}, + ) + unified_checkpoint: bool = field( + default=True, + metadata={"help": "Enable fused linear grad add strategy."}, + ) + + def __post_init__(self): + super().__post_init__() + # NOTE(gongenlei): new add autotuner_benchmark + from paddleformers.trainer.trainer_utils import IntervalStrategy + + if self.autotuner_benchmark: + self.max_steps = 5 + self.do_train = True + self.do_export = False + self.do_predict = False + self.do_eval = False + self.overwrite_output_dir = True + self.load_best_model_at_end = False + self.report_to = [] + self.save_strategy = IntervalStrategy.NO + self.evaluation_strategy = IntervalStrategy.NO + self.unified_checkpoint = False + + +@dataclass +class DataArguments: + """ + Arguments pertaining to what data we are going to input our model for training and evaluating. + Using `PdArgumentParser` we can turn this class into argparse arguments to be able to + specify them on the command line. + """ + + input_dir: str = field( + default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."} + ) + split: str = field(default="949,50,1", metadata={"help": "Train/valid/test data split."}) + + max_seq_length: int = field( + default=1024, + metadata={ + "help": "The maximum total input sequence length after tokenization. Sequences longer " + "than this will be truncated, sequences shorter will be padded." + }, + ) + share_folder: bool = field( + default=False, + metadata={"help": "Use share folder for data dir and output dir on multi machine."}, + ) + + data_impl: str = field(default="mmap", metadata={"help": "The format of the preprocessed data."}) + skip_warmup: bool = field( + default=True, + metadata={"help": "Whether to skip the warmup process of mmap files."}, + ) + data_cache: str = field(default=None, metadata={"help": "The path of the cached dataset."}) + + +@dataclass +class ModelArguments: + """ + Arguments pertaining to which model/config/tokenizer we are going to pre-train from. + """ + + model_name_or_path: str = field( + default="__internal_testing__/tiny-random-llama", + metadata={ + "help": "Path to pretrained model or model identifier from https://paddleformers.readthedocs.io/zh/latest/model_zoo/transformers.html" + }, + ) + tokenizer_name_or_path: Optional[str] = field( + default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} + ) + + use_fast_layer_norm: bool = field( + default=False, + metadata={"help": "GPT3 model, use fast layernorm"}, + ) + + hidden_dropout_prob: float = field(default=0.1, metadata={"help": "The hidden dropout prob."}) + attention_probs_dropout_prob: float = field(default=0.1, metadata={"help": "The attention hidden dropout prob."}) + + fuse_attention_qkv: bool = field( + default=None, + metadata={"help": "whether to fuse attention qkv"}, + ) + fuse_attention_ffn: bool = field( + default=None, + metadata={"help": "whether to fuse first up and gate proj in mlp block"}, + ) + + continue_training: bool = field( + default=False, + metadata={ + "help": "Pre-training from existing paddleformers model weights. Default False and model will train from scratch. If set True, the model_name_or_path argument must exist in the paddleformers models." 
+ }, + ) + num_hidden_layers: Optional[int] = field( + default=None, + metadata={"help": "num_hidden_layers."}, + ) + + +def create_pretrained_dataset( + data_args, + training_args, + data_file, + tokenizer, + need_data=True, +): + + check_data_split(data_args.split, training_args.do_train, training_args.do_eval, training_args.do_predict) + + train_val_test_num_samples = [ + training_args.per_device_train_batch_size + * training_args.dataset_world_size + * training_args.max_steps + * training_args.gradient_accumulation_steps, + training_args.per_device_eval_batch_size + * training_args.dataset_world_size + * training_args.eval_iters + * (training_args.max_steps // training_args.eval_steps + 1), + training_args.per_device_eval_batch_size * training_args.dataset_world_size * training_args.test_iters, + ] + + print_rank_0(" > datasets target sizes (minimum size):") + if training_args.do_train: + print_rank_0(" train: {}".format(train_val_test_num_samples[0])) + if training_args.do_eval: + print_rank_0(" validation: {}".format(train_val_test_num_samples[1])) + if training_args.do_predict: + print_rank_0(" test: {}".format(train_val_test_num_samples[2])) + + # Build the datasets. + train_dataset, valid_dataset, test_dataset = build_train_valid_test_datasets( + data_prefix=data_file, + data_impl=data_args.data_impl, + splits_string=data_args.split, + train_val_test_num_samples=train_val_test_num_samples, + seq_length=data_args.max_seq_length, + seed=training_args.seed, + skip_warmup=data_args.skip_warmup, + share_folder=data_args.share_folder, + data_cache_path=data_args.data_cache, + need_data=need_data, + ) + + def print_dataset(data, mode="train"): + logger.info(f"Sample data for {mode} mode.") + # input_ids, loss_mask, attention_mask, position_ids, labels = data + input_ids = data["text"] + logger.info(tokenizer._decode(list(input_ids))) + + from paddleformers.data import Stack + + def _collate_data(data, stack_fn=Stack()): + tokens_ = stack_fn([x["text"] for x in data]) + + labels = copy.deepcopy(tokens_)[:, 1:] + tokens = tokens_[:, :-1] + + return { + "input_ids": tokens, + "labels": labels, + } + + if need_data: + if training_args.do_train: + print_dataset(train_dataset[0], "train") + if training_args.do_eval: + print_dataset(valid_dataset[0], "valid") + if training_args.do_predict: + print_dataset(test_dataset[0], "test") + + return train_dataset, valid_dataset, test_dataset, _collate_data + + +def get_train_data_file(args): + if len(args.input_dir.split()) > 1: + # weight-1 data-prefix-1 weight-2 data-prefix-2 ... 
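+        # Illustrative only (hypothetical paths): an input_dir such as
+        #   "0.3 /data/enwiki_text_document 0.7 /data/code_text_document"
+        # is split on whitespace and later parsed as alternating weight / data-prefix pairs.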
+ return args.input_dir.split() + else: + files = [ + os.path.join(args.input_dir, f) + for f in os.listdir(args.input_dir) + if (os.path.isfile(os.path.join(args.input_dir, f)) and ("_idx.npz" in str(f) or ".idx" in str(f))) + ] + files = [x.replace("_idx.npz", "") for x in files] + files = [x.replace(".idx", "") for x in files] + + if len(files) > 1: + ret = [] + logger.info("You are using multi-dataset:") + for x in files: + ret.append(1.0) + ret.append(x) + logger.info(" > set weight of %s dataset to 1.0" % x) + return ret + + return files + + +class PretrainingTrainer(Trainer): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.is_pretraining = True + + def evaluate(self, eval_dataset=None, ignore_keys=None, metric_key_prefix: str = "eval"): + # keep eval_dataloader + eval_dataloader = getattr(self, "eval_dataloader", None) + if eval_dataloader is None: + eval_dataset = self.eval_dataset if eval_dataset is None else eval_dataset + eval_dataloader = self.get_eval_dataloader(eval_dataset) + # must call data loader, otherwise, it will init many times, cause OOM error. + self.eval_dataloader = eval_dataloader() + + start_time = time.time() + # Temporarily disable metric computation, we will do it in the loop here. + compute_metrics = self.compute_metrics + eval_loop = self.evaluation_loop + + output = eval_loop( + eval_dataloader, + description="Evaluation", + # No point gathering the predictions if there are no metrics, otherwise we defer to + # self.args.prediction_loss_only + prediction_loss_only=True if compute_metrics is None else None, + ignore_keys=ignore_keys, + # Only evaluate max_eval_iters + max_eval_iters=self.args.eval_iters, + ) + + total_batch_size = self.args.eval_batch_size * self.args.world_size + output.metrics.update( + speed_metrics( + metric_key_prefix, + start_time, + num_samples=output.num_samples, + num_steps=math.ceil(output.num_samples / total_batch_size), + ) + ) + + self.log(output.metrics) + + self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, output.metrics) + return output.metrics + + def _get_eval_sampler(self, eval_dataset) -> Optional[paddle.io.Sampler]: + return DistributedBatchSampler( + eval_dataset, + batch_size=self.args.per_device_eval_batch_size, + shuffle=False, + num_replicas=self.args.dataset_world_size, + rank=self.args.dataset_rank, + drop_last=self.args.dataloader_drop_last, + ) + + def _get_train_sampler(self) -> Optional[paddle.io.Sampler]: + return DistributedBatchSampler( + self.train_dataset, + batch_size=self.args.per_device_train_batch_size, + shuffle=False, + num_replicas=self.args.dataset_world_size, + rank=self.args.dataset_rank, + drop_last=self.args.dataloader_drop_last, + ) + + +def main(): + parser = PdArgumentParser((ModelArguments, DataArguments, PreTrainingArguments)) + # Support format as "args.json --arg1 value1 --arg2 value2.” + # In case of conflict, command line arguments take precedence. 
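+    # Illustrative invocation (paths are examples only):
+    #   python run_pretrain.py ./config/pretrain_argument.json --max_steps 100
+    # loads the JSON config first, then lets --max_steps override the value from the file.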
+ if len(sys.argv) >= 2 and sys.argv[1].endswith(".json"): + model_args, data_args, training_args = parser.parse_json_file_and_cmd_lines() + else: + model_args, data_args, training_args = parser.parse_args_into_dataclasses() + + if training_args.no_recompute_layers is not None: + training_args.no_recompute_layers.sort() + + if training_args.enable_linear_fused_grad_add: + from utils.fused_layers import mock_layers + + mock_layers() + + if model_args.tokenizer_name_or_path is None: + model_args.tokenizer_name_or_path = model_args.model_name_or_path + + if data_args.data_cache is not None: + os.makedirs(data_args.data_cache, exist_ok=True) + + paddle.set_device(training_args.device) + set_seed(seed=training_args.seed) + + training_args.eval_iters = 10 + training_args.test_iters = training_args.eval_iters * 10 + + # Log model and data config + training_args.print_config(model_args, "Model") + training_args.print_config(data_args, "Data") + + # Log on each process the small summary: + logger.warning( + f"Process rank: {training_args.local_rank}, device: {training_args.device}, world_size: {training_args.world_size}, " + + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16 or training_args.bf16}" + ) + + # Detecting last checkpoint. + last_checkpoint = None + if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir: + last_checkpoint = get_last_checkpoint(training_args.output_dir) + # if last_checkpoint is None and len( + # os.listdir(training_args.output_dir)) > 1: + # raise ValueError( + # f"Output directory ({training_args.output_dir}) already exists and is not empty. " + # "Use --overwrite_output_dir to overcome.") + if last_checkpoint is not None and training_args.resume_from_checkpoint is None: + logger.info( + f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change " + "the `--output_dir` or add `--overwrite_output_dir` to train from scratch." + ) + + tokenizer = AutoTokenizer.from_pretrained(model_args.tokenizer_name_or_path, download_hub="huggingface") + config = DeepseekV2FastConfig.from_pretrained(model_args.model_name_or_path) + + # set all llm config + LlmMetaConfig.set_llm_config(config, training_args) + config.use_fast_layer_norm = model_args.use_fast_layer_norm + + config.seq_length = data_args.max_seq_length + # There are some technique extend RotaryEmbedding context. so don't change max_position_embeddings + if not model_args.continue_training: + config.max_position_embeddings = max(config.max_position_embeddings, data_args.max_seq_length) + + if not model_args.continue_training: + config.vocab_size = max(config.vocab_size, ((tokenizer.vocab_size - 1) // 128 + 1) * 128) + logger.info(f"Reset vocab size to {config.vocab_size} for batter amp peformance.") + + config.num_hidden_layers = ( + model_args.num_hidden_layers if model_args.num_hidden_layers is not None else config.num_hidden_layers + ) + # Config for model using dropout, such as GPT. + if hasattr(config, "use_dualpipev"): + # NOTE(zhangyuqin): In Paddle, the segmentation and scheduling of pipeline parallel + # models are separate. Therefore, first we need to set the flag in the model config + # to perform V-shape segmentation. Second, we need to set the flag in the training_args + # to configure strategy.hybrid_configs to choose the DualPipeV schedule. 
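+        # For example, launching with --pipeline_parallel_config containing the token
+        # "use_dualpipev" turns the flag on here; if the token is absent, use_dualpipev is
+        # set to False even when the model config file enables it.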
+ config.use_dualpipev = "use_dualpipev" in training_args.pipeline_parallel_config + if hasattr(config, "hidden_dropout_prob"): + config.hidden_dropout_prob = model_args.hidden_dropout_prob + if hasattr(config, "attention_probs_dropout_prob"): + config.attention_probs_dropout_prob = model_args.attention_probs_dropout_prob + if model_args.fuse_attention_qkv is not None: + config.fuse_attention_qkv = model_args.fuse_attention_qkv + if model_args.fuse_attention_ffn is not None: + config.fuse_attention_ffn = model_args.fuse_attention_ffn + + if config.sequence_parallel: + assert config.tensor_parallel_degree > 1, "tensor_parallel_degree must be larger than 1 for sequence parallel." + assert ( + config.num_attention_heads % config.sep_parallel_degree == 0 + ), f"num_attention_heads:{config.num_attention_heads} must be divisible by sep_parallel_degree {config.sep_parallel_degree}" + assert ( + config.seq_length % config.context_parallel_degree == 0 + ), f"seq_length:{config.seq_length} must be divisible by context_parallel_degree {config.context_parallel_degree}" + + if training_args.sharding_parallel_config is not None: + # for stage1 overlap optimization + if ( + "enable_stage1_allgather_overlap" in training_args.sharding_parallel_config + or "enable_stage1_broadcast_overlap" in training_args.sharding_parallel_config + ): + from paddle.io.reader import use_pinned_memory + + use_pinned_memory(False) + + if get_env_device() == "xpu" and training_args.gradient_accumulation_steps > 1: + try: + from paddle_xpu.layers.nn.linear import LinearConfig # noqa: F401 + + LinearConfig.enable_accumulate_steps_opt() + LinearConfig.set_accumulate_steps(training_args.gradient_accumulation_steps) + except ImportError: + # It's OK, not use accumulate_steps optimization + pass + + print("Final pre-training config:", config) + + # Set the dtype for loading model + dtype = "float32" + if training_args.fp16_opt_level == "O2": + if training_args.fp16: + dtype = "float16" + if training_args.bf16: + dtype = "bfloat16" + + model_class = DeepseekV2ForCausalLM + if training_args.pipeline_parallel_degree > 1: + model_class = DeepseekV2ForCausalLMPipe + if "LLama" in str(config.architectures): + try: + from utils.register_reshard import register_pp_reshard_information + + register_pp_reshard_information(config.num_hidden_layers) + except: + print("Not register llama pp reshard information.") + + architectures_to_check = {"Qwen2Moe", "DeepseekV2", "DeepseekV3"} + if ( + any(architecture in str(config.architectures) for architecture in architectures_to_check) + and training_args.data_parallel_degree > 1 + ): + training_args.use_expert_parallel = True + + if model_args.continue_training: + # NOTE(gongenlei): new add + if training_args.autotuner_benchmark: + model = model_class.from_config(config, dtype=dtype) + else: + model = model_class.from_pretrained( + model_args.model_name_or_path, + config=config, + dtype=dtype, + ) + else: + # Modify here to reduce the number of model layers. The first 3 layers of DeepSeek are dense layers, and sparse layers appear after that. + # config.num_hidden_layers = 4 # v3 uses 61 + # config.first_k_dense_replace = 0 # v3 uses 3 + # Modify here to reduce the number of experts in the model. If EP (Expert Parallelism) is desired, the number of experts should be divisible by the parallelism degree. 
+ # config.n_routed_experts = 64 # v3 uses 256 + # config.num_experts_per_tok = 8 # v3 uses 8 + # config.topk_group = 4 # v3 uses 4 + + # config.using_flex_token = True + # config.num_nextn_predict_layers = 1 + # config.using_fake_gate = True + # config.use_fused_rms_norm = True + # config.fuse_attention_ffn = True + # config.use_fused_rope = True + # config.token_drop_steps = 0 + model = model_class.from_config(config, dtype=dtype) + + if training_args.recompute: + model.recompute_enable() + + # Create the learning_rate sheduler and optimizer + if training_args.decay_steps is None: + training_args.decay_steps = training_args.max_steps + + if training_args.warmup_steps > 0: + warmup_steps = training_args.warmup_steps + else: + warmup_steps = training_args.warmup_ratio * training_args.max_steps + + lr_scheduler = None + if training_args.lr_scheduler_type.value == "cosine": + lr_scheduler = CosineAnnealingWithWarmupDecay( + max_lr=training_args.learning_rate, + min_lr=training_args.min_learning_rate, + warmup_step=warmup_steps, + decay_step=training_args.decay_steps, + last_epoch=0, + ) + elif training_args.lr_scheduler_type.value == "linear": + lr_scheduler = LinearAnnealingWithWarmupDecay( + max_lr=training_args.learning_rate, + min_lr=training_args.min_learning_rate, + warmup_step=warmup_steps, + decay_step=training_args.decay_steps, + last_epoch=0, + ) + + data_file = get_train_data_file(data_args) + train_dataset, eval_dataset, test_dataset, data_collator = create_pretrained_dataset( + data_args, + training_args, + data_file, + tokenizer, + need_data=training_args.should_load_dataset, + ) + + total_effective_tokens = ( + training_args.per_device_train_batch_size + * training_args.dataset_world_size + * training_args.max_steps + * training_args.gradient_accumulation_steps + * data_args.max_seq_length + ) + + callbacks = [StepFlexToken(), FP8QuantWeightCallback()] + + def resume_from_custom_func(model): + if training_args.resume_from_huggingface_ckpt: + load_huggingface_ckpt(model, training_args.resume_from_huggingface_ckpt) + else: + logger.info("No resume from checkpoint since training args 'resume_from_huggingface_ckpt' is None.") + + trainer = PretrainingTrainer( + model=model, + args=training_args, + data_collator=data_collator, + train_dataset=train_dataset if training_args.do_train else None, + eval_dataset=eval_dataset if training_args.do_eval else None, + optimizers=(None, lr_scheduler), + tokenizer=tokenizer, + callbacks=callbacks, + resume_from_custom_func=resume_from_custom_func, + ) + + checkpoint = None + if training_args.resume_from_checkpoint is not None: + checkpoint = training_args.resume_from_checkpoint + elif last_checkpoint is not None: + checkpoint = last_checkpoint + + # Training + if training_args.do_train: + train_result = trainer.train(resume_from_checkpoint=checkpoint) + + # NOTE(gongenlei): new add + if not training_args.autotuner_benchmark: + metrics = train_result.metrics + if not int(os.getenv("test_ci_no_save_model", 0)): + trainer.save_model() + trainer.log_metrics("train", metrics) + trainer.save_metrics("train", metrics) + trainer.save_state() + + if training_args.do_predict: + test_ret = trainer.predict(test_dataset) + trainer.log_metrics("test", test_ret.metrics) + + if training_args.do_train and training_args.should_load_dataset: + effective_tokens_per_second = total_effective_tokens / train_result.metrics["train_runtime"] + print(f"Effective Tokens per second: {effective_tokens_per_second:.2f}") + print(f"ips: {effective_tokens_per_second:.2f} 
tokens/s") + + +if __name__ == "__main__": + main() diff --git a/examples/experiments/deepseek_v3_pretrain/script/kill_process.sh b/examples/experiments/deepseek_v3_pretrain/script/kill_process.sh new file mode 100755 index 00000000000..3c3db6a4639 --- /dev/null +++ b/examples/experiments/deepseek_v3_pretrain/script/kill_process.sh @@ -0,0 +1,32 @@ +#!/bin/bash + +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +set -x +skip_kill_time=${1:-"False"} +function kill_impl() { + skip_kill_time=$1 + # kill aadiff test finally. + pids=`ps -ef | grep pretrain.py | grep -v grep | awk '{print $2}'` + if [[ "$pids" != "" ]] ; then + echo $pids + echo $pids | xargs kill -9 + fi + + echo "Killing processes on gpu" + lsof /dev/nvidia* | awk '{print $2}' | xargs -I {} kill -9 {} +} + +kill_impl $skip_kill_time || true \ No newline at end of file diff --git a/examples/experiments/deepseek_v3_pretrain/script/selective_launch.py b/examples/experiments/deepseek_v3_pretrain/script/selective_launch.py new file mode 100644 index 00000000000..6f87c5ea073 --- /dev/null +++ b/examples/experiments/deepseek_v3_pretrain/script/selective_launch.py @@ -0,0 +1,73 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Selective launch script. + +Usage: python script/selective_launch.py ... +""" +import os +import sys + + +def parse_ranks(ranks_strs): + """ + parse_ranks + """ + # NOTE: You can return ranks directly here to change script/train_gpu.sh + # and script/kill_process.sh together + + # Example 1: Use contiguous nodes [8, 16) + return range(0, 1) + + # Example 2: Use non-contiguous nodes [4, 8) + {10} + [30, 32), i.e., [4, 5, 6, 7, 10, 30, 31] + # return list(range(0, 16)) + list(range(24, 40)) + + # Example 3: + # Just Python code, return any nodes you want! 
+ + if not ranks_strs: + return None + + ranks = [] + for r in ranks_strs: + r = eval(r) + if isinstance(r, int): + ranks.append(r) + else: + ranks.extend(r) + return ranks + + +def main(port, ranks): + """ + main + """ + ips = [ip.strip() for ip in os.getenv("TRAINER_INSTANCES").split(",") if ip.strip()] + if ranks is None: + ranks = list(range(len(ips))) + ranks = sorted(list(set(ranks))) + my_rank = int(os.getenv("POD_INDEX", "0")) + if my_rank not in ranks: + return + + rank = ranks.index(my_rank) + nranks = len(ranks) + + master = ips[ranks[0]] + print(f"--master {master}:{port} --rank {rank} --nnodes {nranks}") + + +if __name__ == "__main__": + main(int(sys.argv[1]), parse_ranks(sys.argv[2:])) diff --git a/examples/experiments/deepseek_v3_pretrain/script/train_gpu.sh b/examples/experiments/deepseek_v3_pretrain/script/train_gpu.sh new file mode 100644 index 00000000000..bc2f5c5c869 --- /dev/null +++ b/examples/experiments/deepseek_v3_pretrain/script/train_gpu.sh @@ -0,0 +1,67 @@ +#!/bin/bash + +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +unset PADDLE_ELASTIC_JOB_ID +unset PADDLE_TRAINER_ENDPOINTS +unset DISTRIBUTED_TRAINER_ENDPOINTS +unset FLAGS_START_PORT +unset PADDLE_ELASTIC_TIMEOUT + +nnodes=$PADDLE_TRAINERS_NUM +rank=$PADDLE_TRAINER_ID + +for name in `env | grep -E 'PADDLE|ENDPOINT' | awk -F'=' '{print $1}'`; do + unset ${name} +done + +#export FLAGS_shard_bypass_dygraph_optimizer=1 +export NCCL_IB_GID_INDEX=3 +export NVSHMEM_IB_GID_INDEX=3 +export NVSHMEM_IB_TRAFFIC_CLASS=162 + +#export NVSHMEM_IB_ENABLE_IBGDA=true +##export NVSHMEM_DISABLE_P2P=1 +export NVSHMEM_BOOTSTRAP=UID + +unset NVSHMEM_HCA_LIST +unset NVSHMEM_ENABLE_NIC_PE_MAPPING + +LAUNCH_CMD=`python script/selective_launch.py 36677` +if [[ -z "$LAUNCH_CMD" ]]; then + exit 0 +fi + +export PYTHONPATH=../../../:$PYTHONPATH +# export PYTHONPATH=../../PaddleFormers/:$PYTHONPATH + +export CUDA_PATH=/usr/local/cuda-12.9 + + +# Flags for allocator +export FLAGS_large_pool_auto_growth_chunk_size_in_mb=500 +export FLAGS_small_pool_auto_growth_chunk_size_in_mb=20 +export FLAGS_small_pool_size_in_mb=10 +export FLAGS_samll_pool_pre_alloc_in_mb=500 + +bash script/kill_process.sh + +source /root/paddlejob/workspace/env_run/chenxi/chenxi_py3.10/bin/activate +python3.10 -m paddle.distributed.launch \ + --log_dir output/paddle_distributed_logs \ + $LAUNCH_CMD \ + --run_mode=collective \ + ${script:-run_pretrain.py} \ + $@ diff --git a/examples/experiments/deepseek_v3_pretrain/token_dispatcher.py b/examples/experiments/deepseek_v3_pretrain/token_dispatcher.py new file mode 100644 index 00000000000..8e6a3f1fb69 --- /dev/null +++ b/examples/experiments/deepseek_v3_pretrain/token_dispatcher.py @@ -0,0 +1,405 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. 
+# Copyright (c) 2025 DeepSeek +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from abc import abstractmethod +from typing import Optional, Tuple + +import paddle +from moe_utils import permute_fast as permute +from moe_utils import topk_to_permuted_indices +from moe_utils import unpermute_fast as unpermute +from paddle.distributed.communication.group import Group + +from paddleformers.transformers import _DispatchManager +from paddleformers.transformers.fused_a2a import fused_combine, fused_dispatch + + +class _DeepepManager(_DispatchManager): + """ + A manager class to handle fused all-to-all communication processes for MoE models using + DeepEP backend. See https://github.com/deepseek-ai/deepep for more details. + + The workflow of the DeepEP dispatcher is: + (1) setup_metadata(): Process routing map and probabilities to prepare dispatch metadata + (2) dispatch(): + - Use fused kernel to permute tokens and perform all-to-all communication in single step + (3) get_permuted_hidden_states_by_instances(): + - Convert routing map and probabilities to multihot format + - Permute tokens using fused kernel + (4) get_restored_hidden_states_by_instances(): + - Reverse permutation using fused kernel + (5) combine(): + - Reverse process using fused kernel to unpermute and perform all-to-all in single step + + This implementation uses fused communication kernels (fused_dispatch/fused_combine) that + combine permutation and communication operations for improved efficiency compared to + separate permute+alltoall steps. + """ + + def __init__( + self, + group: Group, + router_topk: int, + num_experts: int = None, + num_local_experts: int = None, + ): + self.group = group + self.router_topk = router_topk + self.num_experts = num_experts + self.num_local_experts = num_local_experts + + # Metadata + self.token_indices = None + self.token_probs = None + # Handle used for combine operation + self.handle = None + + if fused_dispatch is None: + raise ImportError("DeepEP is not supported in your paddlepaddle whl package.") + + def setup_metadata(self, routing_map: paddle.Tensor, probs: paddle.Tensor): + num_tokens = routing_map.shape[0] + + routing_map = routing_map.reshape([num_tokens, self.num_experts]) + probs = probs.reshape([num_tokens, self.num_experts]) + # Convert the format of routing map from multihot to indices. + self.token_probs, self.token_indices = paddle.topk(probs, self.router_topk, axis=-1) + + def dispatch( + self, hidden_states: paddle.Tensor, token_indices: paddle.Tensor, token_probs: paddle.Tensor + ) -> paddle.Tensor: + hidden_states, dispatched_probs, states = fused_dispatch( + hidden_states, token_indices, token_probs, self.num_experts, self.group + ) + self.handle = states["handle"] + self.tokens_per_expert_list = states["tokens_per_expert"] + dispatched_indices = states["dispatched_indices"] + + return hidden_states, dispatched_indices, dispatched_probs + + def _indices_to_multihot(self, indices, probs): + """ + Converts a tensor of indices to a multihot vector. 
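+
+        For example (illustrative), with num_local_experts = 4, an indices row [1, 3] with
+        probs row [0.6, 0.4] becomes a routing_map row [False, True, False, True] and a probs
+        row [0.0, 0.6, 0.0, 0.4]; positions equal to -1 are treated as masked out.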
+ + Args: + indices (paddle.Tensor): [num_tokens, topk] token indices, where -1 means masked out. + probs (paddle.Tensor): [num_tokens, topk] token probabilities. + + Returns: + Tuple[paddle.Tensor, paddle.Tensor]: + - routing_map: Multihot vector. + - probs: Multihot probabilities. + """ + batch_size = indices.shape[0] + multihot_routing_map = paddle.zeros((batch_size, self.num_local_experts), dtype=paddle.int64) + + multihot_probs = paddle.zeros((batch_size, self.num_local_experts), dtype=paddle.float32) + + mask = indices != -1 + valid_indices = indices[mask] + row_indices = paddle.arange(batch_size).repeat_interleave(mask.sum(axis=1)) + multihot_routing_map[row_indices, valid_indices] = 1 + multihot_probs[row_indices, valid_indices] = probs[mask] + return multihot_routing_map.cast(paddle.bool), multihot_probs + + def get_dispatched_metadata(self) -> paddle.Tensor: + return self.dispatched_indices, self.dispatched_probs + + def get_number_of_tokens_per_expert(self) -> paddle.Tensor: + """ + Get the number of tokens per expert. + """ + return self.tokens_per_expert + + def combine(self, hidden_states: paddle.Tensor) -> paddle.Tensor: + hidden_states = fused_combine(hidden_states, self.group, self.handle) + # Release the handle after combine operation + self.handle = None + return hidden_states + + def get_permuted_hidden_states_by_experts(self, hidden_states: paddle.Tensor) -> paddle.Tensor: + self.dispatched_routing_map, self.dispatched_probs = self._indices_to_multihot( + self.dispatched_indices, self.dispatched_probs + ) + self.hidden_shape_before_permute = hidden_states.shape + hidden_states, self.reversed_mapping_for_combine = permute( + hidden_states, + self.dispatched_routing_map, + num_out_tokens=sum(self.tokens_per_expert), + ) + return hidden_states + + def get_permuted_hidden_states_by_experts_fast( + self, hidden_states: paddle.Tensor, dispatched_indices: paddle.Tensor + ) -> paddle.Tensor: + self.hidden_shape_before_permute = hidden_states.shape + token_permuted_indices, prob_permuted_indices = topk_to_permuted_indices( + dispatched_indices, self.tokens_per_expert_list, self.router_topk + ) + hidden_states = permute(hidden_states, token_permuted_indices) + return hidden_states, token_permuted_indices, prob_permuted_indices + + def get_restored_hidden_states_by_experts(self, hidden_states: paddle.Tensor) -> paddle.Tensor: + input_dtype = hidden_states.dtype + assert self.dispatched_probs.dtype == paddle.float32, "DeepEP only supports float32 probs" + hidden_states = unpermute( + hidden_states, + self.reversed_mapping_for_combine, + restore_shape=self.hidden_shape_before_permute, + routing_map=self.dispatched_routing_map, + probs=self.dispatched_probs, + ) + return hidden_states.to(input_dtype) + + def get_restored_hidden_states_by_experts_fast( + self, + hidden_states: paddle.Tensor, + token_permuted_indices: paddle.Tensor, + prob_permuted_indices: paddle.Tensor, + dispatched_probs: paddle.Tensor, + ) -> paddle.Tensor: + input_dtype = hidden_states.dtype + assert dispatched_probs.dtype == paddle.float32, "DeepEP only supports float32 probs" + hidden_states = unpermute( + permuted_tokens=hidden_states, + token_permuted_indices=token_permuted_indices, + prob_permuted_indices=prob_permuted_indices, + restore_shape=self.hidden_shape_before_permute, + probs=dispatched_probs, + ) + return hidden_states.to(input_dtype) + + +class MoETokenDispatcher: + """ + MoE Token Dispatcher + """ + + def __init__(self, ep_group) -> None: + """ + Initialize the MoE Token Dispatcher. 
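+
+        Args:
+            ep_group: The expert model parallel communication group used by the dispatcher.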
+ """ + self._ep_group = ep_group + + @property + def ep_group(self): + """Get expert model parallel group.""" + return self._ep_group + + @property + def ep_size(self): + """Get expert model parallel world_size.""" + return self.ep_group.world_size + + @abstractmethod + def token_permutation(self, tokens: paddle.Tensor, probs: paddle.Tensor, routing_map: paddle.Tensor): + """Dispatch tokens to experts. + + Args: + tokens (paddle.Tensor): Input tokens. + probs (paddle.Tensor): The routing probability tensor [num_tokens, num_experts]. + routing_map (paddle.Tensor): Token to expert mapping tensor. + + Returns: + paddle.Tensor: Tokens tensor. + """ + raise NotImplementedError("Dispatch function not implemented.") + + @abstractmethod + def token_unpermutation(self, expert_output: paddle.Tensor, bias: paddle.Tensor = None): + """Restores the expert output to its original ordering. + + Args: + expert_output (paddle.Tensor): The output tensor from the expert models. + bias (paddle.Tensor): The bias tensor. + + Returns: + (paddle.Tensor, paddle.Tensor): Unpermuted activation and optional bias. + """ + raise NotImplementedError("Restore function not implemented.") + + +class MoEFlexTokenDispatcher(MoETokenDispatcher): + """ + Flexible token dispatcher for MoE models with Efficient-A2A communication kernels. + """ + + def __init__(self, num_local_experts: int, moe_router_topk: int, num_moe_experts: int, ep_group: Group): + super().__init__(ep_group) + + self.num_local_experts = num_local_experts + assert self.ep_size > 1, "Flex token dispatcher requires EP > 1" + self._comm_manager = _DeepepManager( + group=self.ep_group, + router_topk=moe_router_topk, + num_experts=num_moe_experts, + num_local_experts=self.num_local_experts, + ) + + def token_permutation( + self, hidden_states: paddle.Tensor, probs: paddle.Tensor, routing_map: paddle.Tensor + ) -> Tuple[paddle.Tensor, paddle.Tensor]: + self.hidden_shape = hidden_states.shape + hidden_states = hidden_states.view([-1, self.hidden_shape[-1]]) + + self._comm_manager.setup_metadata(routing_map, probs) + hidden_states, _, _ = self._comm_manager.dispatch(hidden_states) + global_input_tokens = self._comm_manager.get_permuted_hidden_states_by_experts(hidden_states) + tokens_per_expert = self._comm_manager.get_number_of_tokens_per_expert() + + return global_input_tokens, tokens_per_expert + + def token_unpermutation( + self, hidden_states: paddle.Tensor, bias: Optional[paddle.Tensor] = None + ) -> Tuple[paddle.Tensor, Optional[paddle.Tensor]]: + assert bias is None, "Bias is not supported in MoEFlexTokenDispatcher" + hidden_states = self._comm_manager.get_restored_hidden_states_by_experts(hidden_states) + hidden_states = self._comm_manager.combine(hidden_states) + + hidden_states = hidden_states.reshape(self.hidden_shape) + return hidden_states, None + + +class MoEFlexTokenDispatcherFast: + """ + Flexible token dispatcher for MoE models with Efficient-A2A communication kernels. 
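+
+    Compared with MoEFlexTokenDispatcher above, this variant also exposes the individual
+    stages (pre_dispatch / post_dispatch and pre_combine / post_combine); token_permutation
+    and token_unpermutation below simply chain these stages together.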
+ """ + + def __init__(self, num_local_experts: int, moe_router_topk: int, num_moe_experts: int, ep_group: Group): + self._ep_group = ep_group + + self.num_local_experts = num_local_experts + assert self.ep_size > 1, "Flex token dispatcher requires EP > 1" + self._comm_manager = _DeepepManager( + group=self.ep_group, + router_topk=moe_router_topk, + num_experts=num_moe_experts, + num_local_experts=self.num_local_experts, + ) + + @property + def ep_group(self): + """Get expert model parallel group.""" + return self._ep_group + + @property + def ep_size(self): + """Get expert model parallel world_size.""" + return self.ep_group.world_size + + def pre_dispatch(self, hidden_states, probs, routing_map): + self.hidden_shape = hidden_states.shape + hidden_states = hidden_states.view([-1, self.hidden_shape[-1]]) + num_tokens = routing_map.shape[0] + routing_map = routing_map.reshape([num_tokens, self._comm_manager.num_experts]) + probs = probs.reshape([num_tokens, self._comm_manager.num_experts]) + # Convert the format of routing map from multihot to indices. + token_probs, token_indices = paddle.topk(probs, self._comm_manager.router_topk, axis=-1) + return hidden_states, token_indices, token_probs + + def post_dispatch(self, hidden_states, dispatched_indices): + ( + global_input_tokens, + token_permuted_indices, + prob_permuted_indices, + ) = self._comm_manager.get_permuted_hidden_states_by_experts_fast(hidden_states, dispatched_indices) + return (global_input_tokens, token_permuted_indices, prob_permuted_indices) + + def pre_combine(self, hidden_states, token_permuted_indices, prob_permuted_indices, dispatched_probs): + hidden_states = self._comm_manager.get_restored_hidden_states_by_experts_fast( + hidden_states, token_permuted_indices, prob_permuted_indices, dispatched_probs + ) + return hidden_states + + def post_combine(self, hidden_states): + hidden_states = hidden_states.reshape(self.hidden_shape) + return hidden_states + + def token_permutation( + self, hidden_states: paddle.Tensor, probs: paddle.Tensor, routing_map: paddle.Tensor + ) -> Tuple[paddle.Tensor, paddle.Tensor]: + hidden_states, token_indices, token_probs = self.pre_dispatch(hidden_states, probs, routing_map) + hidden_states, dispatched_indices, dispatched_probs = self._comm_manager.dispatch( + hidden_states, token_indices, token_probs + ) + (global_input_tokens, token_permuted_indices, prob_permuted_indices) = self.post_dispatch( + hidden_states, dispatched_indices + ) + + return ( + global_input_tokens, + token_permuted_indices, + prob_permuted_indices, + dispatched_probs, + ) + + def token_unpermutation( + self, + hidden_states: paddle.Tensor, + token_permuted_indices, + prob_permuted_indices, + dispatched_probs, + bias: Optional[paddle.Tensor] = None, + ) -> Tuple[paddle.Tensor, Optional[paddle.Tensor]]: + assert bias is None, "Bias is not supported in MoEFlexTokenDispatcher" + hidden_states = self.pre_combine( + hidden_states, token_permuted_indices, prob_permuted_indices, dispatched_probs + ) + hidden_states = self._comm_manager.combine(hidden_states) + + hidden_states = self.post_combine(hidden_states) + return hidden_states, None + + +class PreDispatchNode: + def __init__(self, token_dispatcher): + self.token_dispatcher = token_dispatcher + self.probs_origin_shape = None + + def reset_status(self): + self.probs = None + self.reshaped_probs = None + self.token_indices = None + + @paddle.no_grad() + def forward(self, routing_map, probs): + num_tokens = routing_map.shape[0] + self.probs_origin_shape = probs.shape + # 
routing_map = routing_map.reshape([num_tokens, token_dispatcher._comm_manager.num_experts]) + self.probs = probs + reshaped_probs = probs.reshape([num_tokens, self.token_dispatcher._comm_manager.num_experts]) + self.reshaped_probs = reshaped_probs + token_probs, token_indices = paddle.topk( + reshaped_probs, self.token_dispatcher._comm_manager.router_topk, axis=-1 + ) + self.token_indices = token_indices + token_probs.stop_gradient = False + return token_indices, token_probs + + @paddle.no_grad() + def backward(self, token_probs_g): + probs_grad = paddle._C_ops.topk_grad( + self.reshaped_probs, + self.token_indices, + token_probs_g, + self.token_dispatcher._comm_manager.router_topk, + -1, + True, + True, + ) + probs_reshape_g = paddle._C_ops.reshape_grad(self.probs, probs_grad) + self.reset_status() + return probs_reshape_g diff --git a/paddleformers/data/blendable_dataset.py b/paddleformers/data/blendable_dataset.py new file mode 100644 index 00000000000..44e16594119 --- /dev/null +++ b/paddleformers/data/blendable_dataset.py @@ -0,0 +1,184 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import hashlib +import importlib.metadata +import os +import time + +import numpy as np +import paddle + +local_rank = int(os.getenv("PADDLE_RANK_IN_NODE", 0)) + + +def print_rank_0(*args, **kwargs): + if paddle.distributed.get_rank() == 0: + print(*args, **kwargs) + + +class BlendableDataset(paddle.io.Dataset): + def __init__(self, datasets, weights, size, share_folder, *, data_cache_path=None): + + self.datasets = datasets + num_datasets = len(datasets) + assert num_datasets == len(weights) + + self.size = size + + # Normalize weights. + weights = np.array(weights, dtype=np.float64) + sum_weights = np.sum(weights) + assert sum_weights > 0.0 + weights /= sum_weights + + # Build indices. + def _build_indices(): + start_time = time.time() + + fast_dataindex_version = importlib.metadata.version("fast_dataindex") + if fast_dataindex_version > "0.1.1": + assert ( + num_datasets < 32767 + ), f"Detect num_datasets({num_datasets})>=32767. Currently, num_datasets should be less than 32767." + dataset_index = np.zeros(self.size, dtype=np.int16) + else: + assert ( + num_datasets < 255 + ), f"Detect num_datasets:({num_datasets})>=255. When 'fast_dataindex<=0.1.1', num_datasets should be less than 255. To support num_datasets greater than 255, please upgrade `fast_dataindex>=0.1.2`." 
+ dataset_index = np.zeros(self.size, dtype=np.uint8) + dataset_sample_index = np.zeros(self.size, dtype=np.int64) + + from fast_dataindex import helpers + + helpers.build_blending_indices( + dataset_index, + dataset_sample_index, + weights, + num_datasets, + self.size, + local_rank == 0, + # paddle.distributed.get_rank() == 0, + ) + print_rank_0( + "> elapsed time for building blendable dataset indices: " + "{:.2f} (sec)".format(time.time() - start_time) + ) + return dataset_index, dataset_sample_index + + desc = "Blendable dataset\n\n" + desc += "Datasets:\n" + for dataset in datasets: + desc += dataset.desc + "\n\n" + desc += f"Weights: {weights}\n" + desc += f"Size: {size}\n" + self.desc = desc + + if data_cache_path: + desc_hash = hashlib.md5(desc.encode("utf-8")).hexdigest() + desc_path = os.path.join(data_cache_path, desc_hash + ".dsc") + index_path = os.path.join(data_cache_path, desc_hash + "_index.npy") + sample_index_path = os.path.join(data_cache_path, desc_hash + "_sample_index.npy") + cache_hit = os.path.isfile(index_path) and os.path.isfile(sample_index_path) + # cache_success = True + # if paddle.distributed.get_rank() == 0 and not cache_hit: + check_rank_flag = not cache_hit and local_rank == 0 + if share_folder: + check_rank_flag = not cache_hit and paddle.distributed.get_rank() == 0 + + print( + f"searching for blendable dataset, cache_hit={cache_hit}, share_folder {share_folder}, check_rank_flag {check_rank_flag}", + flush=True, + ) + if check_rank_flag: + print( + " > WARNING: could not find index map files for blendable" + " dataset, building indices on rank 0 ...", + flush=True, + ) + dataset_index, dataset_sample_index = _build_indices() + try: + os.makedirs(os.path.dirname(index_path), exist_ok=True) + with open(desc_path, "wt") as fd: + fd.write(desc) + np.save(index_path, dataset_index, allow_pickle=True) + np.save(sample_index_path, dataset_sample_index, allow_pickle=True) + except OSError: + print(f"There was an error trying to create the data cache directory ({data_cache_path})") + print("or a file in it. This is set with the --data-cache-path argument. Please") + print("ensure you have write access to this directory or specify one that you do have") + print("write access to.") + # cache_success = False + + # hcg = paddle.distributed.fleet.get_hybrid_communicate_group() + + # counts = paddle.to_tensor([cache_success], dtype="int64") + # paddle.distributed.all_reduce(counts, group=hcg.get_data_parallel_group()) + # paddle.distributed.all_reduce(counts, group=hcg.get_pipeline_model_parallel_group()) + # if counts[0].item() != ( + # paddle.distributed.get_world_size() + # // paddle.distributed.get_world_size(group=hcg.get_tensor_model_parallel_group()) + # ): + # print_rank_0("Data index creation unsuccessful, exiting.") + # exit() + + else: + while True: + if (not os.path.isfile(index_path)) or (not os.path.isfile(sample_index_path)): + print("building indices on rank 0 ...", flush=True) + time.sleep(3) + else: + try: + np.load(index_path, allow_pickle=True, mmap_mode="r") + print("build success", flush=True) + break + except Exception: + print("%s file is still writing or damaged, please wait for a moment." % index_path) + time.sleep(3) + + # paddle.distributed.barrier() + # Load on all ranks. 
+ print_rank_0(f"> loading blendable dataset index: {index_path}") + self.dataset_index = np.load(index_path, allow_pickle=True, mmap_mode="r") + assert self.dataset_index.size == self.size + + print_rank_0(f"> loading blendable dataset sample index: {sample_index_path}") + self.dataset_sample_index = np.load(sample_index_path, allow_pickle=True, mmap_mode="r") + assert self.dataset_sample_index.size == self.size + else: + print_rank_0( + "building indices for the blendable dataset, Since --data_cache is not specified, the index file will not be stored.", + flush=True, + ) + self.dataset_index, self.dataset_sample_index = _build_indices() + + # Check size + _ = self.__getitem__(self.size - 1) + try: + _ = self.__getitem__(self.size) + raise RuntimeError("BlendedDataset size is improperly bounded") + except IndexError: + pass + print_rank_0("> size of blendable dataset: " "{} samples".format(self.size)) + + def __len__(self): + return self.size + + def __getitem__(self, idx): + dataset_idx = self.dataset_index[idx] + sample_idx = self.dataset_sample_index[idx] + return { + "dataset_idx": dataset_idx, + **self.datasets[dataset_idx][sample_idx], + } diff --git a/paddleformers/data/causal_dataset.py b/paddleformers/data/causal_dataset.py new file mode 100644 index 00000000000..d3789287649 --- /dev/null +++ b/paddleformers/data/causal_dataset.py @@ -0,0 +1,711 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. + +"""GPT style dataset.""" +import hashlib +import math +import os +import time + +import numpy as np +import paddle + +from .blendable_dataset import BlendableDataset +from .indexed_dataset import make_dataset as make_indexed_dataset + +local_rank = int(os.getenv("PADDLE_RANK_IN_NODE", 0)) + + +# class FakeHCG: +# def get_data_parallel_group(self): +# return None + +# def get_pipe_parallel_group(self): +# return None + +# def get_model_parallel_group(self): +# return None + + +def check_data_split(splits_string, do_train, do_eval, do_predict): + splits = [] + if splits_string.find(",") != -1: + splits = [float(s) for s in splits_string.split(",")] + elif splits_string.find("/") != -1: + splits = [float(s) for s in splits_string.split("/")] + else: + splits = [float(splits_string)] + while len(splits) < 3: + splits.append(0.0) + splits = splits[:3] + splits_sum = sum(splits) + data_flag = True + assert splits_sum > 0.0, "sum of splits should larger than 0.0!" 
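+    # e.g. the default split "949,50,1" gives non-zero train/valid/test portions, so
+    # do_train/do_eval/do_predict are all allowed; a split like "100,0,0" only permits do_train.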
+ if (do_train and splits[0] == 0) or (do_eval and splits[1] == 0) or (do_predict and splits[2] == 0): + data_flag = False + if not data_flag: + raise ValueError("If do_train/do_eval/do_predict is True, the corresponding dataset split should not be 0!") + + +def get_train_valid_test_split_(splits_string, size): + """Get dataset splits from comma or '/' separated string list.""" + + splits = [] + if splits_string.find(",") != -1: + splits = [float(s) for s in splits_string.split(",")] + elif splits_string.find("/") != -1: + splits = [float(s) for s in splits_string.split("/")] + else: + splits = [float(splits_string)] + while len(splits) < 3: + splits.append(0.0) + splits = splits[:3] + splits_sum = sum(splits) + assert splits_sum > 0.0 + splits = [split / splits_sum for split in splits] + splits_index = [0] + for index, split in enumerate(splits): + splits_index.append(splits_index[index] + int(round(split * float(size)))) + diff = splits_index[-1] - size + for index in range(1, len(splits_index)): + splits_index[index] -= diff + assert len(splits_index) == 4 + assert splits_index[-1] == size + return splits_index + + +def get_datasets_weights_and_num_samples(data_prefix, train_val_test_num_samples): + + # The data prefix should be in the format of: + # weight-1, data-prefix-1, weight-2, data-prefix-2, .. + assert len(data_prefix) % 2 == 0 + num_datasets = len(data_prefix) // 2 + weights = [0] * num_datasets + prefixes = [0] * num_datasets + for i in range(num_datasets): + weights[i] = float(data_prefix[2 * i]) + prefixes[i] = (data_prefix[2 * i + 1]).strip() + # Normalize weights + weight_sum = 0.0 + for weight in weights: + weight_sum += weight + assert weight_sum > 0.0 + weights = [weight / weight_sum for weight in weights] + + # Add 0.5% (the 1.005 factor) so in case the blending dataset does + # not uniformly distribute the number of samples, we still have + # samples left to feed to the network. + # (NOTE, yujun06): This is a workaround to avoid issues with indexing in the blending dataset. Therefore, we need to add 20 samples to each dataset. + datasets_train_valid_test_num_samples = [] + for weight in weights: + datasets_train_valid_test_num_samples.append( + [int(math.ceil(val * weight * 1.005)) + 20 for val in train_val_test_num_samples] + ) + + return prefixes, weights, datasets_train_valid_test_num_samples + + +def print_rank_0(*args, **kwargs): + if paddle.distributed.get_rank() == 0: + print(*args, **kwargs) + + +def build_train_valid_test_datasets( + data_prefix, + data_impl, + splits_string, + train_val_test_num_samples, + seq_length, + seed, + skip_warmup, + train_data_prefix=None, + valid_data_prefix=None, + test_data_prefix=None, + return_doc_ids=False, + share_folder=False, + *, + data_cache_path=None, + need_data=True, +): + """Build train, valid, and test datasets.""" + + # Single dataset. + if len(data_prefix) == 1: + return _build_train_valid_test_datasets( + data_prefix[0], + data_impl, + splits_string, + train_val_test_num_samples, + seq_length, + seed, + skip_warmup, + share_folder=share_folder, + data_cache_path=data_cache_path, + need_data=need_data, + ) + + # Blending dataset. + # Parse the values. + output = get_datasets_weights_and_num_samples(data_prefix, train_val_test_num_samples) + prefixes, weights, datasets_train_valid_test_num_samples = output + # NOTE: megatron/gpt_dataset.py has been updated. When creating BlendableDataset, we will use the raw train_val_test_num_samples instead of the expanded ones. 
+ # Please refer to https://github.com/NVIDIA/NeMo/blob/72f630d087d45655b1a069dc72debf01dfdbdb2d/nemo/collections/nlp/data/language_modeling/megatron/gpt_dataset.py#L74-L80 for more information + train_num_samples, valid_num_samples, test_num_samples = train_val_test_num_samples + + # Build individual datasets. + train_datasets = [] + valid_datasets = [] + test_datasets = [] + for i in range(len(prefixes)): + train_ds, valid_ds, test_ds = _build_train_valid_test_datasets( + prefixes[i], + data_impl, + splits_string, + datasets_train_valid_test_num_samples[i], + seq_length, + seed, + skip_warmup, + return_doc_ids, + share_folder=share_folder, + data_cache_path=data_cache_path, + need_data=need_data, + ) + if train_ds: + train_datasets.append(train_ds) + if valid_ds: + valid_datasets.append(valid_ds) + if test_ds: + test_datasets.append(test_ds) + + blending_train_dataset = None + if train_datasets: + blending_train_dataset = BlendableDataset( + train_datasets, weights, train_num_samples, share_folder, data_cache_path=data_cache_path + ) + blending_valid_dataset = None + if valid_datasets: + blending_valid_dataset = BlendableDataset( + valid_datasets, weights, valid_num_samples, share_folder, data_cache_path=data_cache_path + ) + blending_test_dataset = None + if test_datasets: + blending_test_dataset = BlendableDataset( + test_datasets, + weights, + test_num_samples, + share_folder, + data_cache_path=data_cache_path, + ) + + return (blending_train_dataset, blending_valid_dataset, blending_test_dataset) + + +def _build_train_valid_test_datasets( + data_prefix, + data_impl, + splits_string, + train_val_test_num_samples, + seq_length, + seed, + skip_warmup, + return_doc_ids=False, + share_folder=False, + *, + data_cache_path=None, + need_data=True, +): + """Build train, valid, and test datasets.""" + + # Indexed dataset. + if need_data: + indexed_dataset = get_indexed_dataset_(data_prefix, data_impl, skip_warmup) + + total_num_of_documents = indexed_dataset.sizes.shape[0] + splits = get_train_valid_test_split_(splits_string, total_num_of_documents) + + # Print stats about the splits. 
+ print_rank_0(" > dataset split:") + + def print_split_stats(name, index): + print_rank_0(" {}:".format(name)) + print_rank_0( + " document indices in [{}, {}) total of {} " + "documents".format(splits[index], splits[index + 1], splits[index + 1] - splits[index]) + ) + + print_split_stats("train", 0) + print_split_stats("validation", 1) + print_split_stats("test", 2) + + if paddle.distributed.get_world_size() > 1: + paddle.distributed.barrier() + + def build_dataset(index, name): + documents = np.arange(splits[index], splits[index + 1], 1, np.int32) if need_data else None + dataset = GPTDataset( + name, + data_prefix, + documents, + indexed_dataset if need_data else None, + splits_string, + train_val_test_num_samples[index], + seq_length, + seed, + return_doc_ids, + share_folder, + data_cache_path=data_cache_path, + need_data=need_data, + ) + if need_data: + return dataset if splits[index + 1] > splits[index] else None + else: + return None + + train_dataset = build_dataset(0, "train") + valid_dataset = build_dataset(1, "valid") + test_dataset = build_dataset(2, "test") + + return (train_dataset, valid_dataset, test_dataset) + + +def get_indexed_dataset_(data_prefix, data_impl, skip_warmup): + """Build indexed dataset.""" + print_rank_0(" > building dataset index ...") + + start_time = time.time() + indexed_dataset = make_indexed_dataset(data_prefix, data_impl, skip_warmup) + print_rank_0(" > finished creating indexed dataset in {:4f} " "seconds".format(time.time() - start_time)) + print_rank_0(" number of documents: {}".format(indexed_dataset.sizes.shape[0])) + + return indexed_dataset + + +class GPTDataset(paddle.io.Dataset): + def __init__( + self, + name, + data_prefix, + documents, + indexed_dataset, + splits_string, + num_samples, + seq_length, + seed, + return_doc_ids=False, + share_folder=False, + *, + data_cache_path=None, + need_data=True, + ): + + self.name = name + self.indexed_dataset = indexed_dataset + self.return_doc_ids = return_doc_ids + + # Build index mappings. + if need_data and len(documents) > 0: + assert np.min(documents) >= 0 + assert np.max(documents) < indexed_dataset.sizes.shape[0] + + ( + doc_idx_filename, + sample_idx_filename, + shuffle_idx_filename, + self.desc, + self.desc_hash, + num_epochs, + ) = _build_index_mappings( + self.name, + data_prefix, + documents, + self.indexed_dataset.sizes, + splits_string, + num_samples, + seq_length, + seed, + share_folder, + data_cache_path=data_cache_path, + ) + + if paddle.distributed.get_world_size() > 1: + paddle.distributed.barrier() + + # Load mappings. 
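+        # The doc-idx / sample-idx / shuffle-idx arrays produced above are loaded
+        # below with mmap_mode="r", so the (potentially large) index files are
+        # memory-mapped rather than eagerly loaded on every rank.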
+ if need_data and len(documents) > 0: + start_time = time.time() + print_rank_0(f" > loading doc-idx mapping from {doc_idx_filename}") + self.doc_idx = np.load(doc_idx_filename, allow_pickle=True, mmap_mode="r") + + print_rank_0(f" > loading sample-idx mapping from {sample_idx_filename}") + self.sample_idx = np.load(sample_idx_filename, allow_pickle=True, mmap_mode="r") + + print_rank_0(f" > loading shuffle-idx mapping from {shuffle_idx_filename}") + self.shuffle_idx = np.load(shuffle_idx_filename, allow_pickle=True, mmap_mode="r") + + print_rank_0(" loaded indexed file in {:3.3f} seconds".format(time.time() - start_time)) + print_rank_0(" total number of samples: {}".format(self.sample_idx.shape[0])) + print_rank_0(" total number of epochs: {}".format(num_epochs)) + + if paddle.distributed.get_world_size() > 1: + paddle.distributed.barrier() + + def __len__(self): + # -1 is due to data structure used to retrieve the index: + # sample i --> [sample_idx[i], sample_idx[i+1]) + return self.sample_idx.shape[0] - 1 + + def __getitem__(self, idx): + # Get the shuffled index. + idx = self.shuffle_idx[idx] + # Start and end documents and offsets. + doc_index_f = self.sample_idx[idx][0] + doc_index_l = self.sample_idx[idx + 1][0] + offset_f = self.sample_idx[idx][1] + offset_l = self.sample_idx[idx + 1][1] + # If we are within the same document, just extract the chunk. + doc_ids = [] + if doc_index_f == doc_index_l: + doc_ids.append(self.doc_idx[doc_index_f]) + + sample, mask = self.indexed_dataset.get( + self.doc_idx[doc_index_f], offset=offset_f, length=offset_l - offset_f + 1 + ) + else: + # Otherwise, get the rest of the initial document. + doc_ids.append(self.doc_idx[doc_index_f]) + sample, mask = self.indexed_dataset.get(self.doc_idx[doc_index_f], offset=offset_f) + append_mask = True + if mask is None: + append_mask = False + + sample_list = [sample] + mask_list = [] + mask_list = [mask] + # Loop over all in between documents and add the entire document. + for i in range(doc_index_f + 1, doc_index_l): + doc_ids.append(self.doc_idx[i]) + sample, mask = self.indexed_dataset.get(self.doc_idx[i]) + sample_list.append(sample) + if append_mask: + mask_list.append(mask) + + # And finally add the relevant portion of last document. + doc_ids.append(self.doc_idx[doc_index_l]) + sample, mask = self.indexed_dataset.get(self.doc_idx[doc_index_l], length=offset_l + 1) + sample_list.append(sample) + if append_mask: + mask_list.append(mask) + sample = np.concatenate(sample_list) + if append_mask: + mask = np.concatenate(mask_list) + # print(sample) + if self.return_doc_ids: # for retro preprocessing + if mask is None: + return {"text": np.array(sample, dtype=np.int64), "doc_ids": np.array(doc_ids, dtype=np.int64)} + else: + return { + "text": np.array(sample, dtype=np.int64), + "doc_ids": np.array(doc_ids, dtype=np.int64), + "mask": np.array(mask, dtype=np.int64), + } + else: + if mask is None: + return {"text": np.array(sample, dtype=np.int64)} + else: + return {"text": np.array(sample, dtype=np.int64), "mask": np.array(mask, dtype=np.int64)} + + +def _build_index_mappings( + name, data_prefix, documents, sizes, splits_string, num_samples, seq_length, seed, share_folder, *, data_cache_path +): + """Build doc-idx, sample-idx, and shuffle-idx. + doc-idx: is an array (ordered) of documents to be used in training. + sample-idx: is the start document index and document offset for each + training sample. + shuffle-idx: maps the sample index into a random index into sample-idx. 
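+    Illustrative example (not taken from a real run): with seq_length=4, sample i
+    spans tokens [4*i, 4*i + 4] of the shuffled document stream (seq_length + 1
+    tokens, the last one shared with sample i+1), and sample-idx[i] stores the
+    document index and in-document offset where that span begins.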
+ """ + + # Number of tokens in each epoch and number of required epochs. + tokens_per_epoch = _num_tokens(documents, sizes) + num_epochs = _num_epochs(tokens_per_epoch, seq_length, num_samples) + + # rng state + np_rng = np.random.RandomState(seed=seed) + # Filename of the index mappings. + desc = "GPT Dataset\n\n" + desc += f"Data prefix {data_prefix}\n" + desc += f"Dataset name {name}\n" + desc += f"Number of samples {num_samples}\n" + desc += f"Sequence length {seq_length}\n" + desc += f"Random seed {seed}\n" + desc += f"Split {splits_string}\n" + desc_hash = hashlib.md5(desc.encode("utf-8")).hexdigest() + desc_filename = desc_hash + ".dsc" + doc_idx_filename = desc_hash + "_doc_idx.npy" + sample_idx_filename = desc_hash + "_sample_idx.npy" + shuffle_idx_filename = desc_hash + "_shuffle_idx.npy" + + # Look for cache in main data dir first to avoid unnecessary + # duplication, then look in data-cache-path if specified, + # If nothing is found, use the last path looked in + build_indices = True + prefixes = [os.path.join(os.path.dirname(data_prefix), "index-cache")] + if data_cache_path is not None: + prefixes.append(data_cache_path) + for prefix in prefixes: + idx_path = { + "desc": os.path.join(prefix, desc_filename), + "doc": os.path.join(prefix, doc_idx_filename), + "sample": os.path.join(prefix, sample_idx_filename), + "shuffle": os.path.join(prefix, shuffle_idx_filename), + } + for f in idx_path.values(): + if not os.path.isfile(f): + break + else: + # Found our files! + build_indices = False + break + data_cache_dir = os.path.dirname(idx_path["desc"]) + # data_cache_success = True + # Build the indexed mapping if not exist. + check_rank_flag = build_indices and local_rank == 0 + if share_folder: + check_rank_flag = build_indices and paddle.distributed.get_rank() == 0 + + # if build_indices and paddle.distributed.get_rank() == 0: + + print( + f"searching for causal dataset, build_indices={build_indices}, share_folder {share_folder}, check_rank_flag {check_rank_flag}", + flush=True, + ) + if check_rank_flag: + print_rank_0(" > WARNING: could not find index map files, building " "the indices on rank 0 ...") + + # For the last epoch, decide whether include the entire epoch + # in the global shuffle or not. + + # If we need only one epoch, then separating last epoch does + # not mean anything. + if num_epochs == 1: + separate_last_epoch = False + print(" > only one epoch required, setting " "separate_last_epoch to False", flush=True) + + else: + # Get the number of samples for the last epoch + num_samples_from_epochs_minus_one = ((num_epochs - 1) * tokens_per_epoch - 1) // seq_length + last_epoch_num_samples = num_samples - num_samples_from_epochs_minus_one + assert last_epoch_num_samples >= 0, "last epoch number of samples should be non-negative." + num_samples_per_epoch = (tokens_per_epoch - 1) // seq_length + assert last_epoch_num_samples <= ( + num_samples_per_epoch + 1 + ), "last epoch number of samples exceeded max value." + # If we have less than 80% of the samples for the last epoch, + # separate out the epoch and treat it differently. + # Note: the 80% number is just based on common sense and can + # be adjusted if needed. 
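+            # For example (illustrative): with num_samples_per_epoch = 1000, the last
+            # epoch is shuffled separately only when it contributes fewer than 800
+            # samples (the 0.80 threshold below).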
+ separate_last_epoch = last_epoch_num_samples < int(0.80 * num_samples_per_epoch) + if separate_last_epoch: + string = ( + " > last epoch number of samples ({}) is smaller " + "than 80% of number of samples per epoch ({}), " + "setting separate_last_epoch to True" + ) + else: + string = ( + " > last epoch number of samples ({}) is larger " + "than 80% of number of samples per epoch ({}), " + "setting separate_last_epoch to False" + ) + print(string.format(last_epoch_num_samples, num_samples_per_epoch), flush=True) + + try: + os.makedirs(data_cache_dir, exist_ok=True) + + # description + with open(idx_path["desc"], "wt") as fd: + fd.write(desc) + + # doc-idx. + start_time = time.time() + doc_idx = _build_doc_idx(documents, num_epochs, np_rng, separate_last_epoch) + np.save(idx_path["doc"], doc_idx, allow_pickle=True) + print_rank_0( + " > elapsed time to build and save doc-idx mapping " + "(seconds): {:4f}".format(time.time() - start_time) + ) + # sample-idx. + start_time = time.time() + # Use C++ implementation for speed. + # First compile and then import. + # from megatron.data import helpers + from fast_dataindex import helpers + + assert doc_idx.dtype == np.int32 + assert sizes.dtype == np.int32 + sample_idx = helpers.build_sample_idx(sizes, doc_idx, seq_length, num_epochs, tokens_per_epoch) + np.save(idx_path["sample"], sample_idx, allow_pickle=True) + print_rank_0( + " > elapsed time to build and save sample-idx mapping " + "(seconds): {:4f}".format(time.time() - start_time) + ) + # shuffle-idx. + start_time = time.time() + # -1 is due to data structure used to retrieve the index: + # sample i --> [sample_idx[i], sample_idx[i+1]) + if separate_last_epoch: + num_samples_ = num_samples_from_epochs_minus_one + else: + num_samples_ = sample_idx.shape[0] - 1 + shuffle_idx = _build_shuffle_idx(num_samples_, sample_idx.shape[0] - 1, np_rng) + np.save(idx_path["shuffle"], shuffle_idx, allow_pickle=True) + print_rank_0( + " > elapsed time to build and save shuffle-idx mapping" + " (seconds): {:4f}".format(time.time() - start_time) + ) + except OSError: + print(f"There was an error trying to create the data cache directory ({data_cache_dir})") + print('or a file in it. This defaults to a directory "index-cache" within the directory') + print("the data files are in and can be set with the --data-cache-path argument. Please") + print("ensure you have write access to this directory or specify one that you do have") + print("write access to.") + # data_cache_success = False + else: + while True: + if ( + (not os.path.isfile(idx_path["doc"])) + or (not os.path.isfile(idx_path["sample"])) + or (not os.path.isfile(idx_path["shuffle"])) + ): + print("building indices on rank 0 ...", flush=True) + time.sleep(3) + else: + try: + np.load(idx_path["shuffle"], allow_pickle=True, mmap_mode="r") + print("build success", flush=True) + break + except Exception: + print("%s file is still writing or damaged, please wait for a moment." 
% idx_path["shuffle"]) + time.sleep(3) + # try: + # hcg = paddle.distributed.fleet.get_hybrid_communicate_group() + # except: + # hcg = FakeHCG() + + # counts = paddle.to_tensor([data_cache_success], dtype="int64") + # paddle.distributed.all_reduce(counts, group=hcg.get_data_parallel_group()) + # paddle.distributed.all_reduce(counts, group=hcg.get_pipe_parallel_group()) + # if counts[0].item() != ( + # paddle.distributed.get_world_size() // paddle.distributed.get_world_size(group=hcg.get_model_parallel_group()) + # ): + # print_rank_0("Data index creation unsuccessful, exiting.") + # exit() + # paddle.distributed.barrier() + + return idx_path["doc"], idx_path["sample"], idx_path["shuffle"], desc, desc_hash, num_epochs + + +def _num_tokens(documents, sizes): + """Total number of tokens in the dataset.""" + return np.sum(sizes[documents]) + + +def _num_epochs(tokens_per_epoch, seq_length, num_samples): + """Based on number of samples and sequence length, calculate how many + epochs will be needed.""" + num_epochs = 0 + total_tokens = 0 + while True: + num_epochs += 1 + total_tokens += tokens_per_epoch + # -1 is because we need to retrieve seq_length + 1 token each time + # but the last token will overlap with the first token of the next + # sample except for the last sample. + if ((total_tokens - 1) // seq_length) >= num_samples: + return num_epochs + + +def _build_doc_idx(documents, num_epochs, np_rng, separate_last_epoch): + """Build an array with length = number-of-epochs * number-of-documents. + Each index is mapped to a corresponding document.""" + if not separate_last_epoch or num_epochs == 1: + doc_idx = np.mgrid[0:num_epochs, 0 : len(documents)][1] + doc_idx[:] = documents + doc_idx = doc_idx.reshape(-1) + doc_idx = doc_idx.astype(np.int32) + np_rng.shuffle(doc_idx) + return doc_idx + + doc_idx_first = _build_doc_idx(documents, num_epochs - 1, np_rng, False) + doc_idx_last = _build_doc_idx(documents, 1, np_rng, False) + return np.concatenate((doc_idx_first, doc_idx_last)) + + +def _build_sample_idx(sizes, doc_idx, seq_length, num_epochs, tokens_per_epoch): + """Sample index mapping is a 2D array with sizes + [number-of-samples + 1, 2] where [..., 0] contains + the index into `doc_idx` and [..., 1] is the + starting offset in that document.""" + + # Total number of samples. For -1 see comments in `_num_epochs`. + num_samples = (num_epochs * tokens_per_epoch - 1) // seq_length + sample_idx = np.zeros([num_samples + 1, 2], dtype=np.int32) + + # Index into sample_idx. + sample_index = 0 + # Index into doc_idx. + doc_idx_index = 0 + # Beginning offset for each document. + doc_offset = 0 + # Start with first document and no offset. + sample_idx[sample_index][0] = doc_idx_index + sample_idx[sample_index][1] = doc_offset + sample_index += 1 + while sample_index <= num_samples: + # Start with a fresh sequence. + remaining_seq_length = seq_length + 1 + while remaining_seq_length != 0: + # Get the document length. + doc_id = doc_idx[doc_idx_index] + doc_length = sizes[doc_id] - doc_offset + # And add it to the current sequence. + remaining_seq_length -= doc_length + # If we have more than a full sequence, adjust offset and set + # remaining length to zero so we return from the while loop. + # Note that -1 here is for the same reason we have -1 in + # `_num_epochs` calculations. + if remaining_seq_length <= 0: + doc_offset += remaining_seq_length + doc_length - 1 + remaining_seq_length = 0 + else: + # Otherwise, start from the beginning of the next document. 
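+                # For example (illustrative): with seq_length=4 and document lengths
+                # [3, 6], document 0 is exhausted while 2 tokens are still needed, so we
+                # advance here to document 1 at offset 0; the sample then ends inside
+                # document 1 and the next entry is recorded as (doc_idx_index=1,
+                # doc_offset=1), the trailing token being shared with the next sample.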
+ doc_idx_index += 1 + doc_offset = 0 + # Record the sequence. + sample_idx[sample_index][0] = doc_idx_index + sample_idx[sample_index][1] = doc_offset + sample_index += 1 + + return sample_idx + + +def _build_shuffle_idx(num_samples, total_size, np_rng): + """Build the range [0, size) and shuffle.""" + print( + " > building shuffle index with split [0, {}) and [{}, {}) " + "...".format(num_samples, num_samples, total_size), + flush=True, + ) + + dtype_ = np.uint32 + if total_size >= (np.iinfo(np.uint32).max - 1): + dtype_ = np.int64 + + shuffle_idx_first = np.arange(start=0, stop=num_samples, step=1, dtype=dtype_) + np_rng.shuffle(shuffle_idx_first) + if num_samples == total_size: + return shuffle_idx_first + + shuffle_idx_last = np.arange(start=num_samples, stop=total_size, step=1, dtype=dtype_) + np_rng.shuffle(shuffle_idx_last) + + return np.concatenate((shuffle_idx_first, shuffle_idx_last)) diff --git a/paddleformers/trainer/__init__.py b/paddleformers/trainer/__init__.py index 53ceb66a961..b129b4a20a1 100644 --- a/paddleformers/trainer/__init__.py +++ b/paddleformers/trainer/__init__.py @@ -75,6 +75,8 @@ "TrainerState", "DEFAULT_PROGRESS_CALLBACK", "TrainerCallback", + "StepFlexToken", + "FP8QuantWeightCallback", ], "trainer_utils": [ "get_last_checkpoint", diff --git a/paddleformers/trainer/trainer.py b/paddleformers/trainer/trainer.py index 55fb28d5c09..9d2a85074a2 100644 --- a/paddleformers/trainer/trainer.py +++ b/paddleformers/trainer/trainer.py @@ -302,6 +302,7 @@ def __init__( optimizers: Tuple[paddle.optimizer.Optimizer, paddle.optimizer.lr.LRScheduler] = (None, None), preprocess_logits_for_metrics: Callable[[paddle.Tensor, paddle.Tensor], paddle.Tensor] = None, processing_class: Optional[ImageProcessingMixin] = None, + resume_from_custom_func: Optional[Callable] = None, ): if args is None: @@ -356,6 +357,7 @@ def __init__( self.train_dataset = train_dataset self.eval_dataset = eval_dataset self.tokenizer = tokenizer + self.resume_from_custom_func = resume_from_custom_func if not args.skip_profile_timer: set_timers() self.timers = get_timers() @@ -1133,6 +1135,9 @@ def _inner_training_loop( if self.args.ignore_data_skip: self.timers and self.timers("read-data").start() + if self.resume_from_custom_func is not None: + self.resume_from_custom_func(self.model) + for epoch in range(epochs_trained, num_train_epochs): if isinstance(train_dataloader, paddle.io.DataLoader) and isinstance( train_dataloader.batch_sampler, DistributedBatchSampler diff --git a/paddleformers/trainer/trainer_callback.py b/paddleformers/trainer/trainer_callback.py index 812c8dc9f59..92ed1e182d4 100644 --- a/paddleformers/trainer/trainer_callback.py +++ b/paddleformers/trainer/trainer_callback.py @@ -20,12 +20,14 @@ """ import dataclasses import json +import os from dataclasses import dataclass from typing import Dict, List, Optional, Union import numpy as np from tqdm.auto import tqdm +from ..transformers.moe_utils import offload, reload from ..utils.log import logger from .trainer_utils import IntervalStrategy, has_length from .training_args import TrainingArguments @@ -39,6 +41,8 @@ "ProgressCallback", "PrinterCallback", "EarlyStoppingCallback", + "StepFlexToken", + "FP8QuantWeightCallback", ] @@ -608,3 +612,72 @@ def on_evaluate(self, args, state, control, metrics, **kwargs): self.check_metric_value(args, state, control, metric_value) if self.early_stopping_patience_counter >= self.early_stopping_patience: control.should_training_stop = True + + +class StepFlexToken(TrainerCallback): + def on_step_begin( 
+ self, + args: TrainingArguments, + state: TrainerState, + control: TrainerControl, + **kwargs, + ): + model = kwargs.pop("model") + if hasattr(model, "step_flex_token"): + model.step_flex_token(state.global_step) + + +g_shard_bypass_dygraph_optimizer = int(os.environ.get("FLAGS_shard_bypass_dygraph_optimizer", 0)) + + +def enable_in_dict_config(config, key): + """enable_in_dict_config""" + return key in config and config[key] + + +skip_count = 0 + + +class FP8QuantWeightCallback(TrainerCallback): + """ + Callback for FP8 weight quantization during training + """ + + def on_step_begin(self, args, state, control, **kwargs): + """ + Quantize expert weights to FP8 before each training step + """ + model = kwargs["model"] + optimizer = kwargs["optimizer"] + global skip_count + + if (not g_shard_bypass_dygraph_optimizer or skip_count == 0) and hasattr(model, "fp8_quant_weight"): + model.fp8_quant_weight(True, quant_transpose=True) + optimizer.clear_param_storage("moe_expert") + optimizer.clear_param_storage("rms_linear") + optimizer.clear_param_storage("memory_attn") + optimizer.clear_param_storage("attn_out_project") + optimizer.clear_param_storage("shared_expert") + + self.moe_weights_name = [] + for param in optimizer._inner_opt._parameter_list: + color = getattr(param, "color", -1) + if isinstance(color, dict) and color["color"] == "moe_expert": + self.moe_weights_name.append(param.name) + + for name in self.moe_weights_name: + offload(optimizer._master_weights[name]) + + skip_count += 1 + + def on_optimizer_begin(self, args, state, control, **kwargs): + """ + Reload weights before optimizer step + """ + model = kwargs["model"] + optimizer = kwargs["optimizer"] + global skip_count + + if (not g_shard_bypass_dygraph_optimizer) and hasattr(model, "fp8_quant_weight"): + for name in self.moe_weights_name: + reload(optimizer._master_weights[name]) diff --git a/paddleformers/trainer/training_args.py b/paddleformers/trainer/training_args.py index c505856532b..7436222d37b 100644 --- a/paddleformers/trainer/training_args.py +++ b/paddleformers/trainer/training_args.py @@ -901,6 +901,10 @@ class TrainingArguments: default=None, metadata={"help": "The path to a folder with a valid checkpoint for your model."}, ) + resume_from_huggingface_ckpt: Optional[str] = field( + default=None, + metadata={"help": "The path to a folder with a valid huggingface checkpoint for your model."}, + ) auto_parallel_resume_form_hybrid_parallel: Optional[bool] = field( default=False, metadata={"help": "Whether hybrid parallel checkpoints be loaded in auto parallel mode."}, @@ -1088,6 +1092,10 @@ class TrainingArguments: default=False, metadata={"help": "Save model to HuggingFace safetensors."}, ) + reorder_pipeline_priority: Optional[bool] = field( + default=False, + metadata={"help": "Controls the parallel execution order. 
False (pp first), True (sharding first)."}, + ) def __post_init__(self): world_size = paddle.distributed.get_world_size() @@ -1405,12 +1413,15 @@ def is_segment_parallel_supported(): else: order = ["dp", "sharding", "pp", "mp"] if self.use_expert_parallel: - if self.moe_sharding_parallel_degree >= 1 and self.expert_parallel_degree > 1: - order.insert(-1, "ep") - sd_idx = order.index("sharding") - # if pp_first, the order = ["dp", "pp", "moe_sharding", "sharding", "sep", "ep", "mp"] - # if sharding_first, the order is ["dp", "moe_sharding", "sharding", "pp", "sep", "ep", "mp"] - order.insert(sd_idx, "moe_sharding") + if not self.reorder_pipeline_priority: + if self.moe_sharding_parallel_degree >= 1 and self.expert_parallel_degree > 1: + order.insert(-1, "ep") + sd_idx = order.index("sharding") + # if pp_first, the order = ["dp", "pp", "moe_sharding", "sharding", "sep", "ep", "mp"] + # if sharding_first, the order is ["dp", "moe_sharding", "sharding", "pp", "sep", "ep", "mp"] + order.insert(sd_idx, "moe_sharding") + else: + order = order[1:-1] + ["dp", "mp"] if is_segment_parallel_supported(): hybrid_configs = { @@ -1564,6 +1575,10 @@ def is_segment_parallel_supported(): fleet.init(is_collective=True, strategy=strategy) logger.info(strategy) + if self.reorder_pipeline_priority: + if self.expert_parallel_degree > 1: + self.add_moe_comm_group() + elif self.enable_auto_parallel: self.tensor_parallel_degree = max(self.tensor_parallel_degree, 1) self.sep_parallel_degree = max(self.sep_parallel_degree, 1) diff --git a/paddleformers/transformers/__init__.py b/paddleformers/transformers/__init__.py index 65a063f2d20..c3bbeff17a6 100644 --- a/paddleformers/transformers/__init__.py +++ b/paddleformers/transformers/__init__.py @@ -49,7 +49,7 @@ "feature_extraction_utils": ["BatchFeature", "FeatureExtractionMixin"], "image_processing_utils": ["ImageProcessingMixin"], "moe_gate": ["PretrainedMoEGate", "MoEGateMixin"], - "token_dispatcher": [], + "token_dispatcher": ["_DispatchManager"], "moe_layer": ["combining", "_AllToAll", "MoELayer", "dispatching", "MoEFlexTokenLayer"], "bert.modeling": [ "BertForSequenceClassification", diff --git a/paddleformers/transformers/fp8_utils.py b/paddleformers/transformers/fp8_utils.py new file mode 100644 index 00000000000..b927aa53097 --- /dev/null +++ b/paddleformers/transformers/fp8_utils.py @@ -0,0 +1,1308 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import os +from functools import partial + +import numpy +import paddle +import paddle.nn.functional as F + +try: + import fused_ln + from paddle.incubate.nn.functional import swiglu +except ImportError: + + def swiglu(x, y=None): + if y is None: + x, y = paddle.chunk(x, chunks=2, axis=-1) + return F.silu(x) * y + + +from paddle.distributed.fleet.meta_parallel.zero_bubble_utils import WeightGradStore + +USE_DS_GEMM = os.getenv("USE_DS_GEMM", "False").lower() == "true" + +try: + if USE_DS_GEMM: + import deep_gemm + else: + from paddle.incubate.fp8 import deep_gemm +except: + pass + + +__all__ = [ + "FP8LinearFunctionBase", + "FP8Linear", + "FP8GroupGemmMlpFunctionNode", +] + + +def get_sm_num(): + return 112 + + +def set_parameter_color( + parameters, color, group=None, offline_quant_expert_weight=True, clear_origin_weight_when_offline_quant=True +): + if offline_quant_expert_weight and clear_origin_weight_when_offline_quant: + if group is None: + for p in parameters: + if hasattr(p, "color") and p.color is not None: + continue + setattr(p, "color", {"color": color}) + else: + for p in parameters: + if hasattr(p, "color") and p.color is not None: + continue + setattr(p, "color", {"color": color, "group": group}) + + +def extract_first_if_tuple(x): + return x[0] if isinstance(x, tuple) else x + + +def _get_fp8_weight_and_scale(weight, stacked=False, transpose=False): + """_get_fp8_weight_and_scale""" + if stacked: + if transpose: + fp8_weight, fp8_scale = weight.fp8_weight_stacked_transpose, weight.fp8_scale_stacked_transpose + else: + fp8_weight, fp8_scale = weight.fp8_weight_stacked, weight.fp8_scale_stacked + else: + if transpose: + fp8_weight, fp8_scale = weight.fp8_weight_transpose, weight.fp8_scale_transpose + else: + fp8_weight, fp8_scale = weight.fp8_weight, weight.fp8_scale + return fp8_weight, fp8_scale + + +def fused_stack_quant(expert_weight_list, transpose=False): + if transpose is False and hasattr(expert_weight_list[0], "fp8_weight_stacked"): + w, scale = _get_fp8_weight_and_scale(expert_weight_list[0], stacked=True, transpose=False) + elif transpose is True and hasattr(expert_weight_list[0], "fp8_weight_stacked_transpose"): + w, scale = _get_fp8_weight_and_scale(expert_weight_list[0], stacked=True, transpose=True) + elif transpose is True and hasattr(expert_weight_list[0], "fp8_weight_stacked"): + w, scale = _get_fp8_weight_and_scale(expert_weight_list[0], stacked=True, transpose=False) + elif transpose is False and hasattr(expert_weight_list[0], "fp8_weight_stacked_transpose"): + w, scale = _get_fp8_weight_and_scale(expert_weight_list[0], stacked=True, transpose=True) + else: + w, scale = paddle.incubate.nn.functional.fused_stack_transpose_quant(expert_weight_list, transpose=transpose) + return w, scale + + +def weight_quant(weight, transpose=False): + if transpose: + if hasattr(weight, "fp8_weight_transpose"): + return weight.fp8_weight_transpose, weight.fp8_scale_transpose + elif hasattr(weight, "fp8_weight"): + return weight.fp8_weight.T.contiguous(), weight.fp8_scale.T.contiguous() + else: + return paddle.incubate.nn.functional.fp8_quant_blockwise( + weight, + output_scale_transpose=False, + quant_method="128x128", + input_transpose=True, + return_transpose_only=True, + ) + else: + if hasattr(weight, "fp8_weight"): + return weight.fp8_weight, weight.fp8_scale + elif hasattr(weight, "fp8_weight_transpose"): + return weight.fp8_weight_transpose.T.contiguous(), weight.fp8_scale_transpose.T.contiguous() + else: + return paddle.incubate.nn.functional.fp8_quant_blockwise( + 
weight, + output_scale_transpose=False, + quant_method="128x128", + input_transpose=False, + return_transpose_only=False, + ) + + +class FP8LinearFunctionBase: + @staticmethod + def dequantize_fp8_to_fp32(fp8_tensor, scale): + res = fp8_tensor.reshape([-1, 128]).astype("bfloat16") * (scale.reshape([-1, 1])) + return res.reshape(fp8_tensor.shape) + + @staticmethod + def padding(x, axis): + if x.shape[axis] % 512 != 0: + if (x.shape[axis] + 128 - (x.shape[axis] % 128)) % 512 != 0: + padding_size = 512 + else: + padding_size = 128 + pad_size = padding_size - (x.shape[axis] % padding_size) + if axis == 0: + x = paddle.concat([x, paddle.zeros([pad_size, x.shape[-1]], dtype=x.dtype)], axis=0) + else: + x = paddle.concat([x, paddle.zeros([x.shape[0], pad_size], dtype=x.dtype)], axis=-1) + return x + + @staticmethod + def padding_and_quant_input(tensor): + """Quantize input to FP8, with fallback to padded transposed version if shape not aligned.""" + if tensor.shape[0] % 512 != 0: + tensor_fp8, tensor_scale = paddle.incubate.nn.functional.fp8_quant_blockwise( + tensor, output_scale_transpose=True, quant_method="1x128", input_transpose=False + ) + tensor = FP8LinearFunctionBase.padding(tensor, 0) + tensor_t_fp8, tensor_t_scale = paddle.incubate.nn.functional.fp8_quant_blockwise( + tensor, + output_scale_transpose=True, + quant_method="1x128", + input_transpose=True, + return_transpose_only=True, + ) + return tensor_fp8, tensor_scale, tensor_t_fp8, tensor_t_scale + else: + tensor_fp8, tensor_scale, tensor_t_fp8, tensor_t_scale = paddle.incubate.nn.functional.fp8_quant_blockwise( + tensor, output_scale_transpose=True, quant_method="1x128", input_transpose=True + ) + return tensor_fp8, tensor_scale, tensor_t_fp8, tensor_t_scale + + @staticmethod + def kitchen_gemm( + x_fp8, x_scale, w_fp8, w_scale, is_a_1d_scaled, is_b_1d_scaled, out=None, rtn_dtype=paddle.bfloat16 + ): + if USE_DS_GEMM: + if out is None: + out = paddle.zeros([x_fp8.shape[0], w_fp8.shape[0]], rtn_dtype) + if numpy.prod(x_fp8.shape) != 0 and numpy.prod(w_fp8.shape) != 0: + deep_gemm.wgrad_gemm_fp8_fp8_fp32_nt((x_fp8, x_scale), (w_fp8, w_scale), out, num_sms=get_sm_num()) + return out + + if out is not None: + accumulate = True + out_dtype = out.dtype + else: + accumulate = False + out_dtype = rtn_dtype + if numpy.prod(x_fp8.shape) != 0 and numpy.prod(w_fp8.shape) != 0: + y = paddle.incubate.nn.functional.fp8_gemm_blockwise( + a=x_fp8, + a_decode_scale=x_scale, + b=w_fp8, + b_decode_scale=w_scale, + out_dtype=out_dtype, + out=out, + accumulate=accumulate, + use_split_accumulator=True, + is_a_1d_scaled=is_a_1d_scaled, + is_b_1d_scaled=is_b_1d_scaled, + ) + else: + y = paddle.zeros([x_fp8.shape[0], w_fp8.shape[0]], out_dtype) + if out is not None: + out = out + y + return out + + return y + + @staticmethod + def compute_fp8_linear( + input, weight, weight_transpose=False, return_transpose_only=False, return_mode="output_only", *, out=None + ): + """ + FP8 Linear computation function supporting multiple return modes and quantized/unquantized inputs. 
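+
+        Illustrative usage (mirroring the calls made in fp8_mlp_fwd further below):
+
+            out = FP8LinearFunctionBase.compute_fp8_linear(
+                x, w, weight_transpose=True, return_transpose_only=True
+            )
+            out, x_fp8, x_scale = FP8LinearFunctionBase.compute_fp8_linear(
+                x, w, weight_transpose=True, return_transpose_only=True, return_mode="with_input_quant"
+            )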
+ + Args: + input: Input tensor (raw tensor or quantized as (input_fp8, input_scale) tuple) + weight: Weight tensor + weight_transpose (bool): Whether to transpose weight + return_transpose_only (bool): Whether to return only transposed weight + return_mode (str): Return mode options: + - "output_only": Returns only output tensor + - "with_input_quant": Returns output + input quant results (input_fp8, input_scale) + - "with_input_transpose_quant": Returns output + transposed quant results (input_t_fp8, input_t_scale) + + Returns: + Different combinations of tensors based on return_mode + + Raises: + RuntimeError: If return_mode is not supported + """ + # check input + is_input_quantized = isinstance(input, (tuple, list)) and len(input) == 2 + + if is_input_quantized: + input_fp8, input_scale = input + if return_mode == "with_input_transpose_quant": + raise RuntimeError( + "Cannot return transposed quant if input is already quantized. " "Use raw input instead." + ) + else: + # quant input (with optional transposed output) + if return_mode == "with_input_transpose_quant": + input_fp8, input_scale, input_t_fp8, input_t_scale = FP8LinearFunctionBase.padding_and_quant_input( + input + ) + else: + input_fp8, input_scale = paddle.incubate.nn.functional.fp8_quant_blockwise( + input, + output_scale_transpose=True, + quant_method="1x128", + input_transpose=False, + return_transpose_only=False, + ) + + # quant weight + weight_fp8, weight_scale = weight_quant(weight, weight_transpose) + + # FP8 GEMM + if out is None: + out = paddle.empty([input_fp8.shape[0], weight_fp8.shape[0]], dtype=weight.dtype) + + deep_gemm.gemm_fp8_fp8_bf16_nt( + (input_fp8, input_scale.T), (weight_fp8, weight_scale), out, num_sms=get_sm_num() + ) + + # Return outputs + if return_mode == "output_only": + return out + elif return_mode == "with_input_quant": + return (out, input_fp8, input_scale) + elif return_mode == "with_input_transpose_quant": + return (out, input_t_fp8, input_t_scale) + else: + raise RuntimeError( + f"Unsupported return_mode: {return_mode}. " + "Supported modes: 'output_only', 'with_input_quant', 'with_input_transpose_quant'" + ) + + @staticmethod + def compute_expert_w_grad( + input_t, + input_t_scale, + dout_t, + dout_t_scale, + is_a_1d_scaled=True, + is_b_1d_scaled=True, + weight=None, + rtn_dtype=paddle.bfloat16, + ): + """ + Unified gradient computation for expert_w weights (supports both main_grad and regular grad). 
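+        If the weight has a `main_grad` attribute, the gradient is accumulated there
+        in FP32 (optionally deferred through WeightGradStore when it is enabled);
+        otherwise it is accumulated into `weight.grad`.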
+ """ + + if input_t is None or numpy.prod(input_t.shape) == 0: + return + + if hasattr(weight, "main_grad"): + if weight.main_grad is None: + weight.main_grad = paddle.zeros(shape=weight.shape, dtype=paddle.float32) + if WeightGradStore.enabled: + WeightGradStore.put( + partial( + FP8LinearFunctionBase.kitchen_gemm, + input_t, + input_t_scale, + dout_t, + dout_t_scale, + is_a_1d_scaled, + is_b_1d_scaled, + weight.main_grad, + rtn_dtype, + ) + ) + result = None + + else: + result = FP8LinearFunctionBase.kitchen_gemm( + input_t, + input_t_scale, + dout_t, + dout_t_scale, + is_a_1d_scaled, + is_b_1d_scaled, + weight.main_grad, + rtn_dtype, + ) + else: + if weight.grad is None: + weight.grad = paddle.zeros(shape=weight.shape, dtype=paddle.float32) + result = FP8LinearFunctionBase.kitchen_gemm( + input_t, input_t_scale, dout_t, dout_t_scale, is_a_1d_scaled, is_b_1d_scaled, weight.grad, rtn_dtype + ) + + if hasattr(weight, "_apply_backward_hook"): + weight._apply_backward_hook() + return result + + @staticmethod + def common_fp8_mlp_bwd( + do3, x_t_fp8, x_t_scale, w1, w2, o1=None, x_fp8=None, x_scale=None, apply_backward_hook=False + ): + if o1 is not None and (x_fp8 is not None or x_scale is not None): + raise ValueError("When o1 is provided, both x_fp8 and x_scale must be None.") + + if o1 is None: + if x_fp8 is None or x_scale is None: + raise ValueError("When o1 is None, both x_fp8 and x_scale must be provided.") + + # [recompute] o1 = deep_gemm(x_fp8, w1_t_fp8) + + # Recompute o1 using deep_gemm(x_fp8, w1_t_fp8) + w1_fp8, w1_scale = weight_quant(w1, True) + o1 = paddle.empty([x_fp8.shape[0], w1_fp8.shape[0]], dtype=do3.dtype) + deep_gemm.gemm_fp8_fp8_bf16_nt((x_fp8, x_scale.T), (w1_fp8, w1_scale), o1, num_sms=get_sm_num()) + + # [recompute] o2 = swiglu(o1) + o2 = swiglu(o1) + + # do2 = deep_gemm(do3_fp8, w2_fp8) + do2, do3_t_fp8, do3_t_scale = FP8LinearFunctionBase.compute_fp8_linear( + do3, w2, return_mode="with_input_transpose_quant" + ) + + # dw2 = deep_gemm(o2_t_fp8, do3_t_fp8) + o2 = FP8LinearFunctionBase.padding(o2, 0) + o2_t_fp8, o2_t_scale = paddle.incubate.nn.functional.fp8_quant_blockwise( + o2, output_scale_transpose=True, quant_method="1x128", input_transpose=True, return_transpose_only=True + ) + if apply_backward_hook: + if WeightGradStore.enabled: + WeightGradStore.put( + partial( + FP8LinearFunctionBase.compute_expert_w_grad, + o2_t_fp8, + o2_t_scale, + do3_t_fp8, + do3_t_scale, + True, + True, + w2, + rtn_dtype=paddle.float32, + ) + ) + else: + + FP8LinearFunctionBase.compute_expert_w_grad( + o2_t_fp8, o2_t_scale, do3_t_fp8, do3_t_scale, True, True, w2, rtn_dtype=paddle.float32 + ) + else: + dw2 = FP8LinearFunctionBase.kitchen_gemm( + o2_t_fp8, o2_t_scale, do3_t_fp8, do3_t_scale, True, True, rtn_dtype=paddle.float32 + ) + + # do1 = swiglu_grad(o1, None, do2) + do1, _ = paddle._C_ops.swiglu_grad(o1, None, do2) + + # dx = deep_gemm(do1_fp8, w1_fp8) + dx, do1_t_fp8, do1_t_scale = FP8LinearFunctionBase.compute_fp8_linear( + do1, w1, return_mode="with_input_transpose_quant" + ) + + # dw1 = deep_gemm(x_t_fp8, do1_t_fp8) + if apply_backward_hook: + if WeightGradStore.enabled: + WeightGradStore.put( + partial( + FP8LinearFunctionBase.compute_expert_w_grad, + x_t_fp8, + x_t_scale, + do1_t_fp8, + do1_t_scale, + True, + True, + w1, + rtn_dtype=paddle.float32, + ) + ) + + else: + FP8LinearFunctionBase.compute_expert_w_grad( + x_t_fp8, x_t_scale, do1_t_fp8, do1_t_scale, True, True, w1, rtn_dtype=paddle.float32 + ) + else: + dw1 = FP8LinearFunctionBase.kitchen_gemm( + x_t_fp8, 
x_t_scale, do1_t_fp8, do1_t_scale, True, True, rtn_dtype=paddle.float32 + ) + + if apply_backward_hook: + return dx + else: + assert dw1 is not None and dw2 is not None + return dx, dw1, dw2 + + @staticmethod + def fp8_mlp_fwd(x, w1, w2): + x_orig_shape = x.shape + x = x.reshape([-1, x_orig_shape[-1]]) + + # o1 = deep_gemm(x_fp8, w1_t_fp8) + o1, x_fp8, x_scale = FP8LinearFunctionBase.compute_fp8_linear( + x, w1, weight_transpose=True, return_transpose_only=True, return_mode="with_input_quant" + ) + + # o2 = swiglu(o1) + o2 = swiglu(o1) + + # o3 = deep_gemm(o2_fp8, w2_t_fp8) + o3 = FP8LinearFunctionBase.compute_fp8_linear(o2, w2, weight_transpose=True, return_transpose_only=True) + + if len(x_orig_shape) > 2: + o3 = o3.reshape([x_orig_shape[0], -1, o3.shape[-1]]) + + return o1, x_fp8, x_scale, o3 + + @staticmethod + def fp8_mlp_fwd_norm_rc(x, norm_w, norm_eps, w1, w2): + # compute norm_output + norm_output, _ = fused_ln.fused_rms_norm(x, norm_w, norm_eps) + # compute fp8_mlp_fwd + _, _, _, o3 = FP8LinearFunctionBase.fp8_mlp_fwd(norm_output, w1, w2) + return o3 + + @staticmethod + def fp8_mlp_bwd(do3, x, w1, w2, apply_backward_hook=False): + do3_orig_shape = do3.shape + do3 = do3.reshape([-1, do3_orig_shape[-1]]) + + x_orig_shape = x.shape + x = x.reshape([-1, x_orig_shape[-1]]) + + x_fp8, x_scale, x_t_fp8, x_t_scale = FP8LinearFunctionBase.padding_and_quant_input(x) + + if apply_backward_hook: + dx = FP8LinearFunctionBase.common_fp8_mlp_bwd( + do3, + x_t_fp8, + x_t_scale, + w1, + w2, + o1=None, + x_fp8=x_fp8, + x_scale=x_scale, + apply_backward_hook=apply_backward_hook, + ) + if len(x_orig_shape) > 2: + dx = dx.reshape([x_orig_shape[0], -1, dx.shape[-1]]) + return dx + else: + dx, dw1, dw2 = FP8LinearFunctionBase.common_fp8_mlp_bwd( + do3, + x_t_fp8, + x_t_scale, + w1, + w2, + o1=None, + x_fp8=x_fp8, + x_scale=x_scale, + apply_backward_hook=apply_backward_hook, + ) + if len(x_orig_shape) > 2: + dx = dx.reshape([x_orig_shape[0], -1, dx.shape[-1]]) + + return dx, dw1, dw2 + + @staticmethod + def fp8_mlp_bwd_norm_rc(do3, x, norm_w, norm_eps, w1, w2): + # recompute norm_output + norm_output, invar = fused_ln.fused_rms_norm(x, norm_w, norm_eps) + + # compute fp8_mlp_fwd + d_norm_output = FP8LinearFunctionBase.fp8_mlp_bwd(do3, norm_output, w1, w2, True) + + if hasattr(norm_w, "_apply_backward_hook"): + norm_w._apply_backward_hook() + + return d_norm_output, norm_output, invar + + +class FP8LinearFunction(paddle.autograd.PyLayer): + @staticmethod + def forward(ctx, x, custom_map, keep_x=False): + weight = custom_map.weight + x_orig_shape = x.shape + + # deep_gemm only support 2D + x = x.reshape([-1, x_orig_shape[-1]]).contiguous() + + if keep_x: + out = FP8LinearFunctionBase.compute_fp8_linear( + x, + weight, + weight_transpose=True, + return_transpose_only=True, + ) + # save for bwd + out = out.reshape([x_orig_shape[0], -1, weight.shape[-1]]) + ctx.save_for_backward(x, weight) + return out + else: + x_t = x.T + out, x_t_fp8, x_t_scale = FP8LinearFunctionBase.compute_fp8_linear( + x, weight, weight_transpose=True, return_transpose_only=True, return_mode="with_input_transpose_quant" + ) + out = out.reshape([x_orig_shape[0], -1, weight.shape[-1]]) + ctx.save_for_backward((x_t_fp8, x_t_scale), weight) + ctx.x_t_shape = x_t.shape + return out + + @staticmethod + def backward(ctx, dout): + x, weight = ctx.saved_tensor() + dout_2d = dout.reshape([-1, dout.shape[-1]]) + + keep_x = not isinstance(x, tuple) + + if keep_x: + # padding x and quant + dx_orig_shape = x.shape + x = 
FP8LinearFunctionBase.padding(x, 0) + x_t_fp8, x_t_scale = paddle.incubate.nn.functional.fp8_quant_blockwise( + x, output_scale_transpose=True, quant_method="1x128", input_transpose=True, return_transpose_only=True + ) + + # dx = deep_gemm(dout_fp8, w_fp8) + dx, dout_t_fp8, dout_t_scale = FP8LinearFunctionBase.compute_fp8_linear( + dout_2d, weight, weight_transpose=False, return_mode="with_input_transpose_quant" + ) + dx = dx.reshape(dx_orig_shape) + + else: + x_t_fp8, x_t_scale = x + + # dx = deep_gemm(dout_fp8, w_fp8) + dx, dout_t_fp8, dout_t_scale = FP8LinearFunctionBase.compute_fp8_linear( + dout_2d, weight, weight_transpose=False, return_mode="with_input_transpose_quant" + ) + dx_orig_shape = dout.shape[:-1] + dx_orig_shape.append(ctx.x_t_shape[0]) + dx = dx.reshape(dx_orig_shape) + + # dw1 = deep_gemm(x_t_fp8, dout_t_fp8) + FP8LinearFunctionBase.compute_expert_w_grad( + x_t_fp8, x_t_scale, dout_t_fp8, dout_t_scale, True, True, weight, paddle.float32 + ) + return dx + + +class FP8Linear(paddle.nn.Layer): + def __init__(self, in_features: int, out_features: int, bias_attr: bool = False) -> None: + super().__init__() + self._dtype = self._helper.get_default_dtype() + + self.weight = self.create_parameter( + shape=[in_features, out_features], + dtype="bfloat16", + is_bias=False, + ) + + def forward(self, x): + return FP8LinearFunction.apply(x, self, keep_x=False) + + +def cache_fp8_weight(weight, quant_transpose=None): + if hasattr(weight, "fp8_weight") or hasattr(weight, "fp8_weight_transpose"): + return + if quant_transpose is None: + w_fp8, w_scale, w_t_fp8, w_t_scale = paddle.incubate.nn.functional.fp8_quant_blockwise( + weight, + output_scale_transpose=False, + quant_method="128x128", + input_transpose=True, + return_transpose_only=False, + ) + + setattr(weight, "fp8_weight_transpose", w_t_fp8) + setattr(weight, "fp8_scale_transpose", w_t_scale) + setattr(weight, "fp8_weight", w_fp8) + setattr(weight, "fp8_scale", w_scale) + elif quant_transpose is True: + w_t_fp8, w_t_scale = paddle.incubate.nn.functional.fp8_quant_blockwise( + weight, + output_scale_transpose=False, + quant_method="128x128", + input_transpose=True, + return_transpose_only=True, + ) + setattr(weight, "fp8_weight_transpose", w_t_fp8) + setattr(weight, "fp8_scale_transpose", w_t_scale) + elif quant_transpose is False: + w_fp8, w_scale = paddle.incubate.nn.functional.fp8_quant_blockwise( + weight, + output_scale_transpose=False, + quant_method="128x128", + input_transpose=False, + return_transpose_only=False, + ) + setattr(weight, "fp8_weight", w_fp8) + setattr(weight, "fp8_scale", w_scale) + else: + raise ValueError("quant_transpose must be either True, False or None.") + + +class FP8KeepXLinear(paddle.nn.Layer): + def __init__(self, in_features: int, out_features: int, bias_attr: bool = False) -> None: + super().__init__() + self._dtype = self._helper.get_default_dtype() + + self.weight = self.create_parameter( + shape=[in_features, out_features], + dtype="bfloat16", + is_bias=False, + ) + set_parameter_color([self.weight], "attn_out_project") + + def fp8_quant_weight(self, quant_transpose=None): + cache_fp8_weight(self.weight, quant_transpose=quant_transpose) + + def forward(self, x): + return FP8LinearFunction.apply(x, self, keep_x=True) + + +class FusedNormFP8MLPFunction(paddle.autograd.PyLayer): + @staticmethod + def forward(ctx, x, norm_w, w1, w2, norm_eps): + # compute norm_output + norm_output, invar = fused_ln.fused_rms_norm(x, norm_w, norm_eps) + # reshape for deep_gemm, since deep_gemm only support 2D + 
x_orig_shape = norm_output.shape + norm_output = norm_output.reshape([-1, x_orig_shape[-1]]) + + # call func fp8_mlp_fwd + _, _, _, o3 = FP8LinearFunctionBase.fp8_mlp_fwd(norm_output, w1, w2) + + # reshape to origin shape + if len(x_orig_shape) > 2: + o3 = o3.reshape([x_orig_shape[0], -1, o3.shape[-1]]) + + # save for backward + ctx.save_for_backward( + norm_output, + invar, + x, + norm_w, + w1, + w2, + norm_eps, + paddle.to_tensor(x_orig_shape, dtype="int64", place=paddle.CPUPlace()), + ) + return o3 + + @staticmethod + def backward(ctx, do3): + # reshape for deep_gemm, since deep_gemm only support 2D + do3_orig_shape = do3.shape + do3 = do3.reshape([-1, do3_orig_shape[-1]]) + + # recive saved tensors + norm_output, invar, x, norm_w, w1, w2, norm_eps, x_orig_shape = ctx.saved_tensor() + + x_fp8, x_scale, x_t_fp8, x_t_scale = paddle.incubate.nn.functional.fp8_quant_blockwise( + norm_output, output_scale_transpose=True, quant_method="1x128", input_transpose=True + ) + + # call func common_fp8_mlp_bwd + d_norm_output, dw1, dw2 = FP8LinearFunctionBase.common_fp8_mlp_bwd( + do3, x_t_fp8, x_t_scale, w1, w2, o1=None, x_fp8=x_fp8, x_scale=x_scale + ) + + # reshape to origin shape + if len(x_orig_shape) > 2: + d_norm_output = d_norm_output.reshape([x_orig_shape[0], -1, d_norm_output.shape[-1]]) + + # compute norm grad + dx, d_rms_norm_weight = fused_ln.fused_rms_norm_grad_func(x, norm_w, invar, d_norm_output, norm_eps) + + return dx, d_rms_norm_weight, dw1, dw2 + + +class FP8MlpFunction(paddle.autograd.PyLayer): + @staticmethod + def forward(ctx, x, w1, w2, recompute_fwd_gate_up): + # reshape for deep_gemm, since deep_gemm only support 2D + x_orig_shape = x.shape + x = x.reshape([-1, x_orig_shape[-1]]) + + # call func fp8_mlp_fwd + o1, x_fp8, x_scale, o3 = FP8LinearFunctionBase.fp8_mlp_fwd(x, w1, w2) + # reshape to origin shape + if len(x_orig_shape) > 2: + o3 = o3.reshape([x_orig_shape[0], -1, o3.shape[-1]]) + + # save for backward + o1 = None if recompute_fwd_gate_up else o1 + ctx.save_for_backward( + o1, + x_fp8, + x_scale, + w1, + w2, + paddle.to_tensor(x_orig_shape, dtype="int64", place=paddle.CPUPlace()), + ) + return o3 + + @staticmethod + def backward(ctx, do3): + # reshape for deep_gemm, since deep_gemm only support 2D + do3_orig_shape = do3.shape + do3 = do3.reshape([-1, do3_orig_shape[-1]]) + + # recive saved tensors + o1, x_fp8, x_scale, w1, w2, x_orig_shape = ctx.saved_tensor() + + # compute x_t_fp8, x_t_scale for dw1 + x_dequant_fp16 = paddle.incubate.nn.functional.fused_act_dequant(x_fp8, x_scale.T.contiguous()) + x_dequant_fp16 = FP8LinearFunctionBase.padding(x_dequant_fp16, 0) + + x_t_fp8, x_t_scale = paddle.incubate.nn.functional.fp8_quant_blockwise( + x_dequant_fp16, + output_scale_transpose=True, + quant_method="1x128", + input_transpose=True, + return_transpose_only=True, + ) + + # call func common_fp8_mlp_bwd + if o1 is None: + dx = FP8LinearFunctionBase.common_fp8_mlp_bwd( + do3, x_t_fp8, x_t_scale, w1, w2, o1=None, x_fp8=x_fp8, x_scale=x_scale, apply_backward_hook=True + ) + else: + dx = FP8LinearFunctionBase.common_fp8_mlp_bwd( + do3, x_t_fp8, x_t_scale, w1, w2, o1=o1, x_fp8=None, x_scale=None, apply_backward_hook=True + ) + # reshape to origin shape + if len(x_orig_shape) > 2: + dx = dx.reshape([x_orig_shape[0], -1, dx.shape[-1]]) + + return dx, None, None + + +class FP8Mlp(paddle.nn.Layer): + def __init__( + self, + config, + hidden_size=None, + intermediate_size=None, + is_moe=False, + using_post_norm_recompute=False, + norm_weight=None, + norm_eps=None, + 
recompute_fwd_gate_up=False, + ): + super().__init__() + self.config = config + self.using_post_norm_recompute = using_post_norm_recompute + if self.using_post_norm_recompute: + assert norm_weight is not None and norm_eps is not None + self.norm_weight = norm_weight + self.norm_eps = norm_eps + + self.hidden_size = config.hidden_size if hidden_size is None else hidden_size + self.intermediate_size = config.intermediate_size if intermediate_size is None else intermediate_size + + self.recompute_fwd_gate_up = recompute_fwd_gate_up + + self.w1 = self.create_parameter( + shape=[self.hidden_size, self.intermediate_size * 2], + dtype="bfloat16", + is_bias=False, + ) + self.w2 = self.create_parameter( + shape=[self.intermediate_size, self.hidden_size], + dtype="bfloat16", + is_bias=False, + ) + + def fp8_quant_weight(self, quant_transpose=None): + cache_fp8_weight(self.w1, quant_transpose) + cache_fp8_weight(self.w2, quant_transpose) + + def forward(self, x): + if self.using_post_norm_recompute: + return FusedNormFP8MLPFunction.apply(x, self.norm_weight, self.w1, self.w2, self.norm_eps) + else: + return FP8MlpFunction.apply(x, self.w1, self.w2, self.recompute_fwd_gate_up) + + +def split_group_gemm(x_fp8, x_scale, w_fp8, w_scale, tokens_per_expert, gemm_out): + start_idx = 0 + for i, token_num in enumerate(tokens_per_expert): + if token_num == 0: + continue + end_idx = start_idx + token_num + + x_scale_tma_align = x_scale[start_idx:end_idx].T.contiguous().T + + deep_gemm.gemm_fp8_fp8_bf16_nt( + (x_fp8[start_idx:end_idx], x_scale_tma_align), + (w_fp8[i], w_scale[i]), + gemm_out[start_idx:end_idx], + num_sms=get_sm_num(), + ) + + start_idx = end_idx + + return gemm_out + + +class FP8GroupGemmMlpFunctionNode: + def __init__( + self, + custom_map, + recompute_fwd_gate_up=False, + is_split_group_gemm=False, + name="experts_group_gemm_contiguous_node", + ): + self.experts = custom_map.experts + self.recompute_fwd_gate_up = recompute_fwd_gate_up + self.is_split_group_gemm = is_split_group_gemm + self.m_indices = None + self.input = None + self.input_fp8 = None + self.input_scale = None + self.o1 = None + self.all_unzipped_grad = None + self.fwd_subbatch = None + self.bwd_subbatch = None + + def reset_statue(self): + self.m_indices = None + self.fwd_subbatch = None + self.bwd_subbatch = None + self.clear_activation_tensors() + + def clear_activation_tensors(self): + self.input = None + self.input_fp8 = None + self.input_scale = None + self.o1 = None + self.all_unzipped_grad = None + + def gen_m_indices(self, tokens_per_expert): + tokens = [] + for i in range(len(tokens_per_expert)): + tokens.append(paddle.full([tokens_per_expert[i]], i, dtype="int32")) + out = paddle.concat(tokens, axis=0) + return out + + def fwd_gate_up(self, x, expert_w1, num_expert, tokens_per_expert, m_indices=None): + """ + o1 = x * w1 + [m_sum, n] = [m_sum, k] * [num_groups, k, n] (m_sum = sum(tokens_per_expert)) + """ + if not self.is_split_group_gemm and self.m_indices is None: + self.m_indices = self.gen_m_indices(tokens_per_expert) + # concat w1, shape is [num_groups, n, k] + w1_t_quant, w1_t_scale = fused_stack_quant(expert_w1, transpose=True) + w1_t_quant = w1_t_quant.reshape([num_expert, -1, w1_t_quant.shape[-1]]) + w1_t_scale = w1_t_scale.reshape([num_expert, -1, w1_t_scale.shape[-1]]) + + if hasattr(expert_w1[0], "fp8_weight_stacked") and not hasattr(expert_w1[0], "fp8_weight_stacked_transpose"): + w1_t_quant = w1_t_quant.contiguous().transpose([0, 2, 1]).contiguous() + w1_t_scale = w1_t_scale.contiguous().transpose([0, 2, 
1]).contiguous() + + if x is None: + x_fp8, x_scale = self.input_fp8, self.input_scale + assert x_fp8 is not None and x_scale is not None + else: + if isinstance(x, tuple): + (x_fp8, x_scale) = x + x_scale = paddle.transpose(paddle.transpose(x_scale, [1, 0]).contiguous(), [1, 0]) + else: + # quant x_bf16 + x_fp8, x_scale = paddle.incubate.nn.functional.fp8_quant_blockwise( + x, output_scale_transpose=True, quant_method="1x128", input_transpose=False + ) + x_scale = x_scale.T + + # compute gemm + o1 = paddle.empty([x_fp8.shape[0], w1_t_quant.shape[1]], dtype=expert_w1[0].dtype) + if numpy.prod(x_fp8.shape) != 0: + if self.is_split_group_gemm: + split_group_gemm(x_fp8, x_scale, w1_t_quant, w1_t_scale, tokens_per_expert, o1) + else: + deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous( + (x_fp8, x_scale), + (w1_t_quant, w1_t_scale), + o1, + m_indices=self.m_indices if m_indices is None else m_indices, + num_sms=get_sm_num(), + ) + + if m_indices is None: + self.input_fp8 = x_fp8 + self.input_scale = x_scale + return o1 + + def fwd_swiglu(self, o1): + o2 = swiglu(o1) + return o2 + + def fwd_down( + self, o1, unzipped_probs, expert_w2, num_expert, tokens_per_expert, m_indices=None, o3=None, clear_o1=False + ): + """ + o3 = o2 * w2 + [m_sum, k] = [m_sum, n] * [num_groups, n, k] + """ + # concat and transpose w2 + w2_quant, w2_scale = fused_stack_quant(expert_w2, transpose=True) + w2_quant = w2_quant.reshape([num_expert, -1, w2_quant.shape[-1]]) + w2_scale = w2_scale.reshape([num_expert, -1, w2_scale.shape[-1]]) + + if hasattr(expert_w2[0], "fp8_weight_stacked") and not hasattr(expert_w2[0], "fp8_weight_stacked_transpose"): + w2_quant = w2_quant.contiguous().transpose([0, 2, 1]).contiguous() + w2_scale = w2_scale.contiguous().transpose([0, 2, 1]).contiguous() + + # quant o2 + with paddle.amp.auto_cast(False): + unzipped_probs = unzipped_probs.squeeze(-1) + o2_fp8, o2_scale = paddle.incubate.nn.functional.fused_weighted_swiglu_act_quant( + o1, unzipped_probs, using_pow2_scaling=True + ) + o2_scale = paddle.transpose(paddle.transpose(o2_scale, [1, 0]).contiguous(), [1, 0]) + + if clear_o1: + o1._clear_to_zero_allocation() + + # compute gemm + o3_shape = [o2_fp8.shape[0], w2_quant.shape[1]] + if o3 is not None: + assert o3.shape == o3_shape, "{} vs {}".format(o3.shape, o3_shape) + else: + o3 = paddle.empty(o3_shape, dtype=o1.dtype) + if numpy.prod(o2_fp8.shape) != 0: + if self.is_split_group_gemm: + split_group_gemm(o2_fp8, o2_scale, w2_quant, w2_scale, tokens_per_expert, o3) + else: + deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous( + (o2_fp8, o2_scale), + (w2_quant, w2_scale), + o3, + m_indices=m_indices if self.fwd_subbatch else self.m_indices, + num_sms=get_sm_num(), + ) + + return o3 + + def bwd_dowm_input(self, expert_w2, unzipped_grad, o1, tokens_per_expert, m_indices=None, unzipped_probs=None): + """ + do2 = do3 * w2_t + [m_sum, n] = [m_sum, k] * [num_groups, k, n] + """ + # recompute concated_w2_2d + bw_w2_quant, bw_w2_scale = fused_stack_quant(expert_w2, transpose=False) + bw_w2_quant = bw_w2_quant.reshape([len(expert_w2), -1, bw_w2_quant.shape[-1]]) + bw_w2_scale = bw_w2_scale.reshape([len(expert_w2), -1, bw_w2_scale.shape[-1]]) + + if hasattr(expert_w2[0], "fp8_weight_stacked_transpose") and not hasattr(expert_w2[0], "fp8_weight_stacked"): + bw_w2_quant = bw_w2_quant.contiguous().transpose([0, 2, 1]).contiguous() + bw_w2_scale = bw_w2_scale.contiguous().transpose([0, 2, 1]).contiguous() + + # compute gemm + if isinstance(unzipped_grad, tuple): + (unzipped_grad_fp8, 
unzipped_grad_scale) = unzipped_grad + unzipped_grad_scale = unzipped_grad_scale.T.contiguous().T + else: + unzipped_grad_fp8, unzipped_grad_scale = paddle.incubate.nn.functional.fp8_quant_blockwise( + unzipped_grad, output_scale_transpose=True, quant_method="1x128", input_transpose=False + ) + unzipped_grad_scale = unzipped_grad_scale.T + + do2_s = paddle.empty([unzipped_grad_fp8.shape[0], bw_w2_quant.shape[1]], dtype="bfloat16") + if numpy.prod(unzipped_grad_fp8.shape) != 0: + if self.is_split_group_gemm: + split_group_gemm( + unzipped_grad_fp8, unzipped_grad_scale, bw_w2_quant, bw_w2_scale, tokens_per_expert, do2_s + ) + else: + deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous( + (unzipped_grad_fp8, unzipped_grad_scale), + (bw_w2_quant, bw_w2_scale), + do2_s, + m_indices=m_indices if self.bwd_subbatch else self.m_indices, + num_sms=get_sm_num(), + ) + + with paddle.amp.auto_cast(False): + do1, probs_grad, o2_s = paddle.incubate.nn.functional.fused_swiglu_weighted_bwd(o1, do2_s, unzipped_probs) + + return do1, o2_s, probs_grad + + def bwd_swiglu(self, o1, do2): + do1, _ = paddle._C_ops.swiglu_grad(o1, None, do2) + return do1 + + def bwd_gate_up_input(self, do1, expert_w1, tokens_per_expert, m_indices=None, dx=None): + """ + dx = do1 * w1_t + [m_sum, k] = [m_sum, n] * [num_groups, n, k] + """ + # recompute concated_w1_t + bw_w1_quant, bw_w1_scale = fused_stack_quant(expert_w1, transpose=False) + bw_w1_quant = bw_w1_quant.reshape([len(expert_w1), -1, bw_w1_quant.shape[-1]]) + bw_w1_scale = bw_w1_scale.reshape([len(expert_w1), -1, bw_w1_scale.shape[-1]]) + + if hasattr(expert_w1[0], "fp8_weight_stacked_transpose") and not hasattr(expert_w1[0], "fp8_weight_stacked"): + bw_w1_quant = bw_w1_quant.contiguous().transpose([0, 2, 1]).contiguous() + bw_w1_scale = bw_w1_scale.contiguous().transpose([0, 2, 1]).contiguous() + + # quant do1 + do1_fp8, do1_scale = paddle.incubate.nn.functional.fp8_quant_blockwise( + do1, output_scale_transpose=True, quant_method="1x128", input_transpose=False + ) + do1_scale = do1_scale.T + # compute gemm + dx_shape = [do1_fp8.shape[0], bw_w1_quant.shape[1]] + if dx is None or dx.dtype != do1.dtype: + dx = paddle.empty(shape=dx_shape, dtype=do1.dtype) + else: + assert dx.shape == dx_shape, f"{dx.shape} vs {dx_shape}" + if numpy.prod(do1_fp8.shape) != 0: + if self.is_split_group_gemm: + split_group_gemm(do1_fp8, do1_scale, bw_w1_quant, bw_w1_scale, tokens_per_expert, dx) + else: + deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous( + (do1_fp8, do1_scale), + (bw_w1_quant, bw_w1_scale), + dx, + m_indices=m_indices if self.bwd_subbatch else self.m_indices, + num_sms=get_sm_num(), + ) + + return dx + + def fused_transpose_split_quant(self, x, scale, tokens_per_expert, pow_2_scales): + out, scale = paddle.incubate.nn.functional.fused_transpose_split_quant( + x, scale, tokens_per_expert, pow_2_scales + ) + return out, scale + + def bwd_down_weight(self, do3, o2, expert_w2, tokens_per_expert): + """ + dw2 = do2_t * do3 + [n, k] = [n, m_sum] * [m_sum, k] (m_sum = sum(tokens_per_expert)) + """ + if isinstance(o2, tuple): + o2_t_fp8, o2_t_scale = o2 + else: + o2_t_fp8, o2_t_scale = self.fused_transpose_split_quant(o2, None, tokens_per_expert, True) + + if isinstance(do3, tuple): + do3_t_fp8, do3_t_scale = do3 + else: + do3_t_fp8, do3_t_scale = self.fused_transpose_split_quant(do3, None, tokens_per_expert, True) + + def cal_weight_fn(o2_t_fp8, o2_t_scale, do3_t_fp8, do3_t_scale, expert_w2): + with paddle.no_grad(): + for i in range(len(expert_w2)): + 
FP8LinearFunctionBase.compute_expert_w_grad( + o2_t_fp8[i], + o2_t_scale[i], + do3_t_fp8[i], + do3_t_scale[i], + True, + True, + expert_w2[i], + paddle.float32, + ) + + if WeightGradStore.enabled: + WeightGradStore.put(partial(cal_weight_fn, o2_t_fp8, o2_t_scale, do3_t_fp8, do3_t_scale, expert_w2)) + else: + cal_weight_fn(o2_t_fp8, o2_t_scale, do3_t_fp8, do3_t_scale, expert_w2) + + def bwd_gate_up_weight( + self, + do1, + input_x, + expert_w1, + tokens_per_expert, + input_fp8_slice=None, + input_scale_slice=None, + clear_input=False, + ): + """ + dw1 = dx_t * do1 + [k, n] = [k, m_sum] * [m_sum, n] (m_sum = sum(tokens_per_expert)) + """ + if input_x is None: + inp = (input_fp8_slice, input_scale_slice) if self.bwd_subbatch else (self.input_fp8, self.input_scale) + input_x_t_fp8, input_x_t_scale = self.fused_transpose_split_quant(inp[0], inp[1], tokens_per_expert, True) + + else: + input_x_t_fp8, input_x_t_scale = self.fused_transpose_split_quant(input_x, None, tokens_per_expert, True) + + if clear_input: + if self.input_fp8 is not None: + self.input_fp8._clear_to_zero_allocation() + self.input_fp8 = None + if self.input_scale is not None: + self.input_scale._clear_to_zero_allocation() + self.input_scale = None + if self.input is not None: + self.input._clear_to_zero_allocation() + self.input = None + + do1_t_fp8, do1_t_scale = self.fused_transpose_split_quant(do1, None, tokens_per_expert, True) + + def cal_weight_fn(input_x_t_fp8, input_x_t_scale, do1_t_fp8, do1_t_scale, expert_w1): + with paddle.no_grad(): + for i in range(len(expert_w1)): + FP8LinearFunctionBase.compute_expert_w_grad( + input_x_t_fp8[i], + input_x_t_scale[i], + do1_t_fp8[i], + do1_t_scale[i], + True, + True, + expert_w1[i], + paddle.float32, + ) + + if WeightGradStore.enabled: + WeightGradStore.put( + partial(cal_weight_fn, input_x_t_fp8, input_x_t_scale, do1_t_fp8, do1_t_scale, expert_w1) + ) + else: + cal_weight_fn(input_x_t_fp8, input_x_t_scale, do1_t_fp8, do1_t_scale, expert_w1) + + @paddle.no_grad() + def forward(self, hs_out, unzipped_probs, tokens_per_expert, m_indices=None): + # check subbatch + if self.fwd_subbatch: + assert m_indices is not None + # deal 0 size + dtype = paddle.bfloat16 + if hs_out is None: + assert self.input_fp8 is not None + assert self.input_scale is not None + shape = self.input_fp8.shape + else: + if isinstance(hs_out, tuple): + shape = hs_out[0].shape + else: + shape = hs_out.shape + + if shape[0] == 0: + o3 = paddle.zeros(shape, dtype=dtype) + return o3 + + # get w1/w2 + expert_w1 = [x.w1 for x in self.experts if x is not None] + expert_w2 = [x.w2 for x in self.experts if x is not None] + + num_expert = len(expert_w1) + + # o1 + o1 = self.fwd_gate_up(hs_out, expert_w1, num_expert, tokens_per_expert, m_indices) + if not self.recompute_fwd_gate_up: + self.o1 = o1 + clear_o1 = False + else: + clear_o1 = True + + # o3 + o3 = self.fwd_down( + o1, unzipped_probs, expert_w2, num_expert, tokens_per_expert, clear_o1=clear_o1, m_indices=m_indices + ) + + # save for bwd + return o3 + + @paddle.no_grad() + def backward( + self, + out_grad, + unzipped_probs, + tokens_per_expert, + input_fp8_slice=None, + input_scale_slice=None, + m_indices=None, + reset_status=False, + ): + # check subbatch + if self.bwd_subbatch: + assert ( + m_indices is not None + and input_fp8_slice is not None + and input_scale_slice is not None + and tokens_per_expert is not None + ) + # deal 0 size + dtype = paddle.bfloat16 + shape = out_grad[0].shape if isinstance(out_grad, tuple) else out_grad.shape + if shape[0] == 0: + 
return paddle.zeros_like(extract_first_if_tuple(out_grad), dtype=dtype), paddle.zeros_like(unzipped_probs) + + # recompute expert_w2 and expert_w1 + expert_w1 = [x.w1 for x in self.experts if x is not None] + expert_w2 = [x.w2 for x in self.experts if x is not None] + + if self.recompute_fwd_gate_up: + inp = None if not self.bwd_subbatch else (input_fp8_slice, input_scale_slice) + o1 = self.fwd_gate_up(inp, expert_w1, len(expert_w1), tokens_per_expert, m_indices=m_indices) + else: + o1 = self.o1 + + # do2 + do1, o2_s, probs_grad = self.bwd_dowm_input( + expert_w2, out_grad, o1, tokens_per_expert, unzipped_probs=unzipped_probs, m_indices=m_indices + ) + del o1 + if self.o1 is not None: + self.o1._clear_to_zero_allocation() + self.o1 = None + + # dw1 + self.bwd_gate_up_weight( + do1, + None, + expert_w1, + tokens_per_expert, + input_fp8_slice=input_fp8_slice, + input_scale_slice=input_scale_slice, + clear_input=reset_status, + ) + + if reset_status: + if self.input_fp8 is not None: + self.input_fp8._clear_to_zero_allocation() + self.input_fp8 = None + if self.input_scale is not None: + self.input_scale._clear_to_zero_allocation() + self.input_scale = None + if self.input is not None: + self.input._clear_to_zero_allocation() + self.input = None + + # dx + dx = self.bwd_gate_up_input( + do1, + expert_w1, + tokens_per_expert, + dx=out_grad[0] if isinstance(out_grad, tuple) else out_grad, + m_indices=m_indices, + ) + del do1 + + # dw2 + if isinstance(out_grad, tuple): + do3_fp8, do3_scale = self.fused_transpose_split_quant(out_grad[0], out_grad[1], tokens_per_expert, True) + out_grad[0]._clear_to_zero_allocation() + out_grad[1]._clear_to_zero_allocation() + self.bwd_down_weight((do3_fp8, do3_scale), o2_s, expert_w2, tokens_per_expert) + else: + self.bwd_down_weight(out_grad, o2_s, expert_w2, tokens_per_expert) + + if reset_status: + self.reset_statue() + return dx, probs_grad diff --git a/paddleformers/transformers/fused_a2a.py b/paddleformers/transformers/fused_a2a.py index 7b5fa09c9e0..7e1b6c9c22a 100644 --- a/paddleformers/transformers/fused_a2a.py +++ b/paddleformers/transformers/fused_a2a.py @@ -72,78 +72,145 @@ def get_buffer(group: Group, hidden_bytes: int): return _buffer +def fused_dispatch_forward_func( + x, + token_indices, + token_probs, + num_experts, + group, + previous_event=None, + async_finish=False, + allocate_on_comm_stream=False, +): + """Forward pass of fused dispatch.""" + # Calculate layout before actual dispatch + if isinstance(x, tuple): + buffer = get_buffer(group, get_hidden_bytes(x[0])) + else: + buffer = get_buffer(group, get_hidden_bytes(x)) + ( + num_tokens_per_rank, + num_tokens_per_rdma_rank, + num_tokens_per_expert, + is_token_in_rank, + previous_event_, + ) = buffer.get_dispatch_layout( + token_indices, + num_experts, + previous_event=previous_event, + async_finish=async_finish, + allocate_on_comm_stream=allocate_on_comm_stream, + ) + + assert token_probs.dtype == paddle.float32 + # Do MoE dispatch + # NOTES: the CPU will wait for GPU's signal to arrive, + # so this is not compatible with CUDA graph + (recv_x, recv_token_indices, recv_token_probs, num_recv_tokens_per_expert_list, handle, event,) = buffer.dispatch( + x, + topk_idx=token_indices, + topk_weights=token_probs, + num_tokens_per_rank=num_tokens_per_rank, + num_tokens_per_rdma_rank=num_tokens_per_rdma_rank, + is_token_in_rank=is_token_in_rank, + num_tokens_per_expert=num_tokens_per_expert, + previous_event=previous_event, + async_finish=async_finish, + allocate_on_comm_stream=allocate_on_comm_stream, + 
) + + states = dict() + states["dispatched_indices"] = recv_token_indices + states["tokens_per_expert"] = num_recv_tokens_per_expert_list + states["handle"] = handle + + return recv_x, recv_token_probs, states, event + + +def fused_dispatch_backward_func( + grad_output, + grad_token_probs, + group, + handle, + previous_event=None, + async_finish=False, + allocate_on_comm_stream=False, +): + """Backward pass of fused dispatch.""" + buffer = get_buffer(group, get_hidden_bytes(grad_output)) + + grad_x, grad_token_probs, event = buffer.combine( + grad_output.contiguous(), + handle, + topk_weights=grad_token_probs.cast(paddle.float32), + previous_event=previous_event, + async_finish=async_finish, + allocate_on_comm_stream=allocate_on_comm_stream, + ) + return grad_x, None, grad_token_probs + + +def fused_combine_forward_func( + x, group, states, previous_event=None, async_finish=False, allocate_on_comm_stream=False +): + """Forward pass of fused combine.""" + handle = states["handle"] + buffer = get_buffer(group, get_hidden_bytes(x)) + combined_x, _, event = buffer.combine( + x, + handle=handle, + async_finish=async_finish, + previous_event=previous_event, + allocate_on_comm_stream=allocate_on_comm_stream, + ) + return combined_x + + +def fused_combine_backward_func( + grad_output, group, handle, previous_event=None, async_finish=False, allocate_on_comm_stream=False +): + """Backward pass of fused combine.""" + if isinstance(grad_output, tuple): + buffer = get_buffer(group, get_hidden_bytes(grad_output[0])) + grad_x, _, _, _, _, event = buffer.dispatch( + (grad_output[0].contiguous(), grad_output[1].contiguous()), + handle=handle, + previous_event=previous_event, + async_finish=async_finish, + allocate_on_comm_stream=allocate_on_comm_stream, + ) + else: + buffer = get_buffer(group, get_hidden_bytes(grad_output)) + grad_x, _, _, _, _, event = buffer.dispatch( + grad_output.contiguous(), + handle=handle, + previous_event=previous_event, + async_finish=async_finish, + allocate_on_comm_stream=allocate_on_comm_stream, + ) + return grad_x + + class FusedDispatch(PyLayer): """Fused dispatch operation for MoE routing combining computation and communication.""" @staticmethod def forward(ctx, x, token_indices, token_probs, num_experts, group, previous_event=None): """Forward pass of fused dispatch.""" - # Calculate layout before actual dispatch - buffer = get_buffer(group, get_hidden_bytes(x)) - ( - num_tokens_per_rank, - num_tokens_per_rdma_rank, - num_tokens_per_expert, - is_token_in_rank, - previous_event, - ) = buffer.get_dispatch_layout( - token_indices, - num_experts, - previous_event=None, - async_finish=False, - allocate_on_comm_stream=False, - ) - - # Do MoE dispatch - # NOTES: the CPU will wait for GPU's signal to arrive, - # so this is not compatible with CUDA graph - ( - recv_x, - recv_token_indices, - recv_token_probs, - num_recv_tokens_per_expert_list, - handle, - event, - ) = buffer.dispatch( - x, - topk_idx=token_indices, - topk_weights=token_probs.cast(paddle.float32), - num_tokens_per_rank=num_tokens_per_rank, - num_tokens_per_rdma_rank=num_tokens_per_rdma_rank, - is_token_in_rank=is_token_in_rank, - num_tokens_per_expert=num_tokens_per_expert, - previous_event=None, - async_finish=False, - allocate_on_comm_stream=False, + recv_x, recv_token_probs, states, event = fused_dispatch_forward_func( + x, token_indices, token_probs, num_experts, group, previous_event ) ctx.group = group - ctx.handle = handle + ctx.handle = states["handle"] ctx.event = event - tokens_per_expert = 
paddle.to_tensor(num_recv_tokens_per_expert_list) - - states = dict() - states["dispatched_indices"] = recv_token_indices - states["tokens_per_expert"] = tokens_per_expert - states["handle"] = handle return recv_x, recv_token_probs, states @staticmethod def backward(ctx, grad_output, grad_token_probs): """Backward pass of fused dispatch.""" - buffer = get_buffer(ctx.group, get_hidden_bytes(grad_output)) - handle = ctx.handle - - grad_x, grad_token_probs, event = buffer.combine( - grad_output.contiguous(), - handle, - topk_weights=grad_token_probs.cast(paddle.float32), - previous_event=None, - async_finish=False, - allocate_on_comm_stream=False, - ) - return grad_x, None, grad_token_probs + return fused_dispatch_backward_func(grad_output, grad_token_probs, ctx.group, ctx.handle) class FusedCombine(PyLayer): @@ -152,12 +219,9 @@ class FusedCombine(PyLayer): @staticmethod def forward(ctx, x, group, states, previous_event=None): """Forward pass of fused combine.""" - handle = states["handle"] - buffer = get_buffer(group, get_hidden_bytes(x)) - combined_x, _, event = buffer.combine( - x, handle=handle, async_finish=False, previous_event=None, allocate_on_comm_stream=False - ) - ctx.handle = handle + combined_x = fused_combine_forward_func(x, group, states, previous_event) + + ctx.handle = states["handle"] ctx.group = group ctx.previous_event = previous_event @@ -166,15 +230,7 @@ def forward(ctx, x, group, states, previous_event=None): @staticmethod def backward(ctx, grad_output): """Backward pass of fused combine.""" - buffer = get_buffer(ctx.group, get_hidden_bytes(grad_output)) - grad_x, _, _, _, _, event = buffer.dispatch( - grad_output.contiguous(), - handle=ctx.handle, - previous_event=ctx.previous_event, - async_finish=False, - allocate_on_comm_stream=False, - ) - return grad_x + return fused_combine_backward_func(grad_output, ctx.group, ctx.handle, ctx.previous_event) if HAVE_DEEP_EP: @@ -214,3 +270,96 @@ def fused_combine(x, group, handle, previous_event=None): else: fused_dispatch = None fused_combine = None + + +class DispatchNode: + def __init__(self, name="dispatch"): + self.name = name + + def reset_statue(self): + self.handle = None + + def forward( + self, + x, + token_indices, + token_probs, + num_experts, + group, + previous_event=None, + async_finish=False, + allocate_on_comm_stream=False, + ): + """Forward pass of fused dispatch.""" + recv_x, recv_token_probs, states, event = fused_dispatch_forward_func( + x, + token_indices, + token_probs, + num_experts, + group, + previous_event=previous_event, + async_finish=async_finish, + allocate_on_comm_stream=allocate_on_comm_stream, + ) + + self.group = group + self.handle = states["handle"] + self.event = event + + return recv_x, recv_token_probs, states + + def backward( + self, grad_output, grad_token_probs, previous_event=None, async_finish=False, allocate_on_comm_stream=False + ): + """Backward pass of fused dispatch.""" + out = fused_dispatch_backward_func( + grad_output, + grad_token_probs, + self.group, + self.handle, + previous_event=previous_event, + async_finish=async_finish, + allocate_on_comm_stream=allocate_on_comm_stream, + ) + self.reset_statue() + return out + + +class CombineNode: + def __init__(self, name="combine"): + self.name = name + + def reset_statue(self): + self.handle = None + + def forward(self, x, group, handle, previous_event=None, async_finish=False, allocate_on_comm_stream=False): + """Forward pass of fused combine.""" + states = dict() + states["handle"] = handle + combined_x = 
fused_combine_forward_func( + x, + group, + states, + previous_event=previous_event, + async_finish=async_finish, + allocate_on_comm_stream=allocate_on_comm_stream, + ) + + self.handle = handle + self.group = group + self.previous_event = previous_event + + return combined_x + + def backward(self, grad_output, previous_event=None, async_finish=False, allocate_on_comm_stream=False): + """Backward pass of fused combine.""" + out = fused_combine_backward_func( + grad_output, + self.group, + self.handle, + previous_event=previous_event, + async_finish=async_finish, + allocate_on_comm_stream=allocate_on_comm_stream, + ) + self.reset_statue() + return out diff --git a/paddleformers/transformers/moe_utils.py b/paddleformers/transformers/moe_utils.py index 466591b0638..3776cf79ef9 100644 --- a/paddleformers/transformers/moe_utils.py +++ b/paddleformers/transformers/moe_utils.py @@ -18,6 +18,24 @@ import paddle +from ..utils.tools import get_env_device + + +def to_device(tensor, place=None): + if place is None: + place = get_env_device() + + if isinstance(place, str): + place = paddle.device._convert_to_place(place) + + if not tensor.place._equals(place): + new_t = tensor._copy_to(place, True) + dst_tensor = tensor.value().get_tensor() + src_tensor = new_t.value().get_tensor() + dst_tensor._share_data_with(src_tensor) + + return tensor + def permute( tokens, @@ -99,3 +117,18 @@ def unpermute( include_self=True, ) return output_tokens + + +def offload(tensor): + if paddle.is_compiled_with_cuda(): + place = paddle.CUDAPinnedPlace() + else: + place = paddle.CPUPlace() + + new_tensor = to_device(tensor, place) + assert new_tensor is tensor, "to_device must be inplace operation" + + +def reload(tensor): + new_tensor = to_device(tensor) + assert new_tensor is tensor, "to_device must be inplace operation"
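The new offload/reload helpers in moe_utils.py move a tensor's storage in place: to_device copies the data to the target place and then calls _share_data_with, so the original Python object keeps pointing at the relocated storage and every existing reference stays valid, which is why both helpers assert that the returned tensor is the same object. A minimal usage sketch of that in-place behavior, assuming a CUDA build of Paddle; the tensor name and shape below are illustrative only, not taken from the patch:

import paddle

from paddleformers.transformers.moe_utils import offload, reload

# Illustrative activation tensor; shape and dtype are arbitrary.
act = paddle.randn([1024, 7168])

# Swap the underlying storage to CUDAPinnedPlace (CPUPlace on non-CUDA builds).
# The Python object is unchanged, so existing references to `act` remain usable.
offload(act)

# Copy the storage back to the default device reported by get_env_device().
reload(act)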