diff --git a/paddleformers/examples/deepseek_v3/config/config.json b/paddleformers/examples/deepseek_v3/config/config.json new file mode 100644 index 00000000000..aec7bbb8cb6 --- /dev/null +++ b/paddleformers/examples/deepseek_v3/config/config.json @@ -0,0 +1,75 @@ +{ + "architectures": [ + "DeepseekV3ForCausalLM" + ], + "attention_bias": false, + "attention_dropout": 0.0, + "auto_map": { + "AutoConfig": "configuration_deepseek.DeepseekV3Config", + "AutoModel": "modeling_deepseek.DeepseekV3Model", + "AutoModelForCausalLM": "modeling_deepseek.DeepseekV3ForCausalLM" + }, + "aux_loss_alpha": 0.001, + "bos_token_id": 0, + "eos_token_id": 1, + "ep_size": 1, + "first_k_dense_replace": 3, + "hidden_act": "silu", + "hidden_size": 7168, + "initializer_range": 0.02, + "intermediate_size": 18432, + "kv_lora_rank": 512, + "max_position_embeddings": 163840, + "model_type": "deepseek_v3", + "moe_intermediate_size": 2048, + "moe_layer_freq": 1, + "n_group": 8, + "n_routed_experts": 8, + "n_shared_experts": 1, + "norm_topk_prob": true, + "num_attention_heads": 128, + "num_experts_per_tok": 8, + "num_hidden_layers": 15, + "num_key_value_heads": 128, + "num_nextn_predict_layers": 1, + "pretraining_tp": 1, + "q_lora_rank": 1536, + "qk_nope_head_dim": 128, + "qk_rope_head_dim": 64, + "rms_norm_eps": 1e-06, + "rope_scaling": { + "beta_fast": 32, + "beta_slow": 1, + "factor": 40, + "mscale": 1.0, + "mscale_all_dim": 1.0, + "original_max_position_embeddings": 4096, + "type": "yarn" + }, + "rope_theta": 10000, + "routed_scaling_factor": 2.5, + "scoring_func": "sigmoid", + "seq_aux": true, + "tie_word_embeddings": false, + "topk_group": 4, + "topk_method": "noaux_tc", + "dtype": "bfloat16", + "transformers_version": "4.33.1", + "use_cache": true, + "v_head_dim": 128, + "vocab_size": 129280, + "using_flex_token": true, + "using_fake_gate": true, + "use_fused_rms_norm": true, + "fuse_attention_ffn": true, + "use_fused_rope": true, + "token_drop_steps": 0, + "recompute_fwd_gate_up": true, + "adaptive_remained_O1_recompute_ratio": 0.3, + "using_post_norm_recompute": true, + "is_split_group_gemm": false, + "use_dualpipev": true, + "send_mtp_embed": true, + "offline_quant_expert_weight": false, + "clear_origin_weight_when_offline_quant": false + } \ No newline at end of file diff --git a/paddleformers/examples/deepseek_v3/config/pretrain_argument.json b/paddleformers/examples/deepseek_v3/config/pretrain_argument.json new file mode 100644 index 00000000000..0c8d4aefed9 --- /dev/null +++ b/paddleformers/examples/deepseek_v3/config/pretrain_argument.json @@ -0,0 +1,53 @@ +{ + "model_name_or_path": "./config/", + "tokenizer_name_or_path": "deepseek-ai/DeepSeek-V3", + "input_dir": "./data", + "output_dir": "./checkpoints/pretrain_ckpts", + "per_device_train_batch_size": 1, + "gradient_accumulation_steps": 24, + "per_device_eval_batch_size": 1, + "tensor_parallel_degree": 1, + "pipeline_parallel_degree": 4, + "pipeline_parallel_config": "use_dualpipev", + "sharding_parallel_degree": 2, + "sharding_parallel_config": "split_param", + "sharding_comm_buffer_size_MB": 2048, + "expert_parallel_degree": 2, + "sharding": "stage1", + "virtual_pp_degree": 1, + "sequence_parallel": 0, + "use_flash_attention": true, + "max_seq_length": 4097, + "learning_rate": 3e-05, + "min_learning_rate": 3e-06, + "warmup_steps": 30, + "logging_steps": 1, + "max_steps": 200, + "save_steps": 5000, + "eval_steps": 1000, + "weight_decay": 0.01, + "bf16": true, + "fp16_opt_level": "O2", + "warmup_ratio": 0.01, + "max_grad_norm": 1.0, + "amp_master_grad": 1, 
+ "dataloader_num_workers": 8, + "continue_training": 0, + "do_train": true, + "do_eval": true, + "do_predict": false, + "disable_tqdm": true, + "recompute": false, + "distributed_dataloader": 1, + "unified_checkpoint": true, + "save_total_limit": 2, + "skip_profile_timer": false, + "use_fused_rms_norm": true, + "fuse_attention_ffn": true, + "use_fused_rope": true, + "save_sharded_model": false, + "load_sharded_model": false, + "use_expert_parallel": true, + "unified_checkpoint_config": "skip_save_model_weight", + "offload_optim": true + } \ No newline at end of file diff --git a/paddleformers/examples/deepseek_v3/run.sh b/paddleformers/examples/deepseek_v3/run.sh new file mode 100644 index 00000000000..adb56ecf03f --- /dev/null +++ b/paddleformers/examples/deepseek_v3/run.sh @@ -0,0 +1,23 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# llama 模型数据下载 +# mkdir -p data +# wget https://bj.bcebos.com/paddlenlp/models/transformers/llama/data/llama_openwebtext_100k.bin +# wget https://bj.bcebos.com/paddlenlp/models/transformers/llama/data/llama_openwebtext_100k.idx + +# mpirun sh script/kill_process.sh +# mpirun rm -rf output +nohup bash script/train_gpu.sh config/pretrain_argument.json > run.log 2>&1 & + diff --git a/paddleformers/examples/deepseek_v3/run_pretrain.py b/paddleformers/examples/deepseek_v3/run_pretrain.py new file mode 100644 index 00000000000..eaf966da095 --- /dev/null +++ b/paddleformers/examples/deepseek_v3/run_pretrain.py @@ -0,0 +1,615 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
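+"""Pretraining entry point for the DeepSeek-V3 example.
+
+Parses a JSON argument file plus optional command-line overrides (see
+config/pretrain_argument.json), builds the mmap pretraining dataset, and runs
+the PretrainingTrainer defined below.
+"""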
+import copy
+import math
+import os
+import sys
+import time
+from dataclasses import dataclass, field
+from typing import Optional
+
+import paddle
+
+from paddleformers.data.causal_dataset import (
+    build_train_valid_test_datasets,
+    check_data_split,
+    print_rank_0,
+)
+from paddleformers.trainer import (
+    FP8QuantWeightCallback,
+    PdArgumentParser,
+    StepFlexToken,
+    Trainer,
+    TrainingArguments,
+    get_last_checkpoint,
+    set_seed,
+    speed_metrics,
+)
+from paddleformers.transformers import (
+    AutoConfig,
+    AutoModelForCausalLM,
+    AutoModelForCausalLMPipe,
+    AutoTokenizer,
+    CosineAnnealingWithWarmupDecay,
+    LinearAnnealingWithWarmupDecay,
+)
+from paddleformers.transformers.configuration_utils import LlmMetaConfig, llmmetaclass
+from paddleformers.utils.batch_sampler import DistributedBatchSampler
+from paddleformers.utils.log import logger
+from paddleformers.utils.tools import get_env_device
+
+# Pretraining environment variables to support the sharding stage1 overlap optimization.
+os.environ["USE_CASUAL_MASK"] = "True"
+
+
+from paddleformers.trainer.utils.doc import add_start_docstrings
+
+
+@dataclass
+@llmmetaclass
+@add_start_docstrings(TrainingArguments.__doc__)
+class PreTrainingArguments(TrainingArguments):
+    min_learning_rate: float = field(
+        default=1e-5,
+        metadata={"help": "Minimum learning rate decayed to."},
+    )
+    decay_steps: float = field(
+        default=None,
+        metadata={
+            "help": "The number of steps used to control the learning rate. Once the step exceeds decay_steps, min_learning_rate is used."
+        },
+    )
+    enable_linear_fused_grad_add: bool = field(
+        default=False,
+        metadata={
+            "help": "Enable the fused linear grad add strategy, which reduces elementwise adds for grad accumulation in the backward of nn.Linear."
+        },
+    )
+    # NOTE(gongenlei): new add autotuner_benchmark
+    autotuner_benchmark: bool = field(
+        default=False,
+        metadata={"help": "Whether to run the benchmark by autotuner. True for from_scratch and pad_max_length."},
+    )
+    unified_checkpoint: bool = field(
+        default=True,
+        metadata={"help": "Whether to save and load checkpoints in the unified checkpoint format."},
+    )
+
+    def __post_init__(self):
+        super().__post_init__()
+        # NOTE(gongenlei): new add autotuner_benchmark
+        from paddleformers.trainer.trainer_utils import IntervalStrategy
+
+        if self.autotuner_benchmark:
+            self.max_steps = 5
+            self.do_train = True
+            self.do_export = False
+            self.do_predict = False
+            self.do_eval = False
+            self.overwrite_output_dir = True
+            self.load_best_model_at_end = False
+            self.report_to = []
+            self.save_strategy = IntervalStrategy.NO
+            self.evaluation_strategy = IntervalStrategy.NO
+            self.unified_checkpoint = False
+
+
+@dataclass
+class DataArguments:
+    """
+    Arguments pertaining to what data we are going to input our model for training and evaluating.
+    Using `PdArgumentParser` we can turn this class into argparse arguments to be able to
+    specify them on the command line.
+    """
+
+    input_dir: str = field(
+        default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
+    )
+    split: str = field(default="949,50,1", metadata={"help": "Train/valid/test data split."})
+
+    max_seq_length: int = field(
+        default=1024,
+        metadata={
+            "help": "The maximum total input sequence length after tokenization. Sequences longer "
+            "than this will be truncated, sequences shorter will be padded."
+ }, + ) + share_folder: bool = field( + default=False, + metadata={"help": "Use share folder for data dir and output dir on multi machine."}, + ) + + data_impl: str = field(default="mmap", metadata={"help": "The format of the preprocessed data."}) + skip_warmup: bool = field( + default=True, + metadata={"help": "Whether to skip the warmup process of mmap files."}, + ) + data_cache: str = field(default=None, metadata={"help": "The path of the cached dataset."}) + + +@dataclass +class ModelArguments: + """ + Arguments pertaining to which model/config/tokenizer we are going to pre-train from. + """ + + model_name_or_path: str = field( + default="__internal_testing__/tiny-random-llama", + metadata={ + "help": "Path to pretrained model or model identifier from https://paddleformers.readthedocs.io/zh/latest/model_zoo/transformers.html" + }, + ) + tokenizer_name_or_path: Optional[str] = field( + default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} + ) + + use_fast_layer_norm: bool = field( + default=False, + metadata={"help": "GPT3 model, use fast layernorm"}, + ) + + hidden_dropout_prob: float = field(default=0.1, metadata={"help": "The hidden dropout prob."}) + attention_probs_dropout_prob: float = field(default=0.1, metadata={"help": "The attention hidden dropout prob."}) + + fuse_attention_qkv: bool = field( + default=None, + metadata={"help": "whether to fuse attention qkv"}, + ) + fuse_attention_ffn: bool = field( + default=None, + metadata={"help": "whether to fuse first up and gate proj in mlp block"}, + ) + + continue_training: bool = field( + default=False, + metadata={ + "help": "Pre-training from existing paddleformers model weights. Default False and model will train from scratch. If set True, the model_name_or_path argument must exist in the paddleformers models." + }, + ) + num_hidden_layers: Optional[int] = field( + default=None, + metadata={"help": "num_hidden_layers."}, + ) + + +def create_pretrained_dataset( + data_args, + training_args, + data_file, + tokenizer, + need_data=True, +): + + check_data_split(data_args.split, training_args.do_train, training_args.do_eval, training_args.do_predict) + + train_val_test_num_samples = [ + training_args.per_device_train_batch_size + * training_args.dataset_world_size + * training_args.max_steps + * training_args.gradient_accumulation_steps, + training_args.per_device_eval_batch_size + * training_args.dataset_world_size + * training_args.eval_iters + * (training_args.max_steps // training_args.eval_steps + 1), + training_args.per_device_eval_batch_size * training_args.dataset_world_size * training_args.test_iters, + ] + + print_rank_0(" > datasets target sizes (minimum size):") + if training_args.do_train: + print_rank_0(" train: {}".format(train_val_test_num_samples[0])) + if training_args.do_eval: + print_rank_0(" validation: {}".format(train_val_test_num_samples[1])) + if training_args.do_predict: + print_rank_0(" test: {}".format(train_val_test_num_samples[2])) + + # Build the datasets. 
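+    # The sizes above are only minimum targets; build_train_valid_test_datasets derives the
+    # actual train/valid/test partitions from `data_args.split` (e.g. "949,50,1").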
+ train_dataset, valid_dataset, test_dataset = build_train_valid_test_datasets( + data_prefix=data_file, + data_impl=data_args.data_impl, + splits_string=data_args.split, + train_val_test_num_samples=train_val_test_num_samples, + seq_length=data_args.max_seq_length, + seed=training_args.seed, + skip_warmup=data_args.skip_warmup, + share_folder=data_args.share_folder, + data_cache_path=data_args.data_cache, + need_data=need_data, + ) + + def print_dataset(data, mode="train"): + logger.info(f"Sample data for {mode} mode.") + # input_ids, loss_mask, attention_mask, position_ids, labels = data + input_ids = data["text"] + logger.info(tokenizer._decode(list(input_ids))) + + from paddleformers.data import Stack + + def _collate_data(data, stack_fn=Stack()): + tokens_ = stack_fn([x["text"] for x in data]) + + labels = copy.deepcopy(tokens_)[:, 1:] + tokens = tokens_[:, :-1] + + return { + "input_ids": tokens, + "labels": labels, + } + + if need_data: + if training_args.do_train: + print_dataset(train_dataset[0], "train") + if training_args.do_eval: + print_dataset(valid_dataset[0], "valid") + if training_args.do_predict: + print_dataset(test_dataset[0], "test") + + return train_dataset, valid_dataset, test_dataset, _collate_data + + +def get_train_data_file(args): + if len(args.input_dir.split()) > 1: + # weight-1 data-prefix-1 weight-2 data-prefix-2 ... + return args.input_dir.split() + else: + files = [ + os.path.join(args.input_dir, f) + for f in os.listdir(args.input_dir) + if (os.path.isfile(os.path.join(args.input_dir, f)) and ("_idx.npz" in str(f) or ".idx" in str(f))) + ] + files = [x.replace("_idx.npz", "") for x in files] + files = [x.replace(".idx", "") for x in files] + + if len(files) > 1: + ret = [] + logger.info("You are using multi-dataset:") + for x in files: + ret.append(1.0) + ret.append(x) + logger.info(" > set weight of %s dataset to 1.0" % x) + return ret + + return files + + +class PretrainingTrainer(Trainer): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.is_pretraining = True + + def evaluate(self, eval_dataset=None, ignore_keys=None, metric_key_prefix: str = "eval"): + # keep eval_dataloader + eval_dataloader = getattr(self, "eval_dataloader", None) + if eval_dataloader is None: + eval_dataset = self.eval_dataset if eval_dataset is None else eval_dataset + eval_dataloader = self.get_eval_dataloader(eval_dataset) + # must call data loader, otherwise, it will init many times, cause OOM error. + self.eval_dataloader = eval_dataloader() + + start_time = time.time() + # Temporarily disable metric computation, we will do it in the loop here. 
+ compute_metrics = self.compute_metrics + eval_loop = self.evaluation_loop + + output = eval_loop( + eval_dataloader, + description="Evaluation", + # No point gathering the predictions if there are no metrics, otherwise we defer to + # self.args.prediction_loss_only + prediction_loss_only=True if compute_metrics is None else None, + ignore_keys=ignore_keys, + # Only evaluate max_eval_iters + max_eval_iters=self.args.eval_iters, + ) + + total_batch_size = self.args.eval_batch_size * self.args.world_size + output.metrics.update( + speed_metrics( + metric_key_prefix, + start_time, + num_samples=output.num_samples, + num_steps=math.ceil(output.num_samples / total_batch_size), + ) + ) + + self.log(output.metrics) + + self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, output.metrics) + return output.metrics + + def _get_eval_sampler(self, eval_dataset) -> Optional[paddle.io.Sampler]: + return DistributedBatchSampler( + eval_dataset, + batch_size=self.args.per_device_eval_batch_size, + shuffle=False, + num_replicas=self.args.dataset_world_size, + rank=self.args.dataset_rank, + drop_last=self.args.dataloader_drop_last, + ) + + def _get_train_sampler(self) -> Optional[paddle.io.Sampler]: + return DistributedBatchSampler( + self.train_dataset, + batch_size=self.args.per_device_train_batch_size, + shuffle=False, + num_replicas=self.args.dataset_world_size, + rank=self.args.dataset_rank, + drop_last=self.args.dataloader_drop_last, + ) + + +def main(): + parser = PdArgumentParser((ModelArguments, DataArguments, PreTrainingArguments)) + # Support format as "args.json --arg1 value1 --arg2 value2.” + # In case of conflict, command line arguments take precedence. + if len(sys.argv) >= 2 and sys.argv[1].endswith(".json"): + model_args, data_args, training_args = parser.parse_json_file_and_cmd_lines() + else: + model_args, data_args, training_args = parser.parse_args_into_dataclasses() + + if training_args.no_recompute_layers is not None: + training_args.no_recompute_layers.sort() + + if training_args.enable_linear_fused_grad_add: + from utils.fused_layers import mock_layers + + mock_layers() + + if model_args.tokenizer_name_or_path is None: + model_args.tokenizer_name_or_path = model_args.model_name_or_path + + if data_args.data_cache is not None: + os.makedirs(data_args.data_cache, exist_ok=True) + + paddle.set_device(training_args.device) + set_seed(seed=training_args.seed) + + training_args.eval_iters = 10 + training_args.test_iters = training_args.eval_iters * 10 + + # Log model and data config + training_args.print_config(model_args, "Model") + training_args.print_config(data_args, "Data") + + # Log on each process the small summary: + logger.warning( + f"Process rank: {training_args.local_rank}, device: {training_args.device}, world_size: {training_args.world_size}, " + + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16 or training_args.bf16}" + ) + + # Detecting last checkpoint. + last_checkpoint = None + if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir: + last_checkpoint = get_last_checkpoint(training_args.output_dir) + # if last_checkpoint is None and len( + # os.listdir(training_args.output_dir)) > 1: + # raise ValueError( + # f"Output directory ({training_args.output_dir}) already exists and is not empty. 
" + # "Use --overwrite_output_dir to overcome.") + if last_checkpoint is not None and training_args.resume_from_checkpoint is None: + logger.info( + f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change " + "the `--output_dir` or add `--overwrite_output_dir` to train from scratch." + ) + + tokenizer = AutoTokenizer.from_pretrained(model_args.tokenizer_name_or_path, **{"download_hub": "bos"}) + config = AutoConfig.from_pretrained(model_args.model_name_or_path) + + # set all llm config + LlmMetaConfig.set_llm_config(config, training_args) + config.use_fast_layer_norm = model_args.use_fast_layer_norm + + config.seq_length = data_args.max_seq_length + # There are some technique extend RotaryEmbedding context. so don't change max_position_embeddings + if not model_args.continue_training: + config.max_position_embeddings = max(config.max_position_embeddings, data_args.max_seq_length) + + if not model_args.continue_training: + config.vocab_size = max(config.vocab_size, ((tokenizer.vocab_size - 1) // 128 + 1) * 128) + logger.info(f"Reset vocab size to {config.vocab_size} for batter amp peformance.") + + config.num_hidden_layers = ( + model_args.num_hidden_layers if model_args.num_hidden_layers is not None else config.num_hidden_layers + ) + # Config for model using dropout, such as GPT. + if hasattr(config, "use_dualpipev"): + # NOTE(zhangyuqin): In Paddle, the segmentation and scheduling of pipeline parallel + # models are separate. Therefore, first we need to set the flag in the model config + # to perform V-shape segmentation. Second, we need to set the flag in the training_args + # to configure strategy.hybrid_configs to choose the DualPipeV schedule. + config.use_dualpipev = "use_dualpipev" in training_args.pipeline_parallel_config + if hasattr(config, "hidden_dropout_prob"): + config.hidden_dropout_prob = model_args.hidden_dropout_prob + if hasattr(config, "attention_probs_dropout_prob"): + config.attention_probs_dropout_prob = model_args.attention_probs_dropout_prob + if model_args.fuse_attention_qkv is not None: + config.fuse_attention_qkv = model_args.fuse_attention_qkv + if model_args.fuse_attention_ffn is not None: + config.fuse_attention_ffn = model_args.fuse_attention_ffn + + if config.sequence_parallel: + assert config.tensor_parallel_degree > 1, "tensor_parallel_degree must be larger than 1 for sequence parallel." 
+        assert (
+            config.num_attention_heads % config.sep_parallel_degree == 0
+        ), f"num_attention_heads:{config.num_attention_heads} must be divisible by sep_parallel_degree {config.sep_parallel_degree}"
+        assert (
+            config.seq_length % config.context_parallel_degree == 0
+        ), f"seq_length:{config.seq_length} must be divisible by context_parallel_degree {config.context_parallel_degree}"
+
+    if training_args.sharding_parallel_config is not None:
+        # for stage1 overlap optimization
+        if (
+            "enable_stage1_allgather_overlap" in training_args.sharding_parallel_config
+            or "enable_stage1_broadcast_overlap" in training_args.sharding_parallel_config
+        ):
+            from paddle.io.reader import use_pinned_memory
+
+            use_pinned_memory(False)
+
+    if get_env_device() == "xpu" and training_args.gradient_accumulation_steps > 1:
+        try:
+            from paddle_xpu.layers.nn.linear import LinearConfig  # noqa: F401
+
+            LinearConfig.enable_accumulate_steps_opt()
+            LinearConfig.set_accumulate_steps(training_args.gradient_accumulation_steps)
+        except ImportError:
+            # It's OK to fall back without the accumulate_steps optimization.
+            pass
+
+    print("Final pre-training config:", config)
+
+    # Set the dtype for loading the model
+    dtype = "float32"
+    if training_args.fp16_opt_level == "O2":
+        if training_args.fp16:
+            dtype = "float16"
+        if training_args.bf16:
+            dtype = "bfloat16"
+
+    model_class = AutoModelForCausalLM
+    if training_args.pipeline_parallel_degree > 1:
+        model_class = AutoModelForCausalLMPipe
+        if "LLama" in str(config.architectures):
+            try:
+                from utils.register_reshard import register_pp_reshard_information
+
+                register_pp_reshard_information(config.num_hidden_layers)
+            except:
+                print("Failed to register llama pp reshard information.")
+
+    architectures_to_check = {"Qwen2Moe", "DeepseekV2", "DeepseekV3"}
+    if (
+        any(architecture in str(config.architectures) for architecture in architectures_to_check)
+        and training_args.data_parallel_degree > 1
+    ):
+        training_args.use_expert_parallel = True
+
+    if model_args.continue_training:
+        # NOTE(gongenlei): new add
+        if training_args.autotuner_benchmark:
+            model = model_class.from_config(config, dtype=dtype)
+        else:
+            model = model_class.from_pretrained(
+                model_args.model_name_or_path,
+                config=config,
+                dtype=dtype,
+            )
+    else:
+        # Shrink the model here for debugging: the first 3 DeepSeek layers are dense, the sparse (MoE) layers follow.
+        # config.num_hidden_layers = 4  # 61 in V3
+        # config.first_k_dense_replace = 0  # 3 in V3
+        # # Reduce the expert count here; for EP parallelism it must be divisible by the parallel degree.
+        # config.n_routed_experts = 64  # 256 in V3
+        # config.num_experts_per_tok = 8  # 8 in V3
+        # config.topk_group = 4  # 4 in V3
+
+        # config.using_flex_token = True
+        # config.num_nextn_predict_layers = 1
+        # config.using_fake_gate = True
+        # config.use_fused_rms_norm = True
+        # config.fuse_attention_ffn = True
+        # config.use_fused_rope = True
+        # config.token_drop_steps = 0
+        model = model_class.from_config(config, dtype=dtype)
+
+    if training_args.recompute:
+        model.recompute_enable()
+
+    # Create the learning rate scheduler and optimizer
+    if training_args.decay_steps is None:
+        training_args.decay_steps = training_args.max_steps
+
+    if training_args.warmup_steps > 0:
+        warmup_steps = training_args.warmup_steps
+    else:
+        warmup_steps = training_args.warmup_ratio * training_args.max_steps
+
+    lr_scheduler = None
+    if training_args.lr_scheduler_type.value == "cosine":
+        lr_scheduler = CosineAnnealingWithWarmupDecay(
+            max_lr=training_args.learning_rate,
+            min_lr=training_args.min_learning_rate,
+            warmup_step=warmup_steps,
+            decay_step=training_args.decay_steps,
+            last_epoch=0,
+        )
+    elif
training_args.lr_scheduler_type.value == "linear": + lr_scheduler = LinearAnnealingWithWarmupDecay( + max_lr=training_args.learning_rate, + min_lr=training_args.min_learning_rate, + warmup_step=warmup_steps, + decay_step=training_args.decay_steps, + last_epoch=0, + ) + + data_file = get_train_data_file(data_args) + train_dataset, eval_dataset, test_dataset, data_collator = create_pretrained_dataset( + data_args, + training_args, + data_file, + tokenizer, + need_data=training_args.should_load_dataset, + ) + + total_effective_tokens = ( + training_args.per_device_train_batch_size + * training_args.dataset_world_size + * training_args.max_steps + * training_args.gradient_accumulation_steps + * data_args.max_seq_length + ) + + callbacks = [StepFlexToken(), FP8QuantWeightCallback()] + + trainer = PretrainingTrainer( + model=model, + args=training_args, + data_collator=data_collator, + train_dataset=train_dataset if training_args.do_train else None, + eval_dataset=eval_dataset if training_args.do_eval else None, + optimizers=(None, lr_scheduler), + tokenizer=tokenizer, + callbacks=callbacks, + ) + + checkpoint = None + if training_args.resume_from_checkpoint is not None: + checkpoint = training_args.resume_from_checkpoint + elif last_checkpoint is not None: + checkpoint = last_checkpoint + + # Training + if training_args.do_train: + train_result = trainer.train(resume_from_checkpoint=checkpoint) + + # NOTE(gongenlei): new add + if not training_args.autotuner_benchmark: + metrics = train_result.metrics + if not int(os.getenv("test_ci_no_save_model", 0)): + trainer.save_model() + trainer.log_metrics("train", metrics) + trainer.save_metrics("train", metrics) + trainer.save_state() + + if training_args.do_predict: + test_ret = trainer.predict(test_dataset) + trainer.log_metrics("test", test_ret.metrics) + + if training_args.do_train and training_args.should_load_dataset: + effective_tokens_per_second = total_effective_tokens / train_result.metrics["train_runtime"] + print(f"Effective Tokens per second: {effective_tokens_per_second:.2f}") + print(f"ips: {effective_tokens_per_second:.2f} tokens/s") + + +if __name__ == "__main__": + main() diff --git a/paddleformers/examples/deepseek_v3/script/kill_process.sh b/paddleformers/examples/deepseek_v3/script/kill_process.sh new file mode 100644 index 00000000000..3c3db6a4639 --- /dev/null +++ b/paddleformers/examples/deepseek_v3/script/kill_process.sh @@ -0,0 +1,32 @@ +#!/bin/bash + +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +set -x +skip_kill_time=${1:-"False"} +function kill_impl() { + skip_kill_time=$1 + # kill aadiff test finally. 
+ pids=`ps -ef | grep pretrain.py | grep -v grep | awk '{print $2}'` + if [[ "$pids" != "" ]] ; then + echo $pids + echo $pids | xargs kill -9 + fi + + echo "Killing processes on gpu" + lsof /dev/nvidia* | awk '{print $2}' | xargs -I {} kill -9 {} +} + +kill_impl $skip_kill_time || true \ No newline at end of file diff --git a/paddleformers/examples/deepseek_v3/script/selective_launch.py b/paddleformers/examples/deepseek_v3/script/selective_launch.py new file mode 100644 index 00000000000..1f8a37bfbc5 --- /dev/null +++ b/paddleformers/examples/deepseek_v3/script/selective_launch.py @@ -0,0 +1,73 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Selective launch script. + +Usage: python script/selective_launch.py ... +""" +import os +import sys + + +def parse_ranks(ranks_strs): + """ + parse_ranks + """ + # NOTE: You can return ranks directly here to change script/train_gpu.sh + # and script/kill_process.sh together + + # Example 1: Use contiguous nodes [8, 16) + return range(6, 7) + + # Example 2: Use non-contiguous nodes [4, 8) + {10} + [30, 32), i.e., [4, 5, 6, 7, 10, 30, 31] + # return list(range(0, 16)) + list(range(24, 40)) + + # Example 3: + # Just Python code, return any nodes you want! + + if not ranks_strs: + return None + + ranks = [] + for r in ranks_strs: + r = eval(r) + if isinstance(r, int): + ranks.append(r) + else: + ranks.extend(r) + return ranks + + +def main(port, ranks): + """ + main + """ + ips = [ip.strip() for ip in os.getenv("TRAINER_INSTANCES").split(",") if ip.strip()] + if ranks is None: + ranks = list(range(len(ips))) + ranks = sorted(list(set(ranks))) + my_rank = int(os.getenv("POD_INDEX", "0")) + if my_rank not in ranks: + return + + rank = ranks.index(my_rank) + nranks = len(ranks) + + master = ips[ranks[0]] + print(f"--master {master}:{port} --rank {rank} --nnodes {nranks}") + + +if __name__ == "__main__": + main(int(sys.argv[1]), parse_ranks(sys.argv[2:])) diff --git a/paddleformers/examples/deepseek_v3/script/train_gpu.sh b/paddleformers/examples/deepseek_v3/script/train_gpu.sh new file mode 100644 index 00000000000..5d5313c01c8 --- /dev/null +++ b/paddleformers/examples/deepseek_v3/script/train_gpu.sh @@ -0,0 +1,78 @@ +#!/bin/bash + +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
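+
+# Usage: bash script/train_gpu.sh config/pretrain_argument.json
+# run.sh shows a full nohup launch; node selection is handled by script/selective_launch.py.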
+ +unset PADDLE_ELASTIC_JOB_ID +unset PADDLE_TRAINER_ENDPOINTS +unset DISTRIBUTED_TRAINER_ENDPOINTS +unset FLAGS_START_PORT +unset PADDLE_ELASTIC_TIMEOUT + +nnodes=$PADDLE_TRAINERS_NUM +rank=$PADDLE_TRAINER_ID + +for name in `env | grep -E 'PADDLE|ENDPOINT' | awk -F'=' '{print $1}'`; do + unset ${name} +done + +#export FLAGS_shard_bypass_dygraph_optimizer=1 +export NCCL_IB_GID_INDEX=3 +export NVSHMEM_IB_GID_INDEX=3 +export NVSHMEM_IB_TRAFFIC_CLASS=162 + +#export NVSHMEM_IB_ENABLE_IBGDA=true +##export NVSHMEM_DISABLE_P2P=1 +export NVSHMEM_BOOTSTRAP=UID + +unset NVSHMEM_HCA_LIST +unset NVSHMEM_ENABLE_NIC_PE_MAPPING + +LAUNCH_CMD=`python script/selective_launch.py 36677` +if [[ -z "$LAUNCH_CMD" ]]; then + exit 0 +fi + +export PYTHONPATH=../../../:$PYTHONPATH +# export PYTHONPATH=../../PaddleFormers/:$PYTHONPATH + +export CUDA_PATH=/usr/local/cuda-12.9 + +# Flags for best performance +export DSV3_USE_FP8_GEMM=true +export DSV3_USE_ATTEN_RECOMPUTE=true +export FA_VERSION=3 +export FLAGS_share_tensor_for_grad_tensor_holder=1 +export FLAGS_use_default_stream=false +export DSV3_USE_FP8_DISPATCH=true +export USE_DS_GEMM=true + +# Flags for allocator +export FLAGS_large_pool_auto_growth_chunk_size_in_mb=500 +export FLAGS_small_pool_auto_growth_chunk_size_in_mb=20 +export FLAGS_small_pool_size_in_mb=10 +export FLAGS_samll_pool_pre_alloc_in_mb=500 +export FLAGS_large_pool_pre_alloc_in_mb=61440 +export FLAGS_deep_ep_comm_prealloc_in_mb=1000 + +export DSV3_FAST_PRETRAIN=true +bash script/kill_process.sh + +source /root/paddlejob/workspace/env_run/chenxi/chenxi_py3.10/bin/activate +python3.10 -m paddle.distributed.launch \ + --log_dir output/paddle_distributed_logs \ + $LAUNCH_CMD \ + --run_mode=collective \ + ${script:-run_pretrain.py} \ + $@ diff --git a/paddleformers/trainer/__init__.py b/paddleformers/trainer/__init__.py index 53ceb66a961..b129b4a20a1 100644 --- a/paddleformers/trainer/__init__.py +++ b/paddleformers/trainer/__init__.py @@ -75,6 +75,8 @@ "TrainerState", "DEFAULT_PROGRESS_CALLBACK", "TrainerCallback", + "StepFlexToken", + "FP8QuantWeightCallback", ], "trainer_utils": [ "get_last_checkpoint", diff --git a/paddleformers/trainer/trainer.py b/paddleformers/trainer/trainer.py index 55fb28d5c09..50d8b5491cf 100644 --- a/paddleformers/trainer/trainer.py +++ b/paddleformers/trainer/trainer.py @@ -86,12 +86,15 @@ ) from ..peft import LoKrModel, LoRAModel, PrefixModelForCausalLM, ReFTModel, VeRAModel from ..peft.lora import QuantizationLoRABaseLinear -from ..quantization.quantization_linear import ( - ColumnParallelQuantizationLinear, - QuantizationLinear, - RowParallelQuantizationLinear, -) +try: + from ..quantization.quantization_linear import ( + ColumnParallelQuantizationLinear, + QuantizationLinear, + RowParallelQuantizationLinear, + ) +except: + QuantizationLinear = None try: from paddle.distributed.fleet.utils.sequence_parallel_utils import ( register_sequence_parallel_allreduce_hooks, @@ -196,11 +199,13 @@ nested_numpify, nested_truncate, ) +from .utils.load_hf_ckpt import load_huggingface_ckpt from .utils.sharding_io import ShardingIO DEFAULT_CALLBACKS = [DefaultFlowCallback] DEFAULT_PROGRESS_CALLBACK = ProgressCallback + if is_datasets_available(): import datasets @@ -1133,6 +1138,9 @@ def _inner_training_loop( if self.args.ignore_data_skip: self.timers and self.timers("read-data").start() + if self.args.resume_from_huggingface_ckpt is not None: + load_huggingface_ckpt(model, self.args.resume_from_huggingface_ckpt) + for epoch in range(epochs_trained, num_train_epochs): if 
isinstance(train_dataloader, paddle.io.DataLoader) and isinstance(
                 train_dataloader.batch_sampler, DistributedBatchSampler
diff --git a/paddleformers/trainer/trainer_callback.py b/paddleformers/trainer/trainer_callback.py
index 812c8dc9f59..54f406c03ec 100644
--- a/paddleformers/trainer/trainer_callback.py
+++ b/paddleformers/trainer/trainer_callback.py
@@ -20,13 +20,16 @@
 """
 import dataclasses
 import json
+import os
 from dataclasses import dataclass
 from typing import Dict, List, Optional, Union
 
 import numpy as np
 from tqdm.auto import tqdm
 
-from ..utils.log import logger
+from paddleformers.transformers.moe_utils import offload, reload
+from paddleformers.utils.log import logger
+
 from .trainer_utils import IntervalStrategy, has_length
 from .training_args import TrainingArguments
 
@@ -39,6 +42,8 @@
     "ProgressCallback",
     "PrinterCallback",
     "EarlyStoppingCallback",
+    "StepFlexToken",
+    "FP8QuantWeightCallback",
 ]
 
 
@@ -608,3 +613,69 @@ def on_evaluate(self, args, state, control, metrics, **kwargs):
         self.check_metric_value(args, state, control, metric_value)
         if self.early_stopping_patience_counter >= self.early_stopping_patience:
             control.should_training_stop = True
+
+
+class StepFlexToken(TrainerCallback):
+    def on_step_begin(
+        self,
+        args: TrainingArguments,
+        state: TrainerState,
+        control: TrainerControl,
+        **kwargs,
+    ):
+        model = kwargs.pop("model")
+        if hasattr(model, "step_flex_token"):
+            model.step_flex_token(state.global_step)
+
+
+g_shard_bypass_dygraph_optimizer = int(os.environ.get("FLAGS_shard_bypass_dygraph_optimizer", 0))
+
+
+def enable_in_dict_config(config, key):
+    """Return True if `key` is present in `config` and truthy."""
+    return key in config and config[key]
+
+
+skip_count = 0
+
+
+class FP8QuantWeightCallback(TrainerCallback):
+    """
+    Quantize expert weights to FP8 each step and offload/reload the MoE master weights around it.
+    """
+
+    def on_step_begin(self, args, state, control, **kwargs):
+        """
+        Quantize the expert weights to FP8 before every step.
+        """
+        model = kwargs["model"]
+        optimizer = kwargs["optimizer"]
+        global skip_count
+
+        if (not g_shard_bypass_dygraph_optimizer or skip_count == 0) and hasattr(model, "fp8_quant_weight"):
+            model.fp8_quant_weight(True, quant_transpose=True)
+            optimizer.clear_param_storage("moe_expert")
+            optimizer.clear_param_storage("rms_linear")
+            optimizer.clear_param_storage("memory_attn")
+            optimizer.clear_param_storage("attn_out_project")
+            optimizer.clear_param_storage("shared_expert")
+
+        self.moe_weights_name = []
+        for param in optimizer._inner_opt._parameter_list:
+            color = getattr(param, "color", -1)
+            if isinstance(color, dict) and color["color"] == "moe_expert":
+                self.moe_weights_name.append(param.name)
+
+        for name in self.moe_weights_name:
+            offload(optimizer._master_weights[name])
+
+        skip_count += 1
+
+    def on_optimizer_begin(self, args, state, control, **kwargs):
+        model = kwargs["model"]
+        optimizer = kwargs["optimizer"]
+        global skip_count
+
+        if (not g_shard_bypass_dygraph_optimizer) and hasattr(model, "fp8_quant_weight"):
+            for name in self.moe_weights_name:
+                reload(optimizer._master_weights[name])
diff --git a/paddleformers/trainer/training_args.py b/paddleformers/trainer/training_args.py
index c505856532b..b7c71e1be31 100644
--- a/paddleformers/trainer/training_args.py
+++ b/paddleformers/trainer/training_args.py
@@ -901,6 +901,10 @@ class TrainingArguments:
         default=None,
         metadata={"help": "The path to a folder with a valid checkpoint for your model."},
     )
+    resume_from_huggingface_ckpt: Optional[str] = field(
+        default=None,
+        metadata={"help": "The path to a folder with a valid huggingface checkpoint for your model."},
+    )
auto_parallel_resume_form_hybrid_parallel: Optional[bool] = field( default=False, metadata={"help": "Whether hybrid parallel checkpoints be loaded in auto parallel mode."}, @@ -1405,12 +1409,15 @@ def is_segment_parallel_supported(): else: order = ["dp", "sharding", "pp", "mp"] if self.use_expert_parallel: - if self.moe_sharding_parallel_degree >= 1 and self.expert_parallel_degree > 1: - order.insert(-1, "ep") - sd_idx = order.index("sharding") - # if pp_first, the order = ["dp", "pp", "moe_sharding", "sharding", "sep", "ep", "mp"] - # if sharding_first, the order is ["dp", "moe_sharding", "sharding", "pp", "sep", "ep", "mp"] - order.insert(sd_idx, "moe_sharding") + if not os.getenv("DSV3_FAST_PRETRAIN", "False"): + if self.moe_sharding_parallel_degree >= 1 and self.expert_parallel_degree > 1: + order.insert(-1, "ep") + sd_idx = order.index("sharding") + # if pp_first, the order = ["dp", "pp", "moe_sharding", "sharding", "sep", "ep", "mp"] + # if sharding_first, the order is ["dp", "moe_sharding", "sharding", "pp", "sep", "ep", "mp"] + order.insert(sd_idx, "moe_sharding") + else: + order = order[1:-1] + ["dp", "mp"] if is_segment_parallel_supported(): hybrid_configs = { @@ -1564,6 +1571,10 @@ def is_segment_parallel_supported(): fleet.init(is_collective=True, strategy=strategy) logger.info(strategy) + if os.getenv("DSV3_FAST_PRETRAIN", "False"): + if self.expert_parallel_degree > 1: + self.add_moe_comm_group() + elif self.enable_auto_parallel: self.tensor_parallel_degree = max(self.tensor_parallel_degree, 1) self.sep_parallel_degree = max(self.sep_parallel_degree, 1) diff --git a/paddleformers/trainer/utils/load_hf_ckpt.py b/paddleformers/trainer/utils/load_hf_ckpt.py new file mode 100644 index 00000000000..c0df004428e --- /dev/null +++ b/paddleformers/trainer/utils/load_hf_ckpt.py @@ -0,0 +1,378 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
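+"""Load HuggingFace safetensors checkpoints into Paddle DeepSeek-V2/V3 models.
+
+`load_huggingface_ckpt` maps Paddle parameter names to their HuggingFace
+counterparts and fills the model state dict file by file, following
+model.safetensors.index.json.
+"""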
+ + +import json +import re +import sys +from collections import defaultdict +from typing import List, Optional + +import paddle + +try: + from safetensors import safe_open +except: + safe_open = None + +_LAYER_RE = re.compile(r"^_layers\.(\d+)\.(\d+)(?:\.(.*))?$") +_EXPERT_W1_RE = re.compile(r"^mlp\.experts\.(\d+)\.w1(?:\.weight)?$") +_EXPERT_W2_RE = re.compile(r"^mlp\.experts\.(\d+)\.w2(?:\.weight)?$") +_SHARE_EXPERT_W1_RE = re.compile(r"^mlp\.shared_experts\.w1(?:\.weight)?$") +_SHARE_EXPERT_W2_RE = re.compile(r"^mlp\.shared_experts\.w2(?:\.weight)?$") + +_EXPERT_W1_RE_v2 = re.compile(r"^mlp\.experts\.(\d+)\.gate_up_fused_proj(?:\.weight)?$") +_SHARE_EXPERT_W1_RE_v2 = re.compile(r"^mlp\.shared_experts\.gate_up_fused_proj(?:\.weight)?$") +_LAYER_RE_v2 = re.compile(r"_layers.deepseek_v2.layers\.(\d+)\.(.*)$") + +custom_name_map = { + "self_attn.input_layernorm.weight": "input_layernorm.weight", + "self_attn.fused_rms_norm_linear.rms_norm_weight": "input_layernorm.weight", + "self_attn.memory_recompute_att.kv_ln_weight": "self_attn.kv_a_layernorm.weight", + "self_attn.fused_rms_norm_linear.kv_down_weight": "self_attn.kv_a_proj_with_mqa.weight", + "self_attn.memory_recompute_att.kv_up_weight": "self_attn.kv_b_proj.weight", + "self_attn.memory_recompute_att.q_ln_weight": "self_attn.q_a_layernorm.weight", + "self_attn.fused_rms_norm_linear.q_down_weight": "self_attn.q_a_proj.weight", + "self_attn.memory_recompute_att.q_up_weight": "self_attn.q_b_proj.weight", +} + + +def paddle_name_to_hf_names_ds_v2(paddle_name: str) -> List[str]: + """ + 将Paddle模型参数名称转换为Hugging Face格式的名称列表 + + 参数: + paddle_name: Paddle格式的参数名称 + + 返回: + Hugging Face格式的参数名称列表(可能拆分多个参数) + """ + if paddle_name == "_layers.deepseek_v2.embed_tokens.weight": + return ["model.embed_tokens.weight"] + + if paddle_name == "_layers.deepseek_v2.norm.weight": + return ["model.norm.weight"] + + if paddle_name == "_layers.lm_head.weight": + return ["lm_head.weight"] + + m = _LAYER_RE_v2.match(paddle_name) + if not m: + print("not match here !!", paddle_name) + return [] + + rest = m.group(2) or "" + layer_id = m.group(1) + if rest in custom_name_map: + rest = custom_name_map[rest] + out_name = "model.layers." + layer_id + "." + rest + + if rest == "mlp.gate_up_fused_proj.weight" or rest == "mlp.w1": + return [ + "model.layers." + layer_id + ".mlp.gate_proj.weight", + "model.layers." + layer_id + ".mlp.up_proj.weight", + ] + + if rest == "mlp.w2": + return ["model.layers." + layer_id + ".mlp.down_proj.weight"] + + if rest == "mlp.shared_experts.gate_up_fused_proj.weight": + return [ + "model.layers." + layer_id + ".mlp.shared_experts.gate_proj.weight", + "model.layers." + layer_id + ".mlp.shared_experts.up_proj.weight", + ] + + if m := _EXPERT_W1_RE_v2.match(rest): + expert_id = m.group(1) + return [ + "model.layers." + layer_id + ".mlp.experts." + expert_id + ".gate_proj.weight", + "model.layers." + layer_id + ".mlp.experts." + expert_id + ".up_proj.weight", + ] + + if m := _EXPERT_W1_RE.match(rest): + expert_id = m.group(1) + return [ + "model.layers." + layer_id + ".mlp.experts." + expert_id + ".gate_proj.weight", + "model.layers." + layer_id + ".mlp.experts." + expert_id + ".up_proj.weight", + ] + + if m := _EXPERT_W2_RE.match(rest): + expert_id = m.group(1) + return ["model.layers." + layer_id + ".mlp.experts." + expert_id + ".down_proj.weight"] + + if m := _SHARE_EXPERT_W1_RE.match(rest): + return [ + "model.layers." + layer_id + ".mlp.shared_experts.gate_proj.weight", + "model.layers." 
+ layer_id + ".mlp.shared_experts.up_proj.weight", + ] + + if m := _SHARE_EXPERT_W2_RE.match(rest): + return ["model.layers." + layer_id + ".mlp.shared_experts.down_proj.weight"] + + return [out_name] + + +def paddle_name_to_hf_names(paddle_name: str) -> List[str]: + """ + 将Paddle模型参数名称转换为Hugging Face格式的名称列表 + + 参数: + paddle_name: Paddle格式的参数名称 + + 返回: + Hugging Face格式的参数名称列表(可能拆分多个参数) + """ + if paddle_name == "_layers.local_shared_layers.DeepseekV2_shared_weight.embed_tokens.weight": + return ["model.embed_tokens.weight"] + + if paddle_name == "_layers.deepseek_v2.embed_tokens.weight": + return ["model.embed_tokens.weight"] + + m = _LAYER_RE.match(paddle_name) + + if not m: + print("not match here !!", paddle_name) + return [] + else: + rest = m.group(3) or "" + + segment_id = int(m.group(1)) + id_in_segment = int(m.group(2)) + + hf_prefix = _get_hf_prefix(segment_id, id_in_segment) + + if rest in custom_name_map: + return [f"{hf_prefix}.{custom_name_map[rest]}"] + + if expert_names := _handle_expert_weights(hf_prefix, rest): + return expert_names + + if shared_mlp_names := _handle_shared_expert_weights(hf_prefix, rest): + return shared_mlp_names + + if mlp_names := _handle_mlp_weights(hf_prefix, rest): + return mlp_names + + if rest == "mlp.gate_up_fused_proj.weight" or rest == "mlp.w1": + return [hf_prefix + ".mlp.gate_proj.weight", hf_prefix + ".mlp.up_proj.weight"] + + if rest == "mlp.w2": + return [hf_prefix + ".mlp.down_proj.weight"] + + if rest == "mlp.shared_experts.gate_up_fused_proj.weight": + return [hf_prefix + ".mlp.shared_experts.gate_proj.weight", hf_prefix + ".mlp.shared_experts.up_proj.weight"] + + if m := _EXPERT_W1_RE_v2.match(rest): + expert_id = m.group(1) + return [ + hf_prefix + ".mlp.experts." + expert_id + ".gate_proj.weight", + hf_prefix + ".mlp.experts." + expert_id + ".up_proj.weight", + ] + + if m := _EXPERT_W1_RE.match(rest): + expert_id = m.group(1) + return [ + hf_prefix + ".mlp.experts." + expert_id + ".gate_proj.weight", + hf_prefix + ".mlp.experts." + expert_id + ".up_proj.weight", + ] + + if m := _EXPERT_W2_RE.match(rest): + expert_id = m.group(1) + return [hf_prefix + ".mlp.experts." 
+ expert_id + ".down_proj.weight"] + + if m := _SHARE_EXPERT_W1_RE.match(rest): + return [hf_prefix + ".mlp.shared_experts.gate_proj.weight", hf_prefix + ".mlp.shared_experts.up_proj.weight"] + + if m := _SHARE_EXPERT_W2_RE.match(rest): + return [hf_prefix + ".mlp.shared_experts.down_proj.weight"] + + return [f"{hf_prefix}.{rest}"] if rest else [hf_prefix] + + +def _get_hf_prefix(segment_id: int, id_in_segment: int) -> str: + """生成Hugging Face格式的层级前缀""" + # 特殊层级映射 + # special_cases = {(0, 0): "model", (60, 2): "model.layers.61", (60, 3): "model"} + # special_cases = {(0, 0): "model", (28, 2): "model.layers.61", (28, 3): "model"} + # special_cases = {(0, 0): "model", (28, 2): "model.layers.61", (4, 1): "model"} + # special_cases = {(0, 0): "model", (28, 2): "model", (28,3): "lm_head"} + special_cases = {(0, 0): "model", (60, 2): "model.layers.61", (60, 3): "model", (60, 4): "lm_head"} + + if (segment_id, id_in_segment) in special_cases: + return special_cases[(segment_id, id_in_segment)] + + # 通用层级计算 + layer_idx = segment_id + id_in_segment - 1 + return f"model.layers.{layer_idx}" + + +def _handle_expert_weights(hf_prefix: str, rest: str) -> Optional[List[str]]: + if m := _EXPERT_W1_RE.match(rest): + expert_id = int(m.group(1)) + return [ + f"{hf_prefix}.mlp.experts.{expert_id}.gate_proj.weight", + f"{hf_prefix}.mlp.experts.{expert_id}.up_proj.weight", + ] + + if m := _EXPERT_W2_RE.match(rest): + expert_id = int(m.group(1)) + return [f"{hf_prefix}.mlp.experts.{expert_id}.down_proj.weight"] + + return None + + +def _handle_shared_expert_weights(hf_prefix: str, rest: str) -> Optional[List[str]]: + if _SHARE_EXPERT_W1_RE.match(rest): + return [ + f"{hf_prefix}.mlp.shared_experts.gate_proj.weight", + f"{hf_prefix}.mlp.shared_experts.up_proj.weight", + ] + + if _SHARE_EXPERT_W2_RE.match(rest): + return [f"{hf_prefix}.mlp.shared_experts.down_proj.weight"] + + return None + + +def _handle_mlp_weights(hf_prefix: str, rest: str) -> Optional[List[str]]: + if rest == "mlp.w1": + return [f"{hf_prefix}.mlp.gate_proj.weight", f"{hf_prefix}.mlp.up_proj.weight"] + + if rest == "mlp.w2": + return [f"{hf_prefix}.mlp.down_proj.weight"] + + return None + + +def prepare_tensor(tensor, dst_shape, *, force_transpose=False): + if isinstance(tensor, list): + t = paddle.concat( + [ + paddle.transpose(tensor[0], perm=[1, 0]).contiguous(), + paddle.transpose(tensor[1], perm=[1, 0]).contiguous(), + ], + axis=-1, + ) + if t.shape != dst_shape: + print("base shape", tensor[0].shape, tensor[1].shape) + print("shape not match ", t.shape, dst_shape) + sys.exit() + return t + + if force_transpose: + return tensor.T.contiguous() + + if tensor.shape == dst_shape: + if len(tensor.shape) != 1: + print("attention same shape not transpose !!!!!!!!!!!!!!!!!!!!!!") + return tensor + if len(tensor.shape) == 2 and paddle.transpose(tensor, perm=[1, 0]).contiguous().shape == dst_shape: + return paddle.transpose(tensor, perm=[1, 0]).contiguous() + + print("shape not match here") + sys.exit() + + +def load_huggingface_ckpt(model, huggingface_ckpt_path): + ckpt_pre = huggingface_ckpt_path + + # 1. 加载参数-文件映射表 + weight_map_path = ckpt_pre + "/model.safetensors.index.json" + with open(weight_map_path, "r") as f: + weight_map = json.load(f)["weight_map"] + + # 2. 创建反向索引:文件 -> 参数列表 + file_to_params = defaultdict(list) + for param_name, filename in weight_map.items(): + file_to_params[filename].append(param_name) + + # 2. 
收集模型需要的文件列表 + required_files = set() + file_to_pd_param_name = defaultdict(list) + pd_param_name_to_file = defaultdict(list) + for pd_name, p in model.named_parameters(): + hf_name = paddle_name_to_hf_names(pd_name) + if hf_name[0] in weight_map: + filename = weight_map[hf_name[0]] + required_files.add(filename) + file_to_pd_param_name[filename].append(pd_name) + pd_param_name_to_file[pd_name].append(filename) + else: + print(f"Warning: {pd_name} -> {hf_name[0]} not found in weight map") + import sys + + sys.exit() + + if len(hf_name) > 1: + if hf_name[1] in weight_map: + filename = weight_map[hf_name[1]] + required_files.add(filename) + file_to_pd_param_name[filename].append(pd_name) + if filename != pd_param_name_to_file[pd_name][0]: + pd_param_name_to_file[pd_name].append(filename) + else: + print(f"Warning: {pd_name} -> {hf_name[1]} not found in weight map") + + # 3. 按文件分组加载 + check_list = [] + print("Start load huggingface ckpt") + for i, filename in enumerate(required_files): + try: + with safe_open(ckpt_pre + filename, framework="paddle", device="cpu") as f: + # 加载该文件包含的所有参数 + pd_params = file_to_pd_param_name[filename] + for pd_param in pd_params: + if pd_param in check_list: + continue + + hf_name = paddle_name_to_hf_names(pd_param) + if len(hf_name) == 1: + tensor = f.get_tensor(hf_name[0]) + + force_transpose = False + + model.state_dict()[pd_param].set_value( + paddle.cast( + prepare_tensor( + tensor, model.state_dict()[pd_param].shape, force_transpose=force_transpose + ), + model.state_dict()[pd_param].dtype, + ) + ) + else: + files = pd_param_name_to_file[pd_param] + if len(files) == 1: + tensor0 = f.get_tensor(hf_name[0]) + tensor1 = f.get_tensor(hf_name[1]) + else: + if weight_map[hf_name[0]] == filename: + tensor0 = f.get_tensor(hf_name[0]) + with safe_open( + ckpt_pre + weight_map[hf_name[1]], framework="paddle", device="cpu" + ) as f_other: + tensor1 = f_other.get_tensor(hf_name[1]) + else: + with safe_open( + ckpt_pre + weight_map[hf_name[0]], framework="paddle", device="cpu" + ) as f_other: + tensor0 = f_other.get_tensor(hf_name[0]) + tensor1 = f.get_tensor(hf_name[1]) + model.state_dict()[pd_param].set_value( + prepare_tensor([tensor0, tensor1], model.state_dict()[pd_param].shape) + ) + check_list.append(pd_param) + + except Exception as e: + print(f"Error loading {filename}: {str(e)}") + raise diff --git a/paddleformers/transformers/deepseek_v2/__init__.py b/paddleformers/transformers/deepseek_v2/__init__.py index a0fac197982..07e90115df6 100644 --- a/paddleformers/transformers/deepseek_v2/__init__.py +++ b/paddleformers/transformers/deepseek_v2/__init__.py @@ -56,6 +56,8 @@ "yarn_find_correction_range", "get_triangle_upper_mask", "DeepseekV2LinearScalingRotaryEmbedding", + "set_global_step", + "get_global_step", ], "modeling_auto": [ "DeepseekV2LMHeadAuto", diff --git a/paddleformers/transformers/deepseek_v2/configuration.py b/paddleformers/transformers/deepseek_v2/configuration.py index 1feba3cbec7..2588e8394c1 100644 --- a/paddleformers/transformers/deepseek_v2/configuration.py +++ b/paddleformers/transformers/deepseek_v2/configuration.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
""" DeepSeekV2 model configuration""" -from ..configuration_utils import PretrainedConfig +from paddleformers.transformers.configuration_utils import PretrainedConfig __all__ = [ "DeepseekV2Config", @@ -179,6 +179,20 @@ def __init__( attention_dropout=0.0, speculate_model_type=False, using_flex_token=False, + use_dualpipev=False, + send_mtp_embed=False, + using_post_norm_recompute=False, + stepped_recompute_fwd_gate_up=False, + recompute_fwd_gate_up=0, + recompute_fa3=0, + is_split_group_gemm=False, + fakse_gate_restrict_balance=False, + adaptive_remained_O1_recompute_ratio=0, + offline_quant_expert_weight=True, + clear_origin_weight_when_offline_quant=True, + mlp_bwd_subbatch_rows=0, + mlp_fwd_subbatch_rows=0, + output_subbatch_rows=0, **kwargs, ): self.vocab_size = vocab_size @@ -227,6 +241,20 @@ def __init__( self.speculate_model_type = speculate_model_type self.use_fp8 = False self.using_flex_token = using_flex_token + self.use_dualpipev = use_dualpipev + self.send_mtp_embed = send_mtp_embed + self.using_post_norm_recompute = using_post_norm_recompute + self.recompute_fwd_gate_up = recompute_fwd_gate_up + self.recompute_fa3 = recompute_fa3 + self.stepped_recompute_fwd_gate_up = stepped_recompute_fwd_gate_up + self.is_split_group_gemm = is_split_group_gemm + self.fakse_gate_restrict_balance = fakse_gate_restrict_balance + self.adaptive_remained_O1_recompute_ratio = adaptive_remained_O1_recompute_ratio + self.offline_quant_expert_weight = offline_quant_expert_weight + self.clear_origin_weight_when_offline_quant = clear_origin_weight_when_offline_quant + self.mlp_bwd_subbatch_rows = mlp_bwd_subbatch_rows + self.mlp_fwd_subbatch_rows = mlp_fwd_subbatch_rows + self.output_subbatch_rows = output_subbatch_rows super().__init__( pad_token_id=pad_token_id, diff --git a/paddleformers/transformers/deepseek_v2/modeling.py b/paddleformers/transformers/deepseek_v2/modeling.py index 04a8651f43e..69fb3b19054 100644 --- a/paddleformers/transformers/deepseek_v2/modeling.py +++ b/paddleformers/transformers/deepseek_v2/modeling.py @@ -23,6 +23,7 @@ import contextlib import math +import os import warnings from functools import partial from typing import List, Optional, Tuple, Union @@ -35,7 +36,9 @@ from paddle.distributed import fleet from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker from paddle.distributed.fleet.recompute.recompute import recompute +from paddle.jit import to_static from paddle.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from paddle.utils import try_import try: from paddle.incubate.nn.functional import fused_rotary_position_embedding @@ -51,12 +54,13 @@ except: pass +from paddle import _C_ops + try: from paddle.nn.functional.flash_attention import flash_attention except: flash_attention = None - from ...utils.initializer import kaiming_uniform_ from ...utils.log import logger from ...utils.tools import get_env_device @@ -72,11 +76,44 @@ from ..model_utils import PretrainedModel, dtype_guard, register_base_model from ..moe_gate import PretrainedMoEGate from ..moe_layer import MoEFlexTokenLayer, MoELayer -from ..utils import device_guard +from ..utils import cast_if_needed, device_guard from . 
import fp8_linear as linear_utils from .configuration import DeepseekV2Config + +FA_VERSION = int(os.getenv("FA_VERSION", 2)) + +from paddle.distributed.fleet.meta_parallel.zero_bubble_utils import WeightGradStore + +from ..fp8_utils import ( + FP8Linear, + FP8LinearFunctionBase, + cache_fp8_weight, + set_parameter_color, +) from .fp8_linear import Linear +DSV3_USE_FP8_GEMM = os.getenv("DSV3_USE_FP8_GEMM", "False").lower() == "true" +DSV3_USE_ATTEN_RECOMPUTE = os.getenv("DSV3_USE_ATTEN_RECOMPUTE", "False").lower() == "true" + +Linear = FP8Linear if DSV3_USE_FP8_GEMM else Linear + +try: + import fused_ln + from paddle.incubate.nn.functional import swiglu +except ImportError: + + def swiglu(x, y=None): + if y is None: + x, y = paddle.chunk(x, chunks=2, axis=-1) + return F.silu(x) * y + + +try: + from paddle.incubate.nn.functional import fused_partial_rope +except ImportError: + fused_partial_rope = None + + __all__ = [ "DeepseekV2LMHead", "DeepseekV2PretrainingCriterion", @@ -84,8 +121,54 @@ "DeepseekV2ForSequenceClassification", "DeepseekV2Model", "DeepseekV2PretrainedModel", + "set_global_step", + "get_global_step", ] +global_step = 0 + + +def set_global_step(cur_step): + global global_step + global_step = cur_step + + +def get_global_step(): + global global_step + return global_step + + +def rms_norm_fused(x_in, w, eps, use_fast_ln=False): + if use_fast_ln: + fast_ln = try_import("fast_ln") + return fast_ln.fast_rms_norm(x_in, w, eps)[0] + else: + fused_ln = try_import("fused_ln") + return fused_ln.fused_rms_norm(x_in, w, eps)[0] + + +def fusion_rms_norm(hidden_states, weight, variance_epsilon, use_fast_ln=False): + if get_env_device() == "npu": + return paddle.base.core.eager._run_custom_op("rms_norm_npu", hidden_states, weight, variance_epsilon)[0] + if get_env_device() == "mlu": + return paddle.base.core.eager._run_custom_op("rms_norm_mlu", hidden_states, weight, variance_epsilon)[0] + elif get_env_device() == "gcu": + return paddle.base.core.eager._run_custom_op("rms_norm_gcu", hidden_states, weight, variance_epsilon)[0] + elif get_env_device() == "intel_hpu": + return paddle.incubate.nn.functional.fused_rms_norm( + hidden_states, weight, None, variance_epsilon, hidden_states.dim() - 1 + )[0] + elif get_env_device() == "xpu": + try: + import paddle_xpu_nn # noqa: F821 + + return paddle_xpu_nn.xpu_rms_norm(hidden_states, weight, variance_epsilon)[0] + except ImportError: + raise NotImplementedError( + f"Implementation of fused_rms_norm is not available on {get_env_device()}. 
Please install paddle_xpu to use this feature" + ) + return rms_norm_fused(hidden_states, weight, variance_epsilon, use_fast_ln) + def get_triangle_upper_mask(x, mask=None): if mask is not None: @@ -129,7 +212,46 @@ def assign_kv_heads(num_kv_heads: int, num_gpus: int): return assignment_list -def parallel_matmul(x: Tensor, y: Tensor, tensor_parallel_output=True): +class LMHeadFunction(paddle.autograd.PyLayer): + @staticmethod + def forward(ctx, x, weight, transpose_y): + out = paddle.matmul(x, weight, transpose_y=transpose_y) + + ctx.save_for_backward(x, weight, transpose_y) + return out + + @staticmethod + def backward(ctx, dout): + if dout.dtype == paddle.float32: + dout = dout.cast(paddle.bfloat16) + + x, weight, transpose_y = ctx.saved_tensor() + + dx = paddle.matmul(dout, weight, transpose_y=not transpose_y) + if transpose_y: + with paddle.amp.auto_cast(False): + paddle._C_ops.fused_linear_param_grad_add( + dout.reshape([-1, dout.shape[-1]]), + x.reshape([-1, x.shape[-1]]), + weight.main_grad, + None, + True, + False, + ) + else: + with paddle.amp.auto_cast(False): + paddle._C_ops.fused_linear_param_grad_add( + x.reshape([-1, x.shape[-1]]), + dout.reshape([-1, dout.shape[-1]]), + weight.main_grad, + None, + True, + False, + ) + return dx, None + + +def parallel_matmul(x: Tensor, y: Tensor, transpose_y=False, tensor_parallel_output=True): is_fleet_init = True tensor_parallel_degree = 1 try: @@ -147,7 +269,7 @@ def parallel_matmul(x: Tensor, y: Tensor, tensor_parallel_output=True): if is_fleet_init and tensor_parallel_degree > 1 and y_is_distributed: # if not running under distributed.launch, it will raise AttributeError: 'Fleet' object has no attribute '_hcg' input_parallel = paddle.distributed.collective._c_identity(x, group=model_parallel_group) - logits = paddle.matmul(input_parallel, y, transpose_y=False) + logits = paddle.matmul(input_parallel, y, transpose_y=transpose_y) if tensor_parallel_output: return logits @@ -155,7 +277,7 @@ def parallel_matmul(x: Tensor, y: Tensor, tensor_parallel_output=True): return paddle.distributed.collective._c_concat(logits, group=model_parallel_group) else: - logits = paddle.matmul(x, y, transpose_y=False) + logits = LMHeadFunction.apply(x, y, transpose_y=transpose_y) return logits @@ -328,17 +450,8 @@ def __init__(self, config: DeepseekV2Config, hidden_size=None, eps=1e-6, use_seq mark_as_sequence_parallel_parameter(self.weight) def forward(self, hidden_states): - if self.config.use_fused_rms_norm and get_env_device() == "xpu": - if self.weight.dtype != hidden_states.dtype: - hidden_states = paddle.cast(hidden_states, self.weight.dtype) - try: - import paddle_xpu_nn # noqa: F821 - - return paddle_xpu_nn.xpu_rms_norm(hidden_states, self.weight, self.variance_epsilon)[0] - except ImportError: - raise NotImplementedError( - f"Implementation of fused_rms_norm is not available on {get_env_device()}. 
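For reference, the math that rms_norm_fused and fusion_rms_norm dispatch to can be written in plain Paddle. This unfused sketch is illustrative only (the fused kernels run as a single op and also return the inverse variance), but it is handy for sanity-checking outputs:

import paddle

def rms_norm_reference(x, weight, eps=1e-6):
    x_fp32 = x.astype("float32")
    variance = x_fp32.pow(2).mean(axis=-1, keepdim=True)
    normed = x_fp32 * paddle.rsqrt(variance + eps)
    return (normed * weight.astype("float32")).astype(x.dtype)

x = paddle.randn([2, 4, 8])
w = paddle.ones([8])
print(rms_norm_reference(x, w).shape)  # [2, 4, 8]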
Please install paddle_xpu to use this feature" - ) + if self.config.use_fused_rms_norm: + return fusion_rms_norm(hidden_states, self.weight, self.variance_epsilon, self.config.use_fast_layer_norm) with paddle.amp.auto_cast(False): hidden_states = hidden_states.astype("float32") @@ -528,34 +641,37 @@ def __init__( super().__init__(dim, max_position_embeddings, base) def _set_cos_sin_cache(self, seq_len): - self.max_seq_len_cached = seq_len - dim = self.dim + with paddle.amp.auto_cast(False): + self.max_seq_len_cached = seq_len + dim = self.dim - freq_extra = 1.0 / (self.base ** (paddle.arange(0, dim, 2, dtype=paddle.float32) / dim)) - freq_inter = 1.0 / (self.scaling_factor * self.base ** (paddle.arange(0, dim, 2, dtype=paddle.float32) / dim)) + freq_extra = 1.0 / (self.base ** (paddle.arange(0, dim, 2, dtype=paddle.float32) / dim)) + freq_inter = 1.0 / ( + self.scaling_factor * self.base ** (paddle.arange(0, dim, 2, dtype=paddle.float32) / dim) + ) - low, high = yarn_find_correction_range( - self.beta_fast, - self.beta_slow, - dim, - self.base, - self.original_max_position_embeddings, - ) - inv_freq_mask = 1.0 - yarn_linear_ramp_mask(low, high, dim // 2) - self.inv_freq = freq_inter * (1 - inv_freq_mask) + freq_extra * inv_freq_mask + low, high = yarn_find_correction_range( + self.beta_fast, + self.beta_slow, + dim, + self.base, + self.original_max_position_embeddings, + ) + inv_freq_mask = 1.0 - yarn_linear_ramp_mask(low, high, dim // 2) + self.inv_freq = freq_inter * (1 - inv_freq_mask) + freq_extra * inv_freq_mask - t = paddle.arange(seq_len, dtype=paddle.float32) + t = paddle.arange(seq_len, dtype=paddle.float32) - freqs = paddle.outer(t, paddle.cast(self.inv_freq, dtype="float32")) + freqs = paddle.outer(t, paddle.cast(self.inv_freq, dtype="float32")) - _mscale = float( - yarn_get_mscale(self.scaling_factor, self.mscale) - / yarn_get_mscale(self.scaling_factor, self.mscale_all_dim) - ) + _mscale = float( + yarn_get_mscale(self.scaling_factor, self.mscale) + / yarn_get_mscale(self.scaling_factor, self.mscale_all_dim) + ) - emb = paddle.concat((freqs, freqs), axis=-1) - self.cos_cached = emb.cos() * _mscale - self.sin_cached = emb.sin() * _mscale + emb = paddle.concat((freqs, freqs), axis=-1) + self.cos_cached = emb.cos() * _mscale + self.sin_cached = emb.sin() * _mscale def rotate_half(x): @@ -592,7 +708,7 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, fuse_rope=False): b, s, h, d = k.shape k = k.reshape([b, s, h, d // 2, 2]).transpose([0, 1, 2, 4, 3]).reshape([b, s, h, d]) - if get_env_device() == "xpu" and fuse_rope: + if (get_env_device() == "xpu" or get_env_device() == "gpu") and fuse_rope: q_embed, k_embed, _ = fused_rotary_position_embedding( q, k, @@ -671,9 +787,83 @@ def forward(self, x): return down_proj +class FusedNormGateFunc(paddle.autograd.PyLayer): + """recompute of postnorm and gate""" + + _current_norm_output = None + _current_invar = None + + @classmethod + def set_temporary_vars(cls, norm_output, invar): + FusedNormGateFunc._current_norm_output = norm_output + FusedNormGateFunc._current_invar = invar + + @classmethod + def clear_temporary_vars(cls): + FusedNormGateFunc._current_norm_output = None + FusedNormGateFunc._current_invar = None + + @staticmethod + def forward(ctx, x, rms_norm_weight, moe_gate_weight, eps): + ctx.dtype = paddle.float32 + norm_output, invar = fused_ln.fused_rms_norm(x, rms_norm_weight, eps) + with paddle.amp.auto_cast(False): + gate_logits = F.linear(cast_if_needed(norm_output, ctx.dtype), cast_if_needed(moe_gate_weight, 
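The YaRN cache above is now built inside paddle.amp.auto_cast(False) so the cos/sin tables stay in float32 even under AMP. A stripped-down sketch of the same cache layout, without the YaRN frequency correction and mscale terms, looks like this (illustrative only):

import paddle

dim, base, seq_len = 64, 10000.0, 32
with paddle.amp.auto_cast(False):
    inv_freq = 1.0 / (base ** (paddle.arange(0, dim, 2, dtype="float32") / dim))
    t = paddle.arange(seq_len, dtype="float32")
    freqs = paddle.outer(t, inv_freq)            # [seq_len, dim // 2]
    emb = paddle.concat([freqs, freqs], axis=-1)
    cos_cached, sin_cached = emb.cos(), emb.sin()
print(cos_cached.shape, cos_cached.dtype)        # [32, 64] paddle.float32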
ctx.dtype)) + + ctx.save_for_backward(x, rms_norm_weight, moe_gate_weight, eps) + return gate_logits, norm_output + + @staticmethod + def backward(ctx, d_gate_logits, d_norm_output): + x, rms_norm_weight, moe_gate_weight, eps = ctx.saved_tensor() + # recompute rmsnorm + norm_output = FusedNormGateFunc._current_norm_output + invar = FusedNormGateFunc._current_invar + if norm_output is None or invar is None: + norm_output, invar = fused_ln.fused_rms_norm(x, rms_norm_weight, eps) + d_norm_output_linear, d_moe_gate_weight = paddle._C_ops.matmul_grad( + cast_if_needed(norm_output, ctx.dtype), + cast_if_needed(moe_gate_weight, ctx.dtype), + d_gate_logits, + False, + False, + ) + d_norm_output_linear, d_moe_gate_weight = cast_if_needed( + d_norm_output_linear, norm_output.dtype + ), cast_if_needed(d_moe_gate_weight, moe_gate_weight.dtype) + + d_norm_output = d_norm_output + d_norm_output_linear + dx, d_rms_norm_weight = fused_ln.fused_rms_norm_grad_func(x, rms_norm_weight, invar, d_norm_output, eps) + + return dx, d_rms_norm_weight, d_moe_gate_weight + + +class TemporaryVarContext: + def __init__(self, norm_output, invar): + self.norm_output = norm_output + self.invar = invar + + def __enter__(self): + FusedNormGateFunc.set_temporary_vars(self.norm_output, self.invar) + + def __exit__(self, exc_type, exc_val, exc_tb): + FusedNormGateFunc.clear_temporary_vars() + + +def balance_expert_assignment(n, m, k): + assert k * n % m == 0 + matrix = paddle.zeros((n, m), dtype=paddle.int32) + for row in range(n): + start_col = row % m + for i in range(k): + col = (start_col + i) % m + matrix[row, col] = 1 + return matrix + + class FakeGate(paddle.autograd.PyLayer): @staticmethod - def forward(ctx, hidden_states, weight): + def forward(ctx, hidden_states, weight, fakse_gate_restrict_balance=False, num_experts_per_tok=8): expert_num = weight.shape[1] bsz, seq, _ = hidden_states.shape @@ -681,8 +871,12 @@ def forward(ctx, hidden_states, weight): ctx.x_dtype = hidden_states.dtype ctx.y_shape = weight.shape ctx.y_dtype = weight.dtype - - return paddle.randn([bsz, seq, expert_num]).cast(weight.dtype) + if fakse_gate_restrict_balance: + return paddle.reshape( + balance_expert_assignment(bsz * seq, expert_num, num_experts_per_tok), [bsz, seq, expert_num] + ) + else: + return paddle.randn([bsz, seq, expert_num]).cast(weight.dtype) @staticmethod def backward(ctx, grad_output): @@ -750,10 +944,9 @@ class AddAuxiliaryLoss(paddle.autograd.PyLayer): @staticmethod def forward(ctx, x, loss): - assert paddle.numel(loss) == 1 ctx.dtype = loss.dtype ctx.required_aux_loss = not loss.stop_gradient - return x + return x.clone() # clone to avoid inplace problem when using overlap @staticmethod def backward(ctx, grad_output): @@ -882,6 +1075,849 @@ def repeat_kv(hidden_states: paddle.Tensor, n_rep: int) -> paddle.Tensor: return hidden_states.reshape([batch, slen, num_key_value_heads * n_rep, head_axis]) +@to_static(backend="CINN") +def qkv_pre_process_no_fuse( + q, kv, k_pe, rotary_emb, num_heads, q_head_dim, qk_nope_head_dim, v_head_dim, qk_rope_head_dim, position_ids +): + bsz, q_len, _ = q.shape + + target_query_shape = [0, 0, num_heads, q_head_dim] + target_key_value_shape = [0, 0, num_heads, qk_nope_head_dim + v_head_dim] + + q = q.reshape(shape=target_query_shape) + q_nope = q[..., :qk_nope_head_dim] + q_pe = q[..., qk_nope_head_dim:] + + # DeepSeekV2 kv_lora_rank+qk_rope_head_dim=512+64 + + kv = kv.reshape(shape=target_key_value_shape) + + k_pe = k_pe.reshape([-1, q_len, 1, qk_rope_head_dim]).expand([-1, q_len, num_heads, 
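balance_expert_assignment above builds a round-robin 0/1 routing matrix so that, when the fake gate is asked to be load-balanced, every expert receives the same number of tokens. A small standalone check of that pattern (mirroring the helper rather than importing it):

import paddle

n_tokens, n_experts, top_k = 8, 4, 2           # satisfies top_k * n_tokens % n_experts == 0
assignment = paddle.zeros((n_tokens, n_experts), dtype=paddle.int32)
for row in range(n_tokens):
    start_col = row % n_experts
    for i in range(top_k):
        assignment[row, (start_col + i) % n_experts] = 1
print(assignment.numpy())
print(assignment.sum(axis=0).numpy())          # [4 4 4 4]: every expert gets 8 * 2 / 4 tokens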
qk_rope_head_dim]) + + # self.q_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim = 128+64 + # self.num_heads * (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim) = config.qk_nope_head_dim + self.v_head_dim = 128+128 + k_nope = kv[..., :qk_nope_head_dim] + value_states = kv[..., qk_nope_head_dim:] + + kv_seq_len = value_states.shape[1] + + cos, sin = rotary_emb(value_states, seq_len=kv_seq_len) + cos = cos[None, :, None, :] + sin = sin[None, :, None, :] + q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids, False) + + query_states = paddle.concat([q_nope, q_pe], axis=-1) + key_states = paddle.concat([k_nope, k_pe], axis=-1) + + return query_states, key_states, value_states + + +@to_static(backend="CINN") +def rearrange_kv(kv, k_pe, qk_nope_head_dim, num_heads): + k_nope = kv[..., :qk_nope_head_dim] + value_states = kv[..., qk_nope_head_dim:] + + k_pe = k_pe.expand([k_pe.shape[0], k_pe.shape[1], num_heads, k_pe.shape[3]]) + key_states = paddle.concat([k_nope, k_pe], axis=-1) + + return key_states, value_states + + +def qkv_pre_process( + q, kv, k_pe, rotary_emb, num_heads, q_head_dim, qk_nope_head_dim, v_head_dim, qk_rope_head_dim, position_ids +): + if (fused_partial_rope is None) or (position_ids is not None): + return qkv_pre_process_no_fuse( + q, + kv, + k_pe, + rotary_emb, + num_heads, + q_head_dim, + qk_nope_head_dim, + v_head_dim, + qk_rope_head_dim, + position_ids, + ) + + bsz, q_len, _ = q.shape + + target_query_shape = [0, 0, num_heads, q_head_dim] + target_key_value_shape = [0, 0, num_heads, qk_nope_head_dim + v_head_dim] + + q = q.reshape(shape=target_query_shape) + kv = kv.reshape(shape=target_key_value_shape) + k_pe = k_pe.reshape([-1, q_len, 1, qk_rope_head_dim]) + + value_states = kv[..., qk_nope_head_dim:] + + kv_seq_len = value_states.shape[1] + + cos, sin = rotary_emb(value_states, seq_len=kv_seq_len) + cos = cos[None, :, None, :] + sin = sin[None, :, None, :] + + query_states = fused_partial_rope(q, cos, sin) + k_pe = fused_partial_rope(k_pe, cos, sin) + + key_states, value_states = rearrange_kv(kv, k_pe, qk_nope_head_dim, num_heads) + + return query_states, key_states, value_states + + +def manul_fwd( + q_init, + kv_init, + q_ln_weight, + kv_ln_weight, + q_up_weight, + kv_up_weight, + rotary_emb, + num_heads, + q_head_dim, + qk_nope_head_dim, + v_head_dim, + qk_rope_head_dim, + position_ids, + eps, + kv_lora_rank, + softmax_scale, +): + + q_ln_t, q_ln_invar = fused_ln.fused_rms_norm(q_init, q_ln_weight, eps) + q = paddle.matmul(q_ln_t, q_up_weight) + + compressed_kv, k_pe = paddle.split(kv_init, [kv_lora_rank, qk_rope_head_dim], axis=-1) + + kv_ln_t, kv_ln_invar = fused_ln.fused_rms_norm(compressed_kv, kv_ln_weight, eps) + + kv = paddle.matmul(kv_ln_t, kv_up_weight) + + query_states, key_states, value_states = qkv_pre_process( + q, kv, k_pe, rotary_emb, num_heads, q_head_dim, qk_nope_head_dim, v_head_dim, qk_rope_head_dim, position_ids + ) + + q_head_dim = query_states.shape[-1] + softmax_scale = softmax_scale * (q_head_dim**0.5) + query_states = query_states * softmax_scale + + attn_out, _, softmax_lse, seed_offset = _C_ops.flash_attn( + query_states, + key_states, + query_states, + None, + None, + 0.0, + True, + False, + False, + "", + ) + + return attn_out + + +class MemroyRecomputeAttnFunc(paddle.autograd.PyLayer): + @staticmethod + def forward( + ctx, + q_init, + kv_init, + q_ln_weight, + kv_ln_weight, + q_up_weight, + kv_up_weight, + rotary_emb, + num_heads, + q_head_dim, + qk_nope_head_dim, + v_head_dim, + qk_rope_head_dim, + 
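The pre-processing above is easier to follow with concrete shapes. This sketch (toy batch sizes, but per-head dims as in the DeepSeek-V3 config) shows how q is split into a no-rope part and a rope part, and how the up-projected kv packs k_nope together with the values:

import paddle

num_heads, qk_nope_head_dim, qk_rope_head_dim, v_head_dim = 4, 128, 64, 128
q_head_dim = qk_nope_head_dim + qk_rope_head_dim            # 192
bsz, q_len = 2, 16

q = paddle.randn([bsz, q_len, num_heads * q_head_dim]).reshape([0, 0, num_heads, q_head_dim])
q_nope, q_pe = q[..., :qk_nope_head_dim], q[..., qk_nope_head_dim:]   # rope touches q_pe only

kv = paddle.randn([bsz, q_len, num_heads * (qk_nope_head_dim + v_head_dim)])
kv = kv.reshape([0, 0, num_heads, qk_nope_head_dim + v_head_dim])
k_nope, value_states = kv[..., :qk_nope_head_dim], kv[..., qk_nope_head_dim:]

print(q_nope.shape, q_pe.shape, k_nope.shape, value_states.shape)
# [2, 16, 4, 128] [2, 16, 4, 64] [2, 16, 4, 128] [2, 16, 4, 128]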
position_ids, + eps, + kv_lora_rank, + softmax_scale, + recompute_fa3=False, + ): + + bsz = q_init.shape[0] + q_ln_t, q_ln_invar = fused_ln.fused_rms_norm(q_init, q_ln_weight, eps) + # q = paddle.matmul(q_ln_t, q_up_weight) + q_orig_shape = q_ln_t.shape + q = FP8LinearFunctionBase.compute_fp8_linear( + q_ln_t.reshape([-1, q_orig_shape[-1]]), q_up_weight, weight_transpose=True, return_transpose_only=True + ) + q = q.reshape(q_orig_shape[:-1] + [q_up_weight.shape[-1]]) + + compressed_kv, k_pe = paddle.split(kv_init, [kv_lora_rank, qk_rope_head_dim], axis=-1) + + kv_ln_t, kv_ln_invar = fused_ln.fused_rms_norm(compressed_kv, kv_ln_weight, eps) + # kv = paddle.matmul(kv_ln_t, kv_up_weight) + kv_orig_shape = kv_ln_t.shape + kv = FP8LinearFunctionBase.compute_fp8_linear( + kv_ln_t.reshape([-1, kv_orig_shape[-1]]), kv_up_weight, weight_transpose=True, return_transpose_only=True + ) + kv = kv.reshape(kv_orig_shape[:-1] + [kv_up_weight.shape[-1]]) + + query_states, key_states, value_states = qkv_pre_process( + q, + kv, + k_pe, + rotary_emb, + num_heads, + q_head_dim, + qk_nope_head_dim, + v_head_dim, + qk_rope_head_dim, + position_ids, + ) + + q_head_dim = query_states.shape[-1] + + if FA_VERSION == 2: + softmax_scale = softmax_scale * (q_head_dim**0.5) + query_states = query_states * softmax_scale + kv_seq_len = value_states.shape[1] + v_num_heads = value_states.shape[2] + value_padding = paddle.zeros( + [bsz, kv_seq_len, v_num_heads, q_head_dim - v_head_dim], + dtype=value_states.dtype, + ) + value_states_pad = paddle.concat([value_states, value_padding], axis=-1) + + attn_out, _, softmax_lse, seed_offset = _C_ops.flash_attn( + query_states, + key_states, + value_states_pad, + None, + None, + 0.0, + True, + False, + False, + "", + ) + + elif FA_VERSION == 3: + attn_out, softmax_lse = _C_ops.flash_attn_v3( + query_states, + key_states, + value_states, + None, # q_v_ + None, # q_descale_ + None, # k_descale_ + None, # v_descale_ + softmax_scale, + True, + -1, # window_size_left + -1, # window_size_right + 0.0, # softcap + 1, # num_splits + False, # manual_set_pack_gqa + False, # pack_gqa_ + 0, # sm_margin + ) + else: + assert False, f"invalid {FA_VERSION=}" + + if FA_VERSION == 2: + ctx.save_for_backward( + q_init, + kv_init, + attn_out, + softmax_lse, + seed_offset, + q_ln_weight, + kv_ln_weight, + q_up_weight, + kv_up_weight, + rotary_emb, + num_heads, + q_head_dim, + qk_nope_head_dim, + v_head_dim, + qk_rope_head_dim, + position_ids, + eps, + kv_lora_rank, + softmax_scale, + ) + elif FA_VERSION == 3: + if recompute_fa3: + ctx.save_for_backward( + q_init, + kv_init, + None, + None, + q_ln_weight, + kv_ln_weight, + q_up_weight, + kv_up_weight, + rotary_emb, + num_heads, + q_head_dim, + qk_nope_head_dim, + v_head_dim, + qk_rope_head_dim, + position_ids, + eps, + kv_lora_rank, + softmax_scale, + recompute_fa3, + ) + else: + ctx.save_for_backward( + q_init, + kv_init, + attn_out, + softmax_lse, + q_ln_weight, + kv_ln_weight, + q_up_weight, + kv_up_weight, + rotary_emb, + num_heads, + q_head_dim, + qk_nope_head_dim, + v_head_dim, + qk_rope_head_dim, + position_ids, + eps, + kv_lora_rank, + softmax_scale, + recompute_fa3, + ) + else: + assert False, f"invalid {FA_VERSION=}" + + return attn_out + + @staticmethod + def backward(ctx, dout): + if FA_VERSION == 2: + ( + q_init, + kv_init, + attn_out, + softmax_lse, + seed_offset, + q_ln_weight, + kv_ln_weight, + q_up_weight, + kv_up_weight, + rotary_emb, + num_heads, + q_head_dim, + qk_nope_head_dim, + v_head_dim, + qk_rope_head_dim, + position_ids, + eps, + 
kv_lora_rank, + softmax_scale, + ) = ctx.saved_tensor() + elif FA_VERSION == 3: + ( + q_init, + kv_init, + attn_out, + softmax_lse, + q_ln_weight, + kv_ln_weight, + q_up_weight, + kv_up_weight, + rotary_emb, + num_heads, + q_head_dim, + qk_nope_head_dim, + v_head_dim, + qk_rope_head_dim, + position_ids, + eps, + kv_lora_rank, + softmax_scale, + recompute_fa3, + ) = ctx.saved_tensor() + else: + assert False, f"invalid {FA_VERSION=}" + + if FA_VERSION == 2: + assert not recompute_fa3 + assert attn_out is not None and softmax_lse is not None + if FA_VERSION == 3 and not recompute_fa3: + assert attn_out is not None and softmax_lse is not None + + q_ln_t, q_ln_invar = fused_ln.fused_rms_norm(q_init, q_ln_weight, eps) + + q_ln_fp8, q_ln_scale, q_ln_trans_fp8, q_ln_trans_scale = paddle.incubate.nn.functional.fp8_quant_blockwise( + q_ln_t.reshape([-1, q_ln_t.shape[-1]]), + output_scale_transpose=True, + quant_method="1x128", + input_transpose=True, + ) + + q_orig_shape = q_ln_t.shape + q = FP8LinearFunctionBase.compute_fp8_linear( + (q_ln_fp8, q_ln_scale), q_up_weight, weight_transpose=True, return_transpose_only=True + ) + q = q.reshape(q_orig_shape[:-1] + [q_up_weight.shape[-1]]) + + compressed_kv, k_pe = paddle.split(kv_init, [kv_lora_rank, qk_rope_head_dim], axis=-1) + + kv_ln_t, kv_ln_invar = fused_ln.fused_rms_norm(compressed_kv, kv_ln_weight, eps) + + kv_ln_fp8, kv_ln_scale, kv_ln_trans_fp8, kv_ln_trans_scale = paddle.incubate.nn.functional.fp8_quant_blockwise( + kv_ln_t.reshape([-1, kv_ln_t.shape[-1]]), + output_scale_transpose=True, + quant_method="1x128", + input_transpose=True, + ) + kv_orig_shape = kv_ln_t.shape + kv = FP8LinearFunctionBase.compute_fp8_linear( + (kv_ln_fp8, kv_ln_scale), kv_up_weight, weight_transpose=True, return_transpose_only=True + ) + kv = kv.reshape(kv_orig_shape[:-1] + [kv_up_weight.shape[-1]]) + + paddle.base.core._set_has_grad(True) + q.stop_gradient = False + kv.stop_gradient = False + k_pe.stop_gradient = False + query_states, key_states, value_states = qkv_pre_process( + q, + kv, + k_pe, + rotary_emb, + num_heads, + q_head_dim, + qk_nope_head_dim, + v_head_dim, + qk_rope_head_dim, + position_ids, + ) + + if FA_VERSION == 2: + q_head_dim = query_states.shape[-1] + query_states = query_states * softmax_scale + + bsz = value_states.shape[0] + kv_seq_len = value_states.shape[1] + v_num_heads = value_states.shape[2] + value_padding = paddle.zeros( + [bsz, kv_seq_len, v_num_heads, q_head_dim - v_head_dim], + dtype=value_states.dtype, + ) + value_states_pad = paddle.concat([value_states, value_padding], axis=-1) + + with paddle.no_grad(): + + q_grad, k_grad, v_grad = _C_ops.flash_attn_grad( + query_states, + key_states, + value_states_pad, + attn_out, + softmax_lse.view("bfloat16"), + seed_offset, + None, + dout, + 0.0, + True, + ) + + v_grad = v_grad[..., :v_head_dim] + q_grad = q_grad * softmax_scale + elif FA_VERSION == 3: + # recompute fa3 + if recompute_fa3: + with paddle.no_grad(): + attn_out, softmax_lse = _C_ops.flash_attn_v3( + query_states, + key_states, + value_states, + None, # q_v_ + None, # q_descale_ + None, # k_descale_ + None, # v_descale_ + softmax_scale, + True, + -1, # window_size_left + -1, # window_size_right + 0.0, # softcap + 1, # num_splits + False, # manual_set_pack_gqa + False, # pack_gqa_ + 0, # sm_margin + ) + with paddle.no_grad(): + q_grad, k_grad, v_grad = _C_ops.flash_attn_v3_grad( + query_states, + key_states, + value_states, + attn_out, + softmax_lse.view("bfloat16"), + dout, + softmax_scale, + True, + -1, + -1, + 0.0, + 0, + ) + 
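One detail worth calling out in the FA2 branch above: FlashAttention-2 expects Q, K and V to share the head dimension, so V (v_head_dim = 128) is zero-padded up to q_head_dim = 192 before the kernel and the pad is stripped afterwards, both from the output and from v_grad. A standalone shape check of that padding (not part of the patch):

import paddle

bsz, seq_len, num_heads, q_head_dim, v_head_dim = 2, 16, 4, 192, 128
value_states = paddle.randn([bsz, seq_len, num_heads, v_head_dim])
value_padding = paddle.zeros([bsz, seq_len, num_heads, q_head_dim - v_head_dim], dtype=value_states.dtype)
value_states_pad = paddle.concat([value_states, value_padding], axis=-1)

print(value_states_pad.shape)                    # [2, 16, 4, 192], fed to flash_attn
print(value_states_pad[..., :v_head_dim].shape)  # [2, 16, 4, 128], after slicing the pad off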
else: + assert False, f"invalid {FA_VERSION=}" + + d_q, d_kv, d_k_pe = paddle.grad( + outputs=[query_states, key_states, value_states], + inputs=[q, kv, k_pe], + grad_outputs=[q_grad, k_grad, v_grad], + create_graph=False, + retain_graph=False, + ) + + paddle.base.core._set_has_grad(False) + + # call up proj + if hasattr(kv_up_weight, "main_grad"): + d_kv_fp8, d_kv_scale, d_kv_t_fp8, d_kv_t_scale = paddle.incubate.nn.functional.fp8_quant_blockwise( + d_kv.reshape([-1, d_kv.shape[-1]]), + output_scale_transpose=True, + quant_method="1x128", + input_transpose=True, + ) + + d_kv_ln_t = FP8LinearFunctionBase.compute_fp8_linear( + (d_kv_fp8, d_kv_scale), kv_up_weight, weight_transpose=False + ) + d_kv_ln_t = d_kv_ln_t.reshape(d_kv.shape[:-1] + [kv_up_weight.shape[0]]) + + def kv_up_weight_grad(kv_ln_trans_fp8, kv_ln_trans_scale, d_kv_t_fp8, d_kv_t_scale, kv_up_weight): + FP8LinearFunctionBase.kitchen_gemm( + kv_ln_trans_fp8, + kv_ln_trans_scale, + d_kv_t_fp8, + d_kv_t_scale, + True, + True, + kv_up_weight.main_grad, + paddle.float32, + ) + + if WeightGradStore.enabled: + + WeightGradStore.put( + partial( + kv_up_weight_grad, kv_ln_trans_fp8, kv_ln_trans_scale, d_kv_t_fp8, d_kv_t_scale, kv_up_weight + ) + ) + else: + kv_up_weight_grad(kv_ln_trans_fp8, kv_ln_trans_scale, d_kv_t_fp8, d_kv_t_scale, kv_up_weight) + + d_kv_up_weight = None + + else: + d_kv_ln_t, d_kv_up_weight = _C_ops.matmul_grad(kv_ln_t, kv_up_weight, d_kv, False, False) + + d_compressed_kv, d_kv_ln_weight = fused_ln.fused_rms_norm_grad_func( + compressed_kv, kv_ln_weight, kv_ln_invar, d_kv_ln_t, eps + ) + + d_kv_init = paddle.concat([d_compressed_kv, d_k_pe], axis=-1) + + if hasattr(q_up_weight, "main_grad"): + + d_q_fp8, d_q_scale, d_q_t_fp8, d_q_t_scale = paddle.incubate.nn.functional.fp8_quant_blockwise( + d_q.reshape([-1, d_q.shape[-1]]), + output_scale_transpose=True, + quant_method="1x128", + input_transpose=True, + ) + # d_q_ln_t = paddle.matmul(d_q, q_up_weight, transpose_y=True) + + d_q_ln_t = FP8LinearFunctionBase.compute_fp8_linear( + (d_q_fp8, d_q_scale), q_up_weight, weight_transpose=False + ) + d_q_ln_t = d_q_ln_t.reshape(d_q.shape[:-1] + [q_up_weight.shape[0]]) + + def q_up_weight_grad(q_ln_trans_fp8, q_ln_trans_scale, d_q_t_fp8, d_q_t_scale, q_up_weight): + FP8LinearFunctionBase.kitchen_gemm( + q_ln_trans_fp8, + q_ln_trans_scale, + d_q_t_fp8, + d_q_t_scale, + True, + True, + q_up_weight.main_grad, + paddle.float32, + ) + + if WeightGradStore.enabled: + WeightGradStore.put( + partial(q_up_weight_grad, q_ln_trans_fp8, q_ln_trans_scale, d_q_t_fp8, d_q_t_scale, q_up_weight) + ) + else: + q_up_weight_grad(q_ln_trans_fp8, q_ln_trans_scale, d_q_t_fp8, d_q_t_scale, q_up_weight) + + d_q_up_weight = None + + else: + d_q_ln_t, d_q_up_weight = _C_ops.matmul_grad(q_ln_t, q_up_weight, d_q, False, False) + + d_q_init, d_q_ln_weight = fused_ln.fused_rms_norm_grad_func(q_init, q_ln_weight, q_ln_invar, d_q_ln_t, eps) + + return d_q_init, d_kv_init, d_q_ln_weight, d_kv_ln_weight, d_q_up_weight, d_kv_up_weight + + +class MemroyRecomputeAttn(paddle.nn.Layer): + def __init__( + self, + q_norm_hidden_size, + kv_norm_hidden_size, + q_up_in_dim, + q_up_out_dim, + kv_up_in_dim, + kv_up_out_dim, + rotary_emb, + num_heads, + q_head_dim, + qk_nope_head_dim, + v_head_dim, + qk_rope_head_dim, + eps, + kv_lora_rank, + softmax_scale, + recompute_fa3=False, + ) -> None: + super().__init__() + self._dtype = self._helper.get_default_dtype() + + self.q_ln_weight = paddle.create_parameter( + shape=[q_norm_hidden_size], + dtype=self._dtype, + 
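The backward pass above defers the weight-gradient GEMMs through WeightGradStore.put(partial(...)) so they can run later, for example inside a pipeline bubble, instead of on the critical path of backward. The toy store below illustrates only the pattern; TinyGradStore is hypothetical and only its put/flush usage mirrors the real helper:

from functools import partial

class TinyGradStore:
    enabled = True
    _queue = []

    @classmethod
    def put(cls, fn):
        cls._queue.append(fn)              # defer the dW work

    @classmethod
    def flush(cls):
        while cls._queue:
            cls._queue.pop(0)()            # run it later, off the critical path

def weight_grad(name):
    print(f"computing dW for {name}")

if TinyGradStore.enabled:
    TinyGradStore.put(partial(weight_grad, "kv_up_weight"))
else:
    weight_grad("kv_up_weight")

TinyGradStore.flush()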
default_initializer=nn.initializer.Constant(1.0), + ) + self.kv_ln_weight = paddle.create_parameter( + shape=[kv_norm_hidden_size], + dtype=self._dtype, + default_initializer=nn.initializer.Constant(1.0), + ) + + self.q_up_weight = self.create_parameter( + shape=[q_up_in_dim, q_up_out_dim], + dtype=self._dtype, + is_bias=False, + ) + + self.kv_up_weight = self.create_parameter( + shape=[kv_up_in_dim, kv_up_out_dim], + dtype=self._dtype, + is_bias=False, + ) + ( + self.rotary_emb, + self.num_heads, + self.q_head_dim, + self.qk_nope_head_dim, + self.v_head_dim, + self.qk_rope_head_dim, + self.eps, + self.kv_lora_rank, + self.softmax_scale, + self.recompute_fa3, + ) = ( + rotary_emb, + num_heads, + q_head_dim, + qk_nope_head_dim, + v_head_dim, + qk_rope_head_dim, + eps, + kv_lora_rank, + softmax_scale, + recompute_fa3, + ) + set_parameter_color([self.q_up_weight, self.kv_up_weight], "memory_attn") + + def fp8_quant_weight(self, quant_transpose=None): + cache_fp8_weight(self.q_up_weight, quant_transpose=quant_transpose) + cache_fp8_weight(self.kv_up_weight, quant_transpose=quant_transpose) + + def forward(self, q_init, kv_init, position_ids): + + seq_len = q_init.shape[1] + + if self.rotary_emb.max_seq_len_cached is None or seq_len > self.rotary_emb.max_seq_len_cached: + self.rotary_emb._set_cos_sin_cache(seq_len) + + return MemroyRecomputeAttnFunc.apply( + q_init, + kv_init, + self.q_ln_weight, + self.kv_ln_weight, + self.q_up_weight, + self.kv_up_weight, + self.rotary_emb, + self.num_heads, + self.q_head_dim, + self.qk_nope_head_dim, + self.v_head_dim, + self.qk_rope_head_dim, + position_ids, + self.eps, + self.kv_lora_rank, + self.softmax_scale, + recompute_fa3=self.recompute_fa3, + ) + + +class FusedRMSLinearFunc(paddle.autograd.PyLayer): + @staticmethod + def forward(ctx, x, rms_norm_weight, q_down_weight, kv_down_weight, eps): + + hidden_states, invar = fused_ln.fused_rms_norm(x, rms_norm_weight, eps) + + h_fp8, h_scale = paddle.incubate.nn.functional.fp8_quant_blockwise( + hidden_states.reshape([-1, hidden_states.shape[-1]]), output_scale_transpose=True, quant_method="1x128" + ) + + h_orig_shape = hidden_states.shape + q = FP8LinearFunctionBase.compute_fp8_linear( + (h_fp8, h_scale), q_down_weight, weight_transpose=True, return_transpose_only=True + ) + q = q.reshape(h_orig_shape[:-1] + [q_down_weight.shape[-1]]) + + kv = paddle.matmul(hidden_states, kv_down_weight) + + ctx.save_for_backward(x, rms_norm_weight, q_down_weight, kv_down_weight) + ctx.eps = eps + return q, kv + + @staticmethod + def backward(ctx, d_q, d_kv): + x, rms_norm_weight, q_down_weight, kv_down_weight = ctx.saved_tensor() + eps = ctx.eps + hidden_states, invar = fused_ln.fused_rms_norm(x, rms_norm_weight, eps) + + h_t_fp8, h_t_scale = paddle.incubate.nn.functional.fp8_quant_blockwise( + hidden_states.reshape([-1, hidden_states.shape[-1]]), + output_scale_transpose=True, + quant_method="1x128", + input_transpose=True, + return_transpose_only=True, + ) + + h_grad, d_kv_down_weight = _C_ops.matmul_grad(hidden_states, kv_down_weight, d_kv, False, False) + + if hasattr(q_down_weight, "main_grad"): + d_q_fp8, d_q_scale, d_q_t_fp8, d_q_t_scale = paddle.incubate.nn.functional.fp8_quant_blockwise( + d_q.reshape([-1, d_q.shape[-1]]), + output_scale_transpose=True, + quant_method="1x128", + input_transpose=True, + ) + FP8LinearFunctionBase.compute_fp8_linear( + (d_q_fp8, d_q_scale), q_down_weight, weight_transpose=False, out=h_grad.view([-1, h_grad.shape[-1]]) + ) + + def q_down_weight_grad(h_t_fp8, h_t_scale, d_q_t_fp8, 
d_q_t_scale, q_down_weight): + FP8LinearFunctionBase.kitchen_gemm( + h_t_fp8, h_t_scale, d_q_t_fp8, d_q_t_scale, True, True, q_down_weight.main_grad, paddle.float32 + ) + + if WeightGradStore.enabled: + WeightGradStore.put( + partial(q_down_weight_grad, h_t_fp8, h_t_scale, d_q_t_fp8, d_q_t_scale, q_down_weight) + ) + else: + q_down_weight_grad(h_t_fp8, h_t_scale, d_q_t_fp8, d_q_t_scale, q_down_weight) + + d_q_down_weight = None + + else: + h_grad_0, d_q_down_weight = _C_ops.matmul_grad(hidden_states, q_down_weight, d_q, False, False) + h_grad = h_grad + h_grad_0 + + dx, d_rms_norm_weight = fused_ln.fused_rms_norm_grad_func(x, rms_norm_weight, invar, h_grad, eps) + + return dx, d_rms_norm_weight, d_q_down_weight, d_kv_down_weight + + +class FusedRMSLinear(paddle.nn.Layer): + def __init__(self, hidden_size, q_out_dim, kv_outdim, eps=1e-6) -> None: + super().__init__() + self._dtype = self._helper.get_default_dtype() + + self.rms_norm_weight = paddle.create_parameter( + shape=[hidden_size], + dtype=self._dtype, + default_initializer=nn.initializer.Constant(1.0), + ) + + self.q_down_weight = self.create_parameter( + shape=[hidden_size, q_out_dim], + dtype=self._dtype, + is_bias=False, + ) + + self.kv_down_weight = self.create_parameter( + shape=[hidden_size, kv_outdim], + dtype=self._dtype, + is_bias=False, + ) + self.eps = eps + set_parameter_color([self.q_down_weight], "rms_linear") + + def fp8_quant_weight(self, quant_transpose=None): + cache_fp8_weight(self.q_down_weight, quant_transpose=quant_transpose) + + def forward(self, x): + + return FusedRMSLinearFunc.apply(x, self.rms_norm_weight, self.q_down_weight, self.kv_down_weight, self.eps) + + +class FusedRMSLinearSingleFunc(paddle.autograd.PyLayer): + @staticmethod + def forward(ctx, x, rms_norm_weight, linear_weight, eps): + + hidden_states, invar = fused_ln.fused_rms_norm(x, rms_norm_weight, eps) + q = paddle.matmul(hidden_states, linear_weight) + + ctx.save_for_backward(x, rms_norm_weight, linear_weight, eps) + return q + + @staticmethod + def backward(ctx, d_q, d_kv): + x, rms_norm_weight, linear_weight, eps = ctx.saved_tensor() + hidden_states, invar = fused_ln.fused_rms_norm(x, rms_norm_weight, eps) + + h_grad, d_linear_weight = _C_ops.matmul_grad(hidden_states, linear_weight, d_q, False, False) + + dx, d_rms_norm_weight = fused_ln.fused_rms_norm_grad_func(x, rms_norm_weight, invar, h_grad, eps) + + return dx, d_rms_norm_weight, d_linear_weight + + +class FusedRMSLinearSingle(paddle.nn.Layer): + def __init__(self, hidden_size, q_out_dim, kv_outdim, eps=1e-6) -> None: + super().__init__() + self._dtype = self._helper.get_default_dtype() + + self.rms_norm_weight = paddle.create_parameter( + shape=[hidden_size], + dtype=self._dtype, + default_initializer=nn.initializer.Constant(1.0), + ) + + self.linear_weight = self.create_parameter( + shape=[hidden_size, q_out_dim], + dtype=self._dtype, + is_bias=False, + ) + self.eps = eps + + def forward(self, x): + + return FusedRMSLinearFunc.apply(x, self.rms_norm_weight, self.linear_weight, self.eps) + + # Copied from transformers.models.llama.modeling_llama.LlamaAttention with Llama->DeepseekV2 class DeepseekV2Attention(nn.Layer): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -1890,6 +2926,25 @@ def forward( ) +class FastCrossEntropyFunction(paddle.autograd.PyLayer): + @staticmethod + def forward(ctx, preds, labels): + softmax_val, loss = paddle._C_ops.cross_entropy_with_softmax(preds, labels, False, True, False, -100, -1) + + ctx.save_for_backward(labels, softmax_val) 
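FusedRMSLinear above folds the attention-input RMSNorm and the q/kv down-projections into a single PyLayer with FP8 GEMMs and deferred weight gradients. Functionally it computes the composition sketched below; this is a plain-Paddle reference, not the fused path:

import paddle

hidden_size, q_out_dim, kv_out_dim, eps = 64, 96, 80, 1e-6
x = paddle.randn([2, 8, hidden_size])
rms_norm_weight = paddle.ones([hidden_size])
q_down_weight = paddle.randn([hidden_size, q_out_dim])
kv_down_weight = paddle.randn([hidden_size, kv_out_dim])

variance = x.pow(2).mean(axis=-1, keepdim=True)
normed = x * paddle.rsqrt(variance + eps) * rms_norm_weight   # RMSNorm(x)
q = paddle.matmul(normed, q_down_weight)                      # q_a projection
kv = paddle.matmul(normed, kv_down_weight)                    # compressed_kv (+ k_pe) projection
print(q.shape, kv.shape)                                      # [2, 8, 96] [2, 8, 80]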
+ return loss + + @staticmethod + def backward(ctx, dout): + labels, softmax_val = ctx.saved_tensor() + + preds_grad = paddle.incubate.nn.functional.cross_entropy_with_softmax_bwd_w_downcast( + labels, softmax_val.cast(paddle.float32), dout.cast(paddle.float32) + ) + + return preds_grad, None + + class DeepseekV2PretrainingCriterion(nn.Layer): """ Criterion for Mixtral. @@ -1956,7 +3011,7 @@ def add_loss(main_loss, loss): class DeepseekV2LMHead(nn.Layer): - def __init__(self, config: DeepseekV2Config): + def __init__(self, config: DeepseekV2Config, embedding_weight=None): super(DeepseekV2LMHead, self).__init__() self.config = config @@ -1970,11 +3025,16 @@ def __init__(self, config: DeepseekV2Config): else: vocab_size = config.vocab_size - self.weight = self.create_parameter( - shape=[config.hidden_size, vocab_size], - dtype=paddle.get_default_dtype(), - default_initializer=nn.initializer.XavierNormal(1.0), - ) + if embedding_weight is not None: + self.transpose_y = True + self.weight = embedding_weight + else: + self.transpose_y = False + self.weight = self.create_parameter( + shape=[config.hidden_size, vocab_size], + dtype=paddle.get_default_dtype(), + default_initializer=nn.initializer.XavierNormal(1.0), + ) # Must set distributed attr for Tensor Parallel ! self.weight.is_distributed = True if (vocab_size != config.vocab_size) else False if get_env_device() == "xpu": @@ -2004,7 +3064,9 @@ def forward(self, hidden_states, tensor_parallel_output=None): training=self.training, ) else: - logits = parallel_matmul(hidden_states, self.weight, tensor_parallel_output=tensor_parallel_output) + logits = parallel_matmul( + hidden_states, self.weight, transpose_y=self.transpose_y, tensor_parallel_output=tensor_parallel_output + ) return logits def extra_repr(self): diff --git a/paddleformers/transformers/deepseek_v2/modeling_fast.py b/paddleformers/transformers/deepseek_v2/modeling_fast.py new file mode 100644 index 00000000000..6748112e4c1 --- /dev/null +++ b/paddleformers/transformers/deepseek_v2/modeling_fast.py @@ -0,0 +1,1674 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2025 DeepSeek. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
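For the DeepseekV2LMHead change above: when an embedding weight is passed in (tied embeddings) it is stored as [vocab_size, hidden_size], whereas the head's own parameter is created as [hidden_size, vocab_size], which is why parallel_matmul now threads a transpose_y flag through. A shape-only sketch, not the actual head:

import paddle

hidden_size, vocab_size = 64, 1000
hidden_states = paddle.randn([2, 8, hidden_size])
tied_embedding_weight = paddle.randn([vocab_size, hidden_size])   # shared with the input embedding
untied_head_weight = paddle.randn([hidden_size, vocab_size])      # head-owned parameter

logits_tied = paddle.matmul(hidden_states, tied_embedding_weight, transpose_y=True)
logits_untied = paddle.matmul(hidden_states, untied_head_weight, transpose_y=False)
print(logits_tied.shape, logits_untied.shape)                     # both [2, 8, 1000]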
+"""Paddle DeepSeek model.""" + +from __future__ import annotations + +import contextlib +import math +import os +import warnings +from functools import partial +from typing import List, Optional, Tuple, Union + +import paddle +import paddle.distributed as dist +import paddle.distributed.fleet.meta_parallel as mpu +import paddle.nn.functional as F +from paddle import Tensor, nn +from paddle.distributed import fleet +from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker +from paddle.distributed.fleet.recompute.recompute import recompute + +try: + from paddle.incubate.nn.functional import fused_rotary_position_embedding +except ImportError: + fused_rotary_position_embedding = None + +try: + from paddle.distributed.fleet.utils.sequence_parallel_utils import ( + GatherOp, + ScatterOp, + ) +except: + pass + + +try: + from paddle.nn.functional.flash_attention import flash_attention +except: + flash_attention = None +from paddleformers.transformers.model_utils import dtype_guard + +from ...utils.initializer import kaiming_uniform_ +from ...utils.log import logger +from ...utils.tools import get_env_device +from ..activations import ACT2FN +from ..conversion_utils import StateDictNameMapping, init_name_mappings +from ..llama.modeling import get_use_casual_mask +from ..model_outputs import BaseModelOutputWithPastAndMTP +from ..model_utils import PretrainedModel, register_base_model +from ..moe_gate import PretrainedMoEGate +from ..moe_layer import MoELayer +from . import fp8_linear as linear_utils +from .configuration import DeepseekV2Config + +FA_VERSION = int(os.getenv("FA_VERSION", 2)) + +from ..fp8_utils import ( + FP8KeepXLinear, + FP8Linear, + FP8LinearFunction, + FP8Mlp, + set_parameter_color, +) +from .fp8_linear import Linear + +DSV3_USE_FP8_GEMM = os.getenv("DSV3_USE_FP8_GEMM", "False").lower() == "true" +DSV3_USE_ATTEN_RECOMPUTE = os.getenv("DSV3_USE_ATTEN_RECOMPUTE", "False").lower() == "true" + +Linear = FP8Linear if DSV3_USE_FP8_GEMM else Linear + +try: + from paddle.incubate.nn.functional import swiglu +except ImportError: + + def swiglu(x, y=None): + if y is None: + x, y = paddle.chunk(x, chunks=2, axis=-1) + return F.silu(x) * y + + +try: + from paddle.incubate.nn.functional import fused_partial_rope +except ImportError: + fused_partial_rope = None + +from .modeling import ( + AddAuxiliaryLoss, + DeepseekV2DynamicNTKScalingRotaryEmbedding, + DeepseekV2LinearScalingRotaryEmbedding, + DeepseekV2PretrainingCriterion, + DeepseekV2RMSNorm, + DeepseekV2RotaryEmbedding, + DeepseekV2YarnRotaryEmbedding, + FakeGate, + FastCrossEntropyFunction, + FusedNormGateFunc, + FusedRMSLinear, + LMHeadFunction, + MemroyRecomputeAttn, + _expand_2d_mask, + _make_causal_mask, + apply_rotary_pos_emb, + is_casual_mask, + scaled_dot_product_attention, + set_global_step, + yarn_get_mscale, +) + +__all__ = [ + "DeepseekV2PretrainingCriterionFast", + "DeepseekV2ModelFast", + "DeepseekV2PretrainedModelFast", +] + + +def parallel_matmul(x: Tensor, y: Tensor, transpose_y=False, tensor_parallel_output=True): + is_fleet_init = True + tensor_parallel_degree = 1 + try: + hcg = fleet.get_hybrid_communicate_group() + model_parallel_group = hcg.get_model_parallel_group() + tensor_parallel_degree = hcg.get_model_parallel_world_size() + except AttributeError: + is_fleet_init = False + + if paddle.in_dynamic_mode(): + y_is_distributed = y.is_distributed + else: + y_is_distributed = tensor_parallel_degree > 1 + + if is_fleet_init and tensor_parallel_degree > 1 and y_is_distributed: + # if not running under 
distributed.launch, it will raise AttributeError: 'Fleet' object has no attribute '_hcg' + input_parallel = paddle.distributed.collective._c_identity(x, group=model_parallel_group) + logits = paddle.matmul(input_parallel, y, transpose_y=transpose_y) + + if tensor_parallel_output: + return logits + + return paddle.distributed.collective._c_concat(logits, group=model_parallel_group) + + else: + logits = LMHeadFunction.apply(x, y, transpose_y=transpose_y) + return logits + + +class DeepseekV2MLP(nn.Layer): + def __init__(self, config: DeepseekV2Config, hidden_size=None, intermediate_size=None, is_moe=False): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size if hidden_size is None else hidden_size + self.intermediate_size = config.intermediate_size if intermediate_size is None else intermediate_size + self.fuse_attention_ffn = config.fuse_attention_ffn + + def linear_dtype_gaurd(): + if config.use_fp8: + return dtype_guard("float8_e4m3fn") + else: + return contextlib.nullcontext() + + if config.sequence_parallel: + ColumnParallelLinear = linear_utils.ColumnSequenceParallelLinear + RowParallelLinear = linear_utils.RowSequenceParallelLinear + else: + ColumnParallelLinear = linear_utils.ColumnParallelLinear + RowParallelLinear = linear_utils.RowParallelLinear + + with linear_dtype_gaurd(): + if config.tensor_parallel_degree > 1 and not is_moe: + self.gate_proj = ColumnParallelLinear( + self.hidden_size, + self.intermediate_size, + gather_output=False, + has_bias=False, + ) + self.up_proj = ColumnParallelLinear( + self.hidden_size, + self.intermediate_size, + gather_output=False, + has_bias=False, + ) + self.down_proj = RowParallelLinear( + self.intermediate_size, + self.hidden_size, + input_is_parallel=True, + has_bias=False, + ) + else: + if config.fuse_attention_ffn: + self.gate_up_fused_proj = Linear(self.hidden_size, self.intermediate_size * 2, bias_attr=False) + else: + self.gate_proj = Linear(self.hidden_size, self.intermediate_size, bias_attr=False) + self.up_proj = Linear(self.hidden_size, self.intermediate_size, bias_attr=False) + self.down_proj = Linear(self.intermediate_size, self.hidden_size, bias_attr=False) + + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + if self.fuse_attention_ffn: + x = swiglu(self.gate_up_fused_proj(x)) + else: + x = swiglu(self.gate_proj(x), self.up_proj(x)) + out = self.down_proj(x) + return out + + +class MoEGate(PretrainedMoEGate): + def __init__( + self, + config, + num_experts, + expert_hidden_size, + using_post_norm_recompute=False, + norm_weight=None, + norm_eps=None, + **kwargs + ): + super().__init__(config, num_experts, expert_hidden_size, **kwargs) + # [hidden_size, n_expert] + + self.scoring_func = config.scoring_func + self.topk_method = config.topk_method + + self.weight = paddle.create_parameter( + shape=[expert_hidden_size, num_experts], + dtype=paddle.float32, + is_bias=False, + # default_initializer=nn.initializer.Constant(1.0), + ) + + self.config = config + self.using_post_norm_recompute = using_post_norm_recompute + + if config.topk_method == "noaux_tc": + self.e_score_correction_bias = paddle.create_parameter( + shape=[num_experts], + dtype=paddle.float32, + default_initializer=nn.initializer.Constant(0.0), + ) + self.e_score_correction_bias.is_distributed = True + + if self.using_post_norm_recompute: + assert norm_weight is not None and norm_eps is not None + self.norm_weight = norm_weight + self.norm_eps = norm_eps + self.using_flex_token = False + + def forward(self, 
hidden_states): + """ + Args: + hidden_states (_type_): [batch_size * seq_len, hidden_size] + """ + _, _, h_dim = hidden_states.shape + + # compute gating score + if self.using_post_norm_recompute: + logits, norm_out = FusedNormGateFunc.apply(hidden_states, self.norm_weight, self.weight, self.norm_eps) + if hasattr(self.config, "using_fake_gate") and self.config.using_fake_gate: + logits = FakeGate.apply( + hidden_states, + self.weight, + self.config.fakse_gate_restrict_balance, + self.config.num_experts_per_tok, + ) + else: + with paddle.amp.auto_cast(False): + hidden_states = hidden_states.cast(self.weight.dtype) + if hasattr(self.config, "using_fake_gate") and self.config.using_fake_gate: + logits = FakeGate.apply( + hidden_states, + self.weight, + self.config.fakse_gate_restrict_balance, + self.config.num_experts_per_tok, + ) + else: + logits = F.linear(hidden_states, self.weight, None) + + scores = self.gate_score_func(logits=logits) + scores = scores.cast(paddle.float32) + + # Compute all possible return values + if self.using_flex_token: + scores, routing_map, exp_counts, l_aux, l_zloss = self.topkgating_nodrop( + scores + ) # (scores, routing_map, exp_counts, l_aux, l_zloss) + ret = (scores, routing_map, l_aux, l_zloss) + else: + ret = self.topkgating(scores) # (capacity, combine_weights, dispatch_mask, exp_counts, l_aux, l_zloss) + + # Append norm_out if needed + if self.using_post_norm_recompute: + ret = (*ret, norm_out) + + return ret + + +class DeepseekV2MoE(MoELayer): + """ + A mixed expert module containing shared experts. + """ + + def __init__(self, config: DeepseekV2Config, norm_weight=None, norm_eps=None): + assert config.tensor_parallel_degree <= 1, "tensor_parallel_degree should be 1" + + self.using_post_norm_recompute = config.using_post_norm_recompute + if self.using_post_norm_recompute: + assert norm_weight is not None and norm_eps is not None + + gate = MoEGate( + config=config, + num_experts=config.n_routed_experts, + expert_hidden_size=config.hidden_size, + top_k=config.num_experts_per_tok, + topk_method=config.topk_method, + n_group=config.n_group, + topk_group=config.topk_group, + norm_topk_prob=config.norm_topk_prob, + routed_scaling_factor=config.routed_scaling_factor, + drop_tokens=False, + using_post_norm_recompute=self.using_post_norm_recompute, + norm_weight=norm_weight, + norm_eps=norm_eps, + ) + DeepseekV2MLPClass = FP8Mlp if DSV3_USE_FP8_GEMM else DeepseekV2MLP + + super().__init__( + config=config, + moe_num_experts=config.n_routed_experts, + expert_class=DeepseekV2MLPClass, + expert_kwargs={ + "config": config, + "intermediate_size": config.moe_intermediate_size, + "is_moe": True, + }, + gate=gate, + capacity=2.0, + moe_group="expert", + using_post_norm_recompute=self.using_post_norm_recompute, + ) + + if config.offline_quant_expert_weight and config.clear_origin_weight_when_offline_quant: + moe_grad_group = fleet.get_hybrid_communicate_group().expert_grad_comm_group + expert_w1_list = [expert.w1 for expert in self.experts if expert is not None] + expert_w2_list = [expert.w2 for expert in self.experts if expert is not None] + for p in expert_w1_list: + setattr(p, "color", {"color": "moe_expert", "group": moe_grad_group}) + for p in expert_w2_list: + setattr(p, "color", {"color": "moe_expert", "group": moe_grad_group}) + + self.alpha = config.aux_loss_alpha + if config.n_shared_experts is not None: + intermediate_size = config.moe_intermediate_size * config.n_shared_experts + if self.using_post_norm_recompute: + assert DeepseekV2MLPClass is FP8Mlp + 
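MoEGate.forward above produces routing scores with a sigmoid scoring function and, optionally, a fake load-balanced gate for benchmarking. Stripping away the grouping, the noaux_tc correction bias and the aux losses, the core scoring-plus-top-k step looks roughly like this (a simplified sketch, not the real gate):

import paddle
import paddle.nn.functional as F

num_tokens, hidden_size, num_experts, top_k = 8, 16, 4, 2
hidden_states = paddle.randn([num_tokens, hidden_size], dtype="float32")
gate_weight = paddle.randn([hidden_size, num_experts], dtype="float32")  # same layout as self.weight

logits = F.linear(hidden_states, gate_weight)
scores = F.sigmoid(logits)                                           # scoring_func == "sigmoid"
topk_weight, topk_ids = paddle.topk(scores, k=top_k, axis=-1)
topk_weight = topk_weight / topk_weight.sum(axis=-1, keepdim=True)   # norm_topk_prob
print(topk_ids.shape, topk_weight.shape)                             # [8, 2] [8, 2]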
self.shared_experts = DeepseekV2MLPClass( + config=config, + intermediate_size=intermediate_size, + is_moe=False, + using_post_norm_recompute=self.using_post_norm_recompute, + norm_weight=norm_weight, + norm_eps=norm_eps, + recompute_fwd_gate_up=True, + ) + else: + self.shared_experts = DeepseekV2MLPClass( + config=config, intermediate_size=intermediate_size, is_moe=False + ) + set_parameter_color([self.shared_experts.w1, self.shared_experts.w2], "shared_expert") + + def fp8_quant_weight(self, batch_mode=False, quant_transpose=None): + """Quantize weights in FP8 format. + + Args: + batch_mode: If True, quantize all weights in batch mode using the first expert's weights. + If False, quantize each expert's weights individually. + """ + + def quantize_weights(weight_list, weight_obj=None, quant_transpose=None): + """Helper function to quantize a list of weights.""" + if weight_obj is None: + weight_obj = weight_list[0] + if hasattr(weight_obj, "fp8_weight_stacked") or hasattr(weight_obj, "fp8_weight_stacked_transpose"): + return + + if quant_transpose is None: + fp8_weight, fp8_scale = paddle.incubate.nn.functional.fused_stack_transpose_quant( + weight_list, transpose=False + ) + setattr(weight_obj, "fp8_weight_stacked", fp8_weight) + setattr(weight_obj, "fp8_scale_stacked", fp8_scale) + + fp8_weight_t, fp8_scale_t = paddle.incubate.nn.functional.fused_stack_transpose_quant( + weight_list, transpose=True + ) + setattr(weight_obj, "fp8_weight_stacked_transpose", fp8_weight_t) + setattr(weight_obj, "fp8_scale_stacked_transpose", fp8_scale_t) + elif quant_transpose is False: + # Only quantize without transpose + fp8_weight, fp8_scale = paddle.incubate.nn.functional.fused_stack_transpose_quant( + weight_list, transpose=False + ) + setattr(weight_obj, "fp8_weight_stacked", fp8_weight) + setattr(weight_obj, "fp8_scale_stacked", fp8_scale) + elif quant_transpose is True: + # Only quantize with transpose + fp8_weight_t, fp8_scale_t = paddle.incubate.nn.functional.fused_stack_transpose_quant( + weight_list, transpose=True + ) + setattr(weight_obj, "fp8_weight_stacked_transpose", fp8_weight_t) + setattr(weight_obj, "fp8_scale_stacked_transpose", fp8_scale_t) + else: + raise ValueError("Invalid value for `quant_transpose`.") + + if batch_mode: + # Batch mode: process all experts' weights together + expert_w1_list = [expert.w1 for expert in self.experts if expert is not None] + expert_w2_list = [expert.w2 for expert in self.experts if expert is not None] + + if expert_w1_list: + quantize_weights(expert_w1_list, expert_w1_list[0], quant_transpose) + if expert_w2_list: + quantize_weights(expert_w2_list, expert_w2_list[0], quant_transpose) + else: + # Individual mode: process each expert's weights separately + for expert in self.experts: + if expert is not None: + quantize_weights([expert.w1], quant_transpose=quant_transpose) + quantize_weights([expert.w2], quant_transpose=quant_transpose) + + if self.config.n_shared_experts is not None: + self.shared_experts.fp8_quant_weight(quant_transpose) + + def forward(self, hidden_states): + if self.using_post_norm_recompute: + super().update_flex_token() + if self.using_flex_token: + probs, routing_map, l_aux, l_zloss, norm_out = self.router(hidden_states) + final_hidden_states, l_aux, l_zloss = super().forward( + norm_out, probs=probs, routing_map=routing_map, l_aux=l_aux, l_zloss=l_zloss + ) + else: + capacity, topk_weight, topk_ids, token_priority, l_aux, l_zloss, norm_out = self.gate(hidden_states) + final_hidden_states, l_aux, l_zloss = super().forward( + 
norm_out, + capacity=capacity, + topk_weight=topk_weight, + topk_ids=topk_ids, + token_priority=token_priority, + l_aux=l_aux, + l_zloss=l_zloss, + ) + final_hidden_states = self.post_process(hidden_states, final_hidden_states, l_aux) + else: + final_hidden_states, l_aux, l_zloss = super().forward(hidden_states) + final_hidden_states = self.post_process(hidden_states, final_hidden_states, l_aux) + return final_hidden_states + + def post_process(self, hidden_states, final_hidden_states, l_aux): + if self.training and self.alpha > 0.0: + l_aux = l_aux * self.alpha + final_hidden_states = AddAuxiliaryLoss.apply(final_hidden_states, l_aux) + + if self.config.n_shared_experts is not None: + shared_expert_output = self.shared_experts(hidden_states) + final_hidden_states = final_hidden_states + shared_expert_output + return final_hidden_states + + +# Copied from transformers.models.llama.modeling_llama.LlamaAttention with Llama->DeepseekV2 +class DeepseekV2Attention(nn.Layer): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: DeepseekV2Config, layerwise_recompute: bool = False, recompute_fa3: bool = False): + super().__init__() + self.config = config + self.attention_dropout = config.attention_dropout + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.q_lora_rank = config.q_lora_rank + self.qk_rope_head_dim = config.qk_rope_head_dim + self.kv_lora_rank = config.kv_lora_rank + self.v_head_dim = config.v_head_dim + self.qk_nope_head_dim = config.qk_nope_head_dim + self.q_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim + + self.is_causal = True + self.fuse_rope = config.use_fused_rope + + if config.num_nextn_predict_layers > 0: + self.seq_length = config.seq_length - config.num_nextn_predict_layers + else: + self.seq_length = config.seq_length + self.sequence_parallel = config.sequence_parallel + + self.recompute_fa3 = recompute_fa3 + + self.input_layernorm = DeepseekV2RMSNorm(config) + + # Note that we will actually perform a recompute only if both enable_recompute and layerwise_recompute are set to True + # Enable_recompute defaults to False and is controlled by Trainer + self.enable_recompute = False + self.layerwise_recompute = layerwise_recompute + self.recompute_granularity = config.recompute_granularity + + def linear_dtype_gaurd(): + if config.use_fp8: + return dtype_guard("float8_e4m3fn") + else: + return contextlib.nullcontext() + + # Note (@DrownFish19): For tensor parallel we consider that q_a_proj and kv_a_proj_with_mqa + # are the small weight and cannot achieve performance gain. So we use the original + # linear layers. We use the tensor parallel linear layers for q_proj,q_b_proj and kv_b_proj + # for which are the large weight and can achieve performance gain. 
+ + self._init_rope() + self.softmax_scale = self.q_head_dim ** (-0.5) + + # fmt: off + if self.config.tensor_parallel_degree > 1: + # for tensor parallel + if config.sequence_parallel: + ColumnParallelLinear = linear_utils.ColumnSequenceParallelLinear + RowParallelLinear = linear_utils.RowSequenceParallelLinear + else: + ColumnParallelLinear = linear_utils.ColumnParallelLinear + RowParallelLinear = linear_utils.RowParallelLinear + + if self.q_lora_rank is None: + with linear_dtype_gaurd(): + self.q_proj = ColumnParallelLinear(self.hidden_size, self.num_heads * self.q_head_dim, has_bias=False, gather_output=True) + else: + with linear_dtype_gaurd(): + self.q_a_proj = Linear(self.hidden_size, config.q_lora_rank, bias_attr=config.attention_bias) + self.q_b_proj = ColumnParallelLinear(config.q_lora_rank, self.num_heads * self.q_head_dim, has_bias=False, gather_output=True) + self.q_a_layernorm = DeepseekV2RMSNorm(config=config, hidden_size=config.q_lora_rank, use_sequence_parallel=False) + + with linear_dtype_gaurd(): + self.kv_a_proj_with_mqa = paddle.nn.Linear(self.hidden_size, config.kv_lora_rank + config.qk_rope_head_dim, bias_attr=config.attention_bias) + self.kv_b_proj = ColumnParallelLinear(config.kv_lora_rank, self.num_heads * (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim), has_bias=False, gather_output=True) + self.o_proj = RowParallelLinear(self.num_heads * self.v_head_dim, self.hidden_size, has_bias=config.attention_bias, input_is_parallel=False) + self.kv_a_layernorm = DeepseekV2RMSNorm(config=config, hidden_size=config.kv_lora_rank, use_sequence_parallel=False) + else: + # for without tensor parallel + if DSV3_USE_ATTEN_RECOMPUTE: + self.fused_rms_norm_linear = FusedRMSLinear(self.hidden_size, config.q_lora_rank, config.kv_lora_rank + config.qk_rope_head_dim, 1e-6) + kv_up_dim = self.num_heads * (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim) + self.memory_recompute_att = MemroyRecomputeAttn(config.q_lora_rank, config.kv_lora_rank, config.q_lora_rank, self.num_heads * self.q_head_dim, config.kv_lora_rank, kv_up_dim, self.rotary_emb, self.num_heads, self.q_head_dim, self.qk_nope_head_dim, self.v_head_dim, self.qk_rope_head_dim, 1e-6, self.kv_lora_rank, self.softmax_scale, recompute_fa3=self.recompute_fa3) + self.o_proj = FP8KeepXLinear(self.num_heads * self.v_head_dim, self.hidden_size, bias_attr=config.attention_bias) + else: + + if self.q_lora_rank is None: + with linear_dtype_gaurd(): + self.q_proj = Linear(self.hidden_size, self.num_heads * self.q_head_dim, bias_attr=False) + else: + with linear_dtype_gaurd(): + self.q_a_proj = Linear(self.hidden_size, config.q_lora_rank, bias_attr=config.attention_bias) + self.q_b_proj = Linear(config.q_lora_rank, self.num_heads * self.q_head_dim, bias_attr=False) + self.q_a_layernorm = DeepseekV2RMSNorm(config=config, hidden_size=config.q_lora_rank) + + with linear_dtype_gaurd(): + self.kv_a_proj_with_mqa = paddle.nn.Linear(self.hidden_size, config.kv_lora_rank + config.qk_rope_head_dim, bias_attr=config.attention_bias) + self.kv_b_proj = Linear(config.kv_lora_rank, self.num_heads * (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim), bias_attr=False) + self.o_proj = Linear(self.num_heads * self.v_head_dim, self.hidden_size, bias_attr=config.attention_bias) + self.kv_a_layernorm = DeepseekV2RMSNorm(config=config, hidden_size=config.kv_lora_rank) + + # fmt: on + self.softmax_scale = self.q_head_dim ** (-0.5) + if self.config.rope_scaling is not None: + mscale_all_dim = 
self.config.rope_scaling.get("mscale_all_dim", 0) + scaling_factor = self.config.rope_scaling["factor"] + if mscale_all_dim: + mscale = yarn_get_mscale(scaling_factor, mscale_all_dim) + self.softmax_scale = self.softmax_scale * mscale * mscale + + self.attn_func = scaled_dot_product_attention + + def fp8_quant_weight(self, quant_transpose=None): + + if DSV3_USE_ATTEN_RECOMPUTE: + self.o_proj.fp8_quant_weight(quant_transpose=quant_transpose) + self.memory_recompute_att.fp8_quant_weight(quant_transpose=quant_transpose) + self.fused_rms_norm_linear.fp8_quant_weight(quant_transpose=quant_transpose) + + def _init_rope(self): + if self.config.rope_scaling is None: + self.rotary_emb = DeepseekV2RotaryEmbedding( + self.qk_rope_head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.rope_theta, + ) + else: + scaling_type = self.config.rope_scaling["type"] + scaling_factor = self.config.rope_scaling["factor"] + if scaling_type == "linear": + self.rotary_emb = DeepseekV2LinearScalingRotaryEmbedding( + self.qk_rope_head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + ) + elif scaling_type == "dynamic": + self.rotary_emb = DeepseekV2DynamicNTKScalingRotaryEmbedding( + self.qk_rope_head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + ) + elif scaling_type == "yarn": + kwargs = { + key: self.config.rope_scaling[key] + for key in [ + "original_max_position_embeddings", + "beta_fast", + "beta_slow", + "mscale", + "mscale_all_dim", + ] + if key in self.config.rope_scaling + } + self.rotary_emb = DeepseekV2YarnRotaryEmbedding( + self.qk_rope_head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + **kwargs, + ) + else: + raise ValueError(f"Unknown RoPE scaling type {scaling_type}") + + def _shape(self, tensor: paddle.Tensor, seq_len: int, bsz: int): + return tensor.reshape([bsz, seq_len, self.num_heads, self.v_head_dim]).transpose([1, 0, 2, 3]) + + def forward( + self, + hidden_states: paddle.Tensor, + position_ids: Optional[Tuple[paddle.Tensor]] = None, + past_key_value: Optional[Tuple[paddle.Tensor]] = None, + attention_mask: Optional[paddle.Tensor] = None, + output_attentions: bool = False, + use_cache: bool = False, + attn_mask_startend_row_indices: Optional[paddle.Tensor] = None, + **kwargs, + ) -> Tuple[paddle.Tensor, Optional[paddle.Tensor], Optional[Tuple[paddle.Tensor]]]: + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. 
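The softmax_scale adjustment above compensates for YaRN's attention scaling. With the values used in this example's config (rope factor 40, mscale_all_dim 1.0, q_head_dim 128 + 64) the numbers work out as below; yarn_get_mscale is re-implemented here with the usual YaRN correction formula, which is an assumption since its definition is not part of this hunk:

import math

def yarn_get_mscale(scale=1.0, mscale=1.0):
    # assumed standard YaRN attention-scale correction
    if scale <= 1.0:
        return 1.0
    return 0.1 * mscale * math.log(scale) + 1.0

q_head_dim = 128 + 64                      # qk_nope_head_dim + qk_rope_head_dim
softmax_scale = q_head_dim ** (-0.5)       # ~0.0722
mscale = yarn_get_mscale(40, 1.0)          # ~1.3689
softmax_scale = softmax_scale * mscale * mscale
print(round(mscale, 4), round(softmax_scale, 4))   # 1.3689 0.1352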
Please make sure use `attention_mask` instead.`" + ) + bsz, q_len, _ = hidden_states.shape + + # DeepSeekV2 q_lora_rank=1536 + # DeepSeekV2-lite q_lora_rank=None + if DSV3_USE_ATTEN_RECOMPUTE: + + q_t1, compressed_kv = self.fused_rms_norm_linear(hidden_states) + + outputs = self.memory_recompute_att(q_t1, compressed_kv, position_ids) + + if self.v_head_dim * self.num_heads != outputs.shape[-1]: + outputs = outputs.reshape([bsz, q_len, self.num_heads, -1]) + outputs = outputs[..., : self.v_head_dim] + outputs = outputs.reshape([bsz, q_len, -1]) + else: + hidden_states = self.input_layernorm(hidden_states) + if self.q_lora_rank is None: + q = self.q_proj(hidden_states) + else: + q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))) + + if self.sequence_parallel: + target_query_shape = [-1, self.seq_length, self.num_heads, self.q_head_dim] + target_key_value_shape = [-1, self.seq_length, self.num_heads, self.qk_nope_head_dim + self.v_head_dim] + else: + target_query_shape = [0, 0, self.num_heads, self.q_head_dim] + target_key_value_shape = [0, 0, self.num_heads, self.qk_nope_head_dim + self.v_head_dim] + + q = q.reshape(shape=target_query_shape) + q_nope, q_pe = paddle.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], axis=-1) + + # DeepSeekV2 kv_lora_rank+qk_rope_head_dim=512+64 + compressed_kv = self.kv_a_proj_with_mqa(hidden_states) + compressed_kv, k_pe = paddle.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], axis=-1) + if self.sequence_parallel: + k_pe = GatherOp.apply(k_pe) + k_pe = k_pe.reshape([-1, q_len, 1, self.qk_rope_head_dim]).expand( + [-1, q_len, self.num_heads, self.qk_rope_head_dim] + ) + + # self.q_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim = 128+64 + # self.num_heads * (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim) = config.qk_nope_head_dim + self.v_head_dim = 128+128 + kv = self.kv_b_proj(self.kv_a_layernorm(compressed_kv)).reshape(shape=target_key_value_shape) + + k_nope, value_states = paddle.split(kv, [self.qk_nope_head_dim, self.v_head_dim], axis=-1) + kv_seq_len = value_states.shape[1] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-3] + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + cos = cos[None, :, None, :] + sin = sin[None, :, None, :] + q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids, self.fuse_rope) + + query_states = paddle.concat([q_nope, q_pe], axis=-1) + key_states = paddle.concat([k_nope, k_pe], axis=-1) + + # [bs, seq_len, num_head, head_dim] + if past_key_value is not None: + # reuse k, v, self_attention + key_states = paddle.concat([past_key_value[0], key_states], axis=1) + value_states = paddle.concat([past_key_value[1], value_states], axis=1) + past_key_value = (key_states, value_states) if use_cache else None + + has_gradient = not (query_states.stop_gradient and key_states.stop_gradient and value_states.stop_gradient) + if ( + self.enable_recompute + and self.layerwise_recompute + and has_gradient + and self.recompute_granularity == "core_attn" + ): + outputs = recompute( + self.attn_func, + query_states, + self.config, + key_states, + value_states, + attention_mask, + output_attentions, + attn_mask_startend_row_indices=attn_mask_startend_row_indices, + softmax_scale=self.softmax_scale, + training=self.training, + sequence_parallel=self.sequence_parallel, + use_reentrant=self.config.recompute_use_reentrant, + ) + else: + outputs = self.attn_func( + query_states, + self.config, + key_states, + value_states, + 
attention_mask, + output_attentions, + attn_mask_startend_row_indices=attn_mask_startend_row_indices, + softmax_scale=self.softmax_scale, + training=self.training, + sequence_parallel=self.sequence_parallel, + ) + if output_attentions: + attn_output, attn_weights = outputs + else: + attn_output = outputs + + # if sequence_parallel is true, out shape are [q_len / n, bs, num_head * head_dim] + # else their shape are [bs, q_len, num_head * head_dim], n is mp parallelism. + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + outputs = (attn_output,) + + if output_attentions: + outputs += (attn_weights,) + + if use_cache: + outputs += (past_key_value,) + + if type(outputs) is tuple and len(outputs) == 1: + outputs = outputs[0] + + return outputs + + +class DeepseekV2DecoderLayer(nn.Layer): + def __init__( + self, config: DeepseekV2Config, layer_idx: int, layerwise_recompute: bool = False, recompute_fa3: bool = False + ): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.enable_recompute = False + self.layerwise_recompute = layerwise_recompute + self.recompute_granularity = config.recompute_granularity + self.using_post_norm_recompute = config.using_post_norm_recompute + + self.hidden_size = config.hidden_size + + self.self_attn = DeepseekV2Attention( + config=config, layerwise_recompute=layerwise_recompute, recompute_fa3=recompute_fa3 + ) + + DeepseekV2MLPClass = FP8Mlp if DSV3_USE_FP8_GEMM else DeepseekV2MLP + + self.input_layernorm = DeepseekV2RMSNorm(config) + self.post_attention_layernorm = DeepseekV2RMSNorm(config) + + if ( + config.n_routed_experts is not None + and layer_idx >= config.first_k_dense_replace + and layer_idx % config.moe_layer_freq == 0 + ): + self.mlp = ( + DeepseekV2MoE( + config, self.post_attention_layernorm.weight, self.post_attention_layernorm.variance_epsilon + ) + if config.using_post_norm_recompute + else DeepseekV2MoE(config) + ) + else: + self.mlp = DeepseekV2MLPClass(config, recompute_fwd_gate_up=True) + + def fp8_quant_weight(self, batch_mode=False, quant_transpose=None): + """fp8_quant_weight""" + if isinstance(self.mlp, DeepseekV2MoE): + # logger.info(f"fp8 quant weight for mlp {type(self.mlp)}") + self.mlp.fp8_quant_weight(batch_mode, quant_transpose=quant_transpose) + self.self_attn.fp8_quant_weight(quant_transpose=quant_transpose) + elif isinstance(self.mlp, FP8Mlp): + self.self_attn.fp8_quant_weight(quant_transpose=quant_transpose) + + def forward( + self, + hidden_states: paddle.Tensor, + position_ids: Optional[paddle.Tensor] = None, + attention_mask: Optional[paddle.Tensor] = None, + output_attentions: Optional[bool] = False, + past_key_value: Optional[Tuple[paddle.Tensor]] = None, + use_cache: Optional[bool] = False, + attn_mask_startend_row_indices: Optional[paddle.Tensor] = None, + **kwargs, + ) -> Tuple[paddle.Tensor, Optional[Tuple[paddle.Tensor, paddle.Tensor]]]: + """ + Args: + hidden_states (`paddle.Tensor`): input to the layer of shape `(batch, seq_len, embed_axis)` + attention_mask (`paddle.Tensor`, *optional*): + attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, + query_sequence_length, key_sequence_length)` if default attention is used. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. 
+ use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(paddle.Tensor)`, *optional*): cached past key and value projection states + """ + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" + ) + residual = hidden_states + + # Self Attention + has_gradient = not hidden_states.stop_gradient + if ( + self.enable_recompute + and self.layerwise_recompute + and has_gradient + and self.recompute_granularity == "full_attn" + ): + outputs = recompute( + self.self_attn, + hidden_states=hidden_states, + position_ids=position_ids, + attention_mask=attention_mask, + output_attentions=output_attentions, + past_key_value=past_key_value, + use_cache=use_cache, + attn_mask_startend_row_indices=attn_mask_startend_row_indices, + **kwargs, + ) + else: + outputs = self.self_attn( + hidden_states=hidden_states, + position_ids=position_ids, + attention_mask=attention_mask, + output_attentions=output_attentions, + past_key_value=past_key_value, + use_cache=use_cache, + attn_mask_startend_row_indices=attn_mask_startend_row_indices, + **kwargs, + ) + + if type(outputs) is tuple: + hidden_states = outputs[0] + else: + hidden_states = outputs + + if output_attentions: + self_attn_weights = outputs[1] + + if use_cache: + present_key_value = outputs[2 if output_attentions else 1] + + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + + if not (self.using_post_norm_recompute and isinstance(self.mlp, DeepseekV2MoE)): + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + if type(outputs) is tuple and len(outputs) == 1: + outputs = outputs[0] + + return outputs + + def self_attn_compute(self, hidden_states, **kwargs): + residual = hidden_states + + # Self Attention + has_gradient = not hidden_states.stop_gradient + if ( + self.enable_recompute + and self.layerwise_recompute + and has_gradient + and self.recompute_granularity == "full_attn" + ): + outputs = recompute( + self.self_attn, + hidden_states=hidden_states, + position_ids=None, + attention_mask=None, + output_attentions=False, + past_key_value=None, + use_cache=False, + attn_mask_startend_row_indices=None, + **kwargs, + ) + else: + outputs = self.self_attn( + hidden_states=hidden_states, + position_ids=None, + attention_mask=None, + output_attentions=False, + past_key_value=None, + use_cache=False, + attn_mask_startend_row_indices=None, + **kwargs, + ) + + if type(outputs) is tuple: + hidden_states = outputs[0] + else: + hidden_states = outputs + + hidden_states = residual + hidden_states + + residual = hidden_states + + if not (self.using_post_norm_recompute and isinstance(self.mlp, DeepseekV2MoE)): + hidden_states = self.post_attention_layernorm(hidden_states) + + return hidden_states, residual + + def pre_dispatch_compute(self, hidden_states): + l_aux, l_zloss, intermediate_hidden_states, token_indices, token_probs = self.mlp.pre_dispatch_compute( + hidden_states + ) + + return l_aux, l_zloss, intermediate_hidden_states, token_indices, token_probs + + def expert_forward_compute(self, intermediate_hidden_states, dispatched_indices, 
dispatched_probs): + (global_input_tokens, token_permuted_indices, prob_permuted_indices) = self.mlp.post_dispatch_compute( + intermediate_hidden_states, dispatched_indices, dispatched_probs + ) + + expert_output = self.mlp.expert_forward(global_input_tokens) + + expert_output = self.mlp.pre_combine_compute( + expert_output, token_permuted_indices, prob_permuted_indices, dispatched_probs + ) + + return expert_output + + def post_combine_compute(self, residual, hidden_states, final_hidden_states, l_aux): + final_hidden_states = self.mlp.post_combine_compute(final_hidden_states) + + final_hidden_states = self.mlp.post_process(hidden_states, final_hidden_states, l_aux) + + final_hidden_states = residual + final_hidden_states + + outputs = (final_hidden_states,) + + if type(outputs) is tuple and len(outputs) == 1: + outputs = outputs[0] + + return outputs + + +class DeepseekV2MTPLayer(DeepseekV2DecoderLayer): + def __init__( + self, + config: DeepseekV2Config, + layer_idx: int, + layerwise_recompute: bool = False, + ): + super(DeepseekV2MTPLayer, self).__init__(config, layer_idx, layerwise_recompute) + + self.enorm = DeepseekV2RMSNorm(config) + self.hnorm = DeepseekV2RMSNorm(config) + self.eh_proj = nn.Linear(2 * config.hidden_size, config.hidden_size, bias_attr=False) + + def forward( + self, + hidden_states: paddle.Tensor, + nextn_hidden_state: paddle.Tensor, + position_ids: Optional[paddle.Tensor] = None, + attention_mask: Optional[paddle.Tensor] = None, + output_attentions: Optional[bool] = False, + past_key_value: Optional[Tuple[paddle.Tensor]] = None, + use_cache: Optional[bool] = False, + attn_mask_startend_row_indices: Optional[paddle.Tensor] = None, + **kwargs, + ) -> Tuple[paddle.Tensor, Optional[Tuple[paddle.Tensor, paddle.Tensor]]]: + hidden_states = self.hnorm(hidden_states) + nextn_hidden_state = self.enorm(nextn_hidden_state) + + concat_h = paddle.concat([nextn_hidden_state, hidden_states], axis=-1) + hidden_states = FP8LinearFunction.apply(concat_h, self.eh_proj) + + layer_outputs = super(DeepseekV2MTPLayer, self).forward( + hidden_states, + position_ids, + attention_mask, + output_attentions, + past_key_value, + use_cache, + attn_mask_startend_row_indices, + **kwargs, + ) + + if type(layer_outputs) is tuple: + hidden_states = layer_outputs[0] + else: + hidden_states = layer_outputs + + return hidden_states + + +class DeepseekV2PretrainedModelFast(PretrainedModel): + config_class = DeepseekV2Config + base_model_prefix = "deepseek_v2" + _no_split_modules = ["DeepseekV2DecoderLayer"] + + def _get_model_flops(self, batch_size=1, seq_length=None, **kwargs): + from .mfu_utils import DeepSeekProjection + + # self._ + mfu_cal_proj = DeepSeekProjection(self.config) + if seq_length is None: + if hasattr(self.config, "seq_length"): + seq_length = self.config.seq_length + else: + seq_length = 2048 + + return mfu_cal_proj.get_num_flop_per_token() + + def _get_hardware_flops(self, *args, **kwargs): + return self._get_model_flops(*args, **kwargs) + + @classmethod + def _get_name_mappings(cls, config: DeepseekV2Config) -> list[StateDictNameMapping]: + mappings: list[StateDictNameMapping] = [] + model_mappings = [ + ["embed_tokens.weight"], + ["norm.weight"], + ] + # last one layer contains MTP (eagle) parameters for inference + for layer_index in range(config.num_hidden_layers + config.num_nextn_predict_layers): + layer_mappings = [ + [f"layers.{layer_index}.self_attn.q_proj.weight", None, "transpose"], + [f"layers.{layer_index}.self_attn.q_a_proj.weight", None, "transpose"], + 
[f"layers.{layer_index}.self_attn.q_a_layernorm.weight"], + [f"layers.{layer_index}.self_attn.q_b_proj.weight", None, "transpose"], + [f"layers.{layer_index}.self_attn.kv_a_proj_with_mqa.weight", None, "transpose"], + [f"layers.{layer_index}.self_attn.kv_a_layernorm.weight"], + [f"layers.{layer_index}.self_attn.kv_b_proj.weight", None, "transpose"], + [f"layers.{layer_index}.self_attn.o_proj.weight", None, "transpose"], + [f"layers.{layer_index}.mlp.gate_proj.weight", None, "transpose"], + [f"layers.{layer_index}.mlp.up_proj.weight", None, "transpose"], + [f"layers.{layer_index}.mlp.down_proj.weight", None, "transpose"], + [f"layers.{layer_index}.input_layernorm.weight"], + [f"layers.{layer_index}.post_attention_layernorm.weight"], + ] + model_mappings.extend(layer_mappings) + + # MoE parameters + model_mappings.append([f"layers.{layer_index}.mlp.gate.weight", None, "transpose"]) + model_mappings.append([f"layers.{layer_index}.mlp.gate.e_score_correction_bias"]) + for expert_idx in range(config.n_routed_experts): + expert_mappings = [ + [f"layers.{layer_index}.mlp.experts.{expert_idx}.gate_proj.weight", None, "transpose"], + [f"layers.{layer_index}.mlp.experts.{expert_idx}.up_proj.weight", None, "transpose"], + [f"layers.{layer_index}.mlp.experts.{expert_idx}.down_proj.weight", None, "transpose"], + ] + model_mappings.extend(expert_mappings) + model_mappings.append([f"layers.{layer_index}.mlp.shared_experts.gate_proj.weight", None, "transpose"]) + model_mappings.append([f"layers.{layer_index}.mlp.shared_experts.up_proj.weight", None, "transpose"]) + model_mappings.append([f"layers.{layer_index}.mlp.shared_experts.down_proj.weight", None, "transpose"]) + + # MTP (eagle) parameters for inference + if layer_index >= config.num_hidden_layers: + model_mappings.append([f"layers.{layer_index}.embed_tokens.weight"]) + model_mappings.append([f"layers.{layer_index}.enorm.weight"]) + model_mappings.append([f"layers.{layer_index}.hnorm.weight"]) + model_mappings.append([f"layers.{layer_index}.eh_proj.weight", None, "transpose"]) + model_mappings.append([f"layers.{layer_index}.shared_head.norm.weight"]) + model_mappings.append([f"layers.{layer_index}.shared_head.head.weight", None, "transpose"]) + + init_name_mappings(mappings=model_mappings) + if cls.base_model_class.__name__ not in config.architectures: + for mapping in model_mappings: + mapping[0] = "model." + mapping[0] + mapping[1] = f"{cls.base_model_prefix}." 
+ mapping[1] + if not config.tie_word_embeddings: + model_mappings.append(["lm_head.weight", "lm_head.weight", "transpose"]) + + mappings = [StateDictNameMapping(*mapping, index=index) for index, mapping in enumerate(model_mappings)] + return mappings + + @classmethod + def _get_tensor_parallel_mappings(cls, config: DeepseekV2Config, is_split=True): + from paddleformers.transformers.conversion_utils import split_or_merge_func + + fn = split_or_merge_func( + is_split=is_split, + tensor_parallel_degree=config.tensor_parallel_degree, + tensor_parallel_rank=config.tensor_parallel_rank, + num_attention_heads=config.num_attention_heads, + ) + + def get_tensor_parallel_split_mappings(num_layers): + final_actions = {} + + base_actions = { + # Row Linear + "embed_tokens.weight": partial(fn, is_column=False), + "layers.0.self_attn.o_proj.weight": partial(fn, is_column=False), + } + if config.use_fp8: + base_actions["layers.0.self_attn.o_proj.weight.weight_scale_inv"] = partial(fn, is_column=False) + + if config.tie_word_embeddings: + base_actions["lm_head.weight"] = partial(fn, is_column=False) + else: + base_actions["lm_head.weight"] = partial(fn, is_column=True) + + if not config.vocab_size % config.tensor_parallel_degree == 0: + base_actions.pop("lm_head.weight") + base_actions.pop("embed_tokens.weight") + + # Column Linear + base_actions["layers.0.self_attn.q_proj.weight"] = partial(fn, is_column=True) + base_actions["layers.0.self_attn.q_proj.bias"] = partial(fn, is_column=True) + base_actions["layers.0.self_attn.q_b_proj.weight"] = partial(fn, is_column=True) + + # if we have enough num_key_value_heads to split, then split it. + # ??? + if config.num_key_value_heads % config.tensor_parallel_degree == 0: + base_actions["layers.0.self_attn.kv_b_proj.weight"] = partial(fn, is_column=True) + if config.use_fp8: + base_actions["layers.0.self_attn.kv_b_proj.weight.weight_scale_inv"] = partial(fn, is_column=True) + + # dense mlp + base_actions["layers.0.mlp.up_proj.weight"] = partial(fn, is_column=True) + base_actions["layers.0.mlp.gate_proj.weight"] = partial(fn, is_column=True) + base_actions["layers.0.mlp.down_proj.weight"] = partial(fn, is_column=False) + if config.use_fp8: + base_actions["layers.0.mlp.up_proj.weight.weight_scale_inv"] = partial(fn, is_column=True) + base_actions["layers.0.mlp.gate_proj.weight.weight_scale_inv"] = partial(fn, is_column=True) + base_actions["layers.0.mlp.down_proj.weight.weight_scale_inv"] = partial(fn, is_column=False) + + # moe unit routed experts + moe_group = dist.fleet.get_hybrid_communicate_group().get_data_parallel_group() + expert_parallel_degree = dist.get_world_size(moe_group) + if expert_parallel_degree <= 1: + for e_i in range(config.n_routed_experts): + base_actions[f"layers.0.mlp.experts.{e_i}.up_proj.weight"] = partial(fn, is_column=True) + base_actions[f"layers.0.mlp.experts.{e_i}.gate_proj.weight"] = partial(fn, is_column=True) + base_actions[f"layers.0.mlp.experts.{e_i}.down_proj.weight"] = partial(fn, is_column=False) + + # moe unit shared experts + base_actions["layers.0.mlp.shared_experts.gate_proj.weight"] = partial(fn, is_column=True) + base_actions["layers.0.mlp.shared_experts.up_proj.weight"] = partial(fn, is_column=True) + base_actions["layers.0.mlp.shared_experts.down_proj.weight"] = partial(fn, is_column=False) + if config.use_fp8: + base_actions["layers.0.mlp.shared_experts.gate_proj.weight.weight_scale_inv"] = partial( + fn, is_column=True + ) + base_actions["layers.0.mlp.shared_experts.up_proj.weight.weight_scale_inv"] = partial( + 
fn, is_column=True + ) + base_actions["layers.0.mlp.shared_experts.down_proj.weight.weight_scale_inv"] = partial( + fn, is_column=False + ) + + for key, action in base_actions.items(): + if "layers.0." in key: + for i in range(num_layers): + final_actions[key.replace("layers.0.", f"layers.{i}.")] = action + final_actions[key] = action + + # for MTP (eagle) parameters for inference + base_actions.pop("embed_tokens.weight") + base_actions.pop("lm_head.weight") + base_actions["layers.0.embed_tokens.weight"] = partial(fn, is_column=False) + base_actions["layers.0.shared_head.head.weight"] = partial(fn, is_column=True) + for key, action in base_actions.items(): + if "layers.0." in key: + for i in range( + config.num_hidden_layers, config.num_hidden_layers + config.num_nextn_predict_layers + ): + final_actions[key.replace("layers.0.", f"layers.{i}.")] = action + else: + final_actions[key] = action + + return final_actions + + mappings = get_tensor_parallel_split_mappings(config.num_hidden_layers) + + return mappings + + def _init_weights(self, layer): + if self.config.tensor_parallel_degree > 1: + rng_tracker = get_rng_state_tracker().rng_state + + if isinstance( + layer, + ( + nn.Linear, + nn.Embedding, + mpu.VocabParallelEmbedding, + mpu.RowParallelLinear, + mpu.ColumnParallelLinear, + linear_utils.RowSequenceParallelLinear, + linear_utils.ColumnSequenceParallelLinear, + Linear, + ), + ): + # In the dygraph mode, use the `set_value` to reset the parameter directly, + # and reset the `state_dict` to update parameter in static mode. + if isinstance(layer.weight, paddle.Tensor): + if layer.weight.is_distributed: + with rng_tracker(): + layer.weight.set_value( + paddle.tensor.normal( + mean=0.0, + std=self.config.initializer_range + if hasattr(self.config, "initializer_range") + else self.config.initializer_range, + shape=layer.weight.shape, + ) + ) + else: + layer.weight.set_value( + paddle.tensor.normal( + mean=0.0, + std=self.config.initializer_range + if hasattr(self.config, "initializer_range") + else self.config.initializer_range, + shape=layer.weight.shape, + ) + ) + + # set bias to zeros + if getattr(layer, "bias", None) is not None: + layer.bias.set_value(paddle.zeros(shape=layer.bias.shape)) + + if isinstance(layer, nn.Embedding): + if layer._padding_idx is not None: + layer.weight.data[layer._padding_idx].fill_(0) + + if isinstance(layer, MoEGate): + kaiming_uniform_(layer.weight, a=math.sqrt(5)) + + moe_grad_group = fleet.get_hybrid_communicate_group().expert_grad_comm_group + if moe_grad_group is not None and moe_grad_group.nranks > 1: + for p in layer.parameters(): + if hasattr(p, "color") and "color" in p.color: + if p.color["color"] == "moe_expert": + paddle.distributed.broadcast(p, src=moe_grad_group.ranks[0], group=moe_grad_group) + + def step_flex_token(self, cur_step): + set_global_step(cur_step) + + +@register_base_model +class DeepseekV2ModelFast(DeepseekV2PretrainedModelFast): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. 
Each layer is a [`DeepseekV2DecoderLayer`] + + Args: + config: DeepseekV2Config + """ + + def __init__(self, config: DeepseekV2Config): + super().__init__(config) + + self.config = config + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + # Recompute defaults to False and is controlled by Trainer + self.enable_recompute = False + self.recompute_granularity = config.recompute_granularity + self.no_recompute_layers = config.no_recompute_layers if config.no_recompute_layers is not None else [] + + if config.tensor_parallel_degree > 1 and config.vocab_size % config.tensor_parallel_degree == 0: + self.embed_tokens = mpu.VocabParallelEmbedding(config.vocab_size, config.hidden_size) + else: + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + + self.layers = nn.LayerList( + [ + DeepseekV2DecoderLayer(config, layer_idx, layer_idx not in self.no_recompute_layers) + for layer_idx in range(config.num_hidden_layers) + ] + ) + for layer_idx in range(config.num_hidden_layers, config.num_hidden_layers + config.num_nextn_predict_layers): + self.layers.append(DeepseekV2MTPLayer(config, layer_idx, layer_idx not in self.no_recompute_layers)) + + self.norm = DeepseekV2RMSNorm(config) + + self.enable_recompute = False + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @staticmethod + def _prepare_decoder_attention_mask(attention_mask, input_shape, past_key_values_length, dtype): + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + if len(attention_mask.shape) == 2: + expanded_attn_mask = _expand_2d_mask(attention_mask, dtype, tgt_length=input_shape[-1]) + # For decoding phase in generation, seq_length = 1, we don't need to add causal mask + if input_shape[-1] > 1: + combined_attention_mask = _make_causal_mask( + input_shape, + past_key_values_length=past_key_values_length, + ) + expanded_attn_mask = expanded_attn_mask & combined_attention_mask + # [bsz, seq_len, seq_len] -> [bsz, 1, seq_len, seq_len] + elif len(attention_mask.shape) == 3: + expanded_attn_mask = attention_mask.unsqueeze(1).astype("bool") + # if attention_mask is already 4-D, do nothing + else: + expanded_attn_mask = attention_mask + else: + expanded_attn_mask = _make_causal_mask( + input_shape, + past_key_values_length=past_key_values_length, + ) + # Convert bool attention_mask to float attention mask, which will be added to attention_scores later + if get_env_device() == "xpu": + x = paddle.to_tensor(0.0, dtype="float32") + y = paddle.to_tensor(-1.7005809656952787e38, dtype="float32") + expanded_attn_mask = paddle.where(expanded_attn_mask, x, y) + else: + expanded_attn_mask = paddle.where(expanded_attn_mask.cast("bool"), 0.0, paddle.finfo(dtype).min).astype( + dtype + ) + return expanded_attn_mask + + @paddle.jit.not_to_static + def recompute_training_full( + self, + layer_module: nn.Layer, + hidden_states: Tensor, + position_ids: Optional[Tensor], + attention_mask: Tensor, + output_attentions: bool, + past_key_value: Tensor, + use_cache: bool, + attn_mask_startend_row_indices: Optional[Tensor] = None, + ): + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + hidden_states = recompute( + create_custom_forward(layer_module), + hidden_states, + position_ids, + attention_mask, + output_attentions, + past_key_value, + use_cache, + attn_mask_startend_row_indices, + 
use_reentrant=self.config.recompute_use_reentrant, + ) + + return hidden_states + + def forward( + self, + input_ids: paddle.Tensor = None, + position_ids: Optional[paddle.Tensor] = None, + attention_mask: Optional[paddle.Tensor] = None, + inputs_embeds: Optional[paddle.Tensor] = None, + use_cache: Optional[bool] = None, + past_key_values: Optional[List[paddle.Tensor]] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + attn_mask_startend_row_indices: Optional[Tensor] = None, + **kwargs, + ) -> Union[Tuple, BaseModelOutputWithPastAndMTP]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape[:2] + elif inputs_embeds is not None: + batch_size, seq_length = inputs_embeds.shape[:2] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if self.config.num_nextn_predict_layers > 0: + seq_length -= self.config.num_nextn_predict_layers + + if attention_mask is not None: + attention_mask = attention_mask[ + :, :, : -self.config.num_nextn_predict_layers, : -self.config.num_nextn_predict_layers + ] + + if self.enable_recompute and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
+ ) + use_cache = False + + if past_key_values is None: + past_key_values = tuple([None] * len(self.layers)) + # NOTE: to make cache can be clear in-time + past_key_values = list(past_key_values) + + seq_length_with_past = seq_length + past_key_values_length = 0 + if past_key_values[0] is not None: + past_key_values_length = past_key_values[0][0].shape[1] + seq_length_with_past += past_key_values_length + + if position_ids is None: + position_ids = paddle.arange( + past_key_values_length, seq_length + past_key_values_length, dtype=paddle.int64 + ) + position_ids = position_ids.unsqueeze(0) + + if inputs_embeds is None: + # [bs, seq_len, dim] + inputs_embeds = self.embed_tokens(input_ids) + + # embed positions + if attn_mask_startend_row_indices is not None or get_use_casual_mask(): + attention_mask = None + else: + # [bs, seq_len] + attention_mask = ( + paddle.ones((batch_size, seq_length_with_past), dtype=paddle.bool) + if attention_mask is None + else attention_mask + ) + attention_mask = self._prepare_decoder_attention_mask( + attention_mask, (batch_size, seq_length), past_key_values_length, inputs_embeds.dtype + ) # [bs, 1, seq_len, seq_len] + if self.config.use_flash_attention: + attention_mask = None if is_casual_mask(attention_mask) else attention_mask + + if self.config.num_nextn_predict_layers > 0: + inputs_embeds_extra = inputs_embeds[:, -self.config.num_nextn_predict_layers :, :] # [B, S, D] + inputs_embeds = inputs_embeds[:, : -self.config.num_nextn_predict_layers, :] + inputs_embeds_ori = inputs_embeds + + if self.config.sequence_parallel: + # [bs, seq_len, num_head * head_dim] -> [bs * seq_len, num_head * head_dim] + bs, seq_len, hidden_size = inputs_embeds.shape + inputs_embeds = paddle.reshape(inputs_embeds, [bs * seq_len, hidden_size]) + # [seq_len * bs / n, num_head * head_dim] (n is mp parallelism) + inputs_embeds = ScatterOp.apply(inputs_embeds) + + # embed positions + hidden_states = inputs_embeds + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = () if use_cache else None + mtp_outputs = [] + + for idx in range(self.config.num_hidden_layers): + decoder_layer = self.layers[idx] + + if output_hidden_states: + all_hidden_states += (hidden_states,) + + past_key_value = past_key_values[idx] if past_key_values is not None else None + + has_gradient = not hidden_states.stop_gradient + if ( + self.enable_recompute + and idx not in self.no_recompute_layers + and has_gradient + and self.recompute_granularity == "full" + ): + layer_outputs = self.recompute_training_full( + decoder_layer, + hidden_states, + position_ids, + attention_mask, + output_attentions, + past_key_value, + use_cache, + attn_mask_startend_row_indices, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + position_ids, + attention_mask, + output_attentions, + past_key_value, + use_cache, + attn_mask_startend_row_indices, + ) + + # NOTE: clear outdate cache after it has been used for memory saving + past_key_value = past_key_values[idx] = None + if type(layer_outputs) is tuple: + hidden_states = layer_outputs[0] + else: + hidden_states = layer_outputs + + if use_cache: + next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + if self.config.num_nextn_predict_layers > 0: + mtp_outputs.append(hidden_states) + + for nextn in range(self.config.num_nextn_predict_layers): + decoder_layer = self.layers[nextn + 
self.config.num_hidden_layers] + + if self.config.sequence_parallel: + hidden_states = GatherOp.apply(hidden_states) + hidden_states = hidden_states.reshape([-1, seq_length, hidden_states.shape[-1]]) + + inputs_embeds_cur_depth = paddle.concat( + [inputs_embeds_ori[:, (nextn + 1) :, :], inputs_embeds_extra[:, : (nextn + 1), :]], axis=1 + ) + + past_key_value = None + layer_outputs = decoder_layer( + hidden_states, + inputs_embeds_cur_depth, + position_ids, + attention_mask, + output_attentions, + past_key_value, + use_cache, + attn_mask_startend_row_indices, + ) + + if isinstance(layer_outputs, (tuple, list)): + hidden_states = layer_outputs[0] + else: + hidden_states = layer_outputs + + mtp_outputs.append(hidden_states) + mtp_outputs = [self.norm(hidden_states) for hidden_states in mtp_outputs] + hidden_states, mtp_outputs = mtp_outputs[0], mtp_outputs[1:] + else: + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + + if not return_dict: + return tuple( + v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, mtp_outputs] if v is not None + ) + return BaseModelOutputWithPastAndMTP( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + mtp_outputs=mtp_outputs, + ) + + +class DeepseekV2PretrainingCriterionFast(nn.Layer): + """ + Criterion for DeepseekV2. + It calculates the final loss. + """ + + def __init__(self, config: DeepseekV2Config): + super(DeepseekV2PretrainingCriterionFast, self).__init__() + self.ignore_index = getattr(config, "ignore_index", -100) + self.config = config + self.enable_parallel_cross_entropy = config.tensor_parallel_degree > 1 and config.tensor_parallel_output + + if self.enable_parallel_cross_entropy: # and False: # and lm_head is distributed + self.loss_func = mpu.ParallelCrossEntropy(ignore_index=self.ignore_index) + else: + self.loss_func = paddle.nn.CrossEntropyLoss(reduction="none", ignore_index=self.ignore_index) + + def forward(self, prediction_scores, masked_lm_labels, router_loss=None, mtp_logits=None): + if self.enable_parallel_cross_entropy: + if prediction_scores.shape[-1] == self.config.vocab_size: + warnings.warn( + f"enable_parallel_cross_entropy, the vocab_size should be split: {prediction_scores.shape[-1]}, {self.config.vocab_size}" + ) + self.loss_func = paddle.nn.CrossEntropyLoss(reduction="none", ignore_index=self.ignore_index) + + def compute_loss(preds, labels): + with paddle.amp.auto_cast(False): + masked_lm_loss = FastCrossEntropyFunction.apply(preds, labels.unsqueeze(2)) + binary_sequence = paddle.where( + masked_lm_loss > 0, paddle.ones_like(masked_lm_loss), paddle.zeros_like(masked_lm_loss) + ) + count = paddle.sum(binary_sequence) + loss = paddle.where( + count == 0, + paddle.sum(masked_lm_loss * binary_sequence), + paddle.sum(masked_lm_loss * binary_sequence) / count, + ) + return loss diff --git a/paddleformers/transformers/deepseek_v2/modeling_pp.py b/paddleformers/transformers/deepseek_v2/modeling_pp.py index 42b0e5de776..e9b3ca847fc 100644 --- a/paddleformers/transformers/deepseek_v2/modeling_pp.py +++ b/paddleformers/transformers/deepseek_v2/modeling_pp.py @@ -12,7 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-+import math +import os from typing import OrderedDict, Tuple, Union import paddle @@ -20,32 +21,86 @@ import paddle.nn as nn from paddle.distributed.fleet.meta_parallel import ( LayerDesc, + LocalSharedLayerDesc, PipelineLayer, + ScheduleChunk, + ScheduleNode, SharedLayerDesc, ) +from paddle.distributed.fleet.meta_parallel.zero_bubble_utils import WeightGradStore + +try: + from paddle.distributed.fleet.meta_parallel.zero_bubble_utils import EventStore +except ImportError: + EventStore = None + from paddle.distributed.fleet.recompute.recompute import recompute from paddle.distributed.fleet.utils.sequence_parallel_utils import ScatterOp +from ...utils.log import logger from ...utils.tools import get_env_device from ..model_utils import PipelinePretrainedModel -from .modeling import ( - DeepseekV2Config, - DeepseekV2DecoderLayer, - DeepseekV2LMHead, - DeepseekV2Model, - DeepseekV2MTPLayer, - DeepseekV2PretrainedModel, - DeepseekV2PretrainingCriterion, - DeepseekV2RMSNorm, + +if os.getenv("DSV3_FAST_PRETRAIN", "False").lower() != "true": + from .modeling import ( + DeepseekV2Config, + DeepseekV2DecoderLayer, + DeepseekV2LMHead, + DeepseekV2Model, + DeepseekV2MoE, + DeepseekV2MTPLayer, + DeepseekV2PretrainedModel, + DeepseekV2PretrainingCriterion, + DeepseekV2RMSNorm, + TemporaryVarContext, + set_global_step, + ) +else: + from .modeling import ( + DeepseekV2Config, + DeepseekV2LMHead, + DeepseekV2PretrainingCriterion, + DeepseekV2RMSNorm, + TemporaryVarContext, + set_global_step, + ) + from .modeling_fast import DeepseekV2DecoderLayer + from .modeling_fast import DeepseekV2ModelFast as DeepseekV2Model + from .modeling_fast import DeepseekV2MoE, DeepseekV2MTPLayer + from .modeling_fast import ( + DeepseekV2PretrainedModelFast as DeepseekV2PretrainedModel, + ) + +try: + import paddle.distributed.communication.deep_ep as deep_ep +except ImportError: + deep_ep = None + +from paddleformers.transformers.fused_a2a import ( + fused_combine_backward_func, + fused_combine_forward_func, + fused_dispatch_backward_func, + fused_dispatch_forward_func, ) +from paddleformers.transformers.moe_layer import FusionMoeNode + +from ..fp8_utils import FP8LinearFunction, FP8LinearFunctionBase __all__ = [ "DeepseekV2ForCausalLMPipe", ] +import queue + +global_inputs_embeds_mtp_queue = queue.Queue() + + +DSV3_USE_FP8_GEMM = os.getenv("DSV3_USE_FP8_GEMM", "False").lower() == "true" +DSV3_USE_FP8_DISPATCH = os.getenv("DSV3_USE_FP8_DISPATCH", "False").lower() == "true" + def parse_args(args): - if isinstance(args, tuple): + if isinstance(args, (tuple, list)): if len(args) == 4: hidden_states, attention_mask, attn_mask_startend_row_indices, position_ids = args @@ -55,6 +110,9 @@ def parse_args(args): elif len(args) == 2: hidden_states, attention_mask = args attn_mask_startend_row_indices, position_ids = None, None + else: # len(args) == 1: + hidden_states = args[0] + attention_mask, attn_mask_startend_row_indices, position_ids = None, None, None else: hidden_states = args attention_mask, attn_mask_startend_row_indices, position_ids = None, None, None @@ -93,6 +151,1302 @@ def get_attr(layer, name): return get_attr(layer._layer, name) + +def calc_stream_wait(group_id): + comm_event = deep_ep.get_event_from_comm_stream(group_id) + comm_event.calc_stream_wait(group_id) + + +class TensorMeta: + """Recording the meta info of forward inputs, to avoid 0-size problems""" + + def __init__(self, tensor): + self.shape = tensor.shape + self.dtype = tensor.dtype + + +class PostProcessNode(ScheduleNode): + def __init__( + self, + send_mtp_embed, + 
training, + alpha, + config, + shared_experts=None, + using_post_norm_recompute=False, + output_mtp_embed_first=False, + name="PostProcessNode", + ): + self.send_mtp_embed = send_mtp_embed + self.shared_experts = shared_experts + self.traning = training + self.config = config + self.alpha = alpha + self.using_post_norm_recompute = using_post_norm_recompute + self.output_mtp_embed_first = output_mtp_embed_first + self.name = name + + if self.using_post_norm_recompute: + assert self.shared_experts is not None + assert self.shared_experts.norm_weight is not None and self.shared_experts.norm_eps is not None + + def forward_without_residual(self, inputs): + + if isinstance(inputs, list): + inputs = tuple(inputs) + if self.using_post_norm_recompute: + if self.send_mtp_embed: + (inputs_embeds_mtp, hidden_states, residual, l_aux, final_hidden_states, norm_out) = inputs + else: + (hidden_states, residual, l_aux, final_hidden_states, norm_out) = inputs + else: + if self.send_mtp_embed: + (inputs_embeds_mtp, hidden_states, residual, l_aux, final_hidden_states) = inputs + else: + (hidden_states, residual, l_aux, final_hidden_states) = inputs + + with paddle.no_grad(): + if self.shared_experts is not None: + if self.using_post_norm_recompute: + _, _, _, shared_expert_output = FP8LinearFunctionBase.fp8_mlp_fwd( + norm_out, self.shared_experts.w1, self.shared_experts.w2 + ) + norm_out = None + del norm_out + else: + _, _, _, shared_expert_output = FP8LinearFunctionBase.fp8_mlp_fwd( + hidden_states, self.shared_experts.w1, self.shared_experts.w2 + ) + residual = residual + shared_expert_output + + self.x = hidden_states + self.l_aux = l_aux + + hidden_states = residual + hidden_states.stop_gradient = False + + if self.send_mtp_embed: + assert not self.output_mtp_embed_first, "forward_without_residual doesn't support output_mtp_embed_first" + hidden_states = paddle.concat([hidden_states, inputs_embeds_mtp], axis=-1) + self.mtp_embed_shape = inputs_embeds_mtp.shape # save the mtp_embed shape for the backward pass + + return return_args(hidden_states) + + def forward(self, inputs): + + if isinstance(inputs, list): + inputs = tuple(inputs) + + if self.using_post_norm_recompute: + if self.send_mtp_embed: + (inputs_embeds_mtp, hidden_states, residual, l_aux, final_hidden_states, norm_out) = inputs + else: + (hidden_states, residual, l_aux, final_hidden_states, norm_out) = inputs + else: + if self.send_mtp_embed: + (inputs_embeds_mtp, hidden_states, residual, l_aux, final_hidden_states) = inputs + else: + (hidden_states, residual, l_aux, final_hidden_states) = inputs + + with paddle.no_grad(): + if self.shared_experts is not None: + if self.using_post_norm_recompute: + _, _, _, shared_expert_output = FP8LinearFunctionBase.fp8_mlp_fwd( + norm_out, self.shared_experts.w1, self.shared_experts.w2 + ) + norm_out = None + del norm_out + else: + _, _, _, shared_expert_output = FP8LinearFunctionBase.fp8_mlp_fwd( + hidden_states, self.shared_experts.w1, self.shared_experts.w2 + ) + final_hidden_states = final_hidden_states + shared_expert_output + + self.x = hidden_states + self.l_aux = l_aux + hidden_states = residual + final_hidden_states + + if self.send_mtp_embed: + if self.output_mtp_embed_first: + hidden_states = paddle.concat([inputs_embeds_mtp, hidden_states], axis=-1) + else: + hidden_states = paddle.concat([hidden_states, inputs_embeds_mtp], axis=-1) + self.mtp_embed_shape = inputs_embeds_mtp.shape # save the mtp_embed shape for the backward pass + + return return_args(hidden_states) + + @paddle.no_grad() + def backward(self, output_grad): + (do3,) = 
output_grad + + if self.send_mtp_embed: + # split the gradient: one part of do3 corresponds to hidden_states, the other to inputs_embeds_mtp + hidden_size = do3.shape[-1] - self.mtp_embed_shape[-1] + if self.output_mtp_embed_first: + hidden_states_grad = do3[..., hidden_size:] + inputs_embeds_mtp_grad = do3[..., :hidden_size] + else: + hidden_states_grad = do3[..., :hidden_size] + inputs_embeds_mtp_grad = do3[..., hidden_size:] + else: + hidden_states_grad = do3 + inputs_embeds_mtp_grad = None + + if self.using_post_norm_recompute: + dx, norm_out, invar = FP8LinearFunctionBase.fp8_mlp_bwd_norm_rc( + hidden_states_grad, + self.x, + self.shared_experts.norm_weight, + self.shared_experts.norm_eps, + self.shared_experts.w1, + self.shared_experts.w2, + ) + else: + dx = FP8LinearFunctionBase.fp8_mlp_bwd( + hidden_states_grad, self.x, self.shared_experts.w1, self.shared_experts.w2, True + ) + + self.x = None + + residual_grad = hidden_states_grad + l_aux_grad = paddle.ones(1, dtype=self.l_aux.dtype) * self.alpha + final_hidden_states_grad = hidden_states_grad + + if self.using_post_norm_recompute: + if self.send_mtp_embed: + return ( + inputs_embeds_mtp_grad, + dx, + residual_grad, + l_aux_grad, + final_hidden_states_grad, + norm_out, + invar, + ) + else: + return (dx, residual_grad, l_aux_grad, final_hidden_states_grad, norm_out, invar) + else: + if self.send_mtp_embed: + return (inputs_embeds_mtp_grad, dx, residual_grad, l_aux_grad, final_hidden_states_grad) + else: + return (dx, residual_grad, l_aux_grad, final_hidden_states_grad) + + +class DecoderLayerNode(ScheduleNode): + def __init__( + self, + attn_node, + dispatch_node, + mlp_node, + combine_node, + post_process_node, + mlp_layer, + name="DecoderLayerNode", + ): + super().__init__(fwd_func=None, name=name) + assert (dispatch_node is None and combine_node is None) or ( + dispatch_node is not None and combine_node is not None + ) + self.attn_node = attn_node + self.dispatch_node = dispatch_node + self.mlp_node = mlp_node + self.combine_node = combine_node + self.post_process_node = post_process_node + + self.mlp_layer = mlp_layer + self.moe_group = mlp_layer.moe_group + self.moe_num_experts = mlp_layer.moe_num_experts + + self.states = None + self.hidden_states_meta = None + self.dispatched_probs_meta = None + self.combine_output_meta = None + + def dispatch_forward(self, inputs, previous_event=None, allocate_on_comm_stream=False): + paddle.base.core.nvprof_nvtx_push("raw_dispatch_forward") + if isinstance(inputs, list): + inputs = tuple(inputs) + ( + inputs_embeds_mtp, + hidden_states, + residual, + l_aux, + intermediate_hidden_states, + token_indices, + token_probs, + ) = inputs + + with paddle.no_grad(): + intermediate_hidden_states, dispatched_probs, states, _ = fused_dispatch_forward_func( + intermediate_hidden_states, + token_indices, + token_probs, + self.moe_num_experts, + self.moe_group, + previous_event=previous_event, + async_finish=True, + allocate_on_comm_stream=allocate_on_comm_stream, + ) + dispatched_indices = states["dispatched_indices"] + self.mlp_layer.set_tokens_per_expert(states["tokens_per_expert"]) + dispatched_indices.stop_gradient = True + intermediate_hidden_states.stop_gradient = False + dispatched_probs.stop_gradient = False + self.states = states + self.hidden_states_meta = TensorMeta(intermediate_hidden_states) + self.dispatched_probs_meta = TensorMeta(dispatched_probs) + + inputs = ( + inputs_embeds_mtp, + hidden_states, + residual, + l_aux, + intermediate_hidden_states, + dispatched_indices, + dispatched_probs, + ) + paddle.base.core.nvprof_nvtx_pop() + 
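+        # NOTE: the fused dispatch above was launched with async_finish=True on the MoE
+        # communication stream; the caller (DecoderLayerNode.forward or the overlapped
+        # scheduler) must synchronize via calc_stream_wait(self.moe_group.id) before
+        # consuming these outputs.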
return inputs + + def combine_forward(self, inputs, previous_event=None): + paddle.base.core.nvprof_nvtx_push("raw_combine_forward") + if isinstance(inputs, list): + inputs = tuple(inputs) + (inputs_embeds_mtp, hidden_states, residual, l_aux, expert_output) = inputs + + with paddle.no_grad(): + combine_output = fused_combine_forward_func( + expert_output, self.moe_group, self.states, previous_event=previous_event, async_finish=True + ) + combine_output.stop_gradient = False + self.combine_output_meta = TensorMeta(combine_output) + inputs = (inputs_embeds_mtp, hidden_states, residual, l_aux, combine_output) + paddle.base.core.nvprof_nvtx_pop() + return inputs + + def dispatch_backward(self, output_grad): + paddle.base.core.nvprof_nvtx_push("raw_dispatch_backward") + ( + inputs_embeds_mtp_grad, + hidden_states_grad, + residual_grad, + l_aux_grad, + intermediate_hidden_states_grad, + dispatched_indices_grad, + dispatched_probs_grad, + ) = output_grad + + if intermediate_hidden_states_grad is None: + intermediate_hidden_states_grad = paddle.zeros( + self.hidden_states_meta.shape, self.hidden_states_meta.dtype + ) + if dispatched_probs_grad is None: + dispatched_probs_grad = paddle.zeros(self.dispatched_probs_meta.shape, self.dispatched_probs_meta.dtype) + with paddle.no_grad(): + intermediate_hidden_states_grad, token_indices_grad, token_probs_grad = fused_dispatch_backward_func( + intermediate_hidden_states_grad, + dispatched_probs_grad, + self.moe_group, + self.states["handle"], + async_finish=True, + ) + + output_grad = ( + inputs_embeds_mtp_grad, + hidden_states_grad, + residual_grad, + l_aux_grad, + intermediate_hidden_states_grad, + token_indices_grad, + token_probs_grad, + ) + paddle.base.core.nvprof_nvtx_pop() + return output_grad + + def combine_backward(self, output_grad): + paddle.base.core.nvprof_nvtx_push("raw_combine_backward") + ( + inputs_embeds_mtp_grad, + hidden_states_grad, + residual_grad, + l_aux_grad, + combine_output_grad, + ) = output_grad + + if combine_output_grad is None: + combine_output_grad = paddle.zeros(self.combine_output_meta.shape, self.combine_output_meta.dtype) + with paddle.no_grad(): + expert_output_grad = fused_combine_backward_func( + combine_output_grad, self.moe_group, self.states["handle"], async_finish=True + ) + + output_grad = ( + inputs_embeds_mtp_grad, + hidden_states_grad, + residual_grad, + l_aux_grad, + expert_output_grad, + ) + paddle.base.core.nvprof_nvtx_pop() + return output_grad + + def forward(self, inputs): + inputs = self.attn_node.forward(inputs) + + if self.dispatch_node is None: + inputs = self.dispatch_forward(inputs) + calc_stream_wait(self.moe_group.id) + else: + inputs = self.dispatch_node.forward(inputs) + + inputs = self.mlp_node.forward(inputs) + + if self.combine_node is None: + inputs = self.combine_forward(inputs) + calc_stream_wait(self.moe_group.id) + else: + inputs = self.combine_node.forward(inputs) + + inputs = self.post_process_node.forward(inputs) + return inputs + + def backward(self, output_grad=None, scaler=None): + assert (output_grad is not None) and (scaler is None) + + output_grad = self.post_process_node.backward(output_grad) + + if self.combine_node is None: + output_grad = self.combine_backward(output_grad) + calc_stream_wait(self.moe_group.id) + else: + output_grad = self.combine_node.backward(output_grad) + + output_grad = self.mlp_node.backward(output_grad) + + if self.dispatch_node is None: + output_grad = self.dispatch_backward(output_grad) + calc_stream_wait(self.moe_group.id) + else: + output_grad 
= self.dispatch_node.backward(output_grad) + + output_grad = self.attn_node.backward(output_grad) + return output_grad + + +class OverlapedScheduleChunk: + def __init__(self, forward_nodes, backward_nodes, use_fuion=True): + assert len(forward_nodes) == len(backward_nodes) + self.nodes = [] + for f, b in zip(forward_nodes, backward_nodes): + schedule_node_class = OverlapedScheduleNode + if use_fuion: + schedule_node_class = OverlapedFUsionScheduleNode + if isinstance(f, DenseDecoderLayerNode) or isinstance(b, DenseDecoderLayerNode): + schedule_node_class = OverlapedDenseFusionScheduleNode + self.nodes.append(schedule_node_class(f, b, f"OverlapedNode_{len(self.nodes)}")) + + def forward_backward(self, inputs, output_grad, combine_bw_event_to_wait=None, pp_stream=None): + # print(" fwd pp stream", pp_stream) + event_to_wait = combine_bw_event_to_wait + for i, n in enumerate(self.nodes): + pp_stream_t = pp_stream + if i + 1 != len(self.nodes): + pp_stream_t = None + + inputs, output_grad, event_to_wait = n.forward_backward( + inputs, output_grad, combine_bw_event_to_wait=event_to_wait, pp_stream=pp_stream_t + ) + return inputs, output_grad, None + + +class DecoderBackwardScheduleChunk: + def __init__(self, nodes): + self.nodes = nodes + + def backward(self, output_grad, combine_bw_event_to_wait=None, pp_stream=None): + event_to_wait = combine_bw_event_to_wait + for i, n in enumerate(self.nodes): + pp_stream_t = pp_stream if i + 1 == len(self.nodes) else None + output_grad, event_to_wait = n.backward_for_fusion( + output_grad, combine_bw_event_to_wait=event_to_wait, pp_stream=pp_stream_t + ) + return output_grad + + +class OverlapedScheduleNode: + def __init__(self, forward_node, backward_node, name=""): + assert isinstance(forward_node, DecoderLayerNode) and isinstance(backward_node, DecoderLayerNode) + self.forward_node = forward_node + self.backward_node = backward_node + self.name = name + + def forward_backward(self, inputs, output_grad, event_to_wait=None): + paddle.base.core.nvprof_nvtx_push("forward_backward") + output_grad = self.backward_node.post_process_node.backward(output_grad) + + output_grad = self.backward_node.combine_backward(output_grad) + inputs = self.forward_node.attn_node.forward(inputs) + + calc_stream_wait(self.backward_node.moe_group.id) + attn_compute_event = deep_ep.get_event_from_calc_stream(self.forward_node.moe_group.id) + output_grad = self.backward_node.mlp_node.backward(output_grad) + inputs = self.forward_node.dispatch_forward( + inputs, previous_event=attn_compute_event, allocate_on_comm_stream=True + ) + + calc_stream_wait(self.forward_node.moe_group.id) + output_grad = self.backward_node.dispatch_backward(output_grad) + inputs = self.forward_node.mlp_node.forward(inputs) + + calc_stream_wait(self.backward_node.moe_group.id) + inputs = self.forward_node.combine_forward(inputs) + output_grad = self.backward_node.attn_node.backward(output_grad) + + calc_stream_wait(self.forward_node.moe_group.id) + inputs = self.forward_node.post_process_node.forward(inputs) + paddle.base.core.nvprof_nvtx_pop() + return inputs, output_grad + + +class FusionFp8DecoderLayerNode(ScheduleNode): + def __init__( + self, + attn_and_gate_node, + fp8_fusion_moe_node, + post_process_node, + mlp_layer, + send_mtp_embed, + using_post_norm_recompute=False, + stepped_recompute_fwd_gate_up=False, + name="", + ): + self.attn_and_gate_node = attn_and_gate_node + self.fp8_fusion_moe_node = fp8_fusion_moe_node + self.post_process_node = post_process_node + self.send_mtp_embed = send_mtp_embed 
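+        # `send_mtp_embed` controls whether the multi-token-prediction (MTP) embedding tensor
+        # is threaded through every stage of this node (attn -> dispatch -> mlp -> combine ->
+        # post-process) so the post-process node can concatenate it onto the hidden states
+        # handed to the next pipeline stage.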
+ + self.using_post_norm_recompute = using_post_norm_recompute + self.stepped_recompute_fwd_gate_up = stepped_recompute_fwd_gate_up + self.name = name + + self.moe_group = mlp_layer.moe_group + + def attn_forward(self, inputs): + inputs = self.attn_and_gate_node.forward(inputs) + + if self.send_mtp_embed: + if self.using_post_norm_recompute: + inputs_embeds_mtp, hidden_states, residual, probs, routing_map, l_aux, norm_out = inputs + else: + inputs_embeds_mtp, hidden_states, residual, probs, routing_map, l_aux = inputs + else: + if self.using_post_norm_recompute: + hidden_states, residual, probs, routing_map, l_aux, norm_out = inputs + else: + hidden_states, residual, probs, routing_map, l_aux = inputs + + if self.using_post_norm_recompute: + hs_2d, token_indices, token_probs = self.fp8_fusion_moe_node.dispatch_quant_node.forward( + norm_out, probs, routing_map + ) + # common return values + ret = (hidden_states, residual, l_aux, hs_2d, token_indices, token_probs, norm_out) + + # append mtp embed if needed + ret = (inputs_embeds_mtp, *ret) if self.send_mtp_embed else ret + return ret + else: + hs_2d, token_indices, token_probs = self.fp8_fusion_moe_node.dispatch_quant_node.forward( + hidden_states, probs, routing_map + ) + + # common return values + ret = (hidden_states, residual, l_aux, hs_2d, token_indices, token_probs) + + # append mtp embed if needed + ret = (inputs_embeds_mtp, *ret) if self.send_mtp_embed else ret + return ret + + def dispatch_forward(self, inputs, previous_event=None, async_finish=False, allocate_on_comm_stream=False): + if self.using_post_norm_recompute: + if self.send_mtp_embed: + inputs_embeds_mtp, hidden_states, residual, l_aux, hs_2d, token_indices, token_probs, norm_out = inputs + else: + hidden_states, residual, l_aux, hs_2d, token_indices, token_probs, norm_out = inputs + else: + if self.send_mtp_embed: + inputs_embeds_mtp, hidden_states, residual, l_aux, hs_2d, token_indices, token_probs = inputs + else: + hidden_states, residual, l_aux, hs_2d, token_indices, token_probs = inputs + + (hs_dispatched, dispatched_indices, dispatched_probs,) = self.fp8_fusion_moe_node.dispatch_node.forward( + hs_2d, + token_indices, + token_probs, + previous_event=previous_event, + async_finish=async_finish, + allocate_on_comm_stream=allocate_on_comm_stream, + ) + + ret = (hidden_states, residual, l_aux, hs_dispatched, dispatched_indices, dispatched_probs) + + # append mtp embed if needed + ret = (inputs_embeds_mtp, *ret) if self.send_mtp_embed else ret + ret = (*ret, norm_out) if self.using_post_norm_recompute else ret + return ret + + def mlp_forward(self, inputs): + if self.using_post_norm_recompute: + if self.send_mtp_embed: + ( + inputs_embeds_mtp, + hidden_states, + residual, + l_aux, + hs_dispatched, + dispatched_indices, + dispatched_probs, + norm_out, + ) = inputs + else: + hidden_states, residual, l_aux, hs_dispatched, dispatched_indices, dispatched_probs, norm_out = inputs + else: + if self.send_mtp_embed: + ( + inputs_embeds_mtp, + hidden_states, + residual, + l_aux, + hs_dispatched, + dispatched_indices, + dispatched_probs, + ) = inputs + else: + hidden_states, residual, l_aux, hs_dispatched, dispatched_indices, dispatched_probs = inputs + + hidden_states_out = self.fp8_fusion_moe_node.mlp_node.forward( + hs_dispatched, dispatched_indices, dispatched_probs + ) + ret = (hidden_states, residual, l_aux, hidden_states_out) + + # append mtp embed if needed + ret = (inputs_embeds_mtp, *ret) if self.send_mtp_embed else ret + ret = (*ret, norm_out) if 
self.using_post_norm_recompute else ret + return ret + + def combine_forward(self, inputs, async_finish=False, previous_event=None, allocate_on_comm_stream=False): + if self.using_post_norm_recompute: + if self.send_mtp_embed: + (inputs_embeds_mtp, hidden_states, residual, l_aux, hidden_states_out, norm_out) = inputs + else: + (hidden_states, residual, l_aux, hidden_states_out, norm_out) = inputs + else: + if self.send_mtp_embed: + (inputs_embeds_mtp, hidden_states, residual, l_aux, hidden_states_out) = inputs + else: + (hidden_states, residual, l_aux, hidden_states_out) = inputs + + output_combine = self.fp8_fusion_moe_node.combine_node.forward( + hidden_states_out, + async_finish=async_finish, + previous_event=previous_event, + allocate_on_comm_stream=allocate_on_comm_stream and previous_event is not None, + ) + + ret = (hidden_states, residual, l_aux, output_combine) + + # append mtp embed if needed + ret = (inputs_embeds_mtp, *ret) if self.send_mtp_embed else ret + ret = (*ret, norm_out) if self.using_post_norm_recompute else ret + return ret + + def post_process_forward(self, inputs, with_residual=True): + if self.using_post_norm_recompute: + if self.send_mtp_embed: + (inputs_embeds_mtp, hidden_states, residual, l_aux, output_combine, norm_out) = inputs + else: + (hidden_states, residual, l_aux, output_combine, norm_out) = inputs + else: + if self.send_mtp_embed: + (inputs_embeds_mtp, hidden_states, residual, l_aux, output_combine) = inputs + else: + (hidden_states, residual, l_aux, output_combine) = inputs + final_hidden_states = self.fp8_fusion_moe_node.combine_quant_node.forward(output_combine) + + inputs = (hidden_states, residual, l_aux, final_hidden_states) + inputs = (inputs_embeds_mtp, *inputs) if self.send_mtp_embed else inputs + inputs = (*inputs, norm_out) if self.using_post_norm_recompute else inputs + + if with_residual: + inputs = self.post_process_node.forward(inputs) + else: + inputs = self.post_process_node.forward_without_residual(inputs) + return inputs + + def post_process_backward(self, output_grad, event_to_wait=None): + grad = self.post_process_node.backward(output_grad) + + if self.using_post_norm_recompute: + if self.send_mtp_embed: + ( + inputs_embeds_mtp_grad, + hidden_states_grad, + residual_grad, + l_aux_grad, + final_hidden_states_grad, + norm_out, + invar, + ) = grad + else: + hidden_states_grad, residual_grad, l_aux_grad, final_hidden_states_grad, norm_out, invar = grad + else: + if self.send_mtp_embed: + inputs_embeds_mtp_grad, hidden_states_grad, residual_grad, l_aux_grad, final_hidden_states_grad = grad + else: + hidden_states_grad, residual_grad, l_aux_grad, final_hidden_states_grad = grad + + output_combine_grad, quant_event = self.fp8_fusion_moe_node.combine_quant_node.backward( + final_hidden_states_grad, event_to_wait + ) + + ret = (hidden_states_grad, residual_grad, l_aux_grad, output_combine_grad, quant_event) + ret = (inputs_embeds_mtp_grad, *ret) if self.send_mtp_embed else ret + ret = (*ret, norm_out, invar) if self.using_post_norm_recompute else ret + return ret + + def combine_backward(self, output_grad, previous_event=None, async_finish=False, allocate_on_comm_stream=False): + if self.using_post_norm_recompute: + if self.send_mtp_embed: + ( + inputs_embeds_mtp_grad, + hidden_states_grad, + residual_grad, + l_aux_grad, + output_combine_grad, + quant_event, + norm_out, + invar, + ) = output_grad + else: + ( + hidden_states_grad, + residual_grad, + l_aux_grad, + output_combine_grad, + quant_event, + norm_out, + invar, + ) = output_grad + 
else: + if self.send_mtp_embed: + ( + inputs_embeds_mtp_grad, + hidden_states_grad, + residual_grad, + l_aux_grad, + output_combine_grad, + quant_event, + ) = output_grad + else: + ( + hidden_states_grad, + residual_grad, + l_aux_grad, + output_combine_grad, + quant_event, + ) = output_grad + + if DSV3_USE_FP8_DISPATCH and quant_event is not None: + combine_backward_wait_event = quant_event + else: + combine_backward_wait_event = previous_event + hidden_states_out_grad = self.fp8_fusion_moe_node.combine_node.backward( + output_combine_grad, + async_finish=async_finish, + previous_event=combine_backward_wait_event, + allocate_on_comm_stream=allocate_on_comm_stream and quant_event is not None, + ) + + ret = (hidden_states_grad, residual_grad, l_aux_grad, hidden_states_out_grad) + ret = (inputs_embeds_mtp_grad, *ret) if self.send_mtp_embed else ret + ret = (*ret, norm_out, invar) if self.using_post_norm_recompute else ret + return ret + + def mlp_backward(self, output_grad): + if self.using_post_norm_recompute: + if self.send_mtp_embed: + ( + inputs_embeds_mtp_grad, + hidden_states_grad, + residual_grad, + l_aux_grad, + hidden_states_out_grad, + norm_out, + invar, + ) = output_grad + else: + hidden_states_grad, residual_grad, l_aux_grad, hidden_states_out_grad, norm_out, invar = output_grad + else: + if self.send_mtp_embed: + ( + inputs_embeds_mtp_grad, + hidden_states_grad, + residual_grad, + l_aux_grad, + hidden_states_out_grad, + ) = output_grad + else: + hidden_states_grad, residual_grad, l_aux_grad, hidden_states_out_grad = output_grad + hs_dispatched_grad, dispatched_probs_grad = self.fp8_fusion_moe_node.mlp_node.backward(hidden_states_out_grad) + + ret = (hidden_states_grad, residual_grad, l_aux_grad, hs_dispatched_grad, dispatched_probs_grad) + ret = (inputs_embeds_mtp_grad, *ret) if self.send_mtp_embed else ret + ret = (*ret, norm_out, invar) if self.using_post_norm_recompute else ret + return ret + + def dispatch_backward(self, output_grad, async_finish=False, previous_event=None, allocate_on_comm_stream=False): + if self.using_post_norm_recompute: + if self.send_mtp_embed: + ( + inputs_embeds_mtp_grad, + hidden_states_grad, + residual_grad, + l_aux_grad, + hs_dispatched_grad, + dispatched_probs_grad, + norm_out, + invar, + ) = output_grad + else: + ( + hidden_states_grad, + residual_grad, + l_aux_grad, + hs_dispatched_grad, + dispatched_probs_grad, + norm_out, + invar, + ) = output_grad + else: + if self.send_mtp_embed: + ( + inputs_embeds_mtp_grad, + hidden_states_grad, + residual_grad, + l_aux_grad, + hs_dispatched_grad, + dispatched_probs_grad, + ) = output_grad + else: + hidden_states_grad, residual_grad, l_aux_grad, hs_dispatched_grad, dispatched_probs_grad = output_grad + + hs_grad, token_probs_grad = self.fp8_fusion_moe_node.dispatch_node.backward( + hs_dispatched_grad, + dispatched_probs_grad, + async_finish=async_finish, + previous_event=previous_event, + allocate_on_comm_stream=allocate_on_comm_stream and previous_event is not None, + ) + + ret = (hidden_states_grad, residual_grad, l_aux_grad, hs_grad, token_probs_grad) + ret = (inputs_embeds_mtp_grad, *ret) if self.send_mtp_embed else ret + ret = (*ret, norm_out, invar) if self.using_post_norm_recompute else ret + return ret + + def attn_backward(self, output_grad): + if self.using_post_norm_recompute: + if self.send_mtp_embed: + ( + inputs_embeds_mtp_grad, + hidden_states_grad, + residual_grad, + l_aux_grad, + hs_grad, + token_probs_grad, + norm_out, + invar, + ) = output_grad + inputs_embeds_mtp_grad_shape = 
hidden_states_grad.shape + inputs_embeds_mtp_grad_shape[-1] = -1 + inputs_embeds_mtp_grad = inputs_embeds_mtp_grad.view(inputs_embeds_mtp_grad_shape) + else: + hidden_states_grad, residual_grad, l_aux_grad, hs_grad, token_probs_grad, norm_out, invar = output_grad + else: + if self.send_mtp_embed: + ( + inputs_embeds_mtp_grad, + hidden_states_grad, + residual_grad, + l_aux_grad, + hs_grad, + token_probs_grad, + ) = output_grad + inputs_embeds_mtp_grad_shape = hidden_states_grad.shape + inputs_embeds_mtp_grad_shape[-1] = -1 + inputs_embeds_mtp_grad = inputs_embeds_mtp_grad.view(inputs_embeds_mtp_grad_shape) + else: + hidden_states_grad, residual_grad, l_aux_grad, hs_grad, token_probs_grad = output_grad + + hidden_states_grad_, probs_grad, routing_map_grad = self.fp8_fusion_moe_node.dispatch_quant_node.backward( + hs_grad, token_probs_grad + ) + + output_grad = (residual_grad, probs_grad, routing_map_grad, l_aux_grad) + + output_grad = ( + (hidden_states_grad, *output_grad, hidden_states_grad_) + if self.using_post_norm_recompute + else (hidden_states_grad + hidden_states_grad_, *output_grad) + ) + output_grad = (inputs_embeds_mtp_grad, *output_grad) if self.send_mtp_embed else output_grad + + if self.using_post_norm_recompute: + with TemporaryVarContext(norm_out, invar): + output_grad = self.attn_and_gate_node.backward(output_grad) + else: + output_grad = self.attn_and_gate_node.backward(output_grad) + return output_grad + + def backward_for_fusion(self, output_grad, combine_bw_event_to_wait=None, pp_stream=None): + paddle.base.core.nvprof_nvtx_push("backward") + if combine_bw_event_to_wait is None: + combine_bw_event_to_wait = deep_ep.get_event_from_calc_stream(self.moe_group.id) + + paddle.base.core.nvprof_nvtx_push("post_process_backward") + output_grad = self.post_process_backward(output_grad, combine_bw_event_to_wait) + paddle.base.core.nvprof_nvtx_pop() + + paddle.base.core.nvprof_nvtx_push("combine_backward") + output_grad = self.combine_backward( + output_grad, previous_event=combine_bw_event_to_wait, async_finish=True, allocate_on_comm_stream=True + ) + combine_backward_event = deep_ep.get_event_from_comm_stream(self.moe_group.id) + combine_backward_event.calc_stream_wait(self.moe_group.id) + paddle.base.core.nvprof_nvtx_pop() + + if WeightGradStore.enabled: + paddle.base.core.nvprof_nvtx_push("mlp_backward") + output_grad = self.mlp_backward(output_grad) + paddle.base.core.nvprof_nvtx_pop() + + paddle.base.core.nvprof_nvtx_push("dispatch_backward") + output_grad = self.dispatch_backward(output_grad) + paddle.base.core.nvprof_nvtx_pop() + + paddle.base.core.nvprof_nvtx_push("attn_backward") + output_grad = self.attn_backward(output_grad) + paddle.base.core.nvprof_nvtx_pop() + + event_to_wait = None + + else: + paddle.base.core.nvprof_nvtx_push("mlp_backward_dx") + assert WeightGradStore.funcs_queue.empty() + WeightGradStore.enabled = True + output_grad = self.mlp_backward(output_grad) + WeightGradStore.enabled = False + WeightGradStore.flush() + output_grad_event = deep_ep.get_event_from_calc_stream(self.moe_group.id) + paddle.base.core.nvprof_nvtx_pop() + + paddle.base.core.nvprof_nvtx_push("dispatch_backward") + output_grad = self.dispatch_backward( + output_grad, async_finish=True, previous_event=output_grad_event, allocate_on_comm_stream=True + ) + dispatch_backward_event = deep_ep.get_event_from_comm_stream(self.moe_group.id) + paddle.base.core.nvprof_nvtx_pop() + + paddle.base.core.nvprof_nvtx_push("mlp_backward_dw") + WeightGradStore.pop() + assert 
WeightGradStore.funcs_queue.empty() + paddle.base.core.nvprof_nvtx_pop() + + paddle.base.core.nvprof_nvtx_push("attn_backward_dx") + dispatch_backward_event.calc_stream_wait(self.moe_group.id) + WeightGradStore.enabled = True + output_grad = self.attn_backward(output_grad) + WeightGradStore.enabled = False + WeightGradStore.flush() + event_to_wait = deep_ep.get_event_from_calc_stream(self.moe_group.id) + paddle.base.core.nvprof_nvtx_pop() + + paddle.base.core.nvprof_nvtx_push("attn_backward_dw") + WeightGradStore.pop() + assert WeightGradStore.funcs_queue.empty() + paddle.base.core.nvprof_nvtx_pop() + + paddle.base.core.nvprof_nvtx_pop() + return output_grad, event_to_wait + + def forward(self, inputs): + if self.stepped_recompute_fwd_gate_up: + self.fp8_fusion_moe_node.mlp_node.set_recompute_fwd_gate_up(True) + inputs = self.attn_forward(inputs) + inputs = self.dispatch_forward(inputs) + inputs = self.mlp_forward(inputs) + inputs = self.combine_forward(inputs) + inputs = self.post_process_forward(inputs) + return inputs + + def backward(self, output_grad=None, scaler=None): + assert (output_grad is not None) and (scaler is None) + output_grad = self.post_process_backward(output_grad) + output_grad = self.combine_backward(output_grad) + output_grad = self.mlp_backward(output_grad) + # todo(phlrain): overlap here + output_grad = self.dispatch_backward(output_grad) + output_grad = self.attn_backward(output_grad) + return output_grad + + +class DenseDecoderLayerNode(ScheduleNode): + def __init__( + self, + attn_node, + mlp_node, + name="DenseDecoderLayerNode", + ): + super().__init__(fwd_func=None, name=name) + self.attn_node = attn_node + self.mlp_node = mlp_node + + def forward(self, inputs): + inputs = self.attn_node.forward(inputs) + inputs = self.mlp_node.forward(inputs) + return inputs + + def backward(self, output_grad=None, scaler=None): + assert (output_grad is not None) and (scaler is None) + output_grad = self.mlp_node.backward(output_grad) + output_grad = self.attn_node.backward(output_grad) + return output_grad + + +class OverlapedFUsionScheduleNode: + def __init__(self, forward_node, backward_node, name=""): + assert isinstance(forward_node, FusionFp8DecoderLayerNode) and isinstance( + backward_node, FusionFp8DecoderLayerNode + ) + self.forward_node = forward_node + self.backward_node = backward_node + self.name = name + + def forward_backward(self, inputs, output_grad, combine_bw_event_to_wait=None, pp_stream=None): + paddle.base.core.nvprof_nvtx_push("forward_backward") + + combine_bwd_event = deep_ep.get_event_from_calc_stream(self.backward_node.moe_group.id) + + paddle.base.core.nvprof_nvtx_push("attn_forward") + inputs = self.forward_node.attn_forward(inputs) + paddle.base.core.nvprof_nvtx_pop() + attn_compute_event = deep_ep.get_event_from_calc_stream(self.forward_node.moe_group.id) + + paddle.base.core.nvprof_nvtx_push("post_process_backward") + output_grad = self.backward_node.post_process_backward(output_grad, combine_bw_event_to_wait) + paddle.base.core.nvprof_nvtx_pop() + + paddle.base.core.nvprof_nvtx_push("combine_backward") + if combine_bw_event_to_wait is not None: + # print(" event", combine_bw_event_to_wait) + output_grad = self.backward_node.combine_backward( + output_grad, previous_event=combine_bw_event_to_wait, async_finish=True, allocate_on_comm_stream=True + ) + else: + output_grad = self.backward_node.combine_backward( + output_grad, previous_event=combine_bwd_event, async_finish=True, allocate_on_comm_stream=True + ) + # get combine event + 
combine_backward_event = deep_ep.get_event_from_comm_stream(self.backward_node.moe_group.id) + paddle.base.core.nvprof_nvtx_pop() + + combine_backward_event.calc_stream_wait(self.backward_node.moe_group.id) + paddle.base.core.nvprof_nvtx_push("mlp_backward_dx") + assert WeightGradStore.funcs_queue.empty() + WeightGradStore.enabled = True + output_grad = self.backward_node.mlp_backward(output_grad) + WeightGradStore.enabled = False + WeightGradStore.flush() + paddle.base.core.nvprof_nvtx_pop() + + output_grad_event = deep_ep.get_event_from_calc_stream(self.backward_node.moe_group.id) + + paddle.base.core.nvprof_nvtx_push("dispatch_forward") + inputs = self.forward_node.dispatch_forward( + inputs, previous_event=attn_compute_event, async_finish=True, allocate_on_comm_stream=True + ) + paddle.base.core.nvprof_nvtx_pop() + dispatch_forward_event = deep_ep.get_event_from_comm_stream(self.forward_node.moe_group.id) + + paddle.base.core.nvprof_nvtx_push("dispatch_backward") + output_grad = self.backward_node.dispatch_backward( + output_grad, async_finish=True, previous_event=output_grad_event, allocate_on_comm_stream=True + ) + paddle.base.core.nvprof_nvtx_pop() + # get dispatch backward event + dispatch_backward_event = deep_ep.get_event_from_comm_stream(self.backward_node.moe_group.id) + + paddle.base.core.nvprof_nvtx_push("dispatch_backward_dw") + WeightGradStore.pop() + assert WeightGradStore.funcs_queue.empty() + paddle.base.core.nvprof_nvtx_pop() + + dispatch_forward_event.calc_stream_wait(self.forward_node.moe_group.id) + paddle.base.core.nvprof_nvtx_push("mlp_forward") + inputs = self.forward_node.mlp_forward(inputs) + paddle.base.core.nvprof_nvtx_pop() + + if pp_stream is not None: + paddle.base.core.nvprof_nvtx_push("post_process_forward") + + final_out = self.forward_node.post_process_node.forward_without_residual(inputs) + paddle.base.core.nvprof_nvtx_pop() + + final_out_event = deep_ep.get_event_from_calc_stream(self.forward_node.moe_group.id) + + paddle.base.core.nvprof_nvtx_push("combine_forward") + inputs = self.forward_node.combine_forward( + inputs, previous_event=final_out_event, async_finish=True, allocate_on_comm_stream=True + ) + paddle.base.core.nvprof_nvtx_pop() + + combine_forward_event = deep_ep.get_event_from_comm_stream(self.forward_node.moe_group.id) + + combine_fwd_out = inputs[-2] if self.forward_node.using_post_norm_recompute else inputs[-1] + + if pp_stream is not None: + send_recv_stream = paddle.device.Stream(stream_base=pp_stream) + + paddle.base.core.nvprof_nvtx_push("pp stream add") + + with paddle.device.stream_guard(send_recv_stream): + combine_forward_event.current_stream_wait() + final_out_event.current_stream_wait() + + # TODO: check correct + # if final_out.shape[-1] != combine_fwd_out.shape[-1]: + # final_out[:, :, : combine_fwd_out.shape[-1]] += combine_fwd_out # broadcast and add directly + # else: + # final_out += combine_fwd_out + inputs = final_out + combine_fwd_out + + final_out._record_stream() + combine_fwd_out._record_stream() + + paddle.base.core.nvprof_nvtx_pop() + + dispatch_backward_event.calc_stream_wait(self.backward_node.moe_group.id) + + paddle.base.core.nvprof_nvtx_push("attn_backward") + assert WeightGradStore.funcs_queue.empty() + WeightGradStore.enabled = True + output_grad = self.backward_node.attn_backward(output_grad) + event_to_wait = deep_ep.get_event_from_calc_stream(self.backward_node.moe_group.id) + + if EventStore is not None: + EventStore.set(event_to_wait) + + WeightGradStore.enabled = False + WeightGradStore.flush() + WeightGradStore.pop() + assert WeightGradStore.funcs_queue.empty() + + paddle.base.core.nvprof_nvtx_pop() + + # residual add + if pp_stream is None: + combine_forward_event.calc_stream_wait(self.forward_node.moe_group.id) + + final_out = self.forward_node.post_process_node.forward_without_residual(inputs) + if final_out.shape[-1] != combine_fwd_out.shape[-1]: + final_out[:, :, : combine_fwd_out.shape[-1]] += combine_fwd_out # broadcast and add directly + else: + final_out += combine_fwd_out + inputs = final_out + combine_fwd_out._record_stream() + + paddle.base.core.nvprof_nvtx_pop() + return inputs, output_grad, event_to_wait + + +class OverlapedDenseFusionScheduleNode: + def __init__(self, forward_node, backward_node, name=""): + assert isinstance(forward_node, FusionFp8DecoderLayerNode) or isinstance( + backward_node, FusionFp8DecoderLayerNode + ) + assert isinstance(forward_node, DenseDecoderLayerNode) or isinstance(backward_node, DenseDecoderLayerNode) + self.forward_node = forward_node + self.backward_node = backward_node + self.name = name + + def forward_backward(self, inputs, output_grad, combine_bw_event_to_wait=None, pp_stream=None): + # Dense forward + MoE backward + if isinstance(self.forward_node, DenseDecoderLayerNode): + paddle.base.core.nvprof_nvtx_push("dense_fw_moe_bw") + + paddle.base.core.nvprof_nvtx_push("dense_attn_moe_combine") + # Note: the input combine_bw_event_to_wait is unreliable, so we record a new event here. + combine_bw_event_to_wait = deep_ep.get_event_from_calc_stream(self.backward_node.moe_group.id) + output_grad = self.backward_node.post_process_backward(output_grad, combine_bw_event_to_wait) + output_grad = self.backward_node.combine_backward( + output_grad, previous_event=combine_bw_event_to_wait, async_finish=True, allocate_on_comm_stream=True + ) + combine_bw_event = deep_ep.get_event_from_comm_stream(self.backward_node.moe_group.id) + inputs = self.forward_node.attn_node.forward(inputs) + combine_bw_event.calc_stream_wait(self.backward_node.moe_group.id) + paddle.base.core.nvprof_nvtx_pop() # dense_attn_moe_combine + + paddle.base.core.nvprof_nvtx_push("moe_mlp") + assert WeightGradStore.funcs_queue.empty() + WeightGradStore.enabled = True + output_grad = self.backward_node.mlp_backward(output_grad) + WeightGradStore.enabled = False + WeightGradStore.flush() + paddle.base.core.nvprof_nvtx_pop() # moe_mlp + + paddle.base.core.nvprof_nvtx_push("dense_mlp_moe_dispatch") + output_grad = self.backward_node.dispatch_backward( + output_grad, async_finish=True, allocate_on_comm_stream=True + ) + dispatch_bw_event = deep_ep.get_event_from_comm_stream(self.backward_node.moe_group.id) + paddle.base.core.nvprof_nvtx_push("dispatch_backward_dw") + WeightGradStore.pop() + assert WeightGradStore.funcs_queue.empty() + paddle.base.core.nvprof_nvtx_pop() + inputs = self.forward_node.mlp_node.forward(inputs) + dispatch_bw_event.calc_stream_wait(self.backward_node.moe_group.id) + paddle.base.core.nvprof_nvtx_pop() # dense_mlp_moe_dispatch + + paddle.base.core.nvprof_nvtx_push("moe_attn") + output_grad = self.backward_node.attn_backward(output_grad) + paddle.base.core.nvprof_nvtx_pop() # moe_attn + + event_to_wait = deep_ep.get_event_from_calc_stream(self.backward_node.moe_group.id) + paddle.base.core.nvprof_nvtx_pop() # dense_fw_moe_bw + + # Dense backward + MoE forward + else: + paddle.base.core.nvprof_nvtx_push("dense_bw_moe_fw") + + paddle.base.core.nvprof_nvtx_push("moe_attn") + inputs = self.forward_node.attn_forward(inputs) + attn_fw_event =
deep_ep.get_event_from_calc_stream(self.forward_node.moe_group.id) + paddle.base.core.nvprof_nvtx_pop() # moe_attn + + paddle.base.core.nvprof_nvtx_push("dense_mlp_moe_dispatch") + if combine_bw_event_to_wait is not None: + combine_bw_event_to_wait.calc_stream_wait(self.forward_node.moe_group.id) + output_grad = self.backward_node.mlp_node.backward(output_grad) + inputs = self.forward_node.dispatch_forward( + inputs, previous_event=attn_fw_event, async_finish=True, allocate_on_comm_stream=True + ) + dispatch_fw_event = deep_ep.get_event_from_comm_stream(self.forward_node.moe_group.id) + dispatch_fw_event.calc_stream_wait(self.forward_node.moe_group.id) + paddle.base.core.nvprof_nvtx_pop() # dense_mlp_moe_dispatch + + paddle.base.core.nvprof_nvtx_push("moe_mlp") + inputs = self.forward_node.mlp_forward(inputs) + paddle.base.core.nvprof_nvtx_pop() # moe_mlp + + paddle.base.core.nvprof_nvtx_push("dense_attn_moe_combine") + inputs = self.forward_node.combine_forward(inputs, async_finish=True, allocate_on_comm_stream=True) + combine_fw_event = deep_ep.get_event_from_comm_stream(self.forward_node.moe_group.id) + output_grad = self.backward_node.attn_node.backward(output_grad) + combine_fw_event.calc_stream_wait(self.forward_node.moe_group.id) + paddle.base.core.nvprof_nvtx_pop() # dense_attn_moe_combine + + paddle.base.core.nvprof_nvtx_push("moe_post") + inputs = self.forward_node.post_process_forward(inputs) + paddle.base.core.nvprof_nvtx_pop() # moe_post + + event_to_wait = deep_ep.get_event_from_calc_stream(self.forward_node.moe_group.id) + paddle.base.core.nvprof_nvtx_pop() # dense_bw_moe_fw + + return inputs, output_grad, event_to_wait + + +def build_overlapped_nodes(forward_chunk, backward_chunk): + overlap_element_class = ( + FusionFp8DecoderLayerNode if DSV3_USE_FP8_GEMM else DecoderLayerNode, + DenseDecoderLayerNode, + ) + forward_decoder_layer_num = 0 + backward_decoder_layer_num = 0 + assert isinstance(forward_chunk, ScheduleChunk) and isinstance(backward_chunk, ScheduleChunk) + for n in forward_chunk.nodes: + if isinstance(n, overlap_element_class): + forward_decoder_layer_num += 1 + for n in reversed(backward_chunk.nodes): + if isinstance(n, overlap_element_class): + backward_decoder_layer_num += 1 + + overlap_layers_num = min(forward_decoder_layer_num, backward_decoder_layer_num) + forward_pre_overlap_layers = [] + forward_post_overlap_layers = [] + forward_overlap_layers = [] + is_pre = True + for n in forward_chunk.nodes: + if not isinstance(n, overlap_element_class): + if is_pre: + forward_pre_overlap_layers.append(n) + else: + forward_post_overlap_layers.append(n) + else: + is_pre = False + if len(forward_overlap_layers) == overlap_layers_num: + forward_post_overlap_layers.append(n) + else: + forward_overlap_layers.append(n) + forward_pre_node = ScheduleChunk(forward_pre_overlap_layers) + forward_post_node = ScheduleChunk(forward_post_overlap_layers) + + backward_pre_overlap_layers = [] + backward_post_overlap_layers = [] + backward_overlap_layers = [] + is_pre = True + for n in reversed(backward_chunk.nodes): + if not isinstance(n, overlap_element_class): + if is_pre: + backward_pre_overlap_layers.append(n) + else: + backward_post_overlap_layers.append(n) + else: + is_pre = False + if len(backward_overlap_layers) == overlap_layers_num: + backward_post_overlap_layers.append(n) + else: + backward_overlap_layers.append(n) + + backward_pre_node = ScheduleChunk(list(reversed(backward_pre_overlap_layers))) + backward_post_node = 
ScheduleChunk(list(reversed(backward_post_overlap_layers))) + + if not forward_chunk.nodes and all(isinstance(n, FusionFp8DecoderLayerNode) for n in backward_chunk.nodes): + backward_post_node = DecoderBackwardScheduleChunk(backward_post_overlap_layers) + + overlap_node = OverlapedScheduleChunk(forward_overlap_layers, backward_overlap_layers, use_fuion=DSV3_USE_FP8_GEMM) + return forward_pre_node, backward_pre_node, overlap_node, forward_post_node, backward_post_node + + +class EmbeddingFunction(paddle.autograd.PyLayer): + @staticmethod + def forward(ctx, x, weight): + out = paddle.nn.functional.embedding( + x, weight=weight, padding_idx=None, max_norm=None, norm_type=2.0, sparse=False, scale_grad_by_freq=False + ) + + ctx.save_for_backward(x, weight) + return out + + @staticmethod + def backward(ctx, dout): + x, weight = ctx.saved_tensor() + + if hasattr(weight, "main_grad"): + paddle.incubate.nn.functional.embedding_grad_add_to_(x, weight.main_grad, dout) + else: + paddle.incubate.nn.functional.embedding_grad_add_to_(x, weight.grad, dout) + + return None, None + + class DeepseekV2EmbeddingPipe(nn.Layer): def __init__(self, config: DeepseekV2Config): super(DeepseekV2EmbeddingPipe, self).__init__() @@ -122,7 +1476,10 @@ def forward(self, args): _type_: _description_ """ input_ids, attention_mask, attn_mask_startend_row_indices, position_ids = parse_args(args) - inputs_embeds = self.embed_tokens(input_ids) + if os.getenv("DSV3_FAST_PRETRAIN", "False").lower() != "true": + inputs_embeds = self.embed_tokens(input_ids) + else: + inputs_embeds = EmbeddingFunction.apply(input_ids, self.embed_tokens.weight) batch_size, seq_length = input_ids.shape if self.config.num_nextn_predict_layers > 0: @@ -160,6 +1517,7 @@ def forward(self, args): # [seq_len * bs / n, num_head * head_dim] (n is mp parallelism) inputs_embeds = ScatterOp.apply(inputs_embeds) embeds_res = [inputs_embeds] + mtp_embeds = [] for depth in range(self.config.num_nextn_predict_layers): inputs_embeds_mtp = paddle.concat( [ @@ -171,12 +1529,19 @@ def forward(self, args): if self.sequence_parallel: inputs_embeds_mtp = inputs_embeds_mtp.reshape([-1, inputs_embeds_mtp.shape[-1]]) inputs_embeds_mtp = ScatterOp.apply(inputs_embeds_mtp) - embeds_res.append(inputs_embeds_mtp) - # if not self.sequence_parallel - # mtp_embeds: [B*num_nextn_predict_layers, seq_len, hidden_size] - # else: - # mtp_embeds: [B*seq_len*num_nextn_predict_layers, hidden_size] - inputs_embeds = paddle.concat(embeds_res, axis=-1) + mtp_embeds.append(inputs_embeds_mtp) + + if self.config.send_mtp_embed: + embeds_res.extend(mtp_embeds) + # if not self.sequence_parallel + # mtp_embeds: [B*num_nextn_predict_layers, seq_len, hidden_size] + # else: + # mtp_embeds: [B*seq_len*num_nextn_predict_layers, hidden_size] + inputs_embeds = paddle.concat(embeds_res, axis=-1) + else: + global global_inputs_embeds_mtp_queue + cloned_mtp_embeds = [t.detach() for t in mtp_embeds] + global_inputs_embeds_mtp_queue.put(cloned_mtp_embeds) return return_args(inputs_embeds, attention_mask, attn_mask_startend_row_indices, position_ids) else: if self.sequence_parallel: @@ -184,15 +1549,18 @@ def forward(self, args): inputs_embeds = ScatterOp.apply(inputs_embeds) return return_args(inputs_embeds, attention_mask, attn_mask_startend_row_indices, position_ids) + def build_schedule_node(self): + return ScheduleNode(self.forward, name="DeepseekV2EmbeddingPipe") + class DeepseekV2DecoderLayerPipe(DeepseekV2DecoderLayer): def forward(self, args): hidden_states, attention_mask, attn_mask_startend_row_indices, 
position_ids = parse_args(args) - if self.config.num_nextn_predict_layers > 0: + if self.config.send_mtp_embed: batch_size, _, hidden_size = hidden_states.shape batch_size_mtp = hidden_size // (self.config.num_nextn_predict_layers + 1) - inputs_embeds_mtp = hidden_states[..., -batch_size_mtp:] + inputs_embeds_mtp = hidden_states[..., batch_size_mtp:] hidden_states = hidden_states[..., :batch_size_mtp] has_gradient = not hidden_states.stop_gradient @@ -235,19 +1603,286 @@ def forward(self, args): attn_mask_startend_row_indices=attn_mask_startend_row_indices, ) - if self.config.num_nextn_predict_layers > 0: + if self.config.send_mtp_embed: hidden_states = paddle.concat([hidden_states, inputs_embeds_mtp], axis=-1) return return_args(hidden_states, attention_mask, attn_mask_startend_row_indices, position_ids) + def attn_compute(self, args): + hidden_states, attention_mask, attn_mask_startend_row_indices, position_ids = parse_args(args) + assert attention_mask is None + assert attn_mask_startend_row_indices is None + assert position_ids is None + assert self.config.send_mtp_embed + + batch_size, _, hidden_size = hidden_states.shape + batch_size_mtp = hidden_size // (self.config.num_nextn_predict_layers + 1) + inputs_embeds_mtp = hidden_states[..., batch_size_mtp:] + hidden_states = hidden_states[..., :batch_size_mtp] + + def attn_compute_func(hidden_states): + hidden_states, residual = self.self_attn_compute(hidden_states) + l_aux, _, intermediate_hidden_states, token_indices, token_probs = self.pre_dispatch_compute(hidden_states) + return (hidden_states, residual, l_aux, intermediate_hidden_states, token_indices, token_probs) + + has_gradient = not hidden_states.stop_gradient + if self.enable_recompute and self.config.recompute_granularity == "full" and has_gradient: + # for pretrain + outputs = recompute( + attn_compute_func, + hidden_states, + use_reentrant=self.config.recompute_use_reentrant, + ) + else: + outputs = attn_compute_func(hidden_states) + + return (inputs_embeds_mtp, *outputs) + + def attn_compute_for_fusion(self, args): + hidden_states, attention_mask, attn_mask_startend_row_indices, position_ids = parse_args(args) + assert attention_mask is None + assert attn_mask_startend_row_indices is None + assert position_ids is None + + send_mtp_embed = self.config.send_mtp_embed + + if send_mtp_embed: + # slice from holy tensor + batch_size, _, hidden_size = hidden_states.shape + batch_size_mtp = hidden_size // (self.config.num_nextn_predict_layers + 1) + inputs_embeds_mtp = hidden_states[..., batch_size_mtp:] + hidden_states = hidden_states[..., :batch_size_mtp] + + hidden_states, residual = self.self_attn_compute(hidden_states) + _, _, d_model = hidden_states.shape + + if self.using_post_norm_recompute: + probs, routing_map, l_aux, _, norm_out = self.mlp.router(hidden_states) + else: + probs, routing_map, l_aux, _ = self.mlp.router(hidden_states) + + # common return values + ret = ( + hidden_states, + residual, + probs, + routing_map, + l_aux, + ) + # append mtp embed if needed + ret = (inputs_embeds_mtp, *ret) if send_mtp_embed else ret + # append norm_out if using post_norm recompute + ret = (*ret, norm_out) if self.using_post_norm_recompute else ret + + return ret + + def mlp_compute(self, inputs): + if isinstance(inputs, list): + inputs = tuple(inputs) + send_mtp_embed = self.config.send_mtp_embed + + if send_mtp_embed: + ( + inputs_embeds_mtp, + hidden_states, + residual, + l_aux, + intermediate_hidden_states, + dispatched_indices, + dispatched_probs, + ) = inputs + else: + ( + 
hidden_states, + residual, + l_aux, + intermediate_hidden_states, + dispatched_indices, + dispatched_probs, + ) = inputs + has_gradient = not intermediate_hidden_states.stop_gradient + if self.enable_recompute and self.config.recompute_granularity == "full" and has_gradient: + expert_output = recompute( + self.expert_forward_compute, + intermediate_hidden_states, + dispatched_indices, + dispatched_probs, + use_reentrant=self.config.recompute_use_reentrant, + ) + else: + expert_output = self.expert_forward_compute( + intermediate_hidden_states, dispatched_indices, dispatched_probs + ) + if send_mtp_embed: + return (inputs_embeds_mtp, hidden_states, residual, l_aux, expert_output) + else: + return (hidden_states, residual, l_aux, expert_output) + + def post_process_compute(self, inputs): + send_mtp_embed = self.config.send_mtp_embed + + if isinstance(inputs, list): + inputs = tuple(inputs) + if send_mtp_embed: + (inputs_embeds_mtp, hidden_states, residual, l_aux, combine_output) = inputs + else: + (hidden_states, residual, l_aux, combine_output) = inputs + has_gradient = not hidden_states.stop_gradient + if self.enable_recompute and self.config.recompute_granularity == "full" and has_gradient: + hidden_states = recompute( + self.post_combine_compute, + residual, + hidden_states, + combine_output, + l_aux, + use_reentrant=self.config.recompute_use_reentrant, + ) + else: + hidden_states = self.post_combine_compute( + residual, + hidden_states, + combine_output, + l_aux, + ) + if send_mtp_embed: + hidden_states = paddle.concat([hidden_states, inputs_embeds_mtp], axis=-1) + + return return_args(hidden_states) + + def post_process_compute_for_fusion(self, inputs): + send_mtp_embed = self.config.send_mtp_embed + + if isinstance(inputs, list): + inputs = tuple(inputs) + + if send_mtp_embed: + (inputs_embeds_mtp, hidden_states, residual, l_aux, final_hidden_states) = inputs + else: + (hidden_states, residual, l_aux, final_hidden_states) = inputs + + final_hidden_states = self.mlp.post_process(hidden_states, final_hidden_states, l_aux) + + hidden_states = residual + final_hidden_states + + hidden_states = (hidden_states,) + + if type(hidden_states) is tuple and len(hidden_states) == 1: + hidden_states = hidden_states[0] + + if send_mtp_embed: + hidden_states = paddle.concat([hidden_states, inputs_embeds_mtp], axis=-1) + + return return_args(hidden_states) + + def attn_compute_dense(self, args): + hidden_states, attention_mask, attn_mask_startend_row_indices, position_ids = parse_args(args) + assert attention_mask is None + assert attn_mask_startend_row_indices is None + assert position_ids is None + + if self.config.send_mtp_embed: + batch_size, _, hidden_size = hidden_states.shape + batch_size_mtp = hidden_size // (self.config.num_nextn_predict_layers + 1) + inputs_embeds_mtp = hidden_states[..., batch_size_mtp:] + hidden_states = hidden_states[..., :batch_size_mtp] + + hidden_states, residual = self.self_attn_compute(hidden_states) + + ret = (hidden_states, residual) + ret = (inputs_embeds_mtp, *ret) if self.config.send_mtp_embed else ret + return ret + + def mlp_compute_dense(self, inputs): + if self.config.send_mtp_embed: + (inputs_embeds_mtp, hidden_states, residual) = inputs + else: + (hidden_states, residual) = inputs + + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + if self.config.send_mtp_embed: + hidden_states = paddle.concat([hidden_states, inputs_embeds_mtp], axis=-1) + + return hidden_states + + def build_schedule_node(self): + if 
isinstance(self.mlp, DeepseekV2MoE): + self.mlp.update_flex_token() + if self.mlp.using_flex_token: + if DSV3_USE_FP8_GEMM: + attn_and_gate_node = ScheduleNode(self.attn_compute_for_fusion, name="attn_and_gate_node") + + # recompute_fwd_gate_up_ may be 1, 0 or -1, 1 means recompute, 0 means disable recompute, -1 means adaptive recompute. + recompute_fwd_gate_up_ = 1 if self.layer_idx in self.config.recompute_fwd_gate_up_list else 0 + if recompute_fwd_gate_up_ == 0 and self.config.adaptive_remained_O1_recompute_ratio: + recompute_fwd_gate_up_ = -1 + + fp8_fusion_moe_node = FusionMoeNode( + self.mlp, + recompute_fwd_gate_up=recompute_fwd_gate_up_, + is_split_group_gemm=self.config.is_split_group_gemm, + mlp_fwd_subbatch_rows=self.config.mlp_fwd_subbatch_rows, + mlp_bwd_subbatch_rows=self.config.mlp_bwd_subbatch_rows, + output_subbatch_rows=self.config.output_subbatch_rows, + name="fp8_fusion_moe_node", + ) + post_process_node = PostProcessNode( + self.config.send_mtp_embed, + self.mlp.training, + self.mlp.alpha, + self.config, + self.mlp.shared_experts, + self.config.using_post_norm_recompute, + output_mtp_embed_first=isinstance(self, DeepseekV2MTPLayer), + name="post_process_node", + ) + return FusionFp8DecoderLayerNode( + attn_and_gate_node=attn_and_gate_node, + fp8_fusion_moe_node=fp8_fusion_moe_node, + post_process_node=post_process_node, + mlp_layer=self.mlp, + send_mtp_embed=self.config.send_mtp_embed, + using_post_norm_recompute=self.config.using_post_norm_recompute, + stepped_recompute_fwd_gate_up=self.config.stepped_recompute_fwd_gate_up, + name="FusionFp8DecoderLayerNode", + ) + else: + attn_node = ScheduleNode(self.attn_compute, name="attn_node") + mlp_node = ScheduleNode(self.mlp_compute, name="mlp_node") + post_process_node = ScheduleNode(self.post_process_compute, name="post_process_node") + return DecoderLayerNode( + attn_node=attn_node, + dispatch_node=None, + mlp_node=mlp_node, + combine_node=None, + post_process_node=post_process_node, + mlp_layer=self.mlp, + name="DecoderLayerNode", + ) + + attn_node = ScheduleNode(self.attn_compute_dense, name="attn_node") + mlp_node = ScheduleNode(self.mlp_compute_dense, name="mlp_node") + return DenseDecoderLayerNode( + attn_node=attn_node, + mlp_node=mlp_node, + name="DenseDecoderLayerNode", + ) + class DeepseekV2MTPLayerPipe(DeepseekV2MTPLayer): def forward(self, args): hidden_states, attention_mask, attn_mask_startend_row_indices, position_ids = parse_args(args) - hidden_states_list = paddle.split(hidden_states, self.config.num_nextn_predict_layers + 1, axis=-1) - hidden_states_main_model = hidden_states_list[0] - inputs_embeds_cur_depth_list = hidden_states_list[1:] + if self.config.send_mtp_embed: + hidden_states_list = paddle.split(hidden_states, self.config.num_nextn_predict_layers + 1, axis=-1) + hidden_states_main_model = hidden_states_list[0] + inputs_embeds_cur_depth_list = hidden_states_list[1:] + else: + hidden_states_main_model = hidden_states + global global_inputs_embeds_mtp_queue + inputs_embeds_cur_depth_list = global_inputs_embeds_mtp_queue.get() + has_gradient = not hidden_states_main_model.stop_gradient if attention_mask is not None and attention_mask.dtype == paddle.int32: @@ -299,6 +1934,67 @@ def forward(self, args): hidden_states = paddle.concat(output_list, axis=-1) return return_args(hidden_states, attention_mask, attn_mask_startend_row_indices, position_ids) + def attn_compute_for_fusion(self, args): + hidden_states, attention_mask, attn_mask_startend_row_indices, position_ids = parse_args(args) + assert 
attention_mask is None + assert attn_mask_startend_row_indices is None + assert position_ids is None + assert self.config.num_nextn_predict_layers == 1 + + if self.config.send_mtp_embed: + hidden_states_list = paddle.split(hidden_states, self.config.num_nextn_predict_layers + 1, axis=-1) + hidden_states_main_model = hidden_states_list[0] + inputs_embeds_cur_depth_list = hidden_states_list[1:] + else: + hidden_states_main_model = hidden_states + global global_inputs_embeds_mtp_queue + inputs_embeds_cur_depth_list = global_inputs_embeds_mtp_queue.get() + + hidden_states = hidden_states_main_model + nextn_hidden_state = inputs_embeds_cur_depth_list[0] + + # mtp compute + hidden_states = self.hnorm(hidden_states) + nextn_hidden_state = self.enorm(nextn_hidden_state) + + concat_h = paddle.concat([nextn_hidden_state, hidden_states], axis=-1) + hidden_states = FP8LinearFunction.apply(concat_h, self.eh_proj) + + # attention compute + hidden_states, residual = self.self_attn_compute(hidden_states) + + if self.using_post_norm_recompute: + probs, routing_map, l_aux, _, norm_out = self.mlp.router(hidden_states) + else: + probs, routing_map, l_aux, _ = self.mlp.router(hidden_states) + + # common return values + ret = ( + hidden_states_main_model, + hidden_states, + residual, + probs, + routing_map, + l_aux, + ) + ret = (*ret, norm_out) if self.using_post_norm_recompute else ret + + return ret + + def build_schedule_node(self): + if isinstance(self.mlp, DeepseekV2MoE): + self.mlp.update_flex_token() + if self.mlp.using_flex_token and DSV3_USE_FP8_GEMM and self.config.num_nextn_predict_layers == 1: + prev_send_mtp_embed = self.config.send_mtp_embed + self.config.send_mtp_embed = True # must be True in MTP node + + node = DeepseekV2DecoderLayerPipe.build_schedule_node(self) + assert isinstance(node, FusionFp8DecoderLayerNode) + + self.config.send_mtp_embed = prev_send_mtp_embed + return node + return ScheduleNode(self.forward, name="DeepseekV2MTPLayerPipe") + class DeepseekV2RMSNormPipe(nn.Layer): def __init__(self, config): @@ -321,10 +2017,13 @@ def forward(self, args): else: return self.norm(hidden_states) + def build_schedule_node(self): + return ScheduleNode(self.forward, name="DeepseekV2RMSNormPipe") + class DeepseekV2LMHeadPipe(DeepseekV2LMHead): - def __init__(self, config): - super(DeepseekV2LMHeadPipe, self).__init__(config) + def __init__(self, config, embedding_weight=None): + super(DeepseekV2LMHeadPipe, self).__init__(config, embedding_weight=embedding_weight) @property def embedding_weight(self): @@ -340,6 +2039,9 @@ def forward(self, args: Union[Tuple, paddle.Tensor]): logits = super().forward(hidden_states) return logits + def build_schedule_node(self): + return ScheduleNode(self.forward, name="DeepseekV2LMHeadPipe") + class DeepseekV2PretrainingCriterionPipe(DeepseekV2PretrainingCriterion): def forward(self, logits, labels): @@ -348,9 +2050,14 @@ def forward(self, logits, labels): logits = logits[0] loss = super().forward(logits, labels, mtp_logits=mtp_logits) else: + if isinstance(logits, (tuple, list)): + logits = logits[0] loss = super().forward(logits, labels) return loss + def build_schedule_node(self): + return ScheduleNode(self.forward, name="DeepseekV2PretrainingCriterionPipe") + class DeepseekV2ForCausalLMPipe(PipelinePretrainedModel, PipelineLayer): """DeepseekV2ForPretraining adapted for pipeline parallelism. @@ -371,6 +2078,9 @@ class DeepseekV2ForCausalLMPipe(PipelinePretrainedModel, PipelineLayer): # DONOT Add base_model_prefix !!!! 
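+ # NOTE: step_flex_token is expected to be called by the training loop once per global step;
+ # it simply forwards cur_step to set_global_step so that step-dependent behaviour (for
+ # example flex-token routing refreshed via update_flex_token) can pick up the new step.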
+ def step_flex_token(self, cur_step): + set_global_step(cur_step) + @classmethod def _prepare_pipeline_inputs_func(cls, inputs): first_stage_keys = ["input_ids", "attention_mask", "attn_mask_startend_row_indices", "position_ids"] @@ -408,6 +2118,10 @@ def __init__(self, config: DeepseekV2Config): assert len(self.no_recompute_layers) == 0, "for pp with full recompute, no_recompute_layers is not support" virtual_pp_degree = getattr(self.config, "virtual_pp_degree", 1) + use_dualpipev = getattr(self.config, "use_dualpipev", False) + if use_dualpipev: + assert LocalSharedLayerDesc is not None, "LocalSharedLayerDesc is None, please update your paddle." + shared_class = LocalSharedLayerDesc if use_dualpipev else SharedLayerDesc def get_hcg(): return fleet.get_hybrid_communicate_group() @@ -422,7 +2136,7 @@ def get_hcg(): if config.tie_word_embeddings: self.add_sequential_layer( - SharedLayerDesc( + shared_class( "DeepseekV2_shared_weight", DeepseekV2EmbeddingPipe, shared_weight_attr="embedding_weight", @@ -435,6 +2149,68 @@ def get_hcg(): LayerDesc(DeepseekV2EmbeddingPipe, config=config), self._base_model.base_model_prefix ) + def compute_recompute_fwd_gate_up_list(pp_nums, all_dl_nums, dense_dl_nums, recompute_fwd_gate_up): + all_layers_nums = all_dl_nums + 4 # embedding, rms, lm_head, mtp + segment_size = all_layers_nums // pp_nums + boundary = math.ceil((1 + dense_dl_nums) / segment_size) * segment_size + recompute_fwd_gate_up_list = [dense_dl_nums] + for idx in range(boundary - 1, all_dl_nums, segment_size): + recompute_fwd_gate_up_list.append(idx) + + # If `recompute_fwd_gate_up` is a Boolean value and is True, means all O1 will be recomputed. + # Otherwise `recompute_fwd_gate_up` should be an integer representing how many O1 are recomputed. + assert isinstance(recompute_fwd_gate_up, (int, bool)) + if type(recompute_fwd_gate_up) is bool: + enable_k_o1_rc = segment_size if recompute_fwd_gate_up is True else 0 + else: + enable_k_o1_rc = recompute_fwd_gate_up + + ret = [] + for i in range(len(recompute_fwd_gate_up_list)): + for k in range(min(segment_size, enable_k_o1_rc)): + ret.append(recompute_fwd_gate_up_list[i] + k) + return ret + + def compute_recompute_fa3_list(pp_nums, all_dl_nums, recompute_fa3): + all_layers_nums = all_dl_nums + 4 # embedding, rms, lm_head, mtp + segment_size = all_layers_nums // pp_nums + recompute_fa3_list = [0] + for idx in range(segment_size - 1, all_dl_nums, segment_size): + recompute_fa3_list.append(idx) + + # If `recompute_fa3` is a Boolean value and is True, means all O1 will be recomputed. + # Otherwise `recompute_fa3` should be an integer representing how many O1 are recomputed. 
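+ # Illustrative example (hypothetical values): with pp_nums=8 and all_dl_nums=15,
+ # all_layers_nums is 19 and segment_size is 19 // 8 = 2, so the base list becomes
+ # [0, 1, 3, 5, 7, 9, 11, 13]. recompute_fa3=True then expands every base index by
+ # segment_size consecutive layer ids, while recompute_fa3=1 keeps only the base indices.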
+ assert isinstance(recompute_fa3, (int, bool)) + if type(recompute_fa3) is bool: + enable_k_o1_rc = segment_size if recompute_fa3 is True else 0 + else: + enable_k_o1_rc = recompute_fa3 + + ret = [] + for i in range(len(recompute_fa3_list)): + for k in range(min(segment_size, enable_k_o1_rc)): + ret.append(recompute_fa3_list[i] + k) + return ret + + pp_nums = ( + self.config["pipeline_parallel_degree"] * 2 + if self.config.use_dualpipev + else self.config["pipeline_parallel_degree"] + ) + recompute_fwd_gate_up_list = compute_recompute_fwd_gate_up_list( + pp_nums, + self.config.num_hidden_layers, + self.config.first_k_dense_replace, + self.config.recompute_fwd_gate_up, + ) + recompute_fa3_list = compute_recompute_fa3_list( + pp_nums, self.config.num_hidden_layers, self.config.recompute_fa3 + ) + + logger.info(f"recompute_fa3_list: {recompute_fa3_list}") + logger.info(f"recompute_fwd_gate_up_list: {recompute_fwd_gate_up_list}") + config.recompute_fwd_gate_up_list = recompute_fwd_gate_up_list + for i in range(config.num_hidden_layers): self.add_sequential_layer( LayerDesc( @@ -442,6 +2218,7 @@ def get_hcg(): config=config, layer_idx=i, layerwise_recompute=i not in self.no_recompute_layers, + recompute_fa3=i in recompute_fa3_list, ), f"{self._base_model.base_model_prefix}.layers.{i}", ) @@ -455,7 +2232,7 @@ def get_hcg(): if config.tie_word_embeddings: self.add_sequential_layer( - SharedLayerDesc( + shared_class( "DeepseekV2_shared_weight", DeepseekV2LMHeadPipe, shared_weight_attr="embedding_weight", @@ -491,11 +2268,69 @@ def get_hcg(): "partition": False, }, num_virtual_pipeline_stages=virtual_pp_degree, + use_dualpipev=use_dualpipev, ) # You should call init here, since there is a diamond inheritance problem self.apply(self._init_weights) # DON'T init PipelinePretrainedModel # PipelinePretrainedModel.__init__(self.super(), config=config) + def fp8_quant_weight(self, batch_mode=False, quant_transpose=True): + """fp8_quant_weight""" + with paddle.no_grad(): + for i, layer in self._sub_layers.items(): + if isinstance( + layer, paddle.distributed.fleet.meta_parallel.parallel_layers.pp_layers.PipelineLayerChunk + ): + for i, sub_layer in layer.named_sublayers(): + if isinstance(sub_layer, DeepseekV2DecoderLayer) and hasattr(sub_layer, "fp8_quant_weight"): + sub_layer.fp8_quant_weight(batch_mode, quant_transpose) + if isinstance(layer, DeepseekV2DecoderLayer) and hasattr(layer, "fp8_quant_weight"): + layer.fp8_quant_weight(batch_mode, quant_transpose) + def get_loss_fn(self, config): return DeepseekV2PretrainingCriterionPipe(config) + + def overlapped_forward_backward( + self, + forward_chunk, # the module of the forward chunk + forward_inputs, + forward_loss_fn_node, + backward_chunk, # the module of the backward chunk, maybe not used + backward_loss_fn_node, + backward_input_grads, + scaler, + combine_bw_event_to_wait=None, + pp_stream=None, + ): + if backward_loss_fn_node is not None: + if scaler: + backward_input_grads = backward_loss_fn_node.backward(scaler=scaler) + else: + backward_input_grads = backward_loss_fn_node.backward() + + ( + forward_pre_node, + backward_pre_node, + overlap_node, + forward_post_node, + backward_post_node, + ) = build_overlapped_nodes(forward_chunk, backward_chunk) + forward_inputs = forward_pre_node.forward(forward_inputs) + backward_input_grads = backward_pre_node.backward(backward_input_grads) + forward_inputs, backward_input_grads, _ = overlap_node.forward_backward( + forward_inputs, + backward_input_grads, + combine_bw_event_to_wait=combine_bw_event_to_wait, 
+ pp_stream=pp_stream, + ) + forward_inputs = forward_post_node.forward(forward_inputs) + backward_input_grads = backward_post_node.backward(backward_input_grads) + + if forward_loss_fn_node is not None: + forward_loss = forward_loss_fn_node.forward(forward_inputs) + else: + forward_loss = None + + forward_inputs = [forward_inputs] if isinstance(forward_inputs, paddle.Tensor) else forward_inputs + return forward_inputs, forward_loss, backward_input_grads diff --git a/paddleformers/transformers/deepseek_v3/modeling.py b/paddleformers/transformers/deepseek_v3/modeling.py index 51c0d1978fe..597975adc3b 100644 --- a/paddleformers/transformers/deepseek_v3/modeling.py +++ b/paddleformers/transformers/deepseek_v3/modeling.py @@ -21,17 +21,28 @@ from __future__ import annotations +import os from typing import List, Optional, Tuple, Union import paddle -from ..deepseek_v2.modeling import ( - DeepseekV2ForSequenceClassification, - DeepseekV2LMHead, - DeepseekV2Model, - DeepseekV2PretrainedModel, - DeepseekV2PretrainingCriterion, -) +if os.getenv("DSV3_FAST_PRETRAIN", "False").lower() != "true": + from ..deepseek_v2.modeling import ( + DeepseekV2ForSequenceClassification, + DeepseekV2LMHead, + DeepseekV2Model, + DeepseekV2PretrainedModel, + DeepseekV2PretrainingCriterion, + ) +else: + from ..deepseek_v2.modeling import ( + DeepseekV2ForSequenceClassification, + DeepseekV2LMHead, + DeepseekV2PretrainingCriterion, + ) + from ..deepseek_v2.modeling_fast import DeepseekV2ModelFast as DeepseekV2Model + from ..deepseek_v2.modeling_fast import DeepseekV2PretrainedModelFast as DeepseekV2PretrainedModel + from ..model_outputs import CausalLMOutputWithPast from ..model_utils import register_base_model from .configuration import DeepseekV3Config diff --git a/paddleformers/transformers/fp8_utils.py b/paddleformers/transformers/fp8_utils.py new file mode 100644 index 00000000000..506bcca3b75 --- /dev/null +++ b/paddleformers/transformers/fp8_utils.py @@ -0,0 +1,1307 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
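+ # This module provides blockwise-FP8 training helpers: 1x128 / 128x128 blockwise
+ # quantization of activations and weights, FP8 GEMM wrappers built on deep_gemm or
+ # fp8_gemm_blockwise, and fused FP8 MLP forward/backward routines that can defer
+ # weight-gradient GEMMs through WeightGradStore.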
+ +import os +from functools import partial + +import numpy +import paddle +import paddle.nn.functional as F + +try: + import fused_ln + from paddle.incubate.nn.functional import swiglu +except ImportError: + + def swiglu(x, y=None): + if y is None: + x, y = paddle.chunk(x, chunks=2, axis=-1) + return F.silu(x) * y + + +from paddle.distributed.fleet.meta_parallel.zero_bubble_utils import WeightGradStore + +USE_DS_GEMM = os.getenv("USE_DS_GEMM", "False").lower() == "true" + +try: + if USE_DS_GEMM: + import deep_gemm + else: + from paddle.incubate.fp8 import deep_gemm +except: + pass + + +__all__ = [ + "FP8LinearFunctionBase", + "FP8Linear", + "FP8GroupGemmMlpFunctionNode", +] + + +def get_sm_num(): + return 112 + + +def set_parameter_color( + parameters, color, group=None, offline_quant_expert_weight=True, clear_origin_weight_when_offline_quant=True +): + if offline_quant_expert_weight and clear_origin_weight_when_offline_quant: + if group is None: + for p in parameters: + if hasattr(p, "color") and p.color is not None: + continue + setattr(p, "color", {"color": color}) + else: + for p in parameters: + if hasattr(p, "color") and p.color is not None: + continue + setattr(p, "color", {"color": color, "group": group}) + + +def extract_first_if_tuple(x): + return x[0] if isinstance(x, tuple) else x + + +def _get_fp8_weight_and_scale(weight, stacked=False, transpose=False): + """_get_fp8_weight_and_scale""" + if stacked: + if transpose: + fp8_weight, fp8_scale = weight.fp8_weight_stacked_transpose, weight.fp8_scale_stacked_transpose + else: + fp8_weight, fp8_scale = weight.fp8_weight_stacked, weight.fp8_scale_stacked + else: + if transpose: + fp8_weight, fp8_scale = weight.fp8_weight_transpose, weight.fp8_scale_transpose + else: + fp8_weight, fp8_scale = weight.fp8_weight, weight.fp8_scale + return fp8_weight, fp8_scale + + +def fused_stack_quant(expert_weight_list, transpose=False): + if transpose is False and hasattr(expert_weight_list[0], "fp8_weight_stacked"): + w, scale = _get_fp8_weight_and_scale(expert_weight_list[0], stacked=True, transpose=False) + elif transpose is True and hasattr(expert_weight_list[0], "fp8_weight_stacked_transpose"): + w, scale = _get_fp8_weight_and_scale(expert_weight_list[0], stacked=True, transpose=True) + elif transpose is True and hasattr(expert_weight_list[0], "fp8_weight_stacked"): + w, scale = _get_fp8_weight_and_scale(expert_weight_list[0], stacked=True, transpose=False) + elif transpose is False and hasattr(expert_weight_list[0], "fp8_weight_stacked_transpose"): + w, scale = _get_fp8_weight_and_scale(expert_weight_list[0], stacked=True, transpose=True) + else: + w, scale = paddle.incubate.nn.functional.fused_stack_transpose_quant(expert_weight_list, transpose=transpose) + return w, scale + + +def weight_quant(weight, transpose=False): + if transpose: + if hasattr(weight, "fp8_weight_transpose"): + return weight.fp8_weight_transpose, weight.fp8_scale_transpose + elif hasattr(weight, "fp8_weight"): + return weight.fp8_weight.T.contiguous(), weight.fp8_scale.T.contiguous() + else: + return paddle.incubate.nn.functional.fp8_quant_blockwise( + weight, + output_scale_transpose=False, + quant_method="128x128", + input_transpose=True, + return_transpose_only=True, + ) + else: + if hasattr(weight, "fp8_weight"): + return weight.fp8_weight, weight.fp8_scale + elif hasattr(weight, "fp8_weight_transpose"): + return weight.fp8_weight_transpose.T.contiguous(), weight.fp8_scale_transpose.T.contiguous() + else: + return paddle.incubate.nn.functional.fp8_quant_blockwise( + 
weight, + output_scale_transpose=False, + quant_method="128x128", + input_transpose=False, + return_transpose_only=False, + ) + + +class FP8LinearFunctionBase: + @staticmethod + def dequantize_fp8_to_fp32(fp8_tensor, scale): + res = fp8_tensor.reshape([-1, 128]).astype("bfloat16") * (scale.reshape([-1, 1])) + return res.reshape(fp8_tensor.shape) + + @staticmethod + def padding(x, axis): + if x.shape[axis] % 512 != 0: + if (x.shape[axis] + 128 - (x.shape[axis] % 128)) % 512 != 0: + padding_size = 512 + else: + padding_size = 128 + pad_size = padding_size - (x.shape[axis] % padding_size) + if axis == 0: + x = paddle.concat([x, paddle.zeros([pad_size, x.shape[-1]], dtype=x.dtype)], axis=0) + else: + x = paddle.concat([x, paddle.zeros([x.shape[0], pad_size], dtype=x.dtype)], axis=-1) + return x + + @staticmethod + def padding_and_quant_input(tensor): + """Quantize input to FP8, with fallback to padded transposed version if shape not aligned.""" + if tensor.shape[0] % 512 != 0: + tensor_fp8, tensor_scale = paddle.incubate.nn.functional.fp8_quant_blockwise( + tensor, output_scale_transpose=True, quant_method="1x128", input_transpose=False + ) + tensor = FP8LinearFunctionBase.padding(tensor, 0) + tensor_t_fp8, tensor_t_scale = paddle.incubate.nn.functional.fp8_quant_blockwise( + tensor, + output_scale_transpose=True, + quant_method="1x128", + input_transpose=True, + return_transpose_only=True, + ) + return tensor_fp8, tensor_scale, tensor_t_fp8, tensor_t_scale + else: + tensor_fp8, tensor_scale, tensor_t_fp8, tensor_t_scale = paddle.incubate.nn.functional.fp8_quant_blockwise( + tensor, output_scale_transpose=True, quant_method="1x128", input_transpose=True + ) + return tensor_fp8, tensor_scale, tensor_t_fp8, tensor_t_scale + + @staticmethod + def kitchen_gemm( + x_fp8, x_scale, w_fp8, w_scale, is_a_1d_scaled, is_b_1d_scaled, out=None, rtn_dtype=paddle.bfloat16 + ): + if USE_DS_GEMM: + if out is None: + out = paddle.zeros([x_fp8.shape[0], w_fp8.shape[0]], rtn_dtype) + if numpy.prod(x_fp8.shape) != 0 and numpy.prod(w_fp8.shape) != 0: + deep_gemm.wgrad_gemm_fp8_fp8_fp32_nt((x_fp8, x_scale), (w_fp8, w_scale), out, num_sms=get_sm_num()) + return out + + if out is not None: + accumulate = True + out_dtype = out.dtype + else: + accumulate = False + out_dtype = rtn_dtype + if numpy.prod(x_fp8.shape) != 0 and numpy.prod(w_fp8.shape) != 0: + y = paddle.incubate.nn.functional.fp8_gemm_blockwise( + a=x_fp8, + a_decode_scale=x_scale, + b=w_fp8, + b_decode_scale=w_scale, + out_dtype=out_dtype, + out=out, + accumulate=accumulate, + use_split_accumulator=True, + is_a_1d_scaled=is_a_1d_scaled, + is_b_1d_scaled=is_b_1d_scaled, + ) + else: + y = paddle.zeros([x_fp8.shape[0], w_fp8.shape[0]], out_dtype) + if out is not None: + out = out + y + return out + + return y + + @staticmethod + def compute_fp8_linear( + input, weight, weight_transpose=False, return_transpose_only=False, return_mode="output_only", *, out=None + ): + """ + FP8 linear computation supporting several return modes and both quantized and unquantized inputs. + + Args: + input: input tensor, either raw or already quantized as an (input_fp8, input_scale) tuple. + weight: weight tensor. + weight_transpose (bool): whether to transpose the weight. + return_transpose_only (bool): whether to return only the transposed weight. + return_mode (str): return mode, one of: + - "output_only": return only the output tensor. + - "with_input_quant": return the output plus the quantized input (input_fp8, input_scale). + - "with_input_transpose_quant": return the output (out) plus the transposed quantized input (input_t_fp8, input_t_scale). + + Returns: + Different combinations of tensors depending on return_mode. + + Raises: + RuntimeError: if return_mode is not supported. + """ + # check input + is_input_quantized = isinstance(input, (tuple, list)) and len(input) == 2 + + if is_input_quantized: + input_fp8, input_scale = input + if return_mode == "with_input_transpose_quant": + raise RuntimeError( + "Cannot return transposed quant if input is already quantized. " "Use raw input instead." + ) + else: + # quant input (with optional transposed output) + if return_mode == "with_input_transpose_quant": + input_fp8, input_scale, input_t_fp8, input_t_scale = FP8LinearFunctionBase.padding_and_quant_input( + input + ) + else: + input_fp8, input_scale = paddle.incubate.nn.functional.fp8_quant_blockwise( + input, + output_scale_transpose=True, + quant_method="1x128", + input_transpose=False, + return_transpose_only=False, + ) + + # quant weight + weight_fp8, weight_scale = weight_quant(weight, weight_transpose) + + # FP8 GEMM + if out is None: + out = paddle.empty([input_fp8.shape[0], weight_fp8.shape[0]], dtype=weight.dtype) + + deep_gemm.gemm_fp8_fp8_bf16_nt( + (input_fp8, input_scale.T), (weight_fp8, weight_scale), out, num_sms=get_sm_num() + ) + + # Return outputs + if return_mode == "output_only": + return out + elif return_mode == "with_input_quant": + return (out, input_fp8, input_scale) + elif return_mode == "with_input_transpose_quant": + return (out, input_t_fp8, input_t_scale) + else: + raise RuntimeError( + f"Unsupported return_mode: {return_mode}. " + "Supported modes: 'output_only', 'with_input_quant', 'with_input_transpose_quant'" + ) + + @staticmethod + def compute_expert_w_grad( + input_t, + input_t_scale, + dout_t, + dout_t_scale, + is_a_1d_scaled=True, + is_b_1d_scaled=True, + weight=None, + rtn_dtype=paddle.bfloat16, + ): + """ + Unified gradient computation for expert_w (supports both main_grad and regular grad). + """ + + if input_t is None or numpy.prod(input_t.shape) == 0: + return + + if hasattr(weight, "main_grad"): + if weight.main_grad is None: + weight.main_grad = paddle.zeros(shape=weight.shape, dtype=paddle.float32) + if WeightGradStore.enabled: + WeightGradStore.put( + partial( + FP8LinearFunctionBase.kitchen_gemm, + input_t, + input_t_scale, + dout_t, + dout_t_scale, + is_a_1d_scaled, + is_b_1d_scaled, + weight.main_grad, + rtn_dtype, + ) + ) + result = None + + else: + result = FP8LinearFunctionBase.kitchen_gemm( + input_t, + input_t_scale, + dout_t, + dout_t_scale, + is_a_1d_scaled, + is_b_1d_scaled, + weight.main_grad, + rtn_dtype, + ) + else: + if weight.grad is None: + weight.grad = paddle.zeros(shape=weight.shape, dtype=paddle.float32) + result = FP8LinearFunctionBase.kitchen_gemm( + input_t, input_t_scale, dout_t, dout_t_scale, is_a_1d_scaled, is_b_1d_scaled, weight.grad, rtn_dtype + ) + + if hasattr(weight, "_apply_backward_hook"): + weight._apply_backward_hook() + return result + + @staticmethod + def common_fp8_mlp_bwd( + do3, x_t_fp8, x_t_scale, w1, w2, o1=None, x_fp8=None, x_scale=None, apply_backward_hook=False + ): + if o1 is not None and (x_fp8 is not None or x_scale is not None): + raise ValueError("When o1 is provided, both x_fp8 and x_scale must be None.") + + if o1 is None: + if x_fp8 is None or x_scale is None: + raise ValueError("When o1 is None, both x_fp8 and x_scale must be provided.") + + # # ===== [recompute] o1 = deep_gemm(x_fp8, w1_t_fp8) ===== + + # Recompute o1 using deep_gemm(x_fp8, w1_t_fp8) + w1_fp8, w1_scale = weight_quant(w1, True) + o1 = paddle.empty([x_fp8.shape[0], w1_fp8.shape[0]], dtype=do3.dtype) + deep_gemm.gemm_fp8_fp8_bf16_nt((x_fp8, x_scale.T), 
(w1_fp8, w1_scale), o1, num_sms=get_sm_num()) + + # ===== [recompute] o2 = swiglu(o1) ===== + o2 = swiglu(o1) + + # ===== do2 = deep_gemm(do3_fp8, w2_fp8) + do2, do3_t_fp8, do3_t_scale = FP8LinearFunctionBase.compute_fp8_linear( + do3, w2, return_mode="with_input_transpose_quant" + ) + + # ===== dw2 = deep_gemm(o2_t_fp8, do3_t_fp8) + o2 = FP8LinearFunctionBase.padding(o2, 0) + o2_t_fp8, o2_t_scale = paddle.incubate.nn.functional.fp8_quant_blockwise( + o2, output_scale_transpose=True, quant_method="1x128", input_transpose=True, return_transpose_only=True + ) + if apply_backward_hook: + if WeightGradStore.enabled: + WeightGradStore.put( + partial( + FP8LinearFunctionBase.compute_expert_w_grad, + o2_t_fp8, + o2_t_scale, + do3_t_fp8, + do3_t_scale, + True, + True, + w2, + rtn_dtype=paddle.float32, + ) + ) + else: + + FP8LinearFunctionBase.compute_expert_w_grad( + o2_t_fp8, o2_t_scale, do3_t_fp8, do3_t_scale, True, True, w2, rtn_dtype=paddle.float32 + ) + else: + dw2 = FP8LinearFunctionBase.kitchen_gemm( + o2_t_fp8, o2_t_scale, do3_t_fp8, do3_t_scale, True, True, rtn_dtype=paddle.float32 + ) + + # ===== do1 = swiglu_grad(o1, None, do2) ===== + do1, _ = paddle._C_ops.swiglu_grad(o1, None, do2) + + # ===== dx = deep_gemm(do1_fp8, w1_fp8) ===== + dx, do1_t_fp8, do1_t_scale = FP8LinearFunctionBase.compute_fp8_linear( + do1, w1, return_mode="with_input_transpose_quant" + ) + + # ===== dw1 = deep_gemm(x_t_fp8, do1_t_fp8) ===== + if apply_backward_hook: + if WeightGradStore.enabled: + WeightGradStore.put( + partial( + FP8LinearFunctionBase.compute_expert_w_grad, + x_t_fp8, + x_t_scale, + do1_t_fp8, + do1_t_scale, + True, + True, + w1, + rtn_dtype=paddle.float32, + ) + ) + + else: + FP8LinearFunctionBase.compute_expert_w_grad( + x_t_fp8, x_t_scale, do1_t_fp8, do1_t_scale, True, True, w1, rtn_dtype=paddle.float32 + ) + else: + dw1 = FP8LinearFunctionBase.kitchen_gemm( + x_t_fp8, x_t_scale, do1_t_fp8, do1_t_scale, True, True, rtn_dtype=paddle.float32 + ) + + if apply_backward_hook: + return dx + else: + assert dw1 is not None and dw2 is not None + return dx, dw1, dw2 + + @staticmethod + def fp8_mlp_fwd(x, w1, w2): + x_orig_shape = x.shape + x = x.reshape([-1, x_orig_shape[-1]]) + + # ===== o1 = deep_gemm(x_fp8, w1_t_fp8) ===== + o1, x_fp8, x_scale = FP8LinearFunctionBase.compute_fp8_linear( + x, w1, weight_transpose=True, return_transpose_only=True, return_mode="with_input_quant" + ) + + # ===== o2 = swiglu(o1) ===== + o2 = swiglu(o1) + + # ===== o3 = deep_gemm(o2_fp8, w2_t_fp8) ===== + o3 = FP8LinearFunctionBase.compute_fp8_linear(o2, w2, weight_transpose=True, return_transpose_only=True) + + if len(x_orig_shape) > 2: + o3 = o3.reshape([x_orig_shape[0], -1, o3.shape[-1]]) + + return o1, x_fp8, x_scale, o3 + + @staticmethod + def fp8_mlp_fwd_norm_rc(x, norm_w, norm_eps, w1, w2): + # ===== compute norm_output ===== + norm_output, _ = fused_ln.fused_rms_norm(x, norm_w, norm_eps) + # ===== compute fp8_mlp_fwd ===== + _, _, _, o3 = FP8LinearFunctionBase.fp8_mlp_fwd(norm_output, w1, w2) + return o3 + + @staticmethod + def fp8_mlp_bwd(do3, x, w1, w2, apply_backward_hook=False): + do3_orig_shape = do3.shape + do3 = do3.reshape([-1, do3_orig_shape[-1]]) + + x_orig_shape = x.shape + x = x.reshape([-1, x_orig_shape[-1]]) + + x_fp8, x_scale, x_t_fp8, x_t_scale = FP8LinearFunctionBase.padding_and_quant_input(x) + + if apply_backward_hook: + dx = FP8LinearFunctionBase.common_fp8_mlp_bwd( + do3, + x_t_fp8, + x_t_scale, + w1, + w2, + o1=None, + x_fp8=x_fp8, + x_scale=x_scale, + apply_backward_hook=apply_backward_hook, 
+ ) + if len(x_orig_shape) > 2: + dx = dx.reshape([x_orig_shape[0], -1, dx.shape[-1]]) + return dx + else: + dx, dw1, dw2 = FP8LinearFunctionBase.common_fp8_mlp_bwd( + do3, + x_t_fp8, + x_t_scale, + w1, + w2, + o1=None, + x_fp8=x_fp8, + x_scale=x_scale, + apply_backward_hook=apply_backward_hook, + ) + if len(x_orig_shape) > 2: + dx = dx.reshape([x_orig_shape[0], -1, dx.shape[-1]]) + + return dx, dw1, dw2 + + @staticmethod + def fp8_mlp_bwd_norm_rc(do3, x, norm_w, norm_eps, w1, w2): + # ===== recompute norm_output ===== + norm_output, invar = fused_ln.fused_rms_norm(x, norm_w, norm_eps) + + # ===== compute fp8_mlp_fwd ===== + d_norm_output = FP8LinearFunctionBase.fp8_mlp_bwd(do3, norm_output, w1, w2, True) + + if hasattr(norm_w, "_apply_backward_hook"): + norm_w._apply_backward_hook() + + return d_norm_output, norm_output, invar + + +class FP8LinearFunction(paddle.autograd.PyLayer): + @staticmethod + def forward(ctx, x, custom_map, keep_x=False): + weight = custom_map.weight + x_orig_shape = x.shape + + # deep_gemm only support 2D + x = x.reshape([-1, x_orig_shape[-1]]).contiguous() + + if keep_x: + out = FP8LinearFunctionBase.compute_fp8_linear( + x, + weight, + weight_transpose=True, + return_transpose_only=True, + ) + # save for bwd + out = out.reshape([x_orig_shape[0], -1, weight.shape[-1]]) + ctx.save_for_backward(x, weight) + return out + else: + x_t = x.T + out, x_t_fp8, x_t_scale = FP8LinearFunctionBase.compute_fp8_linear( + x, weight, weight_transpose=True, return_transpose_only=True, return_mode="with_input_transpose_quant" + ) + out = out.reshape([x_orig_shape[0], -1, weight.shape[-1]]) + ctx.save_for_backward((x_t_fp8, x_t_scale), weight) + ctx.x_t_shape = x_t.shape + return out + + @staticmethod + def backward(ctx, dout): + x, weight = ctx.saved_tensor() + dout_2d = dout.reshape([-1, dout.shape[-1]]) + + keep_x = not isinstance(x, tuple) + + if keep_x: + # padding x and quant + dx_orig_shape = x.shape + x = FP8LinearFunctionBase.padding(x, 0) + x_t_fp8, x_t_scale = paddle.incubate.nn.functional.fp8_quant_blockwise( + x, output_scale_transpose=True, quant_method="1x128", input_transpose=True, return_transpose_only=True + ) + + # ===== dx = deep_gemm(dout_fp8, w_fp8) + dx, dout_t_fp8, dout_t_scale = FP8LinearFunctionBase.compute_fp8_linear( + dout_2d, weight, weight_transpose=False, return_mode="with_input_transpose_quant" + ) + dx = dx.reshape(dx_orig_shape) + + else: + x_t_fp8, x_t_scale = x + + # ===== dx = deep_gemm(dout_fp8, w_fp8) + dx, dout_t_fp8, dout_t_scale = FP8LinearFunctionBase.compute_fp8_linear( + dout_2d, weight, weight_transpose=False, return_mode="with_input_transpose_quant" + ) + dx_orig_shape = dout.shape[:-1] + dx_orig_shape.append(ctx.x_t_shape[0]) + dx = dx.reshape(dx_orig_shape) + + # ===== dw1 = deep_gemm(x_t_fp8, dout_t_fp8) + FP8LinearFunctionBase.compute_expert_w_grad( + x_t_fp8, x_t_scale, dout_t_fp8, dout_t_scale, True, True, weight, paddle.float32 + ) + return dx + + +class FP8Linear(paddle.nn.Layer): + def __init__(self, in_features: int, out_features: int, bias_attr: bool = False) -> None: + super().__init__() + self._dtype = self._helper.get_default_dtype() + + self.weight = self.create_parameter( + shape=[in_features, out_features], + dtype="bfloat16", + is_bias=False, + ) + + def forward(self, x): + return FP8LinearFunction.apply(x, self, keep_x=False) + + +def cache_fp8_weight(weight, quant_transpose=None): + if hasattr(weight, "fp8_weight") or hasattr(weight, "fp8_weight_transpose"): + return + if quant_transpose is None: + w_fp8, w_scale, 
w_t_fp8, w_t_scale = paddle.incubate.nn.functional.fp8_quant_blockwise( + weight, + output_scale_transpose=False, + quant_method="128x128", + input_transpose=True, + return_transpose_only=False, + ) + + setattr(weight, "fp8_weight_transpose", w_t_fp8) + setattr(weight, "fp8_scale_transpose", w_t_scale) + setattr(weight, "fp8_weight", w_fp8) + setattr(weight, "fp8_scale", w_scale) + elif quant_transpose is True: + w_t_fp8, w_t_scale = paddle.incubate.nn.functional.fp8_quant_blockwise( + weight, + output_scale_transpose=False, + quant_method="128x128", + input_transpose=True, + return_transpose_only=True, + ) + setattr(weight, "fp8_weight_transpose", w_t_fp8) + setattr(weight, "fp8_scale_transpose", w_t_scale) + elif quant_transpose is False: + w_fp8, w_scale = paddle.incubate.nn.functional.fp8_quant_blockwise( + weight, + output_scale_transpose=False, + quant_method="128x128", + input_transpose=False, + return_transpose_only=False, + ) + setattr(weight, "fp8_weight", w_fp8) + setattr(weight, "fp8_scale", w_scale) + else: + raise ValueError("quant_transpose must be either True, False or None.") + + +class FP8KeepXLinear(paddle.nn.Layer): + def __init__(self, in_features: int, out_features: int, bias_attr: bool = False) -> None: + super().__init__() + self._dtype = self._helper.get_default_dtype() + + self.weight = self.create_parameter( + shape=[in_features, out_features], + dtype="bfloat16", + is_bias=False, + ) + set_parameter_color([self.weight], "attn_out_project") + + def fp8_quant_weight(self, quant_transpose=None): + cache_fp8_weight(self.weight, quant_transpose=quant_transpose) + + def forward(self, x): + return FP8LinearFunction.apply(x, self, keep_x=True) + + +class FusedNormFP8MLPFunction(paddle.autograd.PyLayer): + @staticmethod + def forward(ctx, x, norm_w, w1, w2, norm_eps): + # ===== compute norm_output ===== + norm_output, invar = fused_ln.fused_rms_norm(x, norm_w, norm_eps) + # ===== reshape for deep_gemm, since deep_gemm only support 2D ===== + x_orig_shape = norm_output.shape + norm_output = norm_output.reshape([-1, x_orig_shape[-1]]) + + # ===== call func fp8_mlp_fwd ===== + _, _, _, o3 = FP8LinearFunctionBase.fp8_mlp_fwd(norm_output, w1, w2) + + # ===== reshape to origin shape ===== + if len(x_orig_shape) > 2: + o3 = o3.reshape([x_orig_shape[0], -1, o3.shape[-1]]) + + # ===== save for backward ===== + ctx.save_for_backward( + norm_output, + invar, + x, + norm_w, + w1, + w2, + norm_eps, + paddle.to_tensor(x_orig_shape, dtype="int64", place=paddle.CPUPlace()), + ) + return o3 + + @staticmethod + def backward(ctx, do3): + # ===== reshape for deep_gemm, since deep_gemm only support 2D ===== + do3_orig_shape = do3.shape + do3 = do3.reshape([-1, do3_orig_shape[-1]]) + + # ===== recive saved tensors ===== + norm_output, invar, x, norm_w, w1, w2, norm_eps, x_orig_shape = ctx.saved_tensor() + + x_fp8, x_scale, x_t_fp8, x_t_scale = paddle.incubate.nn.functional.fp8_quant_blockwise( + norm_output, output_scale_transpose=True, quant_method="1x128", input_transpose=True + ) + + # ===== call func common_fp8_mlp_bwd ===== + d_norm_output, dw1, dw2 = FP8LinearFunctionBase.common_fp8_mlp_bwd( + do3, x_t_fp8, x_t_scale, w1, w2, o1=None, x_fp8=x_fp8, x_scale=x_scale + ) + + # ===== reshape to origin shape ===== + if len(x_orig_shape) > 2: + d_norm_output = d_norm_output.reshape([x_orig_shape[0], -1, d_norm_output.shape[-1]]) + + # ===== compute norm grad ===== + dx, d_rms_norm_weight = fused_ln.fused_rms_norm_grad_func(x, norm_w, invar, d_norm_output, norm_eps) + + return dx, 
d_rms_norm_weight, dw1, dw2 + + +class FP8MlpFunction(paddle.autograd.PyLayer): + @staticmethod + def forward(ctx, x, w1, w2, recompute_fwd_gate_up): + # ===== reshape for deep_gemm, since deep_gemm only support 2D ===== + x_orig_shape = x.shape + x = x.reshape([-1, x_orig_shape[-1]]) + + # ===== call func fp8_mlp_fwd ===== + o1, x_fp8, x_scale, o3 = FP8LinearFunctionBase.fp8_mlp_fwd(x, w1, w2) + # ===== reshape to origin shape ===== + if len(x_orig_shape) > 2: + o3 = o3.reshape([x_orig_shape[0], -1, o3.shape[-1]]) + + # ===== save for backward ===== + o1 = None if recompute_fwd_gate_up else o1 + ctx.save_for_backward( + o1, + x_fp8, + x_scale, + w1, + w2, + paddle.to_tensor(x_orig_shape, dtype="int64", place=paddle.CPUPlace()), + ) + return o3 + + @staticmethod + def backward(ctx, do3): + # ===== reshape for deep_gemm, since deep_gemm only support 2D ===== + do3_orig_shape = do3.shape + do3 = do3.reshape([-1, do3_orig_shape[-1]]) + + # ===== recive saved tensors ===== + o1, x_fp8, x_scale, w1, w2, x_orig_shape = ctx.saved_tensor() + + # ===== compute x_t_fp8, x_t_scale for dw1 ===== + x_dequant_fp16 = paddle.incubate.nn.functional.fused_act_dequant(x_fp8, x_scale.T.contiguous()) + x_dequant_fp16 = FP8LinearFunctionBase.padding(x_dequant_fp16, 0) + + x_t_fp8, x_t_scale = paddle.incubate.nn.functional.fp8_quant_blockwise( + x_dequant_fp16, + output_scale_transpose=True, + quant_method="1x128", + input_transpose=True, + return_transpose_only=True, + ) + + # ===== call func common_fp8_mlp_bwd ===== + if o1 is None: + dx = FP8LinearFunctionBase.common_fp8_mlp_bwd( + do3, x_t_fp8, x_t_scale, w1, w2, o1=None, x_fp8=x_fp8, x_scale=x_scale, apply_backward_hook=True + ) + else: + dx = FP8LinearFunctionBase.common_fp8_mlp_bwd( + do3, x_t_fp8, x_t_scale, w1, w2, o1=o1, x_fp8=None, x_scale=None, apply_backward_hook=True + ) + # ===== reshape to origin shape ===== + if len(x_orig_shape) > 2: + dx = dx.reshape([x_orig_shape[0], -1, dx.shape[-1]]) + + return dx, None, None + + +class FP8Mlp(paddle.nn.Layer): + def __init__( + self, + config, + hidden_size=None, + intermediate_size=None, + is_moe=False, + using_post_norm_recompute=False, + norm_weight=None, + norm_eps=None, + recompute_fwd_gate_up=False, + ): + super().__init__() + self.config = config + self.using_post_norm_recompute = using_post_norm_recompute + if self.using_post_norm_recompute: + assert norm_weight is not None and norm_eps is not None + self.norm_weight = norm_weight + self.norm_eps = norm_eps + + self.hidden_size = config.hidden_size if hidden_size is None else hidden_size + self.intermediate_size = config.intermediate_size if intermediate_size is None else intermediate_size + + self.recompute_fwd_gate_up = recompute_fwd_gate_up + + self.w1 = self.create_parameter( + shape=[self.hidden_size, self.intermediate_size * 2], + dtype="bfloat16", + is_bias=False, + ) + self.w2 = self.create_parameter( + shape=[self.intermediate_size, self.hidden_size], + dtype="bfloat16", + is_bias=False, + ) + + def fp8_quant_weight(self, quant_transpose=None): + cache_fp8_weight(self.w1, quant_transpose) + cache_fp8_weight(self.w2, quant_transpose) + + def forward(self, x): + if self.using_post_norm_recompute: + return FusedNormFP8MLPFunction.apply(x, self.norm_weight, self.w1, self.w2, self.norm_eps) + else: + return FP8MlpFunction.apply(x, self.w1, self.w2, self.recompute_fwd_gate_up) + + +def split_group_gemm(x_fp8, x_scale, w_fp8, w_scale, tokens_per_expert, gemm_out): + start_idx = 0 + for i, token_num in enumerate(tokens_per_expert): + if token_num 
== 0: + continue + end_idx = start_idx + token_num + + x_scale_tma_align = x_scale[start_idx:end_idx].T.contiguous().T + + deep_gemm.gemm_fp8_fp8_bf16_nt( + (x_fp8[start_idx:end_idx], x_scale_tma_align), + (w_fp8[i], w_scale[i]), + gemm_out[start_idx:end_idx], + num_sms=get_sm_num(), + ) + + start_idx = end_idx + + return gemm_out + + +class FP8GroupGemmMlpFunctionNode: + def __init__( + self, + custom_map, + recompute_fwd_gate_up=False, + is_split_group_gemm=False, + name="experts_group_gemm_contiguous_node", + ): + self.experts = custom_map.experts + self.recompute_fwd_gate_up = recompute_fwd_gate_up + self.is_split_group_gemm = is_split_group_gemm + self.m_indices = None + self.input = None + self.input_fp8 = None + self.input_scale = None + self.o1 = None + self.all_unzipped_grad = None + self.fwd_subbatch = None + self.bwd_subbatch = None + + def reset_statue(self): + self.m_indices = None + self.fwd_subbatch = None + self.bwd_subbatch = None + self.clear_activation_tensors() + + def clear_activation_tensors(self): + self.input = None + self.input_fp8 = None + self.input_scale = None + self.o1 = None + self.all_unzipped_grad = None + + def gen_m_indices(self, tokens_per_expert): + tokens = [] + for i in range(len(tokens_per_expert)): + tokens.append(paddle.full([tokens_per_expert[i]], i, dtype="int32")) + out = paddle.concat(tokens, axis=0) + return out + + def fwd_gate_up(self, x, expert_w1, num_expert, tokens_per_expert, m_indices=None): + """ + o1 = x * w1 + [m_sum, n] = [m_sum, k] * [num_groups, k, n] (m_sum = sum(tokens_per_expert)) + """ + if not self.is_split_group_gemm and self.m_indices is None: + self.m_indices = self.gen_m_indices(tokens_per_expert) + # concat w1, shape is [num_groups, n, k] + w1_t_quant, w1_t_scale = fused_stack_quant(expert_w1, transpose=True) + w1_t_quant = w1_t_quant.reshape([num_expert, -1, w1_t_quant.shape[-1]]) + w1_t_scale = w1_t_scale.reshape([num_expert, -1, w1_t_scale.shape[-1]]) + + if hasattr(expert_w1[0], "fp8_weight_stacked") and not hasattr(expert_w1[0], "fp8_weight_stacked_transpose"): + w1_t_quant = w1_t_quant.contiguous().transpose([0, 2, 1]).contiguous() + w1_t_scale = w1_t_scale.contiguous().transpose([0, 2, 1]).contiguous() + + if x is None: + x_fp8, x_scale = self.input_fp8, self.input_scale + assert x_fp8 is not None and x_scale is not None + else: + if isinstance(x, tuple): + (x_fp8, x_scale) = x + x_scale = paddle.transpose(paddle.transpose(x_scale, [1, 0]).contiguous(), [1, 0]) + else: + # quant x_bf16 + x_fp8, x_scale = paddle.incubate.nn.functional.fp8_quant_blockwise( + x, output_scale_transpose=True, quant_method="1x128", input_transpose=False + ) + x_scale = x_scale.T + + # compute gemm + o1 = paddle.empty([x_fp8.shape[0], w1_t_quant.shape[1]], dtype=expert_w1[0].dtype) + if numpy.prod(x_fp8.shape) != 0: + if self.is_split_group_gemm: + split_group_gemm(x_fp8, x_scale, w1_t_quant, w1_t_scale, tokens_per_expert, o1) + else: + deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous( + (x_fp8, x_scale), + (w1_t_quant, w1_t_scale), + o1, + m_indices=self.m_indices if m_indices is None else m_indices, + num_sms=get_sm_num(), + ) + + if m_indices is None: + self.input_fp8 = x_fp8 + self.input_scale = x_scale + return o1 + + def fwd_swiglu(self, o1): + o2 = swiglu(o1) + return o2 + + def fwd_down( + self, o1, unzipped_probs, expert_w2, num_expert, tokens_per_expert, m_indices=None, o3=None, clear_o1=False + ): + """ + o3 = o2 * w2 + [m_sum, k] = [m_sum, n] * [num_groups, n, k] + """ + # concat and transpose w2 + w2_quant, w2_scale = 
fused_stack_quant(expert_w2, transpose=True) + w2_quant = w2_quant.reshape([num_expert, -1, w2_quant.shape[-1]]) + w2_scale = w2_scale.reshape([num_expert, -1, w2_scale.shape[-1]]) + + if hasattr(expert_w2[0], "fp8_weight_stacked") and not hasattr(expert_w2[0], "fp8_weight_stacked_transpose"): + w2_quant = w2_quant.contiguous().transpose([0, 2, 1]).contiguous() + w2_scale = w2_scale.contiguous().transpose([0, 2, 1]).contiguous() + + # quant o2 + with paddle.amp.auto_cast(False): + unzipped_probs = unzipped_probs.squeeze(-1) + o2_fp8, o2_scale = paddle.incubate.nn.functional.fused_weighted_swiglu_act_quant( + o1, unzipped_probs, using_pow2_scaling=True + ) + o2_scale = paddle.transpose(paddle.transpose(o2_scale, [1, 0]).contiguous(), [1, 0]) + + if clear_o1: + o1._clear_to_zero_allocation() + + # compute gemm + o3_shape = [o2_fp8.shape[0], w2_quant.shape[1]] + if o3 is not None: + assert o3.shape == o3_shape, "{} vs {}".format(o3.shape, o3_shape) + else: + o3 = paddle.empty(o3_shape, dtype=o1.dtype) + if numpy.prod(o2_fp8.shape) != 0: + if self.is_split_group_gemm: + split_group_gemm(o2_fp8, o2_scale, w2_quant, w2_scale, tokens_per_expert, o3) + else: + deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous( + (o2_fp8, o2_scale), + (w2_quant, w2_scale), + o3, + m_indices=m_indices if self.fwd_subbatch else self.m_indices, + num_sms=get_sm_num(), + ) + + return o3 + + def bwd_dowm_input(self, expert_w2, unzipped_grad, o1, tokens_per_expert, m_indices=None, unzipped_probs=None): + """ + do2 = do3 * w2_t + [m_sum, n] = [m_sum, k] * [num_groups, k, n] + """ + # recompute concated_w2_2d + bw_w2_quant, bw_w2_scale = fused_stack_quant(expert_w2, transpose=False) + bw_w2_quant = bw_w2_quant.reshape([len(expert_w2), -1, bw_w2_quant.shape[-1]]) + bw_w2_scale = bw_w2_scale.reshape([len(expert_w2), -1, bw_w2_scale.shape[-1]]) + + if hasattr(expert_w2[0], "fp8_weight_stacked_transpose") and not hasattr(expert_w2[0], "fp8_weight_stacked"): + bw_w2_quant = bw_w2_quant.contiguous().transpose([0, 2, 1]).contiguous() + bw_w2_scale = bw_w2_scale.contiguous().transpose([0, 2, 1]).contiguous() + + # compute gemm + if isinstance(unzipped_grad, tuple): + (unzipped_grad_fp8, unzipped_grad_scale) = unzipped_grad + unzipped_grad_scale = unzipped_grad_scale.T.contiguous().T + else: + unzipped_grad_fp8, unzipped_grad_scale = paddle.incubate.nn.functional.fp8_quant_blockwise( + unzipped_grad, output_scale_transpose=True, quant_method="1x128", input_transpose=False + ) + unzipped_grad_scale = unzipped_grad_scale.T + + do2_s = paddle.empty([unzipped_grad_fp8.shape[0], bw_w2_quant.shape[1]], dtype="bfloat16") + if numpy.prod(unzipped_grad_fp8.shape) != 0: + if self.is_split_group_gemm: + split_group_gemm( + unzipped_grad_fp8, unzipped_grad_scale, bw_w2_quant, bw_w2_scale, tokens_per_expert, do2_s + ) + else: + deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous( + (unzipped_grad_fp8, unzipped_grad_scale), + (bw_w2_quant, bw_w2_scale), + do2_s, + m_indices=m_indices if self.bwd_subbatch else self.m_indices, + num_sms=get_sm_num(), + ) + + with paddle.amp.auto_cast(False): + do1, probs_grad, o2_s = paddle.incubate.nn.functional.fused_swiglu_weighted_bwd(o1, do2_s, unzipped_probs) + + return do1, o2_s, probs_grad + + def bwd_swiglu(self, o1, do2): + do1, _ = paddle._C_ops.swiglu_grad(o1, None, do2) + return do1 + + def bwd_gate_up_input(self, do1, expert_w1, tokens_per_expert, m_indices=None, dx=None): + """ + dx = do1 * w1_t + [m_sum, k] = [m_sum, n] * [num_groups, n, k] + """ + # recompute concated_w1_t + bw_w1_quant, 
bw_w1_scale = fused_stack_quant(expert_w1, transpose=False) + bw_w1_quant = bw_w1_quant.reshape([len(expert_w1), -1, bw_w1_quant.shape[-1]]) + bw_w1_scale = bw_w1_scale.reshape([len(expert_w1), -1, bw_w1_scale.shape[-1]]) + + if hasattr(expert_w1[0], "fp8_weight_stacked_transpose") and not hasattr(expert_w1[0], "fp8_weight_stacked"): + bw_w1_quant = bw_w1_quant.contiguous().transpose([0, 2, 1]).contiguous() + bw_w1_scale = bw_w1_scale.contiguous().transpose([0, 2, 1]).contiguous() + + # quant do1 + do1_fp8, do1_scale = paddle.incubate.nn.functional.fp8_quant_blockwise( + do1, output_scale_transpose=True, quant_method="1x128", input_transpose=False + ) + do1_scale = do1_scale.T + # compute gemm + dx_shape = [do1_fp8.shape[0], bw_w1_quant.shape[1]] + if dx is None or dx.dtype != do1.dtype: + dx = paddle.empty(shape=dx_shape, dtype=do1.dtype) + else: + assert dx.shape == dx_shape, f"{dx.shape} vs {dx_shape}" + if numpy.prod(do1_fp8.shape) != 0: + if self.is_split_group_gemm: + split_group_gemm(do1_fp8, do1_scale, bw_w1_quant, bw_w1_scale, tokens_per_expert, dx) + else: + deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous( + (do1_fp8, do1_scale), + (bw_w1_quant, bw_w1_scale), + dx, + m_indices=m_indices if self.bwd_subbatch else self.m_indices, + num_sms=get_sm_num(), + ) + + return dx + + def fused_transpose_split_quant(self, x, scale, tokens_per_expert, pow_2_scales): + out, scale = paddle.incubate.nn.functional.fused_transpose_split_quant( + x, scale, tokens_per_expert, pow_2_scales + ) + return out, scale + + def bwd_down_weight(self, do3, o2, expert_w2, tokens_per_expert): + """ + dw2 = do2_t * do3 + [n, k] = [n, m_sum] * [m_sum, k] (m_sum = sum(tokens_per_expert)) + """ + if isinstance(o2, tuple): + o2_t_fp8, o2_t_scale = o2 + else: + o2_t_fp8, o2_t_scale = self.fused_transpose_split_quant(o2, None, tokens_per_expert, True) + + if isinstance(do3, tuple): + do3_t_fp8, do3_t_scale = do3 + else: + do3_t_fp8, do3_t_scale = self.fused_transpose_split_quant(do3, None, tokens_per_expert, True) + + def cal_weight_fn(o2_t_fp8, o2_t_scale, do3_t_fp8, do3_t_scale, expert_w2): + with paddle.no_grad(): + for i in range(len(expert_w2)): + FP8LinearFunctionBase.compute_expert_w_grad( + o2_t_fp8[i], + o2_t_scale[i], + do3_t_fp8[i], + do3_t_scale[i], + True, + True, + expert_w2[i], + paddle.float32, + ) + + if WeightGradStore.enabled: + WeightGradStore.put(partial(cal_weight_fn, o2_t_fp8, o2_t_scale, do3_t_fp8, do3_t_scale, expert_w2)) + else: + cal_weight_fn(o2_t_fp8, o2_t_scale, do3_t_fp8, do3_t_scale, expert_w2) + + def bwd_gate_up_weight( + self, + do1, + input_x, + expert_w1, + tokens_per_expert, + input_fp8_slice=None, + input_scale_slice=None, + clear_input=False, + ): + """ + dw1 = dx_t * do1 + [k, n] = [k, m_sum] * [m_sum, n] (m_sum = sum(tokens_per_expert)) + """ + if input_x is None: + inp = (input_fp8_slice, input_scale_slice) if self.bwd_subbatch else (self.input_fp8, self.input_scale) + input_x_t_fp8, input_x_t_scale = self.fused_transpose_split_quant(inp[0], inp[1], tokens_per_expert, True) + + else: + input_x_t_fp8, input_x_t_scale = self.fused_transpose_split_quant(input_x, None, tokens_per_expert, True) + + if clear_input: + if self.input_fp8 is not None: + self.input_fp8._clear_to_zero_allocation() + self.input_fp8 = None + if self.input_scale is not None: + self.input_scale._clear_to_zero_allocation() + self.input_scale = None + if self.input is not None: + self.input._clear_to_zero_allocation() + self.input = None + + do1_t_fp8, do1_t_scale = self.fused_transpose_split_quant(do1, 
None, tokens_per_expert, True) + + def cal_weight_fn(input_x_t_fp8, input_x_t_scale, do1_t_fp8, do1_t_scale, expert_w1): + with paddle.no_grad(): + for i in range(len(expert_w1)): + FP8LinearFunctionBase.compute_expert_w_grad( + input_x_t_fp8[i], + input_x_t_scale[i], + do1_t_fp8[i], + do1_t_scale[i], + True, + True, + expert_w1[i], + paddle.float32, + ) + + if WeightGradStore.enabled: + WeightGradStore.put( + partial(cal_weight_fn, input_x_t_fp8, input_x_t_scale, do1_t_fp8, do1_t_scale, expert_w1) + ) + else: + cal_weight_fn(input_x_t_fp8, input_x_t_scale, do1_t_fp8, do1_t_scale, expert_w1) + + @paddle.no_grad() + def forward(self, hs_out, unzipped_probs, tokens_per_expert, m_indices=None): + # check subbatch + if self.fwd_subbatch: + assert m_indices is not None + # deal 0 size + dtype = paddle.bfloat16 + if hs_out is None: + assert self.input_fp8 is not None + assert self.input_scale is not None + shape = self.input_fp8.shape + else: + if isinstance(hs_out, tuple): + shape = hs_out[0].shape + else: + shape = hs_out.shape + + if shape[0] == 0: + o3 = paddle.zeros(shape, dtype=dtype) + return o3 + + # get w1/w2 + expert_w1 = [x.w1 for x in self.experts if x is not None] + expert_w2 = [x.w2 for x in self.experts if x is not None] + + num_expert = len(expert_w1) + + # o1 + o1 = self.fwd_gate_up(hs_out, expert_w1, num_expert, tokens_per_expert, m_indices) + if not self.recompute_fwd_gate_up: + self.o1 = o1 + clear_o1 = False + else: + clear_o1 = True + + # o3 + o3 = self.fwd_down( + o1, unzipped_probs, expert_w2, num_expert, tokens_per_expert, clear_o1=clear_o1, m_indices=m_indices + ) + + # save for bwd + return o3 + + @paddle.no_grad() + def backward( + self, + out_grad, + unzipped_probs, + tokens_per_expert, + input_fp8_slice=None, + input_scale_slice=None, + m_indices=None, + reset_status=False, + ): + # check subbatch + if self.bwd_subbatch: + assert ( + m_indices is not None + and input_fp8_slice is not None + and input_scale_slice is not None + and tokens_per_expert is not None + ) + # deal 0 size + dtype = paddle.bfloat16 + shape = out_grad[0].shape if isinstance(out_grad, tuple) else out_grad.shape + if shape[0] == 0: + return paddle.zeros_like(extract_first_if_tuple(out_grad), dtype=dtype), paddle.zeros_like(unzipped_probs) + + # recompute expert_w2 and expert_w1 + expert_w1 = [x.w1 for x in self.experts if x is not None] + expert_w2 = [x.w2 for x in self.experts if x is not None] + + if self.recompute_fwd_gate_up: + inp = None if not self.bwd_subbatch else (input_fp8_slice, input_scale_slice) + o1 = self.fwd_gate_up(inp, expert_w1, len(expert_w1), tokens_per_expert, m_indices=m_indices) + else: + o1 = self.o1 + + # do2 + do1, o2_s, probs_grad = self.bwd_dowm_input( + expert_w2, out_grad, o1, tokens_per_expert, unzipped_probs=unzipped_probs, m_indices=m_indices + ) + del o1 + if self.o1 is not None: + self.o1._clear_to_zero_allocation() + self.o1 = None + + # dw1 + self.bwd_gate_up_weight( + do1, + None, + expert_w1, + tokens_per_expert, + input_fp8_slice=input_fp8_slice, + input_scale_slice=input_scale_slice, + clear_input=reset_status, + ) + + if reset_status: + if self.input_fp8 is not None: + self.input_fp8._clear_to_zero_allocation() + self.input_fp8 = None + if self.input_scale is not None: + self.input_scale._clear_to_zero_allocation() + self.input_scale = None + if self.input is not None: + self.input._clear_to_zero_allocation() + self.input = None + + # dx + dx = self.bwd_gate_up_input( + do1, + expert_w1, + tokens_per_expert, + dx=out_grad[0] if isinstance(out_grad, 
tuple) else out_grad, + m_indices=m_indices, + ) + del do1 + + # dw2 + if isinstance(out_grad, tuple): + do3_fp8, do3_scale = self.fused_transpose_split_quant(out_grad[0], out_grad[1], tokens_per_expert, True) + out_grad[0]._clear_to_zero_allocation() + out_grad[1]._clear_to_zero_allocation() + self.bwd_down_weight((do3_fp8, do3_scale), o2_s, expert_w2, tokens_per_expert) + else: + self.bwd_down_weight(out_grad, o2_s, expert_w2, tokens_per_expert) + + if reset_status: + self.reset_statue() + return dx, probs_grad diff --git a/paddleformers/transformers/fused_a2a.py b/paddleformers/transformers/fused_a2a.py index 7b5fa09c9e0..7e1b6c9c22a 100644 --- a/paddleformers/transformers/fused_a2a.py +++ b/paddleformers/transformers/fused_a2a.py @@ -72,78 +72,145 @@ def get_buffer(group: Group, hidden_bytes: int): return _buffer +def fused_dispatch_forward_func( + x, + token_indices, + token_probs, + num_experts, + group, + previous_event=None, + async_finish=False, + allocate_on_comm_stream=False, +): + """Forward pass of fused dispatch.""" + # Calculate layout before actual dispatch + if isinstance(x, tuple): + buffer = get_buffer(group, get_hidden_bytes(x[0])) + else: + buffer = get_buffer(group, get_hidden_bytes(x)) + ( + num_tokens_per_rank, + num_tokens_per_rdma_rank, + num_tokens_per_expert, + is_token_in_rank, + previous_event_, + ) = buffer.get_dispatch_layout( + token_indices, + num_experts, + previous_event=previous_event, + async_finish=async_finish, + allocate_on_comm_stream=allocate_on_comm_stream, + ) + + assert token_probs.dtype == paddle.float32 + # Do MoE dispatch + # NOTES: the CPU will wait for GPU's signal to arrive, + # so this is not compatible with CUDA graph + (recv_x, recv_token_indices, recv_token_probs, num_recv_tokens_per_expert_list, handle, event,) = buffer.dispatch( + x, + topk_idx=token_indices, + topk_weights=token_probs, + num_tokens_per_rank=num_tokens_per_rank, + num_tokens_per_rdma_rank=num_tokens_per_rdma_rank, + is_token_in_rank=is_token_in_rank, + num_tokens_per_expert=num_tokens_per_expert, + previous_event=previous_event, + async_finish=async_finish, + allocate_on_comm_stream=allocate_on_comm_stream, + ) + + states = dict() + states["dispatched_indices"] = recv_token_indices + states["tokens_per_expert"] = num_recv_tokens_per_expert_list + states["handle"] = handle + + return recv_x, recv_token_probs, states, event + + +def fused_dispatch_backward_func( + grad_output, + grad_token_probs, + group, + handle, + previous_event=None, + async_finish=False, + allocate_on_comm_stream=False, +): + """Backward pass of fused dispatch.""" + buffer = get_buffer(group, get_hidden_bytes(grad_output)) + + grad_x, grad_token_probs, event = buffer.combine( + grad_output.contiguous(), + handle, + topk_weights=grad_token_probs.cast(paddle.float32), + previous_event=previous_event, + async_finish=async_finish, + allocate_on_comm_stream=allocate_on_comm_stream, + ) + return grad_x, None, grad_token_probs + + +def fused_combine_forward_func( + x, group, states, previous_event=None, async_finish=False, allocate_on_comm_stream=False +): + """Forward pass of fused combine.""" + handle = states["handle"] + buffer = get_buffer(group, get_hidden_bytes(x)) + combined_x, _, event = buffer.combine( + x, + handle=handle, + async_finish=async_finish, + previous_event=previous_event, + allocate_on_comm_stream=allocate_on_comm_stream, + ) + return combined_x + + +def fused_combine_backward_func( + grad_output, group, handle, previous_event=None, async_finish=False, allocate_on_comm_stream=False 
+): + """Backward pass of fused combine.""" + if isinstance(grad_output, tuple): + buffer = get_buffer(group, get_hidden_bytes(grad_output[0])) + grad_x, _, _, _, _, event = buffer.dispatch( + (grad_output[0].contiguous(), grad_output[1].contiguous()), + handle=handle, + previous_event=previous_event, + async_finish=async_finish, + allocate_on_comm_stream=allocate_on_comm_stream, + ) + else: + buffer = get_buffer(group, get_hidden_bytes(grad_output)) + grad_x, _, _, _, _, event = buffer.dispatch( + grad_output.contiguous(), + handle=handle, + previous_event=previous_event, + async_finish=async_finish, + allocate_on_comm_stream=allocate_on_comm_stream, + ) + return grad_x + + class FusedDispatch(PyLayer): """Fused dispatch operation for MoE routing combining computation and communication.""" @staticmethod def forward(ctx, x, token_indices, token_probs, num_experts, group, previous_event=None): """Forward pass of fused dispatch.""" - # Calculate layout before actual dispatch - buffer = get_buffer(group, get_hidden_bytes(x)) - ( - num_tokens_per_rank, - num_tokens_per_rdma_rank, - num_tokens_per_expert, - is_token_in_rank, - previous_event, - ) = buffer.get_dispatch_layout( - token_indices, - num_experts, - previous_event=None, - async_finish=False, - allocate_on_comm_stream=False, - ) - - # Do MoE dispatch - # NOTES: the CPU will wait for GPU's signal to arrive, - # so this is not compatible with CUDA graph - ( - recv_x, - recv_token_indices, - recv_token_probs, - num_recv_tokens_per_expert_list, - handle, - event, - ) = buffer.dispatch( - x, - topk_idx=token_indices, - topk_weights=token_probs.cast(paddle.float32), - num_tokens_per_rank=num_tokens_per_rank, - num_tokens_per_rdma_rank=num_tokens_per_rdma_rank, - is_token_in_rank=is_token_in_rank, - num_tokens_per_expert=num_tokens_per_expert, - previous_event=None, - async_finish=False, - allocate_on_comm_stream=False, + recv_x, recv_token_probs, states, event = fused_dispatch_forward_func( + x, token_indices, token_probs, num_experts, group, previous_event ) ctx.group = group - ctx.handle = handle + ctx.handle = states["handle"] ctx.event = event - tokens_per_expert = paddle.to_tensor(num_recv_tokens_per_expert_list) - - states = dict() - states["dispatched_indices"] = recv_token_indices - states["tokens_per_expert"] = tokens_per_expert - states["handle"] = handle return recv_x, recv_token_probs, states @staticmethod def backward(ctx, grad_output, grad_token_probs): """Backward pass of fused dispatch.""" - buffer = get_buffer(ctx.group, get_hidden_bytes(grad_output)) - handle = ctx.handle - - grad_x, grad_token_probs, event = buffer.combine( - grad_output.contiguous(), - handle, - topk_weights=grad_token_probs.cast(paddle.float32), - previous_event=None, - async_finish=False, - allocate_on_comm_stream=False, - ) - return grad_x, None, grad_token_probs + return fused_dispatch_backward_func(grad_output, grad_token_probs, ctx.group, ctx.handle) class FusedCombine(PyLayer): @@ -152,12 +219,9 @@ class FusedCombine(PyLayer): @staticmethod def forward(ctx, x, group, states, previous_event=None): """Forward pass of fused combine.""" - handle = states["handle"] - buffer = get_buffer(group, get_hidden_bytes(x)) - combined_x, _, event = buffer.combine( - x, handle=handle, async_finish=False, previous_event=None, allocate_on_comm_stream=False - ) - ctx.handle = handle + combined_x = fused_combine_forward_func(x, group, states, previous_event) + + ctx.handle = states["handle"] ctx.group = group ctx.previous_event = previous_event @@ -166,15 +230,7 @@ 
def forward(ctx, x, group, states, previous_event=None): @staticmethod def backward(ctx, grad_output): """Backward pass of fused combine.""" - buffer = get_buffer(ctx.group, get_hidden_bytes(grad_output)) - grad_x, _, _, _, _, event = buffer.dispatch( - grad_output.contiguous(), - handle=ctx.handle, - previous_event=ctx.previous_event, - async_finish=False, - allocate_on_comm_stream=False, - ) - return grad_x + return fused_combine_backward_func(grad_output, ctx.group, ctx.handle, ctx.previous_event) if HAVE_DEEP_EP: @@ -214,3 +270,96 @@ def fused_combine(x, group, handle, previous_event=None): else: fused_dispatch = None fused_combine = None + + +class DispatchNode: + def __init__(self, name="dispatch"): + self.name = name + + def reset_statue(self): + self.handle = None + + def forward( + self, + x, + token_indices, + token_probs, + num_experts, + group, + previous_event=None, + async_finish=False, + allocate_on_comm_stream=False, + ): + """Forward pass of fused dispatch.""" + recv_x, recv_token_probs, states, event = fused_dispatch_forward_func( + x, + token_indices, + token_probs, + num_experts, + group, + previous_event=previous_event, + async_finish=async_finish, + allocate_on_comm_stream=allocate_on_comm_stream, + ) + + self.group = group + self.handle = states["handle"] + self.event = event + + return recv_x, recv_token_probs, states + + def backward( + self, grad_output, grad_token_probs, previous_event=None, async_finish=False, allocate_on_comm_stream=False + ): + """Backward pass of fused dispatch.""" + out = fused_dispatch_backward_func( + grad_output, + grad_token_probs, + self.group, + self.handle, + previous_event=previous_event, + async_finish=async_finish, + allocate_on_comm_stream=allocate_on_comm_stream, + ) + self.reset_statue() + return out + + +class CombineNode: + def __init__(self, name="combine"): + self.name = name + + def reset_statue(self): + self.handle = None + + def forward(self, x, group, handle, previous_event=None, async_finish=False, allocate_on_comm_stream=False): + """Forward pass of fused combine.""" + states = dict() + states["handle"] = handle + combined_x = fused_combine_forward_func( + x, + group, + states, + previous_event=previous_event, + async_finish=async_finish, + allocate_on_comm_stream=allocate_on_comm_stream, + ) + + self.handle = handle + self.group = group + self.previous_event = previous_event + + return combined_x + + def backward(self, grad_output, previous_event=None, async_finish=False, allocate_on_comm_stream=False): + """Backward pass of fused combine.""" + out = fused_combine_backward_func( + grad_output, + self.group, + self.handle, + previous_event=previous_event, + async_finish=async_finish, + allocate_on_comm_stream=allocate_on_comm_stream, + ) + self.reset_statue() + return out diff --git a/paddleformers/transformers/moe_gate.py b/paddleformers/transformers/moe_gate.py index e76af803d73..693b2dafcf7 100644 --- a/paddleformers/transformers/moe_gate.py +++ b/paddleformers/transformers/moe_gate.py @@ -226,7 +226,7 @@ def _priority(self, topk_idx: paddle.Tensor, capacity: int) -> paddle.Tensor: chosen_expert = topk_idx.reshape([-1]) # Shape: [seq_len * k, num_experts]. token_priority = F.one_hot(chosen_expert, self.num_experts).cast(paddle.int32) - token_priority = paddle.logical_and(token_priority > 0, token_priority.cumsum(axis=0) <= capacity) + token_priority = paddle.logical_and(token_priority > 0, token_priority.cumsum(axis=0) < capacity) # Shape: [seq_len, num_experts]. 
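Illustrative aside (not part of the patch): the capacity test above keeps a token only while the 1-based running count for its expert stays inside the limit, so `cumsum <= capacity` admits up to `capacity` tokens per expert while the new `cumsum < capacity` admits at most `capacity - 1`. A minimal sketch of that difference:

import paddle
import paddle.nn.functional as F

chosen_expert = paddle.to_tensor([0, 0, 0])  # three tokens, all routed to expert 0
one_hot = F.one_hot(chosen_expert, num_classes=2).cast(paddle.int32)
running = one_hot.cumsum(axis=0)  # per-expert running count: [[1, 0], [2, 0], [3, 0]]
capacity = 2
print(paddle.logical_and(one_hot > 0, running <= capacity).sum(axis=0))  # [2, 0] -> two tokens kept
print(paddle.logical_and(one_hot > 0, running < capacity).sum(axis=0))   # [1, 0] -> one token kept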
token_priority = token_priority.reshape([-1, k, self.num_experts]).sum(axis=1) @@ -270,7 +270,7 @@ def _topk_group_limited_greedy( group_scores = scores.reshape([0, n_group, -1]).max(axis=-1) # [n, n_group] group_idx = paddle.topk(group_scores, k=topk_group, axis=-1, sorted=True)[1] # [n, top_k_group] - group_mask = paddle.zeros_like(group_scores).put_along_axis(group_idx, paddle.to_tensor(1.0), axis=-1) # fmt:skip + group_mask = paddle.zeros_like(group_scores).put_along_axis(group_idx, paddle.ones([], dtype="float32"), axis=-1) # fmt:skip score_mask = ( group_mask.unsqueeze(-1).expand([bsz_seq_len, n_group, n_experts // n_group]).reshape([bsz_seq_len, -1]) ) # [n, e] @@ -302,11 +302,11 @@ def _topk_noaux_tc( assert self.e_score_correction_bias is not None, "e_score_correction_bias is None" scores_for_choice = scores.reshape([bsz_seq_len, -1]) + self.e_score_correction_bias.unsqueeze(0) - group_scores = ( - scores_for_choice.reshape([bsz_seq_len, self.n_group, -1]).topk(2, axis=-1)[0].sum(axis=-1) - ) # fmt:skip [n, n_group] + reshape_tmp_rst = scores_for_choice.reshape([bsz_seq_len, self.n_group, -1]) + top_k = min(reshape_tmp_rst.shape[2], 2) + group_scores = reshape_tmp_rst.topk(top_k, axis=-1)[0].sum(axis=-1) # fmt:skip [n, n_group] group_idx = paddle.topk(group_scores, k=topk_group, axis=-1, sorted=True)[1] # [n, top_k_group] - group_mask = paddle.zeros_like(group_scores).put_along_axis(group_idx, paddle.to_tensor(1.0, dtype="float32"), axis=-1) # fmt:skip + group_mask = paddle.zeros_like(group_scores).put_along_axis(group_idx, paddle.ones([], dtype="float32"), axis=-1) # fmt:skip score_mask = ( group_mask.unsqueeze(-1).expand([bsz_seq_len, n_group, n_experts // n_group]).reshape([bsz_seq_len, -1]) ) # [n, e] @@ -370,9 +370,7 @@ def top1gating( _, top_idx = paddle.topk(mask1_rand, k=capacity, axis=0) # Select top_capacity tokens - new_mask1 = mask1 * paddle.zeros_like(mask1).put_along_axis( - top_idx, paddle.to_tensor(1.0, dtype="float32"), axis=0 - ) + new_mask1 = mask1 * paddle.zeros_like(mask1).put_along_axis(top_idx, paddle.ones([], dtype="float32"), axis=0) mask1 = new_mask1 # Compute locations in capacity buffer @@ -496,7 +494,7 @@ def topkgating( top_gate = top_gate * self.routed_scaling_factor # get topk mask - mask = paddle.zeros_like(gates).put_along_axis(top_idx, paddle.to_tensor(1.0, dtype="float32"), axis=1) + mask = paddle.zeros_like(gates).put_along_axis(top_idx, paddle.ones([], dtype="float32"), axis=1) if hasattr(self.config, "seq_aux") and self.config.seq_aux: l_aux = self._cal_seq_aux_loss(gates_ori, self.top_k, top_idx) else: @@ -532,14 +530,12 @@ def topkgating( token_priority = self._priority(top_idx, capacity) # normalize gates - # gates_masked is equal to top_gate. 
gates_masked = gates * mask - # if self.training: - gates_s = paddle.sum(gates_masked, axis=-1, keepdim=True) - denom_s = paddle.clip(gates_s, min=paddle.finfo(gates_masked.dtype).eps) - if self.norm_topk_prob: - gates_masked = gates_masked / denom_s - gates_masked *= self.routed_scaling_factor + if self.training: + gates_s = paddle.sum(gates_masked, axis=-1, keepdim=True) + denom_s = paddle.clip(gates_s, min=paddle.finfo(gates_masked.dtype).eps) + if self.norm_topk_prob: + gates_masked = gates_masked / denom_s return ( capacity, @@ -569,14 +565,25 @@ def topkgating_nodrop(self, gates: paddle.Tensor): top_gate, top_idx = self._topk_noaux_tc( gates, k=self.top_k, n_group=self.n_group, topk_group=self.topk_group ) + # norm gate to sum 1 - if self.top_k > 1 and self.norm_topk_prob: - denominator = top_gate.sum(axis=-1, keepdim=True) + 1e-20 - top_gate = top_gate / denominator - top_gate = top_gate * self.routed_scaling_factor + # if self.top_k > 1 and self.norm_topk_prob: + # denominator = top_gate.sum(axis=-1, keepdim=True) + 1e-20 + # top_gate = top_gate / denominator + # top_gate = top_gate * self.routed_scaling_factor # get topk mask - mask = paddle.zeros_like(gates).put_along_axis(top_idx, paddle.to_tensor(1.0), axis=1) + mask = paddle.zeros_like(gates).put_along_axis(top_idx, paddle.ones([], dtype="float32"), axis=1) + + gates_masked = gates * mask + # if self.training: + gates_s = paddle.sum(gates_masked, axis=-1, keepdim=True) + denom_s = paddle.clip(gates_s, min=paddle.finfo(gates_masked.dtype).eps) + + if self.norm_topk_prob: + gates_masked = gates_masked / denom_s + + gates_masked *= self.routed_scaling_factor if hasattr(self.config, "seq_aux") and self.config.seq_aux: l_aux = self._cal_seq_aux_loss(gates_ori, self.top_k, top_idx) @@ -584,5 +591,5 @@ def topkgating_nodrop(self, gates: paddle.Tensor): l_aux = self._cal_aux_loss(gates, mask) exp_counts = paddle.sum(mask.cast(paddle.int64), axis=0) - topk_masked_gates = paddle.zeros_like(gates).put_along_axis(top_idx, top_gate, axis=1) - return topk_masked_gates, mask, exp_counts, l_aux, l_zloss + # topk_masked_gates = paddle.zeros_like(gates).put_along_axis(top_idx, top_gate, axis=1) + return gates_masked, mask, exp_counts, l_aux, l_zloss diff --git a/paddleformers/transformers/moe_layer.py b/paddleformers/transformers/moe_layer.py index 340fba1f524..8cc2ae90dee 100644 --- a/paddleformers/transformers/moe_layer.py +++ b/paddleformers/transformers/moe_layer.py @@ -16,6 +16,7 @@ # limitations under the License. 
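Illustrative aside (not part of the patch): the masked-gate normalization that topkgating_nodrop now applies above, written out as a standalone helper. The helper name and standalone form are for illustration only; the logic mirrors the patched code.

import paddle

def normalized_topk_probs(gates, mask, routed_scaling_factor, norm_topk_prob=True):
    # keep only the selected experts' scores
    gates_masked = gates * mask
    # guard the per-token sum against division by zero
    denom = paddle.clip(gates_masked.sum(axis=-1, keepdim=True), min=paddle.finfo(gates_masked.dtype).eps)
    if norm_topk_prob:
        gates_masked = gates_masked / denom
    # apply the routed scaling factor to the normalized probabilities
    return gates_masked * routed_scaling_factor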
from __future__ import annotations +import os from typing import Any, List, Tuple import numpy as np @@ -24,8 +25,56 @@ from paddle import Tensor, nn from paddle.distributed.communication.group import Group +from ..utils.log import logger +from .fp8_utils import FP8GroupGemmMlpFunctionNode, extract_first_if_tuple +from .fused_a2a import CombineNode, DispatchNode, get_buffer, get_hidden_bytes from .moe_gate import PretrainedMoEGate -from .token_dispatcher import MoEFlexTokenDispatcher +from .moe_utils import ( + UnZipNode, + ZipNode, + merge_subbatch_cast, + offload, + reload, + tokens_zip_unique_add_with_subbatch, +) +from .token_dispatcher import PreDispatchNode + +if os.getenv("DSV3_FAST_PRETRAIN", "False").lower() != "true": + from .token_dispatcher import MoEFlexTokenDispatcher +else: + from .token_dispatcher import MoEFlexTokenDispatcherFast as MoEFlexTokenDispatcher + + +try: + import paddle.distributed.communication.deep_ep as deep_ep +except ImportError: + deep_ep = None + +DSV3_USE_FP8_GEMM = os.getenv("DSV3_USE_FP8_GEMM", "False").lower() == "true" + +DSV3_USE_FP8_GROUP_GEMM = os.getenv("DSV3_USE_FP8_GROUP_GEMM", "False").lower() == "true" + +DSV3_USE_FP8_DISPATCH = os.getenv("DSV3_USE_FP8_DISPATCH", "False").lower() == "true" + +try: + import TokenDispatcherUtils as TDU +except ImportError: + TDU = None + + +def record_stream_for_multi_input(x): + if isinstance(x, (tuple, list)): + for i in range(len(x)): + x[i]._record_stream() + else: + x._record_stream() + + +def stop_gradient_for_multi_input(x): + if isinstance(x, (tuple, list)): + x[0].stop_gradient = False + else: + x.stop_gradient = False def dispatching(x, dispatch_mask, scatter_index, num_experts, capacity): @@ -162,6 +211,7 @@ def __init__( capacity: int = 1.0, moe_group: str = "data", all_to_all_dropout=0.0, + using_post_norm_recompute=False, ): super().__init__() @@ -176,12 +226,11 @@ def __init__( except AttributeError: is_fleet_init = False - if ( - is_fleet_init - and dist.fleet.get_hybrid_communicate_group().get_data_parallel_world_size() > 1 - and moe_group == "data" - ): - self.moe_group = dist.fleet.get_hybrid_communicate_group().get_data_parallel_group() + if is_fleet_init and dist.get_world_size() > 1: + if moe_group == "data": + self.moe_group = dist.fleet.get_hybrid_communicate_group().get_data_parallel_group() + elif moe_group == "expert": + self.moe_group = dist.fleet.get_hybrid_communicate_group().expert_parallel_group self.moe_rank = dist.get_rank(self.moe_group) self.moe_rank = 0 if self.moe_rank < 0 else self.moe_rank self.expert_parallel_degree = dist.get_world_size(self.moe_group) @@ -210,8 +259,33 @@ def __init__( self.gate = gate self.gate.group = self.moe_group + # for flex token moe layer + self.router = gate + self.ep_size = dist.get_world_size(self.moe_group) + self.moe_router_topk = gate.top_k + self.num_local_experts = moe_num_experts // self.ep_size + if self.moe_group is not None: + self.token_dispatcher = MoEFlexTokenDispatcher( + self.num_local_experts, self.moe_router_topk, self.moe_num_experts, self.moe_group + ) + self.token_drop_steps = config.token_drop_steps if hasattr(config, "token_drop_steps") else None + self.using_flex_token = False + + self.using_post_norm_recompute = using_post_norm_recompute self._post_init() + def update_flex_token(self): + from paddleformers.transformers.deepseek_v2 import get_global_step + + if (not self.config.using_flex_token) or (get_global_step() < self.token_drop_steps): + self.using_flex_token = False + self.router.using_flex_token = False + else: + if not
self.using_flex_token: + logger.info("Changing to flex token moe mode") + self.using_flex_token = True + self.router.using_flex_token = True + def _parse_moe_expert_parallel(self, moe_num_experts, expert_parallel_degree): assert ( moe_num_experts >= expert_parallel_degree @@ -234,8 +308,35 @@ def _post_init(self): # logger.info(f"expert param={p.name}, no-sync={p.no_sync}") def forward( + self, + hidden_states: paddle.Tensor, + probs=None, + routing_map=None, + capacity=None, + topk_weight=None, + topk_ids=None, + token_priority=None, + l_aux=None, + l_zloss=None, + ): + self.update_flex_token() + + if self.using_flex_token: + return self.forward_flex_token(hidden_states, probs, routing_map, l_aux, l_zloss) + else: + return self.forward_drop_token( + hidden_states, capacity, topk_weight, topk_ids, token_priority, l_aux, l_zloss + ) + + def forward_drop_token( self, hidden_state: paddle.Tensor, + capacity=None, + topk_weight=None, + topk_ids=None, + token_priority=None, + l_aux=None, + l_zloss=None, ): """MoE Layer forward function 1. Gate Forward. @@ -257,7 +358,17 @@ def forward( # topk_ids : sk # token_priority : se # self.exp_counts : - capacity, topk_weight, topk_ids, token_priority, l_aux, l_zloss = self.gate(hidden_state) + if self.using_post_norm_recompute: + assert ( + capacity is not None + and topk_weight is not None + and topk_ids is not None + and token_priority is not None + and l_aux is not None + and l_zloss is not None + ) + else: + capacity, topk_weight, topk_ids, token_priority, l_aux, l_zloss = self.gate(hidden_state) """MoE expert dispatch from: https://huggingface.co/deepseek-ai/DeepSeek-V3/blob/main/modeling_deepseek.py""" cnts = paddle.zeros([topk_ids.shape[0], len(self.experts)], dtype=topk_ids.dtype) @@ -376,3 +487,771 @@ def forward(self, hidden_states: paddle.Tensor): expert_output = self.expert_forward(dispatched_input, tokens_per_expert) output, _ = self.token_dispatcher.token_unpermutation(expert_output, None) return output, l_aux, l_zloss + + def forward_flex_token(self, hidden_states: paddle.Tensor, probs=None, routing_map=None, l_aux=None, l_zloss=None): + _, _, d_model = hidden_states.shape + # reshaped_input = hidden_states.reshape([-1, d_model]) + if self.using_post_norm_recompute: + assert probs is not None and routing_map is not None and l_aux is not None and l_zloss is not None + else: + probs, routing_map, l_aux, l_zloss = self.router(hidden_states) + if DSV3_USE_FP8_GEMM: + if DSV3_USE_FP8_DISPATCH: + output = FusionMoe.apply( + hidden_states, + probs, + routing_map, + self, + recompute_fwd_gate_up=self.config.recompute_fwd_gate_up, + is_split_group_gemm=self.config.is_split_group_gemm, + ) + else: + hidden_states, token_indices, token_probs = self.token_dispatcher.pre_dispatch( + hidden_states, probs, routing_map + ) + output = FusionMoe.apply( + hidden_states, + token_indices, + token_probs, + self, + recompute_fwd_gate_up=self.config.recompute_fwd_gate_up, + is_split_group_gemm=self.config.is_split_group_gemm, + ) + else: + ( + dispatched_input, + token_permuted_indices, + prob_permuted_indices, + dispatched_probs, + ) = self.token_dispatcher.token_permutation(hidden_states, probs, routing_map) + + expert_output = self.expert_forward(dispatched_input) + output, _ = self.token_dispatcher.token_unpermutation( + expert_output, token_permuted_indices, prob_permuted_indices, dispatched_probs, None + ) + return output, l_aux, l_zloss + + def get_tokens_per_expert(self): + return self.token_dispatcher._comm_manager.tokens_per_expert_list + + def 
set_tokens_per_expert(self, tokens_per_expert_list): + self.token_dispatcher._comm_manager.tokens_per_expert_list = tokens_per_expert_list + + def pre_dispatch_compute(self, hidden_states): + _, _, d_model = hidden_states.shape + probs, routing_map, l_aux, l_zloss = self.router(hidden_states) + hidden_states, token_indices, token_probs = self.token_dispatcher.pre_dispatch( + hidden_states, probs, routing_map + ) + return l_aux, l_zloss, hidden_states, token_indices, token_probs + + def post_dispatch_compute(self, hidden_states, dispatched_indices, dispatched_probs): + (global_input_tokens, token_permuted_indices, prob_permuted_indices) = self.token_dispatcher.post_dispatch( + hidden_states, dispatched_indices + ) + return (global_input_tokens, token_permuted_indices, prob_permuted_indices) + + def pre_combine_compute(self, hidden_states, token_permuted_indices, prob_permuted_indices, dispatched_probs): + hidden_states = self.token_dispatcher.pre_combine( + hidden_states, token_permuted_indices, prob_permuted_indices, dispatched_probs + ) + return hidden_states + + def post_combine_compute(self, hidden_states): + hidden_states = self.token_dispatcher.post_combine(hidden_states) + return hidden_states + + +class Fp8DispatchQuantNode: + def __init__(self, token_dispatcher, name="fp8_dispatch_quant_node"): + self.token_dispatcher = token_dispatcher + self.pre_dispatch_node = PreDispatchNode(token_dispatcher) + self.name = name + + @paddle.no_grad() + def forward(self, hidden_states, probs, routing_map): + # reshape + self.token_dispatcher.hidden_shape = hidden_states.shape + hs_2d = hidden_states.view([-1, self.token_dispatcher.hidden_shape[-1]]) + + if DSV3_USE_FP8_DISPATCH: + # quant + hs_fp8, hs_scale = paddle.incubate.nn.functional.fp8_quant_blockwise( + hs_2d, output_scale_transpose=False, quant_method="1x128", input_transpose=False + ) + + # pre_dispatch + token_indices, token_probs = self.pre_dispatch_node.forward(routing_map, probs) + + self.hidden_states_shape = hidden_states.shape + hs_fp8.stop_gradient = False + token_probs.stop_gradient = False + return (hs_fp8, hs_scale), token_indices, token_probs + else: + # pre_dispatch + token_indices, token_probs = self.pre_dispatch_node.forward(routing_map, probs) + + self.hidden_states_shape = hidden_states.shape + hs_2d.stop_gradient = False + token_probs.stop_gradient = False + return hs_2d, token_indices, token_probs + + @paddle.no_grad() + def backward(self, hs_grad, token_probs_grad): + # predispatch grad + probs_grad = self.pre_dispatch_node.backward(token_probs_grad) + token_probs_grad._record_stream() + + # reshape_grad + hs_grad = hs_grad.view(self.hidden_states_shape) + hs_grad._record_stream() + + return hs_grad, probs_grad, None + + +class Fp8DispatchNode: + def __init__(self, token_dispatcher, name="fp8_dispatch_node"): + self.token_dispatcher = token_dispatcher + self.dispatch_act_node = DispatchNode(token_dispatcher) + self.name = name + + @paddle.no_grad() + def forward( + self, + hs_2d, + token_indices, + token_probs, + previous_event=None, + async_finish=False, + allocate_on_comm_stream=False, + ): + # dispatch + hs_2d_dispatched, dispatched_probs, states = self.dispatch_act_node.forward( + hs_2d, + token_indices, + token_probs, + self.token_dispatcher._comm_manager.num_experts, + self.token_dispatcher._comm_manager.group, + previous_event=previous_event, + async_finish=async_finish, + allocate_on_comm_stream=allocate_on_comm_stream, + ) + self.token_dispatcher._comm_manager.handle = states["handle"] + 
self.token_dispatcher._comm_manager.tokens_per_expert = states["tokens_per_expert"] + dispatched_indices = states["dispatched_indices"] + + stop_gradient_for_multi_input(hs_2d_dispatched) + dispatched_probs.stop_gradient = False + return hs_2d_dispatched, dispatched_indices, dispatched_probs + + @paddle.no_grad() + def backward( + self, + hs_dispatched_grad, + dispatched_probs_grad, + previous_event=None, + async_finish=False, + allocate_on_comm_stream=False, + ): + # dispatch grad + hs_grad, _, token_probs_grad = self.dispatch_act_node.backward( + hs_dispatched_grad, + dispatched_probs_grad, + previous_event=previous_event, + async_finish=async_finish, + allocate_on_comm_stream=allocate_on_comm_stream, + ) + return hs_grad, token_probs_grad + + +class Fp8CombineNode: + def __init__(self, token_dispatcher, name="fp8_combine_node"): + self.token_dispatcher = token_dispatcher + self.combine_node = CombineNode(token_dispatcher) + self.name = name + + @paddle.no_grad() + def forward(self, hidden_states_out, previous_event=None, async_finish=False, allocate_on_comm_stream=False): + # combine + output_combine = self.combine_node.forward( + hidden_states_out, + self.token_dispatcher._comm_manager.group, + self.token_dispatcher._comm_manager.handle, + previous_event=previous_event, + async_finish=async_finish, + allocate_on_comm_stream=allocate_on_comm_stream, + ) + output_combine.stop_gradient = False + self.token_dispatcher._comm_manager.handle = None + return output_combine + + @paddle.no_grad() + def backward(self, output_combine_grad, previous_event=None, async_finish=False, allocate_on_comm_stream=False): + # combine grad -> fp8 + hidden_states_out_grad = self.combine_node.backward( + output_combine_grad, + previous_event=previous_event, + async_finish=async_finish, + allocate_on_comm_stream=allocate_on_comm_stream, + ) + return hidden_states_out_grad + + +class Fp8CombineQuantNode: + def __init__(self, token_dispatcher, moe_group=None, name="fp8_combine_quant_node"): + self.token_dispatcher = token_dispatcher + self.name = name + self.moe_group = moe_group + + @paddle.no_grad() + def forward(self, output_combine): + # post combine + output = output_combine.reshape(self.token_dispatcher.hidden_shape) + output_combine._record_stream() + self.output_combine_shape = output_combine.shape + output.stop_gradient = False + return output + + @paddle.no_grad() + def backward(self, output_grad, event_to_wait=None): + # post combine grad + if DSV3_USE_FP8_DISPATCH: + if event_to_wait is not None: + assert self.moe_group is not None + event_to_wait.comm_stream_wait(self.moe_group.id) + buffer = get_buffer(self.token_dispatcher._comm_manager.group, get_hidden_bytes(output_grad)) + custom_stream = paddle.device.Stream(stream_base=buffer.runtime.get_comm_stream()) + else: + custom_stream = paddle.device.current_stream() + with paddle.device.stream_guard(custom_stream): + output_combine_grad = paddle.reshape(output_grad, [-1, output_grad.shape[-1]]) + # output_combine_grad quant to fp8 + output_combine_grad_fp8, output_combine_grad_scale = paddle.incubate.nn.functional.fp8_quant_blockwise( + output_combine_grad, output_scale_transpose=False, quant_method="1x128", input_transpose=False + ) + output_grad._record_stream() + quant_event = None + if event_to_wait is not None: + quant_event = deep_ep.get_event_from_custom_stream(custom_stream.stream_base) + return (output_combine_grad_fp8, output_combine_grad_scale), quant_event + else: + output_combine_grad = paddle.reshape(output_grad, [-1, 
output_grad.shape[-1]]) + return output_combine_grad, None + + + class FusionMlpNode: + """ + The FusionMlpNode class includes operations for unzipping, expert computation, and zipping. + """ + + def __init__( + self, + custom_map, + max_topk, + recompute_fwd_gate_up=False, + is_split_group_gemm=True, + mlp_fwd_subbatch_rows=0, + mlp_bwd_subbatch_rows=0, + output_subbatch_rows=0, + ): + self.token_dispatcher = custom_map.token_dispatcher + self.experts = custom_map.experts + self.unzip_node = UnZipNode() + self.zip_node = ZipNode() + self.experts_group_gemm_node = FP8GroupGemmMlpFunctionNode( + custom_map, + recompute_fwd_gate_up=recompute_fwd_gate_up, + is_split_group_gemm=is_split_group_gemm, + ) + + self.seq_length = custom_map.config.seq_length + self.num_experts_per_tok = custom_map.config.num_experts_per_tok + self.adaptive_remained_O1_recompute_ratio = custom_map.config.adaptive_remained_O1_recompute_ratio + + self.recompute_fwd_gate_up = recompute_fwd_gate_up + self.dispatched_indices = None + self.dispatched_probs = None + self.tokens_per_expert = None + self.padding_token_per_experts = None + self.router_topk = max_topk + self.mlp_fwd_subbatch_rows = mlp_fwd_subbatch_rows + self.mlp_bwd_subbatch_rows = mlp_bwd_subbatch_rows + self.output_subbatch_rows = output_subbatch_rows + + def set_recompute_fwd_gate_up(self, recompute_fwd_gate_up): + self.experts_group_gemm_node.recompute_fwd_gate_up = recompute_fwd_gate_up + + def reset_statue(self): + """ + Reset all state variables. + + Args: + None. + + Returns: + None. + + """ + self.dispatched_indices = None + self.dispatched_probs = None + self.tokens_per_expert = None + self.padding_token_per_experts = None + self.router_topk = None + + del self.unzip_node + del self.zip_node + self.unzip_node = None + self.zip_node = None + + self.experts_group_gemm_node.reset_statue() + self.experts_group_gemm_node = None + + def prepare_env_subbatch(self, unzipped_tokens=None, unzipped_tokens_scale=None, is_fwd=True): + if is_fwd: + assert unzipped_tokens is not None and unzipped_tokens_scale is not None + self.experts_group_gemm_node.input_fp8 = unzipped_tokens + self.experts_group_gemm_node.input_scale = unzipped_tokens_scale + self.m_indices = self.experts_group_gemm_node.gen_m_indices(self.padding_token_per_experts) + self.experts_group_gemm_node.fwd_subbatch = True + else: + self.m_indices = ( + self.experts_group_gemm_node.gen_m_indices(self.padding_token_per_experts) + if not hasattr(self, "m_indices") + else self.m_indices + ) + self.experts_group_gemm_node.bwd_subbatch = True + reload(self.experts_group_gemm_node.input_fp8) + reload(self.experts_group_gemm_node.input_scale) + + def gemm_forward_subbatch( + self, + unzipped_tokens, + unzipped_tokens_scale, + unzipped_probs, + map_unzipped_indices_to_zipped, + output, + total_zipped_tokens, + padding_token_per_experts, + start_idx=None, + end_idx=None, + output_subbatch_rows=None, + ): + if start_idx is None or end_idx is None: + start_idx = 0 + end_idx = unzipped_tokens.shape[0] + start_idx = max(0, start_idx) + end_idx = min(unzipped_tokens.shape[0], end_idx) + + expert_out = self.experts_group_gemm_node.forward( + (unzipped_tokens[start_idx:end_idx], unzipped_tokens_scale[start_idx:end_idx]), + unzipped_probs[start_idx:end_idx], + padding_token_per_experts, + m_indices=self.m_indices[start_idx:end_idx], + ) + + output = tokens_zip_unique_add_with_subbatch( + output, + expert_out, + map_unzipped_indices_to_zipped[start_idx:end_idx], + total_zipped_tokens, + subbatch_rows=output_subbatch_rows, + ) + return output
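gemm_forward_subbatch above computes the expert outputs for one [start_idx, end_idx) row slice and scatter-adds them into the zipped output through tokens_zip_unique_add_with_subbatch, so only one sub-batch of expert activations is alive at a time. Below is a framework-free sketch of that loop structure; expert_fn and the NumPy scatter-add stand in for the grouped GEMM and the fused TDU kernels, so treat it as an illustration rather than the actual implementation.

```python
import numpy as np


def subbatched_expert_apply(unzipped, unzip_to_zip_index, num_zipped_rows, subbatch_rows, expert_fn):
    """Sketch of the forward sub-batch loop.

    unzipped:            [num_unzipped_rows, hidden] expert inputs after unzip/permute
    unzip_to_zip_index:  [num_unzipped_rows] destination row of each unzipped token
    expert_fn:           stands in for the grouped-GEMM expert computation
    """
    hidden = unzipped.shape[1]
    output = np.zeros((num_zipped_rows, hidden), dtype=np.float32)  # fp32 accumulator
    total = unzipped.shape[0]
    nparts = (total + subbatch_rows - 1) // subbatch_rows
    for i in range(nparts):
        start, end = i * subbatch_rows, min((i + 1) * subbatch_rows, total)
        expert_out = expert_fn(unzipped[start:end])
        # scatter-add this sub-batch back to its zipped destination rows
        np.add.at(output, unzip_to_zip_index[start:end], expert_out)
    return output
```

Accumulating in float32 and casting once at the end (what merge_subbatch_cast does with bfloat16 in the real code) bounds peak memory by a single sub-batch instead of the full unzipped tensor.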
+ + def gemm_backward_subbatch( + self, + unzipped_grad, + map_unzipped_indices_to_zipped, + total_zipped_tokens, + output, + padding_token_per_experts, + start_idx=None, + end_idx=None, + output_subbatch_rows=None, + reset_status=False, + ): + def split_list_prefix(l, start, end): + prefix_sum = [0] * (len(l) + 1) + for i in range(len(l)): + prefix_sum[i + 1] = prefix_sum[i] + l[i] + + result = [] + for i in range(len(l)): + segment_start = prefix_sum[i] + segment_end = prefix_sum[i + 1] + overlap_start = max(start, segment_start) + overlap_end = min(end, segment_end) + selected = max(0, overlap_end - overlap_start) + result.append(selected) + return result + + if start_idx is None or end_idx is None: + start_idx = 0 + end_idx = extract_first_if_tuple(unzipped_grad).shape[0] + + start_idx = max(0, start_idx) + end_idx = min(extract_first_if_tuple(unzipped_grad).shape[0], end_idx) + + # m_indices = self.experts_group_gemm_node.gen_m_indices(self.tokens_per_expert) + unzipped_inp_grad = ( + (unzipped_grad[0][start_idx:end_idx].contiguous(), unzipped_grad[1][start_idx:end_idx].contiguous()) + if isinstance(unzipped_grad, tuple) + else unzipped_grad[start_idx:end_idx].contiguous() + ) + unzipped_grad, unzipped_probs_grad = self.experts_group_gemm_node.backward( + unzipped_inp_grad, + self.unzipped_probs[start_idx:end_idx].contiguous(), + input_fp8_slice=self.experts_group_gemm_node.input_fp8[start_idx:end_idx].contiguous(), + input_scale_slice=self.experts_group_gemm_node.input_scale[start_idx:end_idx].contiguous(), + tokens_per_expert=split_list_prefix(padding_token_per_experts, start_idx, end_idx), + m_indices=self.m_indices[start_idx:end_idx].contiguous(), + reset_status=reset_status, + ) + + output = tokens_zip_unique_add_with_subbatch( + output, + unzipped_grad, + map_unzipped_indices_to_zipped[start_idx:end_idx], + zipped_rows=total_zipped_tokens, + subbatch_rows=output_subbatch_rows, + ) + + return output, unzipped_probs_grad + + @paddle.no_grad() + def forward(self, hs_2d_dispatched, dispatched_indices, dispatched_probs): + """ + Forward computation over the dispatched tokens. + + Args: + hs_2d_dispatched (Tensor): tokens dispatched to the local experts (FP8 tuple or bf16 tensor). + dispatched_indices (Tensor): expert indices each dispatched token was routed to. + dispatched_probs (Tensor): routing probabilities of the dispatched tokens. + + Returns: + Tensor: output of the expert computation, zipped back to the dispatched layout. + + """ + self.tokens_per_expert = self.token_dispatcher._comm_manager.tokens_per_expert + self.dispatched_probs = dispatched_probs + num_experts = len(self.tokens_per_expert) + padding_token_per_experts = [(x + 127) // 128 * 128 for x in self.tokens_per_expert] + self.padding_token_per_experts = padding_token_per_experts + # 1 unzip + self.dispatched_indices = dispatched_indices.to(paddle.int32) + + total_zipped_tokens = extract_first_if_tuple(hs_2d_dispatched).shape[0] + (unzipped_tokens, zipped_expertwise_rowmap, unzipped_probs, unzipped_tokens_scale,) = self.unzip_node.forward( + hs_2d_dispatched, + self.dispatched_indices, + dispatched_probs, + topk=self.router_topk, + num_experts=num_experts, + tokens_per_expert=self.tokens_per_expert, + ) + record_stream_for_multi_input(hs_2d_dispatched) + dispatched_indices._record_stream() + dispatched_probs._record_stream() + + self.unzipped_probs = unzipped_probs.unsqueeze(-1) + + if DSV3_USE_FP8_DISPATCH: + total_unzipped_tokens = extract_first_if_tuple(unzipped_tokens).shape[0] + # If adaptive O1 recompute is enabled, determine whether to enable recompute O1 based on the degree of imbalance + if self.recompute_fwd_gate_up == -1: + if ( + total_unzipped_tokens + > self.seq_length * self.num_experts_per_tok * 
self.adaptive_remained_O1_recompute_ratio + ): + # logger.debug(f"recompute_fwd_gate_up changed to True, Because the receives {unzipped_tokens.shape[0]} Tensors greater then {self.seq_length*self.num_experts_per_tok*self.adaptive_remained_O1_recompute_ratio}.") + self.set_recompute_fwd_gate_up(True) + else: + # logger.debug(f"recompute_fwd_gate_up changed to False, Because the receives {unzipped_tokens.shape[0]} Tensors less then {self.seq_length*self.num_experts_per_tok*self.adaptive_remained_O1_recompute_ratio}.") + self.set_recompute_fwd_gate_up(False) + + # if use_mlp_subbatch is enabled, then split the unzipped_tokens into subbatches + if self.mlp_fwd_subbatch_rows != 0 and total_unzipped_tokens > self.mlp_fwd_subbatch_rows * 2: + assert ( + self.experts_group_gemm_node.recompute_fwd_gate_up + ), "recompute_fwd_gate_up must be true when use_mlp_subbatch = True" + map_unzipped_indices_to_zipped = TDU.tokens_unzip_slice( + extract_first_if_tuple(hs_2d_dispatched), + zipped_expertwise_rowmap, + num_experts, + total_unzipped_tokens, + 0, + total_unzipped_tokens + 1, + ) + if isinstance(hs_2d_dispatched, tuple): + hs_2d_dispatched[0]._clear_to_zero_allocation() + hs_2d_dispatched[1]._clear_to_zero_allocation() + else: + hs_2d_dispatched._clear_to_zero_allocation() + + subbatch_rows = min((total_unzipped_tokens // num_experts) // 128 * 128, self.mlp_fwd_subbatch_rows) + nparts = (total_unzipped_tokens + subbatch_rows - 1) // subbatch_rows + output = paddle.empty([0, extract_first_if_tuple(hs_2d_dispatched).shape[-1]], dtype=paddle.float32) + self.prepare_env_subbatch(unzipped_tokens, unzipped_tokens_scale, True) + logger.info( + f"Enable subbatch_forward!! total_zipped_tokens:{total_zipped_tokens}, total_unzipped_tokens:{total_unzipped_tokens}, nparts:{nparts}, subbatch_rows:{subbatch_rows}, output_sub_rows:{self.output_subbatch_rows}" + ) + for i in range(nparts): + start_idx = i * subbatch_rows + end_idx = min(start_idx + subbatch_rows, total_unzipped_tokens) + output = self.gemm_forward_subbatch( + unzipped_tokens, + unzipped_tokens_scale, + unzipped_probs, + map_unzipped_indices_to_zipped, + output, + total_zipped_tokens, + padding_token_per_experts, + start_idx=start_idx, + end_idx=end_idx, + output_subbatch_rows=self.output_subbatch_rows, + ) + + output = merge_subbatch_cast(output, paddle.bfloat16) + output.stop_gradient = False + offload(self.experts_group_gemm_node.input_fp8) + offload(self.experts_group_gemm_node.input_scale) + return output + + # 2 experts + expert_out = self.experts_group_gemm_node.forward( + (unzipped_tokens, unzipped_tokens_scale), unzipped_probs, padding_token_per_experts + ) + else: + # If adaptive O1 recompute is enabled, determine whether to enable recompute O1 based on the degree of imbalance + if self.recompute_fwd_gate_up == -1: + if ( + unzipped_tokens.shape[0] + > self.seq_length * self.num_experts_per_tok * self.adaptive_remained_O1_recompute_ratio + ): + self.set_recompute_fwd_gate_up(True) + else: + self.set_recompute_fwd_gate_up(False) + + # 2 experts + expert_out = self.experts_group_gemm_node.forward( + unzipped_tokens, unzipped_probs, padding_token_per_experts + ) + + # 3 zip + if isinstance(hs_2d_dispatched, tuple): + hs_2d_dispatched[0]._clear_to_zero_allocation() + hs_2d_dispatched[1]._clear_to_zero_allocation() + else: + hs_2d_dispatched._clear_to_zero_allocation() + expert_out_tmp = expert_out.reshape([-1, expert_out.shape[-1]]) + + expert_out_zipped = self.zip_node.forward( + expert_out_tmp, + zipped_expertwise_rowmap, + 
self.dispatched_indices, + unzipped_probs, + total_zipped_tokens=total_zipped_tokens, + num_experts=num_experts, + ) + + expert_out_zipped.stop_gradient = False + return expert_out_zipped + + @paddle.no_grad() + def backward(self, hidden_states_out_grad): + """ + Backward computation. + + Args: + hidden_states_out_grad (Tensor): gradient of the zipped hidden states produced by forward. + + Returns: + Tuple[Tensor, Tensor]: a pair of (hs_dispatched_grad, dispatched_probs_grad). + - hs_dispatched_grad (Tensor): gradient w.r.t. the dispatched hidden states. + - dispatched_probs_grad (Tensor): gradient w.r.t. the dispatched probabilities. + + """ + # zip_grad + unzipped_grad = self.zip_node.backward( + hidden_states_out_grad, + self.dispatched_indices, + self.dispatched_probs, + top_k=self.router_topk, + num_experts=len(self.tokens_per_expert), + tokens_per_expert=self.tokens_per_expert, + ) + record_stream_for_multi_input(hidden_states_out_grad) + + total_zipped_tokens = extract_first_if_tuple(hidden_states_out_grad).shape[0] + total_unzipped_tokens = extract_first_if_tuple(unzipped_grad).shape[0] + hidden_states_size = extract_first_if_tuple(hidden_states_out_grad).shape[-1] + num_experts = len(self.tokens_per_expert) + padding_token_per_experts = [(x + 127) // 128 * 128 for x in self.tokens_per_expert] + + if self.mlp_bwd_subbatch_rows != 0 and total_unzipped_tokens > self.mlp_bwd_subbatch_rows * 2: + map_unzipped_indices_to_zipped = TDU.tokens_unzip_slice( + extract_first_if_tuple(hidden_states_out_grad), + self.unzip_node.zipped_expertwise_rowmap, + num_experts, + total_unzipped_tokens, + 0, + total_unzipped_tokens + 1, + ) + if isinstance(hidden_states_out_grad, tuple): + hidden_states_out_grad[0]._clear_to_zero_allocation() + hidden_states_out_grad[1]._clear_to_zero_allocation() + else: + hidden_states_out_grad._clear_to_zero_allocation() + + subbatch_rows = min((total_unzipped_tokens // num_experts) // 128 * 128, self.mlp_bwd_subbatch_rows) + nparts = (total_unzipped_tokens + subbatch_rows - 1) // subbatch_rows + output = paddle.empty([0, hidden_states_size], dtype=paddle.float32) + probs_grad_list = [] + self.prepare_env_subbatch(is_fwd=False) + logger.info( + f"Enable subbatch_backward!! total_zipped_tokens:{total_zipped_tokens}, total_unzipped_tokens:{total_unzipped_tokens}, nparts:{nparts}, subbatch_rows:{subbatch_rows}, output_sub_rows:{self.output_subbatch_rows}" + ) + for i in range(nparts): + reset_status = True if i == nparts - 1 else False # release saved status in the last part. 
+ start_idx = i * subbatch_rows + end_idx = min(start_idx + subbatch_rows, total_unzipped_tokens) + output, probs_grad = self.gemm_backward_subbatch( + unzipped_grad, + map_unzipped_indices_to_zipped, + total_zipped_tokens, + output, + padding_token_per_experts, + start_idx=start_idx, + end_idx=end_idx, + output_subbatch_rows=self.output_subbatch_rows, + reset_status=reset_status, + ) + probs_grad_list.append(probs_grad) + if isinstance(unzipped_grad, tuple): + unzipped_grad[0]._clear_to_zero_allocation() + unzipped_grad[1]._clear_to_zero_allocation() + else: + unzipped_grad._clear_to_zero_allocation() + hs_dispatched_grad = merge_subbatch_cast(output, paddle.bfloat16) + dispatched_probs_grad = TDU.tokens_zip_prob_seq_subbatch( + probs_grad_list, self.unzip_node.zipped_expertwise_rowmap, self.dispatched_indices, subbatch_rows + ) + self.reset_statue() + return hs_dispatched_grad, dispatched_probs_grad + + if isinstance(hidden_states_out_grad, tuple): + hidden_states_out_grad[0]._clear_to_zero_allocation() + hidden_states_out_grad[1]._clear_to_zero_allocation() + else: + hidden_states_out_grad._clear_to_zero_allocation() + + # expert_grad + expert_out, probs_grad = self.experts_group_gemm_node.backward( + unzipped_grad, self.unzipped_probs, padding_token_per_experts + ) + + hs_dispatched_grad, dispatched_probs_grad = self.unzip_node.backward( + expert_out, + total_zipped_tokens, + probs_grad, + self.dispatched_indices, + num_experts=num_experts, + ) + + self.reset_statue() + return hs_dispatched_grad, dispatched_probs_grad + + +class FusionMoeNode: + def __init__( + self, + custom_map, + recompute_fwd_gate_up=False, + is_split_group_gemm=True, + mlp_fwd_subbatch_rows=0, + mlp_bwd_subbatch_rows=0, + output_subbatch_rows=0, + name="fusion_moe_node", + ): + self.token_dispatcher = custom_map.token_dispatcher + self.moe_router_topk = custom_map.moe_router_topk + self.dispatch_quant_node = Fp8DispatchQuantNode(self.token_dispatcher) + self.dispatch_node = Fp8DispatchNode(self.token_dispatcher) + self.mlp_node = FusionMlpNode( + custom_map, + self.moe_router_topk, + recompute_fwd_gate_up=recompute_fwd_gate_up, + is_split_group_gemm=is_split_group_gemm, + mlp_fwd_subbatch_rows=mlp_fwd_subbatch_rows, + mlp_bwd_subbatch_rows=mlp_bwd_subbatch_rows, + output_subbatch_rows=output_subbatch_rows, + ) + self.combine_node = Fp8CombineNode(self.token_dispatcher) + self.combine_quant_node = Fp8CombineQuantNode(self.token_dispatcher, custom_map.moe_group) + self.name = name + + @paddle.no_grad() + def forward(self, hidden_states, probs, routing_map): + if DSV3_USE_FP8_DISPATCH: + (hs_fp8, hs_scale), token_indices, token_probs = self.dispatch_quant_node.forward( + hidden_states, probs, routing_map + ) + ( + (hs_fp8_dispatched, hs_scale_dispatched), + dispatched_indices, + dispatched_probs, + ) = self.dispatch_node.forward((hs_fp8, hs_scale), token_indices, token_probs) + hidden_states_out = self.mlp_node.forward( + (hs_fp8_dispatched, hs_scale_dispatched), dispatched_indices, dispatched_probs + ) + output_combine = self.combine_node.forward(hidden_states_out) + output = self.combine_quant_node.forward(output_combine) + output.stop_gradient = False + return output + else: + hs_2d_dispatched, dispatched_indices, dispatched_probs = self.dispatch_node.forward( + hidden_states, probs, routing_map + ) + hidden_states_out = self.mlp_node.forward(hs_2d_dispatched, dispatched_indices, dispatched_probs) + output_combine = self.combine_node.forward(hidden_states_out) + output = 
self.combine_quant_node.forward(output_combine) + output.stop_gradient = False + return output + + @paddle.no_grad() + def backward(self, output_grad): + output_combine_grad, _ = self.combine_quant_node.backward(output_grad) + hidden_states_out_grad = self.combine_node.backward(output_combine_grad) + + hs_dispatched_grad, dispatched_probs_grad = self.mlp_node.backward(hidden_states_out_grad) + + if DSV3_USE_FP8_DISPATCH: + hs_fp8_grad, token_probs_grad = self.dispatch_node.backward(hs_dispatched_grad, dispatched_probs_grad) + hs_grad, probs_grad, routing_map_grad = self.dispatch_quant_node.backward(hs_fp8_grad, token_probs_grad) + return hs_grad, probs_grad, routing_map_grad + else: + hs_bf16_grad, token_probs_grad = self.dispatch_node.backward(hs_dispatched_grad, dispatched_probs_grad) + return hs_bf16_grad, None, token_probs_grad + + +class FusionMoe(paddle.autograd.PyLayer): + @staticmethod + def forward( + ctx, + hidden_states, + probs, + routing_map, + custom_map, + recompute_fwd_gate_up=False, + is_split_group_gemm=True, + ): + ctx.node = FusionMoeNode( + custom_map, + recompute_fwd_gate_up=recompute_fwd_gate_up, + is_split_group_gemm=is_split_group_gemm, + ) + return ctx.node.forward(hidden_states, probs, routing_map) + + @staticmethod + def backward(ctx, output_grad): + return ctx.node.backward(output_grad) diff --git a/paddleformers/transformers/moe_utils.py b/paddleformers/transformers/moe_utils.py index 466591b0638..881ca4ac07a 100644 --- a/paddleformers/transformers/moe_utils.py +++ b/paddleformers/transformers/moe_utils.py @@ -16,8 +16,57 @@ from typing import Optional +import numpy as np import paddle +try: + import TokenDispatcherUtils as TDU +except ImportError: + TDU = None + +from .fp8_utils import FP8LinearFunctionBase + +if not hasattr(paddle.Tensor, "_clear_to_zero_allocation"): + + def _clear_to_zero_allocation(self): + """ + _clear_to_zero_allocation + """ + old_shape = self.shape + dst = paddle.empty([0], dtype=self.dtype) + dst_t = dst.value().get_tensor() + src_t = self.value().get_tensor() + src_t._share_data_with(dst_t) + src_t._set_dims(old_shape) + + setattr(paddle.Tensor, "_clear_to_zero_allocation", _clear_to_zero_allocation) + + +if not hasattr(paddle.Tensor, "_holder_size"): + + def _holder_size(self): + """ + _holder_size + """ + if self._is_initialized(): + return int(np.prod(self.shape)) * paddle.core.size_of_dtype(self.dtype) + else: + return 0 + + setattr(paddle.Tensor, "_holder_size", _holder_size) + + +def topk_to_permuted_indices(x, num_tokens_per_expert_list, topk): + x = paddle.flatten(x) + prob_permuted_indices = paddle.concat( + [ + paddle.tensor.search._restrict_nonzero(x == i, total_true_num) + for i, total_true_num in enumerate(num_tokens_per_expert_list) + ] + ).flatten() + token_permuted_indices = prob_permuted_indices // topk + return token_permuted_indices, prob_permuted_indices + def permute( tokens, @@ -99,3 +148,404 @@ def unpermute( include_self=True, ) return output_tokens + + +def permute_fast( + tokens, + token_permuted_indices, + drop_and_pad: bool = False, +): + """Permute the tokens and probs based on the mask. + Tokens with the same designated expert will be grouped together. + The shape of mask is [tokens, num_experts], it indicates which experts were selected + by each token. + + Args: + tokens (paddle.Tensor): The input token tensor, [num_tokens, hidden]. + drop_and_pad (bool, optional): Whether or not the token dispatcher uses token-drop + and pads the number of tokens to the expert capacity. 
+ """ + assert not drop_and_pad, "token-drop and pads is not supported" + # permuted_input = paddle.gather(tokens, token_permuted_indices) + permuted_input = tokens.index_select(axis=0, index=token_permuted_indices) + return permuted_input + + +def unpermute_fast( + permuted_tokens: paddle.Tensor, + token_permuted_indices: paddle.Tensor, + prob_permuted_indices: paddle.Tensor, + restore_shape: paddle.shape, + probs: paddle.Tensor = None, + drop_and_pad: bool = False, +): + """ + Restore the original order of tokens after permutation. If probs are provided, it + will also apply them to the tokens before restoring the order. + + Args: + permuted_tokens (paddle.Tensor): The permuted token tensor. + token_permuted_indices (paddle.Tensor): The indices used to sort the tokens. + restore_shape (paddle.shape): The shape of the unpermuted tensor. + probs (paddle.Tensor, optional): The unpermuted probs tensor, + drop_and_pad (bool, optional): Whether or not the token dispatcher uses token-drop + and pads the number of tokens to the expert capacity. + + Returns: + paddle.Tensor: The tokens restored to their original order. + """ + assert not drop_and_pad, "token-drop and pads is not supported" + _, hidden = restore_shape + if probs is not None: + permuted_probs = paddle.gather(probs.flatten(), prob_permuted_indices) + permuted_tokens = permuted_tokens * permuted_probs.unsqueeze(-1) + # Create an output tensor filled with zeros + output_tokens = paddle.zeros(restore_shape, dtype=permuted_tokens.dtype) + # Scatter add the permuted_input back to the original positions + + output_tokens.put_along_axis_( + axis=0, + indices=token_permuted_indices.unsqueeze(1).expand([-1, hidden]), + values=permuted_tokens, + reduce="add", + include_self=True, + ) + return output_tokens + + +class UnZipNode: + def __init__(self, name="unzip"): + self.name = name + self.unzipped_probs = None + self.zipped_expertwise_rowmap = None + + def reset_statue(self): + self.unzipped_probs = None + self.zipped_expertwise_rowmap = None + + @paddle.no_grad() + def forward( + self, + hs_2d_dispatched, + dispatched_indices, + dispatched_probs, + topk, + num_experts, + tokens_per_expert, + ): + if isinstance(hs_2d_dispatched, tuple): + with paddle.amp.auto_cast(False): + ( + unzipped_tokens, + zipped_expertwise_rowmap, + unzipped_probs, + unzipped_scale, + ) = paddle.nn.functional.moe_permute( + hs_2d_dispatched[0], + hs_2d_dispatched[1], + dispatched_indices, + dispatched_probs, + num_experts=num_experts, + tokens_per_expert=tokens_per_expert, + padding_alignment=128, + ) + else: + with paddle.amp.auto_cast(False): + ( + unzipped_tokens, + zipped_expertwise_rowmap, + unzipped_probs, + unzipped_scale, + ) = paddle.nn.functional.moe_permute( + hs_2d_dispatched, + None, + dispatched_indices, + dispatched_probs, + num_experts=num_experts, + tokens_per_expert=tokens_per_expert, + padding_alignment=128, + ) + self.unzipped_probs = unzipped_probs + self.zipped_expertwise_rowmap = zipped_expertwise_rowmap + return (unzipped_tokens, zipped_expertwise_rowmap, unzipped_probs, unzipped_scale) + + @paddle.no_grad() + def backward(self, dx, total_zipped_tokens, probs_grad, dispatched_indices, num_experts): + with paddle.amp.auto_cast(False): + weighted_zipped_tokens, probs_grad_zipped = paddle.nn.functional.moe_unpermute( + dx, + self.zipped_expertwise_rowmap, + dispatched_indices, + probs_grad, + total_zipped_tokens=total_zipped_tokens, + num_experts=num_experts, + ) + self.reset_statue() + return weighted_zipped_tokens, probs_grad_zipped + + +class 
ZipNode: + def __init__(self, name="zip"): + self.name = name + + @paddle.no_grad() + def forward( + self, expert_out, zipped_expertwise_rowmap, routemap_topk, unzipped_probs, total_zipped_tokens, num_experts + ): + with paddle.amp.auto_cast(False): + expert_out_zipped, zipped_probs_topk = paddle.nn.functional.moe_unpermute( + expert_out, zipped_expertwise_rowmap, routemap_topk, unzipped_probs, total_zipped_tokens, num_experts + ) + return expert_out_zipped + + @paddle.no_grad() + def backward( + self, + grad_output, + dispatched_indices, + dispatched_probs, + top_k, + num_experts, + tokens_per_expert, + ): + if isinstance(grad_output, tuple): + with paddle.amp.auto_cast(False): + ( + unzipped_grad, + zipped_expertwise_rowmap_grad, + unzipped_probs_grad, + unzipped_scale_grad, + ) = paddle.nn.functional.moe_permute( + grad_output[0], + grad_output[1], + dispatched_indices, + dispatched_probs, + num_experts, + tokens_per_expert, + padding_alignment=128, + ) + return (unzipped_grad, unzipped_scale_grad) + else: + with paddle.amp.auto_cast(False): + ( + unzipped_grad, + zipped_expertwise_rowmap_grad, + unzipped_probs_grad, + unzipped_scale_grad, + ) = paddle.nn.functional.moe_permute( + grad_output, + None, + dispatched_indices, + dispatched_probs, + num_experts, + tokens_per_expert, + padding_alignment=128, + ) + + return unzipped_grad + + +class PermuteNode: + def __init__(self, token_dispatcher, name="permute"): + self.token_dispatcher = token_dispatcher + self.name = name + + def reset_status(self): + self.token_permuted_indices = None + self.prob_permuted_indices = None + + def forward(self, hidden_states, hidden_states_scale, dispatched_indices): + self.token_dispatcher._comm_manager.hidden_shape_before_permute = hidden_states.shape + self.hidden_shape_before_permute = hidden_states.shape + self.token_permuted_indices, self.prob_permuted_indices = topk_to_permuted_indices( + dispatched_indices, + self.token_dispatcher._comm_manager.tokens_per_expert, + self.token_dispatcher._comm_manager.router_topk, + ) + hidden_states = permute_fast(hidden_states, self.token_permuted_indices) + # permute scale + hidden_states_scale = permute_fast(hidden_states_scale, self.token_permuted_indices) + + return hidden_states, hidden_states_scale, self.token_permuted_indices, self.prob_permuted_indices + + def backward(self, out_grad, dispatched_probs): + input_dtype = out_grad.dtype + hidden_states_grad = unpermute_fast( + permuted_tokens=out_grad, + token_permuted_indices=self.token_permuted_indices, + prob_permuted_indices=self.prob_permuted_indices, + restore_shape=self.hidden_shape_before_permute, + probs=dispatched_probs, + ) + self.reset_status() + return hidden_states_grad.to(input_dtype) + + +class UnPermuteNode: + def __init__(self, token_dispatcher, name="unpermute"): + self.token_dispatcher = token_dispatcher + self.name = name + + def reset_status(self): + self.token_permuted_indices = None + self.hidden_states = None + self.prob_permuted_indices = None + self.faltten_dispatched_probs = None + self.hidden = None + self.permuted_tokens = None + self.output_tokens = None + + def forward( + self, + hidden_states, + token_permuted_indices, + prob_permuted_indices, + dispatched_probs, + ): + self.token_permuted_indices = token_permuted_indices + self.input_dtype = hidden_states.dtype + self.hidden_states = hidden_states + self.prob_permuted_indices = prob_permuted_indices + self.dispatched_probs_shape = dispatched_probs.shape + # permute + _, self.hidden = 
self.token_dispatcher._comm_manager.hidden_shape_before_permute + + self.faltten_dispatched_probs = dispatched_probs.flatten() + + self.permuted_probs = paddle.gather(self.faltten_dispatched_probs, self.prob_permuted_indices) + permuted_tokens = self.hidden_states * self.permuted_probs.unsqueeze(-1) + permuted_tokens = permuted_tokens.cast(self.hidden_states.dtype) + + # Create an output tensor filled with zeros + output_tokens = paddle.zeros( + self.token_dispatcher._comm_manager.hidden_shape_before_permute, dtype=self.hidden_states.dtype + ) + # Scatter add the permuted_input back to the original positions + output_tokens.put_along_axis_( + axis=0, + indices=self.token_permuted_indices.cast("int32").unsqueeze(1).expand([-1, self.hidden]), + values=permuted_tokens, + reduce="add", + include_self=True, + ) + with paddle.base.device_guard("cpu"): + self.output_tokens = paddle.empty(shape=output_tokens.shape, dtype=output_tokens.dtype) + + return output_tokens.to(self.input_dtype) + + def backward(self, out_grad, out_grad_scale): + hidden_states_grad = paddle.gather(out_grad, self.token_permuted_indices) + + output_tokens_grad = FP8LinearFunctionBase.dequantize_fp8_to_fp32(out_grad, out_grad_scale) + permuted_tokens = self.hidden_states * self.permuted_probs.unsqueeze(-1) + permuted_tokens = permuted_tokens.cast(self.hidden_states.dtype) + + _, permuted_tokens_grad = paddle._C_ops.put_along_axis_grad( + self.output_tokens, + self.token_permuted_indices.cast("int32").unsqueeze(1).expand([-1, self.hidden]), + permuted_tokens, + self.output_tokens, + output_tokens_grad, + 0, + "add", + True, + ) + + permuted_probs_grad = (permuted_tokens_grad * self.hidden_states).sum(axis=-1) + + faltten_dispatched_probs_grad = paddle._C_ops.gather_grad( + self.faltten_dispatched_probs, self.prob_permuted_indices, permuted_probs_grad, 0 + ) + + # dispatched_probs_grad = paddle._C_ops.flatten_grad(self.dispatched_probs, faltten_dispatched_probs_grad) + dispatched_probs_grad = faltten_dispatched_probs_grad.reshape(self.dispatched_probs_shape) + + self.reset_status() + return hidden_states_grad, dispatched_probs_grad + + +def tokens_zip_unique_add_with_subbatch(zipped, unzipped, index_unzipped, zipped_rows, subbatch_rows=None): + """ + tokens_zip_unique_add_with_subbatch + """ + if subbatch_rows is None or subbatch_rows <= 0 or zipped_rows <= 0: + return TDU.tokens_zip_unique_add(zipped, unzipped, index_unzipped, zipped_rows) + else: + if isinstance(zipped, paddle.Tensor): + num_split = (zipped_rows + subbatch_rows - 1) // subbatch_rows + remainder = zipped_rows % subbatch_rows + if remainder == 0: + rows = [subbatch_rows] * num_split + else: + rows = [subbatch_rows] * (num_split - 1) + [remainder] + + if zipped.shape[0] == 0: + dtype = zipped.dtype + hidden_size = zipped.shape[1] + zipped = [paddle.zeros([r, hidden_size], dtype=dtype) for r in rows] + else: + zipped = paddle.split(zipped, rows, axis=0) + return TDU.tokens_zip_unique_add_subbatch(zipped, unzipped, index_unzipped, zipped_rows, subbatch_rows) + + +def merge_subbatch_cast(x, dtype): + if isinstance(x, (list, tuple)): + if len(x) == 1: + x = x[0] + return x.cast(dtype) if x.dtype != dtype else x + else: + return TDU.merge_subbatch_cast(x, dtype) + else: + return x.cast(dtype) if x.dtype != dtype else x + + +def get_env_device(): + """ + Return the device name of running environment. 
+ """ + if paddle.is_compiled_with_cuda(): + return "gpu" + elif "npu" in paddle.device.get_all_custom_device_type(): + return "npu" + elif "mlu" in paddle.device.get_all_custom_device_type(): + return "mlu" + elif "gcu" in paddle.device.get_all_custom_device_type(): + return "gcu" + elif "intel_hpu" in paddle.device.get_all_custom_device_type(): + return "intel_hpu" + elif paddle.is_compiled_with_rocm(): + return "rocm" + elif paddle.is_compiled_with_xpu(): + return "xpu" + return "cpu" + + +def to_device(tensor, place=None): + if place is None: + place = get_env_device() + + if isinstance(place, str): + place = paddle.device._convert_to_place(place) + + if not tensor.place._equals(place): + new_t = tensor._copy_to(place, True) + dst_tensor = tensor.value().get_tensor() + src_tensor = new_t.value().get_tensor() + dst_tensor._share_data_with(src_tensor) + + return tensor + + +def offload(tensor): + if paddle.is_compiled_with_cuda(): + place = paddle.CUDAPinnedPlace() + else: + place = paddle.CPUPlace() + + new_tensor = to_device(tensor, place) + assert new_tensor is tensor, "to_device must be inplace operation" + + +def reload(tensor): + new_tensor = to_device(tensor) + assert new_tensor is tensor, "to_device must be inplace operation" diff --git a/paddleformers/transformers/token_dispatcher.py b/paddleformers/transformers/token_dispatcher.py index 128f6e52f4d..fd395cedac8 100644 --- a/paddleformers/transformers/token_dispatcher.py +++ b/paddleformers/transformers/token_dispatcher.py @@ -14,6 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os from abc import ABC, abstractmethod from typing import Optional, Tuple @@ -21,7 +22,13 @@ from paddle.distributed.communication.group import Group from .fused_a2a import fused_combine, fused_dispatch -from .moe_utils import permute, unpermute + +if not os.getenv("DSV3_FAST_PRETRAIN", "False"): + from .moe_utils import permute, topk_to_permuted_indices, unpermute +else: + from .moe_utils import permute_fast as permute + from .moe_utils import topk_to_permuted_indices + from .moe_utils import unpermute_fast as unpermute class _DispatchManager(ABC): @@ -118,16 +125,17 @@ def setup_metadata(self, routing_map: paddle.Tensor, probs: paddle.Tensor): # Convert the format of routing map from multihot to indices. 
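One detail of the import switch added to token_dispatcher.py above: os.getenv returns a string, and any non-empty string (including "False") is truthy, so `not os.getenv("DSV3_FAST_PRETRAIN", "False")` evaluates to False whenever the default is supplied. If the flag is meant to be an explicit opt-in, an explicit parse along the following lines is more predictable; this is a hedged sketch, not a helper the repository provides.

```python
import os


def env_flag(name: str, default: bool = False) -> bool:
    """Parse a boolean environment variable explicitly.

    os.getenv returns a string (or None), so compare against known
    spellings instead of relying on string truthiness.
    """
    value = os.getenv(name)
    if value is None:
        return default
    return value.strip().lower() in ("1", "true", "yes", "on")


# Hypothetical usage mirroring the selection above:
# if env_flag("DSV3_FAST_PRETRAIN"):
#     from .moe_utils import permute_fast as permute, unpermute_fast as unpermute
# else:
#     from .moe_utils import permute, unpermute
```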
self.token_probs, self.token_indices = paddle.topk(probs, self.router_topk, axis=-1) - def dispatch(self, hidden_states: paddle.Tensor) -> paddle.Tensor: + def dispatch( + self, hidden_states: paddle.Tensor, token_indices: paddle.Tensor, token_probs: paddle.Tensor + ) -> paddle.Tensor: hidden_states, dispatched_probs, states = fused_dispatch( - hidden_states, self.token_indices, self.token_probs, self.num_experts, self.group + hidden_states, token_indices, token_probs, self.num_experts, self.group ) self.handle = states["handle"] - self.tokens_per_expert = states["tokens_per_expert"] - self.dispatched_indices = states["dispatched_indices"] - self.dispatched_probs = dispatched_probs + self.tokens_per_expert_list = states["tokens_per_expert"] + dispatched_indices = states["dispatched_indices"] - return hidden_states + return hidden_states, dispatched_indices, dispatched_probs def _indices_to_multihot(self, indices, probs): """ @@ -181,6 +189,16 @@ def get_permuted_hidden_states_by_experts(self, hidden_states: paddle.Tensor) -> ) return hidden_states + def get_permuted_hidden_states_by_experts_fast( + self, hidden_states: paddle.Tensor, dispatched_indices: paddle.Tensor + ) -> paddle.Tensor: + self.hidden_shape_before_permute = hidden_states.shape + token_permuted_indices, prob_permuted_indices = topk_to_permuted_indices( + dispatched_indices, self.tokens_per_expert_list, self.router_topk + ) + hidden_states = permute(hidden_states, token_permuted_indices) + return hidden_states, token_permuted_indices, prob_permuted_indices + def get_restored_hidden_states_by_experts(self, hidden_states: paddle.Tensor) -> paddle.Tensor: input_dtype = hidden_states.dtype assert self.dispatched_probs.dtype == paddle.float32, "DeepEP only supports float32 probs" @@ -193,6 +211,24 @@ def get_restored_hidden_states_by_experts(self, hidden_states: paddle.Tensor) -> ) return hidden_states.to(input_dtype) + def get_restored_hidden_states_by_experts_fast( + self, + hidden_states: paddle.Tensor, + token_permuted_indices: paddle.Tensor, + prob_permuted_indices: paddle.Tensor, + dispatched_probs: paddle.Tensor, + ) -> paddle.Tensor: + input_dtype = hidden_states.dtype + assert dispatched_probs.dtype == paddle.float32, "DeepEP only supports float32 probs" + hidden_states = unpermute( + permuted_tokens=hidden_states, + token_permuted_indices=token_permuted_indices, + prob_permuted_indices=prob_permuted_indices, + restore_shape=self.hidden_shape_before_permute, + probs=dispatched_probs, + ) + return hidden_states.to(input_dtype) + class MoETokenDispatcher: """ @@ -267,7 +303,7 @@ def token_permutation( hidden_states = hidden_states.view([-1, self.hidden_shape[-1]]) self._comm_manager.setup_metadata(routing_map, probs) - hidden_states = self._comm_manager.dispatch(hidden_states) + hidden_states, _, _ = self._comm_manager.dispatch(hidden_states) global_input_tokens = self._comm_manager.get_permuted_hidden_states_by_experts(hidden_states) tokens_per_expert = self._comm_manager.get_number_of_tokens_per_expert() @@ -282,3 +318,135 @@ def token_unpermutation( hidden_states = hidden_states.reshape(self.hidden_shape) return hidden_states, None + + +class MoEFlexTokenDispatcherFast: + """ + Flexible token dispatcher for MoE models with Efficient-A2A communication kernels. 
+ """ + + def __init__(self, num_local_experts: int, moe_router_topk: int, num_moe_experts: int, ep_group: Group): + self._ep_group = ep_group + + self.num_local_experts = num_local_experts + assert self.ep_size > 1, "Flex token dispatcher requires EP > 1" + self._comm_manager = _DeepepManager( + group=self.ep_group, + router_topk=moe_router_topk, + num_experts=num_moe_experts, + num_local_experts=self.num_local_experts, + ) + + @property + def ep_group(self): + """Get expert model parallel group.""" + return self._ep_group + + @property + def ep_size(self): + """Get expert model parallel world_size.""" + return self.ep_group.world_size + + def pre_dispatch(self, hidden_states, probs, routing_map): + self.hidden_shape = hidden_states.shape + hidden_states = hidden_states.view([-1, self.hidden_shape[-1]]) + num_tokens = routing_map.shape[0] + routing_map = routing_map.reshape([num_tokens, self._comm_manager.num_experts]) + probs = probs.reshape([num_tokens, self._comm_manager.num_experts]) + # Convert the format of routing map from multihot to indices. + token_probs, token_indices = paddle.topk(probs, self._comm_manager.router_topk, axis=-1) + return hidden_states, token_indices, token_probs + + def post_dispatch(self, hidden_states, dispatched_indices): + ( + global_input_tokens, + token_permuted_indices, + prob_permuted_indices, + ) = self._comm_manager.get_permuted_hidden_states_by_experts_fast(hidden_states, dispatched_indices) + return (global_input_tokens, token_permuted_indices, prob_permuted_indices) + + def pre_combine(self, hidden_states, token_permuted_indices, prob_permuted_indices, dispatched_probs): + hidden_states = self._comm_manager.get_restored_hidden_states_by_experts_fast( + hidden_states, token_permuted_indices, prob_permuted_indices, dispatched_probs + ) + return hidden_states + + def post_combine(self, hidden_states): + hidden_states = hidden_states.reshape(self.hidden_shape) + return hidden_states + + def token_permutation( + self, hidden_states: paddle.Tensor, probs: paddle.Tensor, routing_map: paddle.Tensor + ) -> Tuple[paddle.Tensor, paddle.Tensor]: + hidden_states, token_indices, token_probs = self.pre_dispatch(hidden_states, probs, routing_map) + hidden_states, dispatched_indices, dispatched_probs = self._comm_manager.dispatch( + hidden_states, token_indices, token_probs + ) + (global_input_tokens, token_permuted_indices, prob_permuted_indices) = self.post_dispatch( + hidden_states, dispatched_indices + ) + + return ( + global_input_tokens, + token_permuted_indices, + prob_permuted_indices, + dispatched_probs, + ) + + def token_unpermutation( + self, + hidden_states: paddle.Tensor, + token_permuted_indices, + prob_permuted_indices, + dispatched_probs, + bias: Optional[paddle.Tensor] = None, + ) -> Tuple[paddle.Tensor, Optional[paddle.Tensor]]: + assert bias is None, "Bias is not supported in MoEFlexTokenDispatcher" + hidden_states = self.pre_combine( + hidden_states, token_permuted_indices, prob_permuted_indices, dispatched_probs + ) + hidden_states = self._comm_manager.combine(hidden_states) + + hidden_states = self.post_combine(hidden_states) + return hidden_states, None + + +class PreDispatchNode: + def __init__(self, token_dispatcher): + self.token_dispatcher = token_dispatcher + self.probs_origin_shape = None + + def reset_status(self): + self.probs = None + self.reshaped_probs = None + self.token_indices = None + + @paddle.no_grad() + def forward(self, routing_map, probs): + num_tokens = routing_map.shape[0] + self.probs_origin_shape = probs.shape + # 
routing_map = routing_map.reshape([num_tokens, token_dispatcher._comm_manager.num_experts]) + self.probs = probs + reshaped_probs = probs.reshape([num_tokens, self.token_dispatcher._comm_manager.num_experts]) + self.reshaped_probs = reshaped_probs + token_probs, token_indices = paddle.topk( + reshaped_probs, self.token_dispatcher._comm_manager.router_topk, axis=-1 + ) + self.token_indices = token_indices + token_probs.stop_gradient = False + return token_indices, token_probs + + @paddle.no_grad() + def backward(self, token_probs_g): + probs_grad = paddle._C_ops.topk_grad( + self.reshaped_probs, + self.token_indices, + token_probs_g, + self.token_dispatcher._comm_manager.router_topk, + -1, + True, + True, + ) + probs_reshape_g = paddle._C_ops.reshape_grad(self.probs, probs_grad) + self.reset_status() + return probs_reshape_g diff --git a/paddleformers/transformers/utils.py b/paddleformers/transformers/utils.py index 83c85fc147f..fb74c5a3f86 100644 --- a/paddleformers/transformers/utils.py +++ b/paddleformers/transformers/utils.py @@ -31,8 +31,7 @@ from filelock import FileLock from paddleformers import __version__ - -from ..utils.downloader import ( +from paddleformers.utils.downloader import ( COMMUNITY_MODEL_PREFIX, download_check, get_path_from_url_with_filelock, @@ -629,7 +628,7 @@ def cached_file_for_hf_hub( filename=filename, cache_dir=cache_dir, subfolder=subfolder, - library_name="PaddleNLP", + library_name="PaddleFormers", library_version=__version__, ) return resolved_file @@ -1005,3 +1004,10 @@ def caculate_llm_per_token_flops( # 2 for mul + add in matmul # 1 for forward, 2 for backwards since we caluate gradients for input_x and input_y return 2 * (layer_num * (flops_per_transformer * 3 + flops_recompute_transformer) + 3 * flops_loggits) / seq_length + + +def cast_if_needed(x, dtype): + """ + cast_if_needed + """ + return x.cast(dtype) if x.dtype != dtype else x diff --git a/paddleformers/utils/download/bos_download.py b/paddleformers/utils/download/bos_download.py new file mode 100644 index 00000000000..e9f3183ae32 --- /dev/null +++ b/paddleformers/utils/download/bos_download.py @@ -0,0 +1,287 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
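The bos_download.py module introduced below follows a huggingface_hub-style download flow for BOS-hosted community models: build a URL from repo_id and filename, probe it with a HEAD request for metadata, then download under a file lock into a cache keyed by repo_id. A small sketch of the URL and cache layout it assumes (illustrative helper names; the authoritative templates are the BOS_URL_TEMPLATE* constants defined in the module):

```python
import os

# Illustrative sketch only; the module's BOS_URL_TEMPLATE* constants are authoritative.
EXAMPLE_ENDPOINT = os.getenv("PPNLP_ENDPOINT", "https://bj.bcebos.com/paddlenlp")


def example_bos_url(repo_id: str, filename: str, repo_type: str = "models") -> str:
    """Community file URL: <endpoint>/<repo_type>/community/<repo_id>/<filename>."""
    return f"{EXAMPLE_ENDPOINT}/{repo_type}/community/{repo_id}/{filename}"


def example_cache_path(cache_dir: str, repo_id: str, filename: str) -> str:
    """Downloads land at <cache_dir>/<repo_id>/<filename>; locks live under <cache_dir>/.locks."""
    return os.path.join(cache_dir, repo_id, filename)


# example_bos_url("deepseek-ai/DeepSeek-V3", "config.json")
# -> "https://bj.bcebos.com/paddlenlp/models/community/deepseek-ai/DeepSeek-V3/config.json"
```

bos_try_to_load_from_cache later checks exactly this <cache_dir>/<repo_id>/<filename> location before hitting the network.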
+ +import logging +import os +import re +import tempfile +from contextlib import contextmanager +from functools import partial +from pathlib import Path +from typing import Dict, Literal, Optional, Union + +from filelock import FileLock +from huggingface_hub.utils import ( + EntryNotFoundError, + GatedRepoError, + HfHubHTTPError, + RepositoryNotFoundError, + RevisionNotFoundError, +) + +logger = logging.getLogger(__name__) + +from ..env import MODEL_HOME +from .common import ( + DEFAULT_ETAG_TIMEOUT, + DEFAULT_REQUEST_TIMEOUT, + AistudioBosFileMetadata, + _as_int, + _chmod_and_replace, + _normalize_etag, + _request_wrapper, + http_get, + raise_for_status, +) + +ENDPOINT = os.getenv("PPNLP_ENDPOINT", "https://bj.bcebos.com/paddlenlp") +ENDPOINT_v2 = "https://paddlenlp.bj.bcebos.com" + +BOS_URL_TEMPLATE = ENDPOINT + "/{repo_type}/community/{repo_id}/{revision}/{filename}" +BOS_URL_TEMPLATE_WITHOUT_REVISION = ENDPOINT + "/{repo_type}/community/{repo_id}/{filename}" + + +REGEX_COMMIT_HASH = re.compile(r"^[0-9a-f]{40}$") +REPO_TYPE = "models" + + +def get_bos_file_metadata( + url: str, + token: Union[bool, str, None] = None, + proxies: Optional[Dict] = None, + timeout: Optional[float] = DEFAULT_REQUEST_TIMEOUT, + library_name: Optional[str] = None, + library_version: Optional[str] = None, + user_agent: Union[Dict, str, None] = None, +): + """Fetch metadata of a file versioned on the Hub for a given url. + + Args: + url (`str`): + File url, for example returned by [`bos_url`]. + token (`str` or `bool`, *optional*): + A token to be used for the download. + - If `True`, the token is read from the BOS config + folder. + - If `False` or `None`, no token is provided. + - If a string, it's used as the authentication token. + proxies (`dict`, *optional*): + Dictionary mapping protocol to the URL of the proxy passed to + `requests.request`. + timeout (`float`, *optional*, defaults to 10): + How many seconds to wait for the server to send metadata before giving up. + library_name (`str`, *optional*): + The name of the library to which the object corresponds. + library_version (`str`, *optional*): + The version of the library. + user_agent (`dict`, `str`, *optional*): + The user-agent info in the form of a dictionary or a string. + + Returns: + A [`AistudioBosFileMetadata`] object containing metadata such as location, etag, size and + commit_hash. 
+ """ + headers = {} + headers["Accept-Encoding"] = "identity" # prevent any compression => we want to know the real size of the file + + # Retrieve metadata + r = _request_wrapper( + method="HEAD", + url=url, + headers=headers, + allow_redirects=False, + follow_relative_redirects=True, + proxies=proxies, + timeout=timeout, + ) + raise_for_status(r) + + # Return + return AistudioBosFileMetadata( + commit_hash=None, + etag=_normalize_etag(r.headers.get("ETag")), + location=url, + size=_as_int(r.headers.get("Content-Length")), + ) + + +def bos_url( + repo_id: str, + filename: str, + *, + subfolder: Optional[str] = None, + repo_type: Optional[str] = None, + revision: Optional[str] = None, + endpoint: Optional[str] = None, +) -> str: + if subfolder == "": + subfolder = None + if subfolder is not None: + filename = f"{subfolder}/{filename}" + + url = BOS_URL_TEMPLATE_WITHOUT_REVISION.format( + repo_type=REPO_TYPE, + repo_id=repo_id, + filename=filename, + ) + + # Update endpoint if provided + if endpoint is not None and url.startswith(ENDPOINT): + url = endpoint + url[len(ENDPOINT) :] + return url + + +def bos_download( + repo_id: str = None, + filename: str = None, + subfolder: Optional[str] = None, + repo_type: Optional[str] = None, + revision: Optional[str] = None, + library_name: Optional[str] = None, + library_version: Optional[str] = None, + cache_dir: Union[str, Path, None] = None, + local_dir: Union[str, Path, None] = None, + local_dir_use_symlinks: Union[bool, Literal["auto"]] = "auto", + user_agent: Union[Dict, str, None] = None, + force_download: bool = False, + proxies: Optional[Dict] = None, + etag_timeout: float = DEFAULT_ETAG_TIMEOUT, + resume_download: bool = False, + token: Optional[str] = None, + local_files_only: bool = False, + endpoint: Optional[str] = None, + url: Optional[str] = None, + **kwargs, +): + if url is not None: + if repo_id is None: + if url.startswith(ENDPOINT): + repo_id = "/".join(url[len(ENDPOINT) + 1 :].split("/")[:-1]) + else: + repo_id = "/".join(url[len(ENDPOINT_v2) + 1 :].split("/")[:-1]) + if filename is None: + filename = url.split("/")[-1] + subfolder = None + + if cache_dir is None: + cache_dir = MODEL_HOME + if isinstance(cache_dir, Path): + cache_dir = str(cache_dir) + + if subfolder == "": + subfolder = None + if subfolder is not None: + # This is used to create a URL, and not a local path, hence the forward slash. + filename = f"{subfolder}/{filename}" + + locks_dir = os.path.join(cache_dir, ".locks") + + storage_folder = os.path.join(cache_dir, repo_id) + os.makedirs(storage_folder, exist_ok=True) + if subfolder is not None: + storage_sub_folder = os.path.join(storage_folder, subfolder) + os.makedirs(storage_sub_folder, exist_ok=True) + + if url is None: + url = bos_url(repo_id, filename, repo_type=REPO_TYPE, endpoint=endpoint) + headers = None + url_to_download = url + lock_path = os.path.join(locks_dir, repo_id, f"{filename}.lock") + file_path = os.path.join(cache_dir, repo_id, filename) + + if os.name == "nt" and len(os.path.abspath(lock_path)) > 255: + lock_path = "\\\\?\\" + os.path.abspath(lock_path) + + if os.name == "nt" and len(os.path.abspath(file_path)) > 255: + file_path = "\\\\?\\" + os.path.abspath(file_path) + + Path(lock_path).parent.mkdir(parents=True, exist_ok=True) + with FileLock(lock_path): + # If the download just completed while the lock was activated. + if os.path.exists(file_path) and not force_download: + # Even if returning early like here, the lock will be released. 
+ return file_path + + if resume_download: + incomplete_path = file_path + ".incomplete" + + @contextmanager + def _resumable_file_manager(): + with open(incomplete_path, "ab") as f: + yield f + + temp_file_manager = _resumable_file_manager + if os.path.exists(incomplete_path): + resume_size = os.stat(incomplete_path).st_size + else: + resume_size = 0 + else: + temp_file_manager = partial( # type: ignore + tempfile.NamedTemporaryFile, mode="wb", dir=cache_dir, delete=False + ) + resume_size = 0 + + # Download to temporary file, then copy to cache dir once finished. + # Otherwise you get corrupt cache entries if the download gets interrupted. + with temp_file_manager() as temp_file: + logger.info("downloading %s to %s", url_to_download, temp_file.name) + + http_get( + url_to_download, + temp_file, + proxies=proxies, + resume_size=resume_size, + headers=headers, + ) + + logger.info("storing %s in cache at %s", url_to_download, file_path) + _chmod_and_replace(temp_file.name, file_path) + try: + os.remove(lock_path) + except OSError: + pass + return file_path + + +def bos_file_exists( + repo_id: str, + filename: str, + *, + repo_type: Optional[str] = None, + revision: Optional[str] = None, + token: Optional[str] = None, + endpoint: Optional[str] = None, +) -> bool: + url = bos_url(repo_id=repo_id, repo_type=REPO_TYPE, filename=filename, endpoint=endpoint) + try: + get_bos_file_metadata(url, token=token) + return True + except GatedRepoError: # raise specifically on gated repo + raise + except (RepositoryNotFoundError, EntryNotFoundError, RevisionNotFoundError, HfHubHTTPError): + return False + + +def bos_try_to_load_from_cache( + repo_id: str, + filename: str, + cache_dir: Union[str, Path, None] = None, + revision: Optional[str] = None, + repo_type: Optional[str] = None, +): + if cache_dir is None: + cache_dir = MODEL_HOME + + cached_file = os.path.join(cache_dir, repo_id, filename) + return cached_file if os.path.isfile(cached_file) else None diff --git a/paddleformers/utils/download/download.py b/paddleformers/utils/download/download.py index bcc2e5bde70..dce51c9618c 100644 --- a/paddleformers/utils/download/download.py +++ b/paddleformers/utils/download/download.py @@ -44,6 +44,7 @@ class DownloadSource(str, Enum): HUGGINGFACE = "huggingface" AISTUDIO = "aistudio" MODELSCOPE = "modelscope" + BOS = "bos" MODEL_MAPPINGS = {} @@ -64,6 +65,7 @@ def check_repo(model_name_or_path, download_hub): DownloadSource.HUGGINGFACE, DownloadSource.AISTUDIO, DownloadSource.MODELSCOPE, + DownloadSource.BOS, ], f"download_hub must be one of {DownloadSource.HUGGINGFACE}, {DownloadSource.AISTUDIO}, {DownloadSource.MODELSCOPE}" if model_name_or_path not in HF_MODEL_MAPPINGS.keys(): # repo id set by user @@ -88,6 +90,39 @@ def strtobool(v): ) +from .bos_download import bos_download, bos_file_exists + + +def bos_hf_file_exist( + repo_id: str, + filename: str, + *, + subfolder: Optional[str] = None, + repo_type: Optional[str] = None, + revision: Optional[str] = None, + token: Optional[str] = None, + endpoint: Optional[str] = None, + from_bos: bool = True, + from_aistudio: bool = False, + from_hf_hub: bool = False, +): + assert repo_id is not None, "repo_id cannot be None" + assert filename is not None, "filename cannot be None" + + if subfolder is None: + subfolder = "" + filename = os.path.join(subfolder, filename) + out = bos_file_exists( + repo_id=repo_id, + filename=filename, + repo_type=repo_type, + revision=revision, + token=token, # donot need token + endpoint=endpoint, + ) + return out + + def 
resolve_file_path( repo_id: str = None, filenames: Union[str, list] = None, @@ -238,6 +273,28 @@ def resolve_file_path( ) if cached_file is not None: return cached_file + else: + log_endpoint = "BOS" + for filename in filenames: + download_kwargs["filename"] = filename + is_available = bos_hf_file_exist( + repo_id, + filename, + subfolder=subfolder, + repo_type=repo_type, + revision=revision, + token=token, + endpoint=endpoint, + from_bos=True, + from_aistudio=False, + from_hf_hub=False, + ) + if is_available: + cached_file = bos_download( + **download_kwargs, + ) + if cached_file is not None: + return cached_file except LocalEntryNotFoundError: raise EnvironmentError( "Cannot find the requested files in the cached path and"