diff --git a/paddleformers/examples/deepseek_v3/config/config.json b/paddleformers/examples/deepseek_v3/config/config.json new file mode 100644 index 00000000000..21d15b679cd --- /dev/null +++ b/paddleformers/examples/deepseek_v3/config/config.json @@ -0,0 +1,76 @@ +{ + "architectures": [ + "DeepseekV3ForCausalLM" + ], + "attention_bias": false, + "attention_dropout": 0.0, + "auto_map": { + "AutoConfig": "configuration_deepseek.DeepseekV3Config", + "AutoModel": "modeling_deepseek.DeepseekV3Model", + "AutoModelForCausalLM": "modeling_deepseek.DeepseekV3ForCausalLM" + }, + "aux_loss_alpha": 0.001, + "bos_token_id": 0, + "eos_token_id": 1, + "ep_size": 1, + "first_k_dense_replace": 3, + "hidden_act": "silu", + "hidden_size": 7168, + "initializer_range": 0.02, + "intermediate_size": 18432, + "kv_lora_rank": 512, + "max_position_embeddings": 163840, + "model_type": "deepseek_v3", + "moe_intermediate_size": 2048, + "moe_layer_freq": 1, + "n_group": 8, + "n_routed_experts": 8, + "n_shared_experts": 1, + "norm_topk_prob": true, + "num_attention_heads": 128, + "num_experts_per_tok": 8, + "num_hidden_layers": 15, + "num_key_value_heads": 128, + "num_nextn_predict_layers": 1, + "pretraining_tp": 1, + "q_lora_rank": 1536, + "qk_nope_head_dim": 128, + "qk_rope_head_dim": 64, + "rms_norm_eps": 1e-06, + "rope_scaling": { + "beta_fast": 32, + "beta_slow": 1, + "factor": 40, + "mscale": 1.0, + "mscale_all_dim": 1.0, + "original_max_position_embeddings": 4096, + "type": "yarn" + }, + "rope_theta": 10000, + "routed_scaling_factor": 2.5, + "scoring_func": "sigmoid", + "seq_aux": true, + "tie_word_embeddings": false, + "topk_group": 4, + "topk_method": "noaux_tc", + "dtype": "bfloat16", + "transformers_version": "4.33.1", + "use_cache": true, + "v_head_dim": 128, + "vocab_size": 129280, + "using_flex_token": true, + "using_fake_gate": true, + "use_fused_rms_norm": true, + "fuse_attention_ffn": true, + "use_fused_rope": true, + "token_drop_steps": 0, + "recompute_fwd_gate_up": true, + "adaptive_remained_O1_recompute_ratio": 0.3, + "using_post_norm_recompute": true, + "is_split_group_gemm": false, + "use_dualpipev": true, + "send_mtp_embed": true, + "offline_quant_expert_weight": false, + "clear_origin_weight_when_offline_quant": false + } + diff --git a/paddleformers/examples/deepseek_v3/config/pretrain_argument.json b/paddleformers/examples/deepseek_v3/config/pretrain_argument.json new file mode 100644 index 00000000000..0c8d4aefed9 --- /dev/null +++ b/paddleformers/examples/deepseek_v3/config/pretrain_argument.json @@ -0,0 +1,53 @@ +{ + "model_name_or_path": "./config/", + "tokenizer_name_or_path": "deepseek-ai/DeepSeek-V3", + "input_dir": "./data", + "output_dir": "./checkpoints/pretrain_ckpts", + "per_device_train_batch_size": 1, + "gradient_accumulation_steps": 24, + "per_device_eval_batch_size": 1, + "tensor_parallel_degree": 1, + "pipeline_parallel_degree": 4, + "pipeline_parallel_config": "use_dualpipev", + "sharding_parallel_degree": 2, + "sharding_parallel_config": "split_param", + "sharding_comm_buffer_size_MB": 2048, + "expert_parallel_degree": 2, + "sharding": "stage1", + "virtual_pp_degree": 1, + "sequence_parallel": 0, + "use_flash_attention": true, + "max_seq_length": 4097, + "learning_rate": 3e-05, + "min_learning_rate": 3e-06, + "warmup_steps": 30, + "logging_steps": 1, + "max_steps": 200, + "save_steps": 5000, + "eval_steps": 1000, + "weight_decay": 0.01, + "bf16": true, + "fp16_opt_level": "O2", + "warmup_ratio": 0.01, + "max_grad_norm": 1.0, + "amp_master_grad": 1, + 
"dataloader_num_workers": 8, + "continue_training": 0, + "do_train": true, + "do_eval": true, + "do_predict": false, + "disable_tqdm": true, + "recompute": false, + "distributed_dataloader": 1, + "unified_checkpoint": true, + "save_total_limit": 2, + "skip_profile_timer": false, + "use_fused_rms_norm": true, + "fuse_attention_ffn": true, + "use_fused_rope": true, + "save_sharded_model": false, + "load_sharded_model": false, + "use_expert_parallel": true, + "unified_checkpoint_config": "skip_save_model_weight", + "offload_optim": true + } \ No newline at end of file diff --git a/paddleformers/examples/deepseek_v3/run.sh b/paddleformers/examples/deepseek_v3/run.sh new file mode 100644 index 00000000000..b701688fae8 --- /dev/null +++ b/paddleformers/examples/deepseek_v3/run.sh @@ -0,0 +1,22 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# llama 模型数据下载 +# mkdir -p data +# wget https://bj.bcebos.com/paddlenlp/models/transformers/llama/data/llama_openwebtext_100k.bin +# wget https://bj.bcebos.com/paddlenlp/models/transformers/llama/data/llama_openwebtext_100k.idx + +rm -rf output +GLOG_v=6 nohup sh script/train_gpu.sh config/pretrain_argument.json > run.log 2>&1 & + diff --git a/paddleformers/examples/deepseek_v3/run_pretrain.py b/paddleformers/examples/deepseek_v3/run_pretrain.py new file mode 100644 index 00000000000..fbffa892782 --- /dev/null +++ b/paddleformers/examples/deepseek_v3/run_pretrain.py @@ -0,0 +1,615 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import copy
+import math
+import os
+import sys
+import time
+from dataclasses import dataclass, field
+from typing import Optional
+
+import paddle
+
+from paddleformers.data.causal_dataset import (
+    build_train_valid_test_datasets,
+    check_data_split,
+    print_rank_0,
+)
+from paddleformers.trainer import (
+    FP8QuantWeightCallback,
+    PdArgumentParser,
+    StepFlexToken,
+    Trainer,
+    TrainingArguments,
+    get_last_checkpoint,
+    set_seed,
+    speed_metrics,
+)
+from paddleformers.transformers import (
+    AutoConfig,
+    AutoModelForCausalLM,
+    AutoModelForCausalLMPipe,
+    AutoTokenizer,
+    CosineAnnealingWithWarmupDecay,
+    LinearAnnealingWithWarmupDecay,
+)
+from paddleformers.transformers.configuration_utils import LlmMetaConfig, llmmetaclass
+from paddleformers.utils.batch_sampler import DistributedBatchSampler
+from paddleformers.utils.log import logger
+from paddleformers.utils.tools import get_env_device
+
+# Pretraining environment variable to support the sharding stage1 overlap optimization.
+os.environ["USE_CASUAL_MASK"] = "True"
+
+
+from paddleformers.trainer.utils.doc import add_start_docstrings
+
+
+@dataclass
+@llmmetaclass
+@add_start_docstrings(TrainingArguments.__doc__)
+class PreTrainingArguments(TrainingArguments):
+    min_learning_rate: float = field(
+        default=1e-5,
+        metadata={"help": "Minimum learning rate decayed to."},
+    )
+    decay_steps: float = field(
+        default=None,
+        metadata={
+            "help": "The number of steps used to control the learning rate decay. Once the step exceeds decay_steps, min_learning_rate is used."
+        },
+    )
+    enable_linear_fused_grad_add: bool = field(
+        default=False,
+        metadata={
+            "help": "Enable the fused linear grad add strategy, which reduces the elementwise adds for grad accumulation in the backward of nn.Linear."
+        },
+    )
+    # NOTE(gongenlei): new add autotuner_benchmark
+    autotuner_benchmark: bool = field(
+        default=False,
+        metadata={"help": "Whether to run the benchmark by autotuner. True for from_scratch and pad_max_length."},
+    )
+    unified_checkpoint: bool = field(
+        default=True,
+        metadata={"help": "Whether to enable unified checkpoint."},
+    )
+
+    def __post_init__(self):
+        super().__post_init__()
+        # NOTE(gongenlei): new add autotuner_benchmark
+        from paddleformers.trainer.trainer_utils import IntervalStrategy
+
+        if self.autotuner_benchmark:
+            self.max_steps = 5
+            self.do_train = True
+            self.do_export = False
+            self.do_predict = False
+            self.do_eval = False
+            self.overwrite_output_dir = True
+            self.load_best_model_at_end = False
+            self.report_to = []
+            self.save_strategy = IntervalStrategy.NO
+            self.evaluation_strategy = IntervalStrategy.NO
+            self.unified_checkpoint = False
+
+
+@dataclass
+class DataArguments:
+    """
+    Arguments pertaining to the data we are going to use for training and evaluation.
+    Using `PdArgumentParser` we can turn this class into argparse arguments to be able to
+    specify them on the command line.
+    """
+
+    input_dir: str = field(
+        default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
+    )
+    split: str = field(default="949,50,1", metadata={"help": "Train/valid/test data split."})
+
+    max_seq_length: int = field(
+        default=1024,
+        metadata={
+            "help": "The maximum total input sequence length after tokenization. Sequences longer "
+            "than this will be truncated, sequences shorter will be padded."
+ }, + ) + share_folder: bool = field( + default=False, + metadata={"help": "Use share folder for data dir and output dir on multi machine."}, + ) + + data_impl: str = field(default="mmap", metadata={"help": "The format of the preprocessed data."}) + skip_warmup: bool = field( + default=True, + metadata={"help": "Whether to skip the warmup process of mmap files."}, + ) + data_cache: str = field(default=None, metadata={"help": "The path of the cached dataset."}) + + +@dataclass +class ModelArguments: + """ + Arguments pertaining to which model/config/tokenizer we are going to pre-train from. + """ + + model_name_or_path: str = field( + default="__internal_testing__/tiny-random-llama", + metadata={ + "help": "Path to pretrained model or model identifier from https://paddleformers.readthedocs.io/zh/latest/model_zoo/transformers.html" + }, + ) + tokenizer_name_or_path: Optional[str] = field( + default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} + ) + + use_fast_layer_norm: bool = field( + default=False, + metadata={"help": "GPT3 model, use fast layernorm"}, + ) + + hidden_dropout_prob: float = field(default=0.1, metadata={"help": "The hidden dropout prob."}) + attention_probs_dropout_prob: float = field(default=0.1, metadata={"help": "The attention hidden dropout prob."}) + + fuse_attention_qkv: bool = field( + default=None, + metadata={"help": "whether to fuse attention qkv"}, + ) + fuse_attention_ffn: bool = field( + default=None, + metadata={"help": "whether to fuse first up and gate proj in mlp block"}, + ) + + continue_training: bool = field( + default=False, + metadata={ + "help": "Pre-training from existing paddleformers model weights. Default False and model will train from scratch. If set True, the model_name_or_path argument must exist in the paddleformers models." + }, + ) + num_hidden_layers: Optional[int] = field( + default=None, + metadata={"help": "num_hidden_layers."}, + ) + + +def create_pretrained_dataset( + data_args, + training_args, + data_file, + tokenizer, + need_data=True, +): + + check_data_split(data_args.split, training_args.do_train, training_args.do_eval, training_args.do_predict) + + train_val_test_num_samples = [ + training_args.per_device_train_batch_size + * training_args.dataset_world_size + * training_args.max_steps + * training_args.gradient_accumulation_steps, + training_args.per_device_eval_batch_size + * training_args.dataset_world_size + * training_args.eval_iters + * (training_args.max_steps // training_args.eval_steps + 1), + training_args.per_device_eval_batch_size * training_args.dataset_world_size * training_args.test_iters, + ] + + print_rank_0(" > datasets target sizes (minimum size):") + if training_args.do_train: + print_rank_0(" train: {}".format(train_val_test_num_samples[0])) + if training_args.do_eval: + print_rank_0(" validation: {}".format(train_val_test_num_samples[1])) + if training_args.do_predict: + print_rank_0(" test: {}".format(train_val_test_num_samples[2])) + + # Build the datasets. 
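+    # Worked example of the target sizes above, assuming the provided pretrain_argument.json on 8 GPUs
+    # (per_device_train_batch_size=1, gradient_accumulation_steps=24, max_steps=200, and a
+    # dataset_world_size of 2 from data-parallel x sharding): the train split must provide at least
+    # 1 * 2 * 200 * 24 = 9600 samples.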
+ train_dataset, valid_dataset, test_dataset = build_train_valid_test_datasets( + data_prefix=data_file, + data_impl=data_args.data_impl, + splits_string=data_args.split, + train_val_test_num_samples=train_val_test_num_samples, + seq_length=data_args.max_seq_length, + seed=training_args.seed, + skip_warmup=data_args.skip_warmup, + share_folder=data_args.share_folder, + data_cache_path=data_args.data_cache, + need_data=need_data, + ) + + def print_dataset(data, mode="train"): + logger.info(f"Sample data for {mode} mode.") + # input_ids, loss_mask, attention_mask, position_ids, labels = data + input_ids = data["text"] + logger.info(tokenizer._decode(list(input_ids))) + + from paddleformers.data import Stack + + def _collate_data(data, stack_fn=Stack()): + tokens_ = stack_fn([x["text"] for x in data]) + + labels = copy.deepcopy(tokens_)[:, 1:] + tokens = tokens_[:, :-1] + + return { + "input_ids": tokens, + "labels": labels, + } + + if need_data: + if training_args.do_train: + print_dataset(train_dataset[0], "train") + if training_args.do_eval: + print_dataset(valid_dataset[0], "valid") + if training_args.do_predict: + print_dataset(test_dataset[0], "test") + + return train_dataset, valid_dataset, test_dataset, _collate_data + + +def get_train_data_file(args): + if len(args.input_dir.split()) > 1: + # weight-1 data-prefix-1 weight-2 data-prefix-2 ... + return args.input_dir.split() + else: + files = [ + os.path.join(args.input_dir, f) + for f in os.listdir(args.input_dir) + if (os.path.isfile(os.path.join(args.input_dir, f)) and ("_idx.npz" in str(f) or ".idx" in str(f))) + ] + files = [x.replace("_idx.npz", "") for x in files] + files = [x.replace(".idx", "") for x in files] + + if len(files) > 1: + ret = [] + logger.info("You are using multi-dataset:") + for x in files: + ret.append(1.0) + ret.append(x) + logger.info(" > set weight of %s dataset to 1.0" % x) + return ret + + return files + + +class PretrainingTrainer(Trainer): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.is_pretraining = True + + def evaluate(self, eval_dataset=None, ignore_keys=None, metric_key_prefix: str = "eval"): + # keep eval_dataloader + eval_dataloader = getattr(self, "eval_dataloader", None) + if eval_dataloader is None: + eval_dataset = self.eval_dataset if eval_dataset is None else eval_dataset + eval_dataloader = self.get_eval_dataloader(eval_dataset) + # must call data loader, otherwise, it will init many times, cause OOM error. + self.eval_dataloader = eval_dataloader() + + start_time = time.time() + # Temporarily disable metric computation, we will do it in the loop here. 
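+        # NOTE: evaluation is capped at self.args.eval_iters batches (set to 10 in main() below), so the
+        # reported metrics cover a fixed subset of the validation split rather than a full pass over it.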
+ compute_metrics = self.compute_metrics + eval_loop = self.evaluation_loop + + output = eval_loop( + eval_dataloader, + description="Evaluation", + # No point gathering the predictions if there are no metrics, otherwise we defer to + # self.args.prediction_loss_only + prediction_loss_only=True if compute_metrics is None else None, + ignore_keys=ignore_keys, + # Only evaluate max_eval_iters + max_eval_iters=self.args.eval_iters, + ) + + total_batch_size = self.args.eval_batch_size * self.args.world_size + output.metrics.update( + speed_metrics( + metric_key_prefix, + start_time, + num_samples=output.num_samples, + num_steps=math.ceil(output.num_samples / total_batch_size), + ) + ) + + self.log(output.metrics) + + self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, output.metrics) + return output.metrics + + def _get_eval_sampler(self, eval_dataset) -> Optional[paddle.io.Sampler]: + return DistributedBatchSampler( + eval_dataset, + batch_size=self.args.per_device_eval_batch_size, + shuffle=False, + num_replicas=self.args.dataset_world_size, + rank=self.args.dataset_rank, + drop_last=self.args.dataloader_drop_last, + ) + + def _get_train_sampler(self) -> Optional[paddle.io.Sampler]: + return DistributedBatchSampler( + self.train_dataset, + batch_size=self.args.per_device_train_batch_size, + shuffle=False, + num_replicas=self.args.dataset_world_size, + rank=self.args.dataset_rank, + drop_last=self.args.dataloader_drop_last, + ) + + +def main(): + parser = PdArgumentParser((ModelArguments, DataArguments, PreTrainingArguments)) + # Support format as "args.json --arg1 value1 --arg2 value2.” + # In case of conflict, command line arguments take precedence. + if len(sys.argv) >= 2 and sys.argv[1].endswith(".json"): + model_args, data_args, training_args = parser.parse_json_file_and_cmd_lines() + else: + model_args, data_args, training_args = parser.parse_args_into_dataclasses() + + if training_args.no_recompute_layers is not None: + training_args.no_recompute_layers.sort() + + if training_args.enable_linear_fused_grad_add: + from utils.fused_layers import mock_layers + + mock_layers() + + if model_args.tokenizer_name_or_path is None: + model_args.tokenizer_name_or_path = model_args.model_name_or_path + + if data_args.data_cache is not None: + os.makedirs(data_args.data_cache, exist_ok=True) + + paddle.set_device(training_args.device) + set_seed(seed=training_args.seed) + + training_args.eval_iters = 10 + training_args.test_iters = training_args.eval_iters * 10 + + # Log model and data config + training_args.print_config(model_args, "Model") + training_args.print_config(data_args, "Data") + + # Log on each process the small summary: + logger.warning( + f"Process rank: {training_args.local_rank}, device: {training_args.device}, world_size: {training_args.world_size}, " + + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16 or training_args.bf16}" + ) + + # Detecting last checkpoint. + last_checkpoint = None + if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir: + last_checkpoint = get_last_checkpoint(training_args.output_dir) + # if last_checkpoint is None and len( + # os.listdir(training_args.output_dir)) > 1: + # raise ValueError( + # f"Output directory ({training_args.output_dir}) already exists and is not empty. 
" + # "Use --overwrite_output_dir to overcome.") + if last_checkpoint is not None and training_args.resume_from_checkpoint is None: + logger.info( + f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change " + "the `--output_dir` or add `--overwrite_output_dir` to train from scratch." + ) + + tokenizer = AutoTokenizer.from_pretrained(model_args.tokenizer_name_or_path) + config = AutoConfig.from_pretrained(model_args.model_name_or_path) + + # set all llm config + LlmMetaConfig.set_llm_config(config, training_args) + config.use_fast_layer_norm = model_args.use_fast_layer_norm + + config.seq_length = data_args.max_seq_length + # There are some technique extend RotaryEmbedding context. so don't change max_position_embeddings + if not model_args.continue_training: + config.max_position_embeddings = max(config.max_position_embeddings, data_args.max_seq_length) + + if not model_args.continue_training: + config.vocab_size = max(config.vocab_size, ((tokenizer.vocab_size - 1) // 128 + 1) * 128) + logger.info(f"Reset vocab size to {config.vocab_size} for batter amp peformance.") + + config.num_hidden_layers = ( + model_args.num_hidden_layers if model_args.num_hidden_layers is not None else config.num_hidden_layers + ) + # Config for model using dropout, such as GPT. + if hasattr(config, "use_dualpipev"): + # NOTE(zhangyuqin): In Paddle, the segmentation and scheduling of pipeline parallel + # models are separate. Therefore, first we need to set the flag in the model config + # to perform V-shape segmentation. Second, we need to set the flag in the training_args + # to configure strategy.hybrid_configs to choose the DualPipeV schedule. + config.use_dualpipev = "use_dualpipev" in training_args.pipeline_parallel_config + if hasattr(config, "hidden_dropout_prob"): + config.hidden_dropout_prob = model_args.hidden_dropout_prob + if hasattr(config, "attention_probs_dropout_prob"): + config.attention_probs_dropout_prob = model_args.attention_probs_dropout_prob + if model_args.fuse_attention_qkv is not None: + config.fuse_attention_qkv = model_args.fuse_attention_qkv + if model_args.fuse_attention_ffn is not None: + config.fuse_attention_ffn = model_args.fuse_attention_ffn + + if config.sequence_parallel: + assert config.tensor_parallel_degree > 1, "tensor_parallel_degree must be larger than 1 for sequence parallel." 
+        assert (
+            config.num_attention_heads % config.sep_parallel_degree == 0
+        ), f"num_attention_heads:{config.num_attention_heads} must be divisible by sep_parallel_degree {config.sep_parallel_degree}"
+        assert (
+            config.seq_length % config.context_parallel_degree == 0
+        ), f"seq_length:{config.seq_length} must be divisible by context_parallel_degree {config.context_parallel_degree}"
+
+    if training_args.sharding_parallel_config is not None:
+        # for stage1 overlap optimization
+        if (
+            "enable_stage1_allgather_overlap" in training_args.sharding_parallel_config
+            or "enable_stage1_broadcast_overlap" in training_args.sharding_parallel_config
+        ):
+            from paddle.io.reader import use_pinned_memory
+
+            use_pinned_memory(False)
+
+    if get_env_device() == "xpu" and training_args.gradient_accumulation_steps > 1:
+        try:
+            from paddle_xpu.layers.nn.linear import LinearConfig  # noqa: F401
+
+            LinearConfig.enable_accumulate_steps_opt()
+            LinearConfig.set_accumulate_steps(training_args.gradient_accumulation_steps)
+        except ImportError:
+            # It's OK to run without the accumulate_steps optimization.
+            pass
+
+    print("Final pre-training config:", config)
+
+    # Set the dtype for loading the model.
+    dtype = "float32"
+    if training_args.fp16_opt_level == "O2":
+        if training_args.fp16:
+            dtype = "float16"
+        if training_args.bf16:
+            dtype = "bfloat16"
+
+    model_class = AutoModelForCausalLM
+    if training_args.pipeline_parallel_degree > 1:
+        model_class = AutoModelForCausalLMPipe
+        if "LLama" in str(config.architectures):
+            try:
+                from utils.register_reshard import register_pp_reshard_information
+
+                register_pp_reshard_information(config.num_hidden_layers)
+            except Exception:
+                print("Did not register llama pp reshard information.")
+
+    architectures_to_check = {"Qwen2Moe", "DeepseekV2", "DeepseekV3"}
+    if (
+        any(architecture in str(config.architectures) for architecture in architectures_to_check)
+        and training_args.data_parallel_degree > 1
+    ):
+        training_args.use_expert_parallel = True
+
+    if model_args.continue_training:
+        # NOTE(gongenlei): new add
+        if training_args.autotuner_benchmark:
+            model = model_class.from_config(config, dtype=dtype)
+        else:
+            model = model_class.from_pretrained(
+                model_args.model_name_or_path,
+                config=config,
+                dtype=dtype,
+            )
+    else:
+        # Modify these to shrink the model: the first 3 DeepSeek-V3 layers are dense, the MoE layers come after.
+        # config.num_hidden_layers = 4  # 61 in V3
+        # config.first_k_dense_replace = 0  # 3 in V3
+        # # Modify these to reduce the expert count; for expert parallelism (EP) it must be divisible by the EP degree.
+        # config.n_routed_experts = 64  # 256 in V3
+        # config.num_experts_per_tok = 8  # 8 in V3
+        # config.topk_group = 4  # 4 in V3
+
+        # config.using_flex_token = True
+        # config.num_nextn_predict_layers = 1
+        # config.using_fake_gate = True
+        # config.use_fused_rms_norm = True
+        # config.fuse_attention_ffn = True
+        # config.use_fused_rope = True
+        # config.token_drop_steps = 0
+        model = model_class.from_config(config, dtype=dtype)
+
+    if training_args.recompute:
+        model.recompute_enable()
+
+    # Create the learning rate scheduler and optimizer.
+    if training_args.decay_steps is None:
+        training_args.decay_steps = training_args.max_steps
+
+    if training_args.warmup_steps > 0:
+        warmup_steps = training_args.warmup_steps
+    else:
+        warmup_steps = training_args.warmup_ratio * training_args.max_steps
+
+    lr_scheduler = None
+    if training_args.lr_scheduler_type.value == "cosine":
+        lr_scheduler = CosineAnnealingWithWarmupDecay(
+            max_lr=training_args.learning_rate,
+            min_lr=training_args.min_learning_rate,
+            warmup_step=warmup_steps,
+            decay_step=training_args.decay_steps,
+            last_epoch=0,
+        )
+    elif
training_args.lr_scheduler_type.value == "linear": + lr_scheduler = LinearAnnealingWithWarmupDecay( + max_lr=training_args.learning_rate, + min_lr=training_args.min_learning_rate, + warmup_step=warmup_steps, + decay_step=training_args.decay_steps, + last_epoch=0, + ) + + data_file = get_train_data_file(data_args) + train_dataset, eval_dataset, test_dataset, data_collator = create_pretrained_dataset( + data_args, + training_args, + data_file, + tokenizer, + need_data=training_args.should_load_dataset, + ) + + total_effective_tokens = ( + training_args.per_device_train_batch_size + * training_args.dataset_world_size + * training_args.max_steps + * training_args.gradient_accumulation_steps + * data_args.max_seq_length + ) + + callbacks = [StepFlexToken(), FP8QuantWeightCallback()] + + trainer = PretrainingTrainer( + model=model, + args=training_args, + data_collator=data_collator, + train_dataset=train_dataset if training_args.do_train else None, + eval_dataset=eval_dataset if training_args.do_eval else None, + optimizers=(None, lr_scheduler), + tokenizer=tokenizer, + callbacks=callbacks, + ) + + checkpoint = None + if training_args.resume_from_checkpoint is not None: + checkpoint = training_args.resume_from_checkpoint + elif last_checkpoint is not None: + checkpoint = last_checkpoint + + # Training + if training_args.do_train: + train_result = trainer.train(resume_from_checkpoint=checkpoint) + + # NOTE(gongenlei): new add + if not training_args.autotuner_benchmark: + metrics = train_result.metrics + if not int(os.getenv("test_ci_no_save_model", 0)): + trainer.save_model() + trainer.log_metrics("train", metrics) + trainer.save_metrics("train", metrics) + trainer.save_state() + + if training_args.do_predict: + test_ret = trainer.predict(test_dataset) + trainer.log_metrics("test", test_ret.metrics) + + if training_args.do_train and training_args.should_load_dataset: + effective_tokens_per_second = total_effective_tokens / train_result.metrics["train_runtime"] + print(f"Effective Tokens per second: {effective_tokens_per_second:.2f}") + print(f"ips: {effective_tokens_per_second:.2f} tokens/s") + + +if __name__ == "__main__": + main() diff --git a/paddleformers/examples/deepseek_v3/script/kill_process.sh b/paddleformers/examples/deepseek_v3/script/kill_process.sh new file mode 100644 index 00000000000..3c3db6a4639 --- /dev/null +++ b/paddleformers/examples/deepseek_v3/script/kill_process.sh @@ -0,0 +1,32 @@ +#!/bin/bash + +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +set -x +skip_kill_time=${1:-"False"} +function kill_impl() { + skip_kill_time=$1 + # kill aadiff test finally. 
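+    # Find any running *pretrain.py processes (this also matches run_pretrain.py) and force-kill them,
+    # then kill whatever is still holding the GPU devices.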
+ pids=`ps -ef | grep pretrain.py | grep -v grep | awk '{print $2}'` + if [[ "$pids" != "" ]] ; then + echo $pids + echo $pids | xargs kill -9 + fi + + echo "Killing processes on gpu" + lsof /dev/nvidia* | awk '{print $2}' | xargs -I {} kill -9 {} +} + +kill_impl $skip_kill_time || true \ No newline at end of file diff --git a/paddleformers/examples/deepseek_v3/script/selective_launch.py b/paddleformers/examples/deepseek_v3/script/selective_launch.py new file mode 100644 index 00000000000..1f8a37bfbc5 --- /dev/null +++ b/paddleformers/examples/deepseek_v3/script/selective_launch.py @@ -0,0 +1,73 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Selective launch script. + +Usage: python script/selective_launch.py ... +""" +import os +import sys + + +def parse_ranks(ranks_strs): + """ + parse_ranks + """ + # NOTE: You can return ranks directly here to change script/train_gpu.sh + # and script/kill_process.sh together + + # Example 1: Use contiguous nodes [8, 16) + return range(6, 7) + + # Example 2: Use non-contiguous nodes [4, 8) + {10} + [30, 32), i.e., [4, 5, 6, 7, 10, 30, 31] + # return list(range(0, 16)) + list(range(24, 40)) + + # Example 3: + # Just Python code, return any nodes you want! + + if not ranks_strs: + return None + + ranks = [] + for r in ranks_strs: + r = eval(r) + if isinstance(r, int): + ranks.append(r) + else: + ranks.extend(r) + return ranks + + +def main(port, ranks): + """ + main + """ + ips = [ip.strip() for ip in os.getenv("TRAINER_INSTANCES").split(",") if ip.strip()] + if ranks is None: + ranks = list(range(len(ips))) + ranks = sorted(list(set(ranks))) + my_rank = int(os.getenv("POD_INDEX", "0")) + if my_rank not in ranks: + return + + rank = ranks.index(my_rank) + nranks = len(ranks) + + master = ips[ranks[0]] + print(f"--master {master}:{port} --rank {rank} --nnodes {nranks}") + + +if __name__ == "__main__": + main(int(sys.argv[1]), parse_ranks(sys.argv[2:])) diff --git a/paddleformers/examples/deepseek_v3/script/train_gpu.sh b/paddleformers/examples/deepseek_v3/script/train_gpu.sh new file mode 100644 index 00000000000..798ba4a1473 --- /dev/null +++ b/paddleformers/examples/deepseek_v3/script/train_gpu.sh @@ -0,0 +1,80 @@ +#!/bin/bash + +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
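+
+# Usage sketch (illustrative): this script is normally invoked through run.sh, e.g.
+#   sh script/train_gpu.sh config/pretrain_argument.json
+# All arguments are forwarded to run_pretrain.py via $@ at the bottom of this script.
+# script/selective_launch.py prints the --master/--rank/--nnodes arguments for paddle.distributed.launch
+# based on TRAINER_INSTANCES and POD_INDEX; nodes that are not in the selected ranks exit early.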
+ +unset PADDLE_ELASTIC_JOB_ID +unset PADDLE_TRAINER_ENDPOINTS +unset DISTRIBUTED_TRAINER_ENDPOINTS +unset FLAGS_START_PORT +unset PADDLE_ELASTIC_TIMEOUT + +nnodes=$PADDLE_TRAINERS_NUM +rank=$PADDLE_TRAINER_ID + +for name in `env | grep -E 'PADDLE|ENDPOINT' | awk -F'=' '{print $1}'`; do + unset ${name} +done + +#export FLAGS_shard_bypass_dygraph_optimizer=1 +export NCCL_IB_GID_INDEX=3 +export NVSHMEM_IB_GID_INDEX=3 +export NVSHMEM_IB_TRAFFIC_CLASS=162 + +#export NVSHMEM_IB_ENABLE_IBGDA=true +##export NVSHMEM_DISABLE_P2P=1 +export NVSHMEM_BOOTSTRAP=UID +# export NVSHMEM_BOOTSTRAP_UID_SOCK_IFNAME==xgbe0 + +unset NVSHMEM_HCA_LIST +unset NVSHMEM_ENABLE_NIC_PE_MAPPING + +LAUNCH_CMD=`python script/selective_launch.py 36677` +if [[ -z "$LAUNCH_CMD" ]]; then + exit 0 +fi + +# export PYTHONPATH=../../../:$PYTHONPATH + +export PATH=/opt/nvidia/nsight-systems/2025.3.1/bin/:$PATH + +export DSV3_USE_FP8_GEMM=true +export DSV3_USE_ATTEN_RECOMPUTE=true +export FA_VERSION=3 +export CUDA_PATH=/usr/local/cuda-12.9 +export FLAGS_share_tensor_for_grad_tensor_holder=1 +export FLAGS_use_default_stream=false +export DSV3_USE_FP8_DISPATCH=true +export USE_DS_GEMM=false + +bash script/kill_process.sh + +# source /root/paddlejob/workspace/env_run/zhangbo/env_ds/bin/activate +source /root/paddlejob/workspace/env_run/chenxi/chenxi_py3.10/bin/activate + +export FLAGS_large_pool_auto_growth_chunk_size_in_mb=500 +export FLAGS_small_pool_auto_growth_chunk_size_in_mb=20 +export FLAGS_small_pool_size_in_mb=10 + +export FLAGS_samll_pool_pre_alloc_in_mb=500 +export FLAGS_large_pool_pre_alloc_in_mb=61440 + +export DSV3_FAST_PRETRAIN=true +# nsys profile --stats=true -t cuda,nvtx -o test_no_quant_cache --force-overwrite true \ +python3.10 -m paddle.distributed.launch \ + --log_dir output/paddle_distributed_logs \ + $LAUNCH_CMD \ + --run_mode=collective \ + ${script:-run_pretrain.py} \ + $@