diff --git a/examples/run_finetune.py b/examples/run_finetune.py
index a635b147d34..b93e08971eb 100644
--- a/examples/run_finetune.py
+++ b/examples/run_finetune.py
@@ -140,6 +140,8 @@ def main():
     model_config.max_sequence_length = training_args.max_seq_len
     model_config.num_nextn_predict_layers = model_args.num_nextn_predict_layers
     model_config._attn_implementation = model_args.attn_impl
+    model_config.using_fake_gate = model_args.using_fake_gate
+    model_config.aux_loss_alpha = model_args.aux_loss_alpha
     logger.info(f"Final model config: {model_config}")

     logger.info("Creating model")
@@ -278,13 +280,16 @@ def neft_post_hook(module, input, output):
            training_args.logging_steps = int(training_args.max_steps / training_args.num_train_epochs)

     callbacks = []
+
     if getattr(model_config, "topk_method", None) == "noaux_tc":
-        callbacks += [MoECorrectionBiasAdjustCallback(lr=0)]
+        # deepseek_v3 finetune does not update the bias, so set lr to 0.0
+        callbacks += [MoECorrectionBiasAdjustCallback(lr=0.0)]
     if training_args.use_expert_parallel:
         callbacks += [MoeExpertsGradScaleCallback(training_args)]
-    print("callbacks:", callbacks, flush=True)
+    logger.info(f"callbacks: {callbacks}")
+
     trainer = SFTTrainer(
         model=model,
         args=training_args,
@@ -295,6 +300,7 @@ def neft_post_hook(module, input, output):
         data_collator=data_collator,
         do_generation=data_args.eval_with_do_generation,
         data_args=data_args,
+        callbacks=callbacks,
     )
     trainable_parameters = [
         p for p in model.parameters() if not p.stop_gradient or ("quantization_linear" in p.name and "w_1" in p.name)
diff --git a/paddleformers/nn/pp_model.py b/paddleformers/nn/pp_model.py
index b5b70ae5e53..38e5d9ee06b 100644
--- a/paddleformers/nn/pp_model.py
+++ b/paddleformers/nn/pp_model.py
@@ -508,12 +508,28 @@ class GeneralModelForCausalLMPipe(PipelinePretrainedModel, PipelineLayer):
     _embed_cls = None
     _rotary_emb_cls = None
     _norm_cls = "rms_norm"
+    _mtp_layer_pipe_cls = None
+    _embedding_pipe_cls = None
+    _decoder_layer_pipe_cls = None
+    _criterion_pipe_cls = None
+    _lmhead_pipe_cls = None
+    _rms_norm_pipe_cls = None

     def __init__(self, config: PretrainedConfig, **kwargs):
         # dynamic inherit DecoderLayer
         if self._decoder_layer_cls is None:
             raise ValueError("_decoder_layer_cls must be set before init.")
-        DecoderLayerPipe = make_decoder_layer_pipe(self._decoder_layer_cls)
+
+        EmbeddingPipeCls = self._embedding_pipe_cls if self._embedding_pipe_cls is not None else Embedding
+
+        if self._decoder_layer_pipe_cls is None:
+            DecoderLayerPipe = make_decoder_layer_pipe(self._decoder_layer_cls)
+        else:
+            DecoderLayerPipe = self._decoder_layer_pipe_cls
+
+        LMHeadPipeCls = self._lmhead_pipe_cls if self._lmhead_pipe_cls is not None else LMHeadPipe
+        MTPLayerPipeCls = self._mtp_layer_pipe_cls if self._mtp_layer_pipe_cls is not None else None
+        RMSNormPipeCls = self._rms_norm_pipe_cls if self._rms_norm_pipe_cls is not None else RMSNormPipe

         new_initializer_range = math.sqrt(0.3333 / config.hidden_size)
         logger.info(f"change initializer-range from {config.initializer_range} to {new_initializer_range}")
@@ -560,7 +576,7 @@ def __init__(self, config: PretrainedConfig, **kwargs):
         else:
             self.add_sequential_layer(
                 LayerDesc(
-                    EmbeddingPipe, config=config, embed_cls=self._embed_cls, rotary_emb_cls=self._rotary_emb_cls
+                    EmbeddingPipeCls, config=config, embed_cls=self._embed_cls, rotary_emb_cls=self._rotary_emb_cls
                 ),
                 "model",
             )
@@ -574,6 +590,12 @@ def __init__(self, config: PretrainedConfig, **kwargs):
                 ),
                 f"model.layers.{i}",
             )
+        for i in range(config.num_nextn_predict_layers):
+            if MTPLayerPipeCls is not None:
+                self.add_sequential_layer(
+                    LayerDesc(MTPLayerPipeCls, config=config, layer_idx=config.num_hidden_layers + i),
+                    f"model.layers.{config.num_hidden_layers + i}",
+                )
         for i in range(config.add_tail_layers):
             self.add_sequential_layer(
                 LayerDesc(
@@ -583,7 +605,7 @@ def __init__(self, config: PretrainedConfig, **kwargs):
             )

         self.add_sequential_layer(
-            LayerDesc(RMSNormPipe if self._norm_cls == "rms_norm" else LayerNormPipe, config=config),
+            LayerDesc(RMSNormPipeCls if self._norm_cls == "rms_norm" else LayerNormPipe, config=config),
             "model.norm",
         )
@@ -591,14 +613,14 @@ def __init__(self, config: PretrainedConfig, **kwargs):
             self.add_sequential_layer(
                 SharedLayerDesc(
                     "model_shared_weight",
-                    LMHeadPipe,
+                    LMHeadPipeCls,
                     shared_weight_attr="embedding_weight",
                     config=config,
                 ),
                 "lm_head",
             )
         else:
-            self.add_sequential_layer(LayerDesc(LMHeadPipe, config=config), "lm_head")
+            self.add_sequential_layer(LayerDesc(LMHeadPipeCls, config=config), "lm_head")

         recompute_interval = 0
         seg_method = config.pp_seg_method if hasattr(config, "pp_seg_method") else "layer:DecoderLayer|EmptyLayer"
@@ -631,10 +653,12 @@ def __init__(self, config: PretrainedConfig, **kwargs):
         )

     def get_loss_fn(self, config):
+        CriterionPipeCls = self._criterion_pipe_cls if self._criterion_pipe_cls is not None else CriterionLayerPipe
+
         if config.get("dpo_config", None) is not None:
-            loss_fn = CriterionLayerPipe(config, use_infohub=True)
+            loss_fn = CriterionPipeCls(config, use_infohub=True)
         else:
-            loss_fn = CriterionLayerPipe(config)
+            loss_fn = CriterionPipeCls(config)

         return loss_fn
diff --git a/paddleformers/optimizers/moe_hybrid_parallel_optimizer.py b/paddleformers/optimizers/moe_hybrid_parallel_optimizer.py
new file mode 100644
index 00000000000..a2865fb9c3f
--- /dev/null
+++ b/paddleformers/optimizers/moe_hybrid_parallel_optimizer.py
@@ -0,0 +1,406 @@
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +import paddle +import paddle.distributed as dist +from paddle.autograd import no_grad +from paddle.distributed.fleet.base.topology import ParallelMode +from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.dygraph_sharding_optimizer import ( + DygraphShardingOptimizer, + DygraphShardingOptimizerV2, +) +from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.hybrid_parallel_optimizer import ( + HybridParallelOptimizer as HPBase, +) +from paddle.distributed.fleet.utils import timer_helper as timer +from paddle.distributed.fleet.utils.hybrid_parallel_util import unwrap_optimizer +from paddle.distributed.fleet.utils.log_util import logger +from paddle.distributed.fleet.utils.mix_precision_utils import MixPrecisionOptimizer +from paddle.framework import core +from paddle.nn import ClipGradByGlobalNorm, clip + +__all__ = [] + + +class MoEHybridParallelClipGrad: + def __init__(self, clip, hcg, timers=None): + self._clip = clip + self._hcg = hcg + if hasattr(hcg, "get_moe_sharding_parallel_world_size") and hcg.get_moe_sharding_parallel_world_size() > 0: + # hybrid expert parallel + self.moe_group = hcg.get_expert_parallel_group() + self.moe_sharding_group = hcg.get_moe_sharding_parallel_group() + + self.stat = {} # for logging + self._timers = timers + self.processed_steps = 0 + + def _global_norm( + self, global_norm_var_dist, global_norm_var_not_dist, global_norm_var_dist_moe, global_norm_var_not_dist_moe + ): + # sharding first + sharding_flag = self._hcg.get_sharding_parallel_world_size() > 1 + mp_flag = self._hcg.get_model_parallel_world_size() > 1 + pp_flag = self._hcg.get_pipe_parallel_world_size() > 1 + + """do comm""" + logger.info( + f"before reduce: dist-moe-grad-norm={global_norm_var_dist_moe.item()} " + f"before reduce: non-dist-moe-grad-norm={global_norm_var_not_dist_moe.item()}" + ) + + if self.moe_sharding_group: + dist.all_reduce( + global_norm_var_dist_moe, + op=dist.ReduceOp.SUM, + group=self.moe_sharding_group, + ) + dist.all_reduce( + global_norm_var_not_dist_moe, + op=dist.ReduceOp.SUM, + group=self.moe_sharding_group, + ) + + if self.moe_group: + dist.all_reduce( + global_norm_var_dist_moe, + op=dist.ReduceOp.SUM, + group=self.moe_group, + ) + dist.all_reduce( + global_norm_var_not_dist_moe, + op=dist.ReduceOp.SUM, + group=self.moe_group, + ) + + if pp_flag: + paddle.distributed.all_reduce( + global_norm_var_dist_moe, + group=self._hcg.get_pipe_parallel_group(), + ) + paddle.distributed.all_reduce( + global_norm_var_not_dist_moe, + group=self._hcg.get_pipe_parallel_group(), + ) + + # add all reduce to get global norm of distributed params_and_grads + if sharding_flag: + # norm of mp distributed variable + if mp_flag: + # dist should reduce among sharding group、mp group、pp group + paddle.distributed.all_reduce( + global_norm_var_dist, + group=self._hcg.get_sharding_parallel_group(), + ) + # not dist only reduce among sharding group and pp group later + paddle.distributed.all_reduce( + global_norm_var_not_dist, + group=self._hcg.get_sharding_parallel_group(), + ) + + # norm of mp distributed variable + if mp_flag: + # dist should reduce among sharding group、mp group、pp group + paddle.distributed.all_reduce( + global_norm_var_dist, + group=self._hcg.get_model_parallel_group(), + ) + if pp_flag: + paddle.distributed.all_reduce( + global_norm_var_dist, + group=self._hcg.get_pipe_parallel_group(), + ) + + # add all reduce to get global norm of non-distributed params_and_grads in groups of pp + if pp_flag: + paddle.distributed.all_reduce( + 
global_norm_var_not_dist, + group=self._hcg.get_pipe_parallel_group(), + ) + + logger.info( + f"after reduce: dist-grad-norm={global_norm_var_dist.item()} " + f"after reduce: non-dist-grad-norm={global_norm_var_not_dist.item()}" + ) + + @no_grad() + def _dygraph_clip(self, params_grads): + if self._timers: + self._timers("dygraph-clip").start() + sum_square_dist_fp16 = [] + sum_square_dist_bf16 = [] + sum_square_dist_fp32 = [] + + sum_square_dist_moe_fp16 = [] + sum_square_dist_moe_bf16 = [] + sum_square_dist_moe_fp32 = [] + + sum_square_not_dist_fp16 = [] + sum_square_not_dist_bf16 = [] + sum_square_not_dist_fp32 = [] + + sum_square_not_dist_moe_fp16 = [] + sum_square_not_dist_moe_bf16 = [] + sum_square_not_dist_moe_fp32 = [] + + for p, g in params_grads: + if g is None: + continue + if getattr(p, "need_clip", True) is False: + continue + merge_grad = g + if g.type == core.VarDesc.VarType.SELECTED_ROWS: + merge_grad = clip.merge_selected_rows(g) + merge_grad = clip.get_tensor_from_selected_rows(merge_grad) + sum_square = clip._squared_l2_norm(merge_grad) + + not_shared_enable = (not hasattr(p, "is_firstly_shared")) or ( + hasattr(p, "is_firstly_shared") and getattr(p, "is_firstly_shared", True) + ) + + is_moe_param = getattr(p, "is_moe_param", False) + + if is_moe_param: + assert 0 + if not_shared_enable: + if getattr(p, "no_sync", False): + if p.is_distributed: + if g.dtype == paddle.float16: + sum_square_dist_moe_fp16.append(sum_square) + elif g.dtype == paddle.bfloat16: + sum_square_dist_moe_bf16.append(sum_square) + elif g.dtype == paddle.float32: + sum_square_dist_moe_fp32.append(sum_square) + else: + if g.dtype == paddle.float16: + sum_square_not_dist_moe_fp16.append(sum_square) + elif g.dtype == paddle.bfloat16: + sum_square_not_dist_moe_bf16.append(sum_square) + elif g.dtype == paddle.float32: + sum_square_not_dist_moe_fp32.append(sum_square) + + elif p.is_distributed: + if g.dtype == paddle.float16: + sum_square_dist_fp16.append(sum_square) + elif g.dtype == paddle.bfloat16: + sum_square_dist_bf16.append(sum_square) + elif g.dtype == paddle.float32: + sum_square_dist_fp32.append(sum_square) + else: + assert not getattr( + p, "no_sync", False + ), f"moe param shoud be distributed, got: {p.name}, shape={p.shape}" + if g.dtype == paddle.float16: + sum_square_not_dist_fp16.append(sum_square) + if g.dtype == paddle.bfloat16: + sum_square_not_dist_bf16.append(sum_square) + elif g.dtype == paddle.float32: + sum_square_not_dist_fp32.append(sum_square) + else: + assert not getattr(p, "no_sync", False), "MoE cannot handle shared param" + + def add_n_list(tensor_list): + if not tensor_list: + return paddle.zeros((1,), dtype=paddle.float32) + return paddle.add_n(tensor_list).cast(paddle.float32) + + # moe global norm of distributed FP16 params_and_grads + global_norm_dist_moe_fp16 = add_n_list( + sum_square_dist_moe_fp16, + ) + global_norm_not_dist_moe_fp16 = add_n_list( + sum_square_not_dist_moe_fp16, + ) + global_norm_dist_fp16 = add_n_list( + sum_square_dist_fp16, + ) + global_norm_not_dist_fp16 = add_n_list( + sum_square_not_dist_fp16, + ) + + global_norm_dist_moe_bf16 = add_n_list( + sum_square_dist_moe_bf16, + ) + global_norm_not_dist_moe_bf16 = add_n_list( + sum_square_not_dist_moe_bf16, + ) + global_norm_dist_bf16 = add_n_list( + sum_square_dist_bf16, + ) + global_norm_not_dist_bf16 = add_n_list( + sum_square_not_dist_bf16, + ) + + global_norm_dist_moe_fp32 = add_n_list( + sum_square_dist_moe_fp32, + ) + global_norm_not_dist_moe_fp32 = add_n_list( + sum_square_not_dist_moe_fp32, + ) 
+ global_norm_dist_fp32 = add_n_list( + sum_square_dist_fp32, + ) + global_norm_not_dist_fp32 = add_n_list( + sum_square_not_dist_fp32, + ) + + global_norm_var_dist_moe = global_norm_dist_moe_fp16 + global_norm_dist_moe_bf16 + global_norm_dist_moe_fp32 + + global_norm_var_not_dist_moe = ( + global_norm_not_dist_moe_fp16 + global_norm_not_dist_moe_bf16 + global_norm_not_dist_moe_fp32 + ) + + global_norm_var_dist = global_norm_dist_fp16 + global_norm_dist_bf16 + global_norm_dist_fp32 + global_norm_var_not_dist = global_norm_not_dist_fp16 + global_norm_not_dist_bf16 + global_norm_not_dist_fp32 + result = self._comm_and_clip( + params_grads, + global_norm_var_dist, + global_norm_var_not_dist, + global_norm_var_dist_moe, + global_norm_var_not_dist_moe, + ) + if self._timers: + self._timers("dygraph-clip").stop() + + return result + + def _comm_and_clip( + self, + params_grads, + global_norm_var_dist, + global_norm_var_not_dist, + global_norm_var_dist_moe, + global_norm_var_not_dist_moe, + ): + + self._global_norm( + global_norm_var_dist, global_norm_var_not_dist, global_norm_var_dist_moe, global_norm_var_not_dist_moe + ) + + global_norm_var_fp32 = paddle.sqrt( + global_norm_var_dist + global_norm_var_not_dist + global_norm_var_dist_moe + global_norm_var_not_dist_moe + ) + self.stat["global_grad_norm"] = global_norm_var_fp32.astype("float32").item() + + max_global_norm = paddle.full( + shape=[], + dtype=global_norm_var_fp32.dtype, + fill_value=self.clip_norm, + ) + clip_var = paddle.divide( + x=max_global_norm, + y=paddle.maximum(x=global_norm_var_fp32, y=max_global_norm) + + paddle.full(shape=[], dtype=paddle.float32, fill_value=1.0e-6), + ) + logger.info(f"hybrid-moe-clip, var={clip_var.item()}, global_norm:{global_norm_var_fp32.item()}") + clip_var_fp16 = paddle.cast(clip_var, paddle.float16) + + if ( + not isinstance(paddle.framework._current_expected_place(), paddle.CustomPlace) + or paddle.framework._current_expected_place().get_device_type() == "npu" + ): + clip_var_bf16 = paddle.cast(clip_var, paddle.bfloat16) + for p, g in params_grads: + if g is None: + continue + if getattr(p, "need_clip", True) is False: + continue + if g.dtype == paddle.float16: + g.multiply_(clip_var_fp16) + elif g.dtype == paddle.bfloat16: + if paddle.is_compiled_with_xpu(): + raise NotImplementedError("BF16 is not supported on XPU now") + g.multiply_(clip_var_bf16) + else: + g.multiply_(clip_var) + p._reset_grad_inplace_version(True) + + return params_grads + + def __getattr__(self, item): + return getattr(self._clip, item) + + def __call__(self, params_grads): + return self._dygraph_clip(params_grads) + + +class MoEHybridParallelOptimizer(HPBase): + # adapter wrapper for optimizer + def __init__(self, optimizer, hcg, strategy): + # Note: Only sharding stage 1 is considered in HybridParallelOptimizer. + # The sharding stage2 and stage3 optimizers are invoked in other api. 
+ print( + f"moe_sharding_degree:{hcg.get_moe_sharding_parallel_world_size()}, sharding_degree:{hcg.get_sharding_parallel_world_size()}, ep_degree:{hcg.get_expert_parallel_world_size()}" + ) + if hcg.get_moe_sharding_parallel_world_size() > 0: + split_param = strategy.hybrid_configs["sharding_configs"].split_param + assert ( + hcg.get_sharding_parallel_world_size() >= 1 and split_param is True + ), "Hybrid expert parallel only supports ShardingV2 now" + if hcg.get_sharding_parallel_world_size() > 1: + split_param = strategy.hybrid_configs["sharding_configs"].split_param + ShardingOptimizer = DygraphShardingOptimizerV2 if split_param else DygraphShardingOptimizer + optimizer = ShardingOptimizer(optimizer, hcg) + + self._enable_timer = strategy.hybrid_configs["enable_optimizer_timer"] + + if self._enable_timer: + if not timer.is_timer_initialized(): + timer.set_timers() + self._timers = timer.get_timers() + else: + self._timers = None + + self._inner_opt = optimizer + self._strategy = strategy + self._hcg = hcg + + self._use_dp_mode = self._hcg.get_parallel_mode() == ParallelMode.DATA_PARALLEL + + self._need_dp = self._hcg.get_data_parallel_world_size() > 1 + + self._dp_enable = not self._use_dp_mode and self._need_dp + + self._sharding_enable = self._hcg.get_sharding_parallel_world_size() > 1 + + self._sep_enable = self._hcg.get_sep_parallel_world_size() > 1 + + if isinstance(self._inner_opt._grad_clip, ClipGradByGlobalNorm) and not self._use_dp_mode: + logger.warning( + "While using ClipGradByGlobalNorm in TensorParallel, PipelineParallel " + "or Sharding, the grad clip of original optimizer will be changed." + ) + + inner_opt = unwrap_optimizer( + self._inner_opt, + ( + MixPrecisionOptimizer, + DygraphShardingOptimizer, + DygraphShardingOptimizerV2, + ), + ) + + if ( + inner_opt._parameter_list + and not isinstance(inner_opt._parameter_list[0], dict) + and len([p for p in inner_opt._parameter_list if hasattr(p, "main_grad")]) > 0 + ): + inner_opt._grad_clip = MoEHybridParallelClipGrad(inner_opt._grad_clip, hcg, self._timers) + else: + inner_opt._grad_clip = MoEHybridParallelClipGrad(inner_opt._grad_clip, hcg, self._timers) + if inner_opt._parameter_list and isinstance(inner_opt._parameter_list[0], dict): + for item in inner_opt._param_groups: + if "grad_clip" in item.keys(): + item["grad_clip"] = MoEHybridParallelClipGrad(inner_opt._grad_clip, hcg, self._timers) + self.processed_steps = 0 diff --git a/paddleformers/trainer/trainer.py b/paddleformers/trainer/trainer.py index 6cd0e2d66a4..195bcc80117 100644 --- a/paddleformers/trainer/trainer.py +++ b/paddleformers/trainer/trainer.py @@ -166,6 +166,7 @@ get_last_checkpoint, get_scheduler, has_length, + mock_offload_optimizer, set_seed, should_skip_data, speed_metrics, @@ -388,7 +389,6 @@ def __init__( self.optimizer, remap_parameter_name=self.args.load_sharded_model_remap_parameter_name, ) - if self.args.unified_checkpoint: self.unified_checkpoint_handler = UnifiedCheckpointHandler(self.args) @@ -2001,6 +2001,9 @@ def apply_decay_param_fun(x): **optimizer_kwargs, ) + if self.args.tensorwise_offload_optimizer: + mock_offload_optimizer() + return self.optimizer def _apply_to_optimizer(self, action): @@ -2200,6 +2203,30 @@ def _decorate_exclude_layers(self, model: nn.Layer): exclude_layers = [] return exclude_layers + def _wrap_distributed_optimizer(self, optimizer): + """ + In hybrid expert parallel, use customized optimizer and grad clip + """ + if ( + self.args.use_expert_parallel + and self.args.moe_sharding_parallel_degree >= 1 + and 
self.args.expert_parallel_degree > 1 + ): + from paddleformers.optimizers import MoEHybridParallelOptimizer + + fleet_env = fleet.fleet + fleet_env.user_defined_optimizer = optimizer + hp_optim = MoEHybridParallelOptimizer(optimizer, fleet_env._hcg, fleet_env._user_defined_strategy) + + if fleet_env._user_defined_strategy.hybrid_configs["pp_configs"].dp_comm_overlap: + hp_optim._dp_enable = False + + if fleet_env._user_defined_strategy.hybrid_configs["pp_configs"].sharding_comm_overlap: + hp_optim._sharding_enable = False + return hp_optim + else: + return fleet.distributed_optimizer(optimizer) + def _wrap_model(self, model, training=True): # train/eval could be run multiple-times - if already wrapped, don't re-wrap it again @@ -2320,7 +2347,7 @@ def get_expected_keys(inputs, keys): assert self.optimizer is not None, "Pipeline mode need decorate optimizer, pelease init optimizer." if self.args.amp_master_grad: self.optimizer = mix_precision_utils.MixPrecisionOptimizer(self.optimizer) - self.optimizer = fleet.distributed_optimizer(self.optimizer) + self.optimizer = self._wrap_distributed_optimizer(self.optimizer) if ( hasattr(self.args, "enable_sharding_comm_overlap") @@ -2351,7 +2378,7 @@ def get_expected_keys(inputs, keys): if self.args.amp_master_grad: self.optimizer = mix_precision_utils.MixPrecisionOptimizer(self.optimizer) - self.optimizer = fleet.distributed_optimizer(self.optimizer) + self.optimizer = self._wrap_distributed_optimizer(self.optimizer) else: cpu_offload = ShardingOption.OFFLOAD in self.args.sharding assert self.optimizer is not None, "optimizer is empty!" @@ -2409,7 +2436,7 @@ def get_expected_keys(inputs, keys): assert self.optimizer is not None, "Tensor parallel mode need decorate optimizer, pelease init optimizer." if self.args.amp_master_grad: self.optimizer = mix_precision_utils.MixPrecisionOptimizer(self.optimizer) - self.optimizer = fleet.distributed_optimizer(self.optimizer) + self.optimizer = self._wrap_distributed_optimizer(self.optimizer) # stage1 has v1 and v2 version if in_sharding_parallel_mode and ShardingOption.SHARD_OP in self.args.sharding: diff --git a/paddleformers/trainer/trainer_callback.py b/paddleformers/trainer/trainer_callback.py index 956a2d87c8c..5c5d069626d 100644 --- a/paddleformers/trainer/trainer_callback.py +++ b/paddleformers/trainer/trainer_callback.py @@ -690,7 +690,9 @@ def on_optimizer_begin(self, args, state, control, **kwargs): class MoECorrectionBiasAdjustCallback(TrainerCallback): - """used for moe aux loss free balance""" + """ + used for moe aux loss free balance + """ def __init__(self, lr=0.001, use_mp=False): super().__init__() @@ -750,7 +752,7 @@ def update_bias(layer): class MoeExpertsGradScaleCallback(TrainerCallback): """ - 此 hook 用于修正专家参数的梯度被放大N倍的问题 + This hook is used to correct the issue where the gradients of expert parameters are amplified by a factor of N. 
""" def __init__(self, args): diff --git a/paddleformers/trainer/trainer_utils.py b/paddleformers/trainer/trainer_utils.py index 0337b3276a3..38c8be9e3fa 100644 --- a/paddleformers/trainer/trainer_utils.py +++ b/paddleformers/trainer/trainer_utils.py @@ -68,6 +68,19 @@ ] +def mock_offload_optimizer(): + """ + mock offload optimizer + """ + try: + from paddleformers.trainer.utils.offload_optimizer import hack_offload_optimizer + + hack_offload_optimizer() + logger.warning("hack_offload_optimizer called.") + except ImportError: + logger.warning("hack_offload_optimizer is not imported") + + def log_trainer_start(): if "MAIN_PROCESS_STARTED" not in os.environ: start_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) diff --git a/paddleformers/trainer/utils/offload_optimizer.py b/paddleformers/trainer/utils/offload_optimizer.py new file mode 100644 index 00000000000..65f5b77e2e5 --- /dev/null +++ b/paddleformers/trainer/utils/offload_optimizer.py @@ -0,0 +1,81 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +from paddle import _C_ops +from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.hybrid_parallel_optimizer import ( + HybridParallelOptimizer, +) +from paddle.optimizer import Optimizer + +from .sharding_io import to_device + + +def offload(tensor): + if paddle.is_compiled_with_cuda(): + place = paddle.CUDAPinnedPlace() + else: + place = paddle.CPUPlace() + + new_tensor = to_device(tensor, place) + assert new_tensor is tensor, "to_device must be inplace operation" + + +def reload(tensor): + new_tensor = to_device(tensor) + assert new_tensor is tensor, "to_device must be inplace operation" + + +def hack_offload_optimizer(): + # Step 1: mock _add_accumulator + origin_add_accumulator = getattr(Optimizer, "_add_accumulator") + + def new_add_accumulator(self, *args, **kwargs): + x = origin_add_accumulator(self, *args, **kwargs) + offload(x) + return x + + setattr(Optimizer, "_add_accumulator", new_add_accumulator) + + # Step 2: mock _C_ops.adamw_ and _C_ops.adamw + for name in ["adam_", "adamw_"]: + origin_op = getattr(_C_ops, name) + + def new_opt_op(*args): + for arg in args: + if isinstance(arg, paddle.Tensor): + reload(arg) + + ret = origin_op(*args) + + for i, arg in enumerate(args): + if i >= 2 and isinstance(arg, paddle.Tensor): # do not offload parameter and gradient + offload(arg) + return ret + + setattr(_C_ops, name, new_opt_op) + + # Step 3: mock _insert_sync + opt_type = HybridParallelOptimizer + origin_insert_sync = getattr(opt_type, "_insert_sync") + + def new_insert_sync(self, sync_var, *args, **kwargs): + origin_place = sync_var.place + reload(sync_var) + ret = origin_insert_sync(self, sync_var, *args, **kwargs) + new_sync_var = to_device(sync_var, origin_place) + assert new_sync_var is sync_var, "to_device must be inplace operation" + return ret + + setattr(opt_type, "_insert_sync", new_insert_sync) diff --git a/paddleformers/transformers/__init__.py 
b/paddleformers/transformers/__init__.py index c012fa66938..aca3da5a8e5 100644 --- a/paddleformers/transformers/__init__.py +++ b/paddleformers/transformers/__init__.py @@ -123,12 +123,11 @@ "DeepseekV2DynamicNTKScalingRotaryEmbedding", "DeepseekV2MLP", "yarn_get_mscale", - "DeepseekV2LMHead", "DeepseekV2DecoderLayer", - "DeepseekV2PretrainingCriterion", "yarn_find_correction_range", "get_triangle_upper_mask", "DeepseekV2LinearScalingRotaryEmbedding", + "DeepseekV2ForCausalLMPipe", ], "deepseek_v2.modeling_auto": [ "DeepseekV2LMHeadAuto", @@ -136,7 +135,6 @@ "DeepseekV2ModelAuto", "DeepseekV2PretrainedModelAuto", ], - "deepseek_v2.modeling_pp": ["DeepseekV2ForCausalLMPipe"], "deepseek_v2.mfu_utils": ["DeepSeekProjection"], "deepseek_v2.kernel": [ "act_quant", @@ -160,6 +158,7 @@ "DeepseekV3ForSequenceClassification", "DeepseekV3Model", "DeepseekV3PretrainedModel", + "DeepseekV3ForCausalLMPipe", ], "deepseek_v3.modeling_auto": [ "DeepseekV3LMHeadAuto", @@ -167,7 +166,6 @@ "DeepseekV3ModelAuto", "DeepseekV3PretrainedModelAuto", ], - "deepseek_v3.modeling_pp": ["DeepseekV3ForCausalLMPipe"], "ernie4_5.configuration": ["Ernie4_5Config"], "ernie4_5.modeling": ["Ernie4_5Model", "Ernie4_5ForCausalLM", "Ernie4_5ForCausalLMPipe"], "ernie4_5.tokenizer": ["Ernie4_5Tokenizer"], diff --git a/paddleformers/transformers/deepseek_v2/__init__.py b/paddleformers/transformers/deepseek_v2/__init__.py index a0fac197982..68e4bef3ab5 100644 --- a/paddleformers/transformers/deepseek_v2/__init__.py +++ b/paddleformers/transformers/deepseek_v2/__init__.py @@ -50,12 +50,11 @@ "DeepseekV2DynamicNTKScalingRotaryEmbedding", "DeepseekV2MLP", "yarn_get_mscale", - "DeepseekV2LMHead", "DeepseekV2DecoderLayer", - "DeepseekV2PretrainingCriterion", "yarn_find_correction_range", "get_triangle_upper_mask", "DeepseekV2LinearScalingRotaryEmbedding", + "DeepseekV2ForCausalLMPipe", ], "modeling_auto": [ "DeepseekV2LMHeadAuto", @@ -63,7 +62,6 @@ "DeepseekV2ModelAuto", "DeepseekV2PretrainedModelAuto", ], - "modeling_pp": ["DeepseekV2ForCausalLMPipe"], "mfu_utils": ["DeepSeekProjection"], "kernel": [ "act_quant", diff --git a/paddleformers/transformers/deepseek_v2/configuration.py b/paddleformers/transformers/deepseek_v2/configuration.py index 1feba3cbec7..77b53621732 100644 --- a/paddleformers/transformers/deepseek_v2/configuration.py +++ b/paddleformers/transformers/deepseek_v2/configuration.py @@ -178,7 +178,6 @@ def __init__( attention_bias=False, attention_dropout=0.0, speculate_model_type=False, - using_flex_token=False, **kwargs, ): self.vocab_size = vocab_size @@ -226,7 +225,6 @@ def __init__( self.attention_dropout = attention_dropout self.speculate_model_type = speculate_model_type self.use_fp8 = False - self.using_flex_token = using_flex_token super().__init__( pad_token_id=pad_token_id, diff --git a/paddleformers/transformers/deepseek_v2/modeling.py b/paddleformers/transformers/deepseek_v2/modeling.py index 60603cb0d6c..86945ed7edc 100644 --- a/paddleformers/transformers/deepseek_v2/modeling.py +++ b/paddleformers/transformers/deepseek_v2/modeling.py @@ -24,6 +24,7 @@ import contextlib import math import warnings +from copy import deepcopy from functools import partial from typing import List, Optional, Tuple, Union @@ -35,32 +36,25 @@ from paddle.distributed import fleet from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker from paddle.distributed.fleet.recompute.recompute import recompute +from paddle.distributed.fleet.utils.sequence_parallel_utils import ( + GatherOp, + ScatterOp, + 
mark_as_sequence_parallel_parameter, +) from paddle.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss - -try: - from paddle.incubate.nn.functional import fused_rotary_position_embedding -except ImportError: - fused_rotary_position_embedding = None - -try: - from paddle.distributed.fleet.utils.sequence_parallel_utils import ( - GatherOp, - ScatterOp, - mark_as_sequence_parallel_parameter, - ) -except: - pass - -try: - from paddle.nn.functional.flash_attention import flash_attention -except: - flash_attention = None - - +from paddle.nn.functional.flash_attention import flash_attention + +from ...nn.criterion.interface import CriterionLayer +from ...nn.embedding import Embedding as GeneralEmbedding +from ...nn.linear import Linear as GeneralLinear +from ...nn.lm_head import LMHead as GeneralLMHead +from ...nn.mlp import MLP as DeepseekV2MLP +from ...nn.norm import Norm as GeneralNorm +from ...nn.norm import RMSNorm +from ...nn.pp_model import EmbeddingPipe, GeneralModelForCausalLMPipe, parse_args from ...utils.initializer import kaiming_uniform_ from ...utils.log import logger from ...utils.tools import get_env_device -from ..activations import ACT2FN from ..conversion_utils import StateDictNameMapping, init_name_mappings from ..llama import fusion_ops from ..llama.modeling import get_use_casual_mask @@ -71,19 +65,18 @@ ) from ..model_utils import PretrainedModel, dtype_guard, register_base_model from ..moe_gate import PretrainedMoEGate -from ..moe_layer import MoEFlexTokenLayer, MoELayer +from ..moe_layer import MoEFlexTokenLayer from ..utils import device_guard from . import fp8_linear as linear_utils from .configuration import DeepseekV2Config from .fp8_linear import Linear __all__ = [ - "DeepseekV2LMHead", - "DeepseekV2PretrainingCriterion", "DeepseekV2ForCausalLM", "DeepseekV2ForSequenceClassification", "DeepseekV2Model", "DeepseekV2PretrainedModel", + "DeepseekV2ForCausalLMPipe", ] @@ -302,57 +295,6 @@ def _expand_2d_mask(mask, dtype, tgt_length): return expanded_mask -class DeepseekV2RMSNorm(nn.Layer): - def __init__(self, config: DeepseekV2Config, hidden_size=None, eps=1e-6, use_sequence_parallel=True): - """DeepseekV2RMSNorm is equivalent to T5LayerNorm - - Args: - config (DeepseekV2Config): config dict of DeepseekV2 - hidden_size (_type_): history_states size - eps (_type_, optional): eps value. Defaults to 1e-6. - use_sequence_parallel (bool, optional): A switch to disable sequence parallelism for inputs that are not in tensor parallel mode. - By default, this is set to True. - """ - super().__init__() - self.config = config - self.hidden_size = hidden_size if hidden_size is not None else config.hidden_size - self.variance_epsilon = eps - - self.weight = paddle.create_parameter( - shape=[self.hidden_size], - dtype=paddle.get_default_dtype(), - default_initializer=nn.initializer.Constant(1.0), - ) - - if config.sequence_parallel and use_sequence_parallel: - mark_as_sequence_parallel_parameter(self.weight) - - def forward(self, hidden_states): - if self.config.use_fused_rms_norm and get_env_device() == "xpu": - if self.weight.dtype != hidden_states.dtype: - hidden_states = paddle.cast(hidden_states, self.weight.dtype) - try: - import paddle_xpu_nn # noqa: F821 - - return paddle_xpu_nn.xpu_rms_norm(hidden_states, self.weight, self.variance_epsilon)[0] - except ImportError: - raise NotImplementedError( - f"Implementation of fused_rms_norm is not available on {get_env_device()}. 
Please install paddle_xpu to use this feature" - ) - - with paddle.amp.auto_cast(False): - hidden_states = hidden_states.astype("float32") - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = paddle.rsqrt(variance + self.variance_epsilon) * hidden_states - - if self.weight.dtype in [paddle.float16, paddle.bfloat16]: - hidden_states = paddle.cast(hidden_states, self.weight.dtype) - return hidden_states * self.weight - - def extra_repr(self): - return f"hidden_size={self.hidden_size}, dtype={self.weight.dtype}" - - class DeepseekV2RotaryEmbedding(nn.Layer): def __init__(self, dim, max_position_embeddings=2048, base=10000): super().__init__() @@ -592,18 +534,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, fuse_rope=False): b, s, h, d = k.shape k = k.reshape([b, s, h, d // 2, 2]).transpose([0, 1, 2, 4, 3]).reshape([b, s, h, d]) - if get_env_device() == "xpu" and fuse_rope: - q_embed, k_embed, _ = fused_rotary_position_embedding( - q, - k, - None, - sin=sin, - cos=cos, - position_ids=position_ids, - use_neox_rotary_style=False, - ) - return q_embed, k_embed - if position_ids is None: # Note: Only for MixtralForCausalLMPipe model pretraining cos = cos[:, : q.shape[1], :, :] # [bs, seq_len, 1, axis] @@ -619,58 +549,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, fuse_rope=False): return q_embed, k_embed -class DeepseekV2MLP(nn.Layer): - def __init__(self, config: DeepseekV2Config, hidden_size=None, intermediate_size=None, is_moe=False): - super().__init__() - self.config = config - self.hidden_size = config.hidden_size if hidden_size is None else hidden_size - self.intermediate_size = config.intermediate_size if intermediate_size is None else intermediate_size - - def linear_dtype_gaurd(): - if config.use_fp8: - return dtype_guard("float8_e4m3fn") - else: - return contextlib.nullcontext() - - if config.sequence_parallel: - ColumnParallelLinear = linear_utils.ColumnSequenceParallelLinear - RowParallelLinear = linear_utils.RowSequenceParallelLinear - else: - ColumnParallelLinear = linear_utils.ColumnParallelLinear - RowParallelLinear = linear_utils.RowParallelLinear - - with linear_dtype_gaurd(): - if config.tensor_parallel_degree > 1 and not is_moe: - self.gate_proj = ColumnParallelLinear( - self.hidden_size, - self.intermediate_size, - gather_output=False, - has_bias=False, - ) - self.up_proj = ColumnParallelLinear( - self.hidden_size, - self.intermediate_size, - gather_output=False, - has_bias=False, - ) - self.down_proj = RowParallelLinear( - self.intermediate_size, - self.hidden_size, - input_is_parallel=True, - has_bias=False, - ) - else: - self.gate_proj = Linear(self.hidden_size, self.intermediate_size, bias_attr=False) - self.up_proj = Linear(self.hidden_size, self.intermediate_size, bias_attr=False) - self.down_proj = Linear(self.intermediate_size, self.hidden_size, bias_attr=False) - - self.act_fn = ACT2FN[config.hidden_act] - - def forward(self, x): - down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) - return down_proj - - class FakeGate(paddle.autograd.PyLayer): @staticmethod def forward(ctx, hidden_states, weight): @@ -712,8 +590,12 @@ def __init__(self, config, num_experts, expert_hidden_size, **kwargs): default_initializer=nn.initializer.Constant(0.0), ) self.e_score_correction_bias.is_distributed = True - - self.using_flex_token = config.using_flex_token + self.e_score_correction_bias.stop_gradient = True + self.expert_usage = paddle.zeros( + shape=[num_experts], + dtype=paddle.int64, + ) + 
self.expert_usage.stop_gradient = True def forward(self, hidden_states): """ @@ -734,12 +616,10 @@ def forward(self, hidden_states): scores = self.gate_score_func(logits=logits) scores = scores.cast(paddle.float32) - if self.using_flex_token: - scores, routing_map, exp_counts, l_aux, l_zloss = self.topkgating_nodrop(scores) - return scores, routing_map, l_aux, l_zloss - - capacity, combine_weights, dispatch_mask, exp_counts, l_aux, l_zloss = self.topkgating(scores) - return capacity, combine_weights, dispatch_mask, exp_counts, l_aux, l_zloss + scores, routing_map, exp_counts, l_aux, l_zloss = self.topkgating_nodrop(scores) + with paddle.no_grad(): + self.expert_usage += exp_counts + return scores, routing_map, l_aux, l_zloss class AddAuxiliaryLoss(paddle.autograd.PyLayer): @@ -763,59 +643,6 @@ def backward(ctx, grad_output): return grad_output, grad_loss -class DeepseekV2MoE(MoELayer): - """ - A mixed expert module containing shared experts. - """ - - def __init__(self, config: DeepseekV2Config): - gate = MoEGate( - config=config, - num_experts=config.n_routed_experts, - expert_hidden_size=config.hidden_size, - top_k=config.num_experts_per_tok, - topk_method=config.topk_method, - n_group=config.n_group, - topk_group=config.topk_group, - norm_topk_prob=config.norm_topk_prob, - routed_scaling_factor=config.routed_scaling_factor, - drop_tokens=False, - ) - - # (LiuTing) only support either tp or ep. - moe_group = dist.fleet.get_hybrid_communicate_group().get_data_parallel_group() - expert_parallel_degree = dist.get_world_size(moe_group) - expert_parallel_degree = 1 if expert_parallel_degree < 0 else expert_parallel_degree - act_tp_shard = config.tensor_parallel_degree > 1 and expert_parallel_degree <= 1 - super().__init__( - config=config, - moe_num_experts=config.n_routed_experts, - expert_class=DeepseekV2MLP, - expert_kwargs={ - "config": config, - "intermediate_size": config.moe_intermediate_size, - "is_moe": not act_tp_shard, - }, - gate=gate, - capacity=2.0, - ) - self.alpha = config.aux_loss_alpha - if config.n_shared_experts is not None: - intermediate_size = config.moe_intermediate_size * config.n_shared_experts - self.shared_experts = DeepseekV2MLP(config=config, intermediate_size=intermediate_size, is_moe=False) - - def forward(self, hidden_states): - final_hidden_states, l_aux, l_zloss = super().forward(hidden_states) - if self.training and self.alpha > 0.0: - l_aux = l_aux * self.alpha - final_hidden_states = AddAuxiliaryLoss.apply(final_hidden_states, l_aux) - - if self.config.n_shared_experts is not None: - shared_expert_output = self.shared_experts(hidden_states) - final_hidden_states = final_hidden_states + shared_expert_output - return final_hidden_states - - class DeepseekV2MoEFlexToken(MoEFlexTokenLayer): """ A mixed expert module containing shared experts. 
@@ -836,25 +663,35 @@ def __init__(self, config: DeepseekV2Config): ) hcg = fleet.get_hybrid_communicate_group() - moe_group = hcg.expert_parallel_group - moe_grad_group = hcg.expert_grad_comm_group + moe_group = hcg.get_expert_parallel_group() + moe_grad_group = hcg.get_moe_sharding_parallel_group() + config = deepcopy(config) + config.tensor_parallel_degree = 1 super().__init__( config=config, moe_num_experts=config.n_routed_experts, expert_class=DeepseekV2MLP, - expert_kwargs={"config": config, "intermediate_size": config.moe_intermediate_size, "is_moe": True}, + expert_kwargs={"config": config, "intermediate_size": config.moe_intermediate_size}, gate=gate, moe_group=moe_group, ) + self.is_mp_moe = False + self.is_ep_moe = True for p in self.experts.parameters(): + setattr(p, "is_moe_param", True) setattr(p, "color", {"color": "moe_expert", "group": moe_grad_group}) + p.no_sync = not self.is_mp_moe + p.expert = not self.is_mp_moe + logger.info(f"expert no-sync={p.no_sync}-{p.name}") + if self.is_mp_moe or self.is_ep_moe: + p.is_distributed = True self.alpha = config.aux_loss_alpha if config.n_shared_experts is not None: intermediate_size = config.moe_intermediate_size * config.n_shared_experts - self.shared_experts = DeepseekV2MLP(config=config, intermediate_size=intermediate_size, is_moe=False) + self.shared_experts = DeepseekV2MLP(config=config, intermediate_size=intermediate_size) def forward(self, hidden_states): final_hidden_states, l_aux, l_zloss = super().forward(hidden_states) @@ -868,21 +705,6 @@ def forward(self, hidden_states): return final_hidden_states -def repeat_kv(hidden_states: paddle.Tensor, n_rep: int) -> paddle.Tensor: - """ - This is the equivalent of paddle.repeat_interleave(hidden_states, n_rep, axis=1). - The hidden states go from (batch, seqlen, num_key_value_heads, head_axis) - to (batch, seqlen, num_attention_heads, head_axis) - """ - batch, slen, num_key_value_heads, head_axis = hidden_states.shape - if n_rep == 1: - return hidden_states - - hidden_states = hidden_states.unsqueeze(-2).tile([1, 1, 1, n_rep, 1]) - return hidden_states.reshape([batch, slen, num_key_value_heads * n_rep, head_axis]) - - -# Copied from transformers.models.llama.modeling_llama.LlamaAttention with Llama->DeepseekV2 class DeepseekV2Attention(nn.Layer): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -892,6 +714,12 @@ def __init__(self, config: DeepseekV2Config, layerwise_recompute: bool = False): self.attention_dropout = config.attention_dropout self.hidden_size = config.hidden_size self.num_heads = config.num_attention_heads + self.num_local_heads = self.num_heads + if config.tensor_parallel_degree > 1: + assert ( + self.num_heads % config.tensor_parallel_degree == 0 + ), f"Attention head num ({self.num_heads}) is not divisible by tensor_parallel_degree ({config.tensor_parallel_degree})." 
+ self.num_local_heads = self.num_heads // config.tensor_parallel_degree self.max_position_embeddings = config.max_position_embeddings self.rope_theta = config.rope_theta @@ -905,10 +733,7 @@ def __init__(self, config: DeepseekV2Config, layerwise_recompute: bool = False): self.is_causal = True self.fuse_rope = config.use_fused_rope - if config.num_nextn_predict_layers > 0: - self.seq_length = config.seq_length - config.num_nextn_predict_layers - else: - self.seq_length = config.seq_length + self.seq_length = config.seq_length self.sequence_parallel = config.sequence_parallel # Note that we will actually perform a recompute only if both enable_recompute and layerwise_recompute are set to True @@ -929,47 +754,89 @@ def linear_dtype_gaurd(): # for which are the large weight and can achieve performance gain. # fmt: off - if self.config.tensor_parallel_degree > 1: - # for tensor parallel - if config.sequence_parallel: - ColumnParallelLinear = linear_utils.ColumnSequenceParallelLinear - RowParallelLinear = linear_utils.RowSequenceParallelLinear - else: - ColumnParallelLinear = linear_utils.ColumnParallelLinear - RowParallelLinear = linear_utils.RowParallelLinear - - if self.q_lora_rank is None: - with linear_dtype_gaurd(): - self.q_proj = ColumnParallelLinear(self.hidden_size, self.num_heads * self.q_head_dim, has_bias=False, gather_output=True) - else: - with linear_dtype_gaurd(): - self.q_a_proj = Linear(self.hidden_size, config.q_lora_rank, bias_attr=config.attention_bias) - self.q_b_proj = ColumnParallelLinear(config.q_lora_rank, self.num_heads * self.q_head_dim, has_bias=False, gather_output=True) - self.q_a_layernorm = DeepseekV2RMSNorm(config=config, hidden_size=config.q_lora_rank, use_sequence_parallel=False) + if self.q_lora_rank is None: with linear_dtype_gaurd(): - self.kv_a_proj_with_mqa = Linear(self.hidden_size, config.kv_lora_rank + config.qk_rope_head_dim, bias_attr=config.attention_bias) - self.kv_b_proj = ColumnParallelLinear(config.kv_lora_rank, self.num_heads * (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim), has_bias=False, gather_output=True) - self.o_proj = RowParallelLinear(self.num_heads * self.v_head_dim, self.hidden_size, has_bias=config.attention_bias, input_is_parallel=False) - self.kv_a_layernorm = DeepseekV2RMSNorm(config=config, hidden_size=config.kv_lora_rank, use_sequence_parallel=False) + self.q_proj = GeneralLinear.create( + self.hidden_size, + self.num_heads * self.q_head_dim, + has_bias=False, + config=config, + fuse_matmul_bias=config.fuse_linear, + tp_plan="colwise", + gather_output=False, + ) else: - # for without tensor parallel - if self.q_lora_rank is None: - with linear_dtype_gaurd(): - self.q_proj = Linear(self.hidden_size, self.num_heads * self.q_head_dim, bias_attr=False) - else: - with linear_dtype_gaurd(): - self.q_a_proj = Linear(self.hidden_size, config.q_lora_rank, bias_attr=config.attention_bias) - self.q_b_proj = Linear(config.q_lora_rank, self.num_heads * self.q_head_dim, bias_attr=False) - self.q_a_layernorm = DeepseekV2RMSNorm(config=config, hidden_size=config.q_lora_rank) - with linear_dtype_gaurd(): - self.kv_a_proj_with_mqa = Linear(self.hidden_size, config.kv_lora_rank + config.qk_rope_head_dim, bias_attr=config.attention_bias) - self.kv_b_proj = Linear(config.kv_lora_rank, self.num_heads * (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim), bias_attr=False) - self.o_proj = Linear(self.num_heads * self.v_head_dim, self.hidden_size, bias_attr=config.attention_bias) - self.kv_a_layernorm = 
DeepseekV2RMSNorm(config=config, hidden_size=config.kv_lora_rank) + self.q_a_proj = GeneralLinear.create( + self.hidden_size, + config.q_lora_rank, + has_bias=config.attention_bias, + config=config, + fuse_matmul_bias=config.fuse_linear, + linear_type="default", + gather_output=False, + ) + self.q_b_proj = GeneralLinear.create( + config.q_lora_rank, + self.num_heads * self.q_head_dim, + has_bias=False, + config=config, + fuse_matmul_bias=config.fuse_linear, + tp_plan="colwise", + gather_output=False, + ) + self.q_a_layernorm = GeneralNorm.create( + config=config, + hidden_size=config.q_lora_rank, + norm_type="rms_norm", + ) + + with linear_dtype_gaurd(): + self.kv_a_proj_with_mqa = GeneralLinear.create( + self.hidden_size, + config.kv_lora_rank + config.qk_rope_head_dim, + has_bias=config.attention_bias, + config=config, + fuse_matmul_bias=config.fuse_linear, + linear_type="default", + gather_output=False, + ) + + self.kv_b_proj = GeneralLinear.create( + config.kv_lora_rank, + self.num_heads * (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim), + has_bias=False, + config=config, + fuse_matmul_bias=config.fuse_linear, + tp_plan="colwise", + gather_output=False, + ) + + self.o_proj = GeneralLinear.create( + self.num_heads * self.v_head_dim, + self.hidden_size, + has_bias=config.attention_bias, + config=config, + fuse_matmul_bias=config.fuse_linear, + tp_plan="rowwise", + gather_output=False, + input_is_parallel=True + ) + + self.kv_a_layernorm = GeneralNorm.create( + config=config, + hidden_size=config.kv_lora_rank, + norm_type="rms_norm", + ) # fmt: on + if self.config.tensor_parallel_degree > 1 and self.config.sequence_parallel: + mark_as_sequence_parallel_parameter(self.kv_a_proj_with_mqa.weight) + mark_as_sequence_parallel_parameter(self.q_a_proj.weight) + if config.attention_bias: + mark_as_sequence_parallel_parameter(self.kv_a_proj_with_mqa.bias) + mark_as_sequence_parallel_parameter(self.q_a_proj.bias) self._init_rope() @@ -1047,7 +914,11 @@ def forward( warnings.warn( "Passing `padding_mask` is deprecated and will be removed in v4.37. 
Please make sure use `attention_mask` instead.`" ) - bsz, q_len, _ = hidden_states.shape + ori_shape = hidden_states.shape + if self.config.sequence_parallel: + seq_len, bsz, _ = hidden_states.shape + else: + bsz, seq_len, _ = hidden_states.shape # DeepSeekV2 q_lora_rank=1536 # DeepSeekV2-lite q_lora_rank=None @@ -1057,8 +928,13 @@ def forward( q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))) if self.sequence_parallel: - target_query_shape = [-1, self.seq_length, self.num_heads, self.q_head_dim] - target_key_value_shape = [-1, self.seq_length, self.num_heads, self.qk_nope_head_dim + self.v_head_dim] + target_query_shape = [bsz, self.seq_length, self.num_local_heads, self.q_head_dim] + target_key_value_shape = [ + bsz, + self.seq_length, + self.num_local_heads, + self.qk_nope_head_dim + self.v_head_dim, + ] else: target_query_shape = [0, 0, self.num_heads, self.q_head_dim] target_key_value_shape = [0, 0, self.num_heads, self.qk_nope_head_dim + self.v_head_dim] @@ -1071,8 +947,9 @@ def forward( compressed_kv, k_pe = paddle.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], axis=-1) if self.sequence_parallel: k_pe = GatherOp.apply(k_pe) - k_pe = k_pe.reshape([-1, q_len, 1, self.qk_rope_head_dim]).expand( - [-1, q_len, self.num_heads, self.qk_rope_head_dim] + k_pe = paddle.transpose(k_pe, [1, 0, 2]) + k_pe = k_pe.reshape([-1, self.seq_length, 1, self.qk_rope_head_dim]).expand( + [-1, self.seq_length, self.num_local_heads, self.qk_rope_head_dim] ) # self.q_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim = 128+64 @@ -1140,6 +1017,8 @@ def forward( # if sequence_parallel is true, out shape are [q_len / n, bs, num_head * head_dim] # else their shape are [bs, q_len, num_head * head_dim], n is mp parallelism. attn_output = self.o_proj(attn_output) + if attn_output.shape != ori_shape: + attn_output = attn_output.reshape(ori_shape) if not output_attentions: attn_weights = None @@ -1171,7 +1050,7 @@ def __init__(self, config: DeepseekV2Config, layer_idx: int, layerwise_recompute self.self_attn = DeepseekV2Attention(config=config, layerwise_recompute=layerwise_recompute) - MoELayerClass = DeepseekV2MoEFlexToken if config.using_flex_token else DeepseekV2MoE + MoELayerClass = DeepseekV2MoEFlexToken self.mlp = ( MoELayerClass(config) @@ -1182,10 +1061,17 @@ def __init__(self, config: DeepseekV2Config, layer_idx: int, layerwise_recompute ) else DeepseekV2MLP(config) ) - self.input_layernorm = DeepseekV2RMSNorm(config) - self.post_attention_layernorm = DeepseekV2RMSNorm(config) - def forward( + self.input_layernorm = GeneralNorm.create( + config=config, + norm_type="rms_norm", + ) + self.post_attention_layernorm = GeneralNorm.create( + config=config, + norm_type="rms_norm", + ) + + def subbatch_recompute_forward( self, hidden_states: paddle.Tensor, position_ids: Optional[paddle.Tensor] = None, @@ -1194,26 +1080,68 @@ def forward( past_key_value: Optional[Tuple[paddle.Tensor]] = None, use_cache: Optional[bool] = False, attn_mask_startend_row_indices: Optional[paddle.Tensor] = None, - **kwargs, ) -> Tuple[paddle.Tensor, Optional[Tuple[paddle.Tensor, paddle.Tensor]]]: - """ - Args: - hidden_states (`paddle.Tensor`): input to the layer of shape `(batch, seq_len, embed_axis)` - attention_mask (`paddle.Tensor`, *optional*): - attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, - query_sequence_length, key_sequence_length)` if default attention is used. 
- output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding - (see `past_key_values`). - past_key_value (`Tuple(paddle.Tensor)`, *optional*): cached past key and value projection states - """ - if "padding_mask" in kwargs: - warnings.warn( - "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" + offload_kwargs = {} + offload_kwargs["offload_indices"] = [0] + assert self.recompute_granularity != "full_attn" + attn_outputs = recompute( + self.attn, + hidden_states, + position_ids, + attention_mask, + output_attentions, + past_key_value, + use_cache, + attn_mask_startend_row_indices, + **offload_kwargs, + ) + + hidden_states = attn_outputs[0] + residual = attn_outputs[1] + self_attn_weights = attn_outputs[2] if output_attentions else None + present_key_value = attn_outputs[3] if use_cache else None + + assert len(hidden_states.shape) == 3 + sub_seq_len = self.config.moe_subbatch_token_num + seq_axis = 0 if self.config.sequence_parallel else 1 + seq_len = hidden_states.shape[seq_axis] + assert seq_len % sub_seq_len == 0 + num_chunks = seq_len // sub_seq_len + split_list = [sub_seq_len] * num_chunks + input_list = paddle.split(hidden_states, split_list, axis=seq_axis) + output_list = [] + + for chunk in input_list: + out = recompute( + self.mlp.forward, + chunk, + **offload_kwargs, ) + output_list.append(out) + hidden_states = paddle.concat(output_list, axis=seq_axis) + outputs = recompute( + self.post_process, + hidden_states, + residual, + output_attentions, + use_cache, + self_attn_weights, + present_key_value, + **offload_kwargs, + ) + return outputs + + def attn( + self, + hidden_states: paddle.Tensor, + position_ids: Optional[paddle.Tensor] = None, + attention_mask: Optional[paddle.Tensor] = None, + output_attentions: Optional[bool] = False, + past_key_value: Optional[Tuple[paddle.Tensor]] = None, + use_cache: Optional[bool] = False, + attn_mask_startend_row_indices: Optional[paddle.Tensor] = None, + **kwargs, + ): residual = hidden_states hidden_states = self.input_layernorm(hidden_states) @@ -1254,18 +1182,32 @@ def forward( else: hidden_states = outputs + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + attn_outputs = (hidden_states, residual) + if output_attentions: self_attn_weights = outputs[1] + attn_outputs += (self_attn_weights,) if use_cache: present_key_value = outputs[2 if output_attentions else 1] + attn_outputs += (present_key_value,) - hidden_states = residual + hidden_states + return attn_outputs - # Fully Connected - residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states = self.mlp(hidden_states) + def post_process( + self, + hidden_states, + residual, + output_attentions=False, + use_cache=False, + self_attn_weights=None, + present_key_value=None, + ): hidden_states = residual + hidden_states outputs = (hidden_states,) @@ -1281,6 +1223,43 @@ def forward( return outputs + def forward( + self, + hidden_states: paddle.Tensor, + position_ids: Optional[paddle.Tensor] = None, + attention_mask: Optional[paddle.Tensor] = None, + output_attentions: Optional[bool] = False, + past_key_value: 
Optional[Tuple[paddle.Tensor]] = None, + use_cache: Optional[bool] = False, + attn_mask_startend_row_indices: Optional[paddle.Tensor] = None, + *args, + **kwargs, + ) -> Tuple[paddle.Tensor, Optional[Tuple[paddle.Tensor, paddle.Tensor]]]: + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" + ) + + attn_outputs = self.attn( + hidden_states, + position_ids, + attention_mask, + output_attentions, + past_key_value, + use_cache, + attn_mask_startend_row_indices, + **kwargs, + ) + hidden_states = attn_outputs[0] + residual = attn_outputs[1] + self_attn_weights = attn_outputs[2] if output_attentions else None + present_key_value = attn_outputs[3] if use_cache else None + hidden_states = self.mlp(hidden_states) + outputs = self.post_process( + hidden_states, residual, output_attentions, use_cache, self_attn_weights, present_key_value + ) + return outputs + class DeepseekV2MTPLayer(DeepseekV2DecoderLayer): def __init__( @@ -1291,10 +1270,55 @@ def __init__( ): super(DeepseekV2MTPLayer, self).__init__(config, layer_idx, layerwise_recompute) - self.enorm = DeepseekV2RMSNorm(config) - self.hnorm = DeepseekV2RMSNorm(config) + self.enorm = GeneralNorm.create( + config=config, + norm_type="rms_norm", + ) + self.hnorm = GeneralNorm.create( + config=config, + norm_type="rms_norm", + ) self.eh_proj = nn.Linear(2 * config.hidden_size, config.hidden_size) + if config.sequence_parallel and config.tensor_parallel_degree > 1: + mark_as_sequence_parallel_parameter(self.eh_proj.weight) + mark_as_sequence_parallel_parameter(self.eh_proj.bias) + + def subbatch_recompute_forward( + self, + hidden_states: paddle.Tensor, + nextn_hidden_state: paddle.Tensor, + position_ids: Optional[paddle.Tensor] = None, + attention_mask: Optional[paddle.Tensor] = None, + output_attentions: Optional[bool] = False, + past_key_value: Optional[Tuple[paddle.Tensor]] = None, + use_cache: Optional[bool] = False, + attn_mask_startend_row_indices: Optional[paddle.Tensor] = None, + **kwargs, + ) -> Tuple[paddle.Tensor, Optional[Tuple[paddle.Tensor, paddle.Tensor]]]: + hidden_states = self.hnorm(hidden_states) + nextn_hidden_state = self.enorm(nextn_hidden_state) + + hidden_states = self.eh_proj(paddle.concat([nextn_hidden_state, hidden_states], axis=-1)) + + layer_outputs = super(DeepseekV2MTPLayer, self).subbatch_recompute_forward( + hidden_states, + position_ids, + attention_mask, + output_attentions, + past_key_value, + use_cache, + attn_mask_startend_row_indices, + **kwargs, + ) + + if type(layer_outputs) is tuple: + hidden_states = layer_outputs[0] + else: + hidden_states = layer_outputs + + return hidden_states + def forward( self, hidden_states: paddle.Tensor, @@ -1333,13 +1357,24 @@ def forward( class DeepseekV2PretrainedModel(PretrainedModel): config_class = DeepseekV2Config - base_model_prefix = "deepseek_v2" + base_model_prefix = "model" _no_split_modules = ["DeepseekV2DecoderLayer"] + transpose_weight_keys = [ + "kv_a_proj_with_mqa", + "kv_b_proj", + "o_proj", + "q_a_proj", + "q_b_proj", + "gate_proj", + "up_proj", + "down_proj", + "gate", + "eh_proj", + ] def _get_model_flops(self, batch_size=1, seq_length=None, **kwargs): from .mfu_utils import DeepSeekProjection - # self._ mfu_cal_proj = DeepSeekProjection(self.config) if seq_length is None: if hasattr(self.config, "seq_length"): @@ -1449,7 +1484,6 @@ def get_tensor_parallel_split_mappings(num_layers): base_actions["layers.0.self_attn.q_b_proj.weight"] = partial(fn, 
is_column=True) # if we have enough num_key_value_heads to split, then split it. - # ??? if config.num_key_value_heads % config.tensor_parallel_degree == 0: base_actions["layers.0.self_attn.kv_b_proj.weight"] = partial(fn, is_column=True) if config.use_fp8: @@ -1465,7 +1499,7 @@ def get_tensor_parallel_split_mappings(num_layers): base_actions["layers.0.mlp.down_proj.weight.weight_scale_inv"] = partial(fn, is_column=False) # moe unit routed experts - moe_group = dist.fleet.get_hybrid_communicate_group().get_data_parallel_group() + moe_group = fleet.get_hybrid_communicate_group().get_expert_parallel_group() expert_parallel_degree = dist.get_world_size(moe_group) if expert_parallel_degree <= 1: for e_i in range(config.n_routed_experts): @@ -1589,10 +1623,9 @@ def __init__(self, config: DeepseekV2Config): self.recompute_granularity = config.recompute_granularity self.no_recompute_layers = config.no_recompute_layers if config.no_recompute_layers is not None else [] - if config.tensor_parallel_degree > 1 and config.vocab_size % config.tensor_parallel_degree == 0: - self.embed_tokens = mpu.VocabParallelEmbedding(config.vocab_size, config.hidden_size) - else: - self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.embed_tokens = GeneralEmbedding.create( + config=config, num_embeddings=config.vocab_size, embedding_dim=config.hidden_size + ) self.layers = nn.LayerList( [ @@ -1603,16 +1636,13 @@ def __init__(self, config: DeepseekV2Config): for layer_idx in range(config.num_hidden_layers, config.num_hidden_layers + config.num_nextn_predict_layers): self.layers.append(DeepseekV2MTPLayer(config, layer_idx, layer_idx not in self.no_recompute_layers)) - self.norm = DeepseekV2RMSNorm(config) + self.norm = GeneralNorm.create( + config=config, + norm_type="rms_norm", + ) self.enable_recompute = False - def get_input_embeddings(self): - return self.embed_tokens - - def set_input_embeddings(self, value): - self.embed_tokens = value - @staticmethod def _prepare_decoder_attention_mask(attention_mask, input_shape, past_key_values_length, dtype): if attention_mask is not None: @@ -1638,14 +1668,8 @@ def _prepare_decoder_attention_mask(attention_mask, input_shape, past_key_values past_key_values_length=past_key_values_length, ) # Convert bool attention_mask to float attention mask, which will be added to attention_scores later - if get_env_device() == "xpu": - x = paddle.to_tensor(0.0, dtype="float32") - y = paddle.to_tensor(-1.7005809656952787e38, dtype="float32") - expanded_attn_mask = paddle.where(expanded_attn_mask, x, y) - else: - expanded_attn_mask = paddle.where(expanded_attn_mask.cast("bool"), 0.0, paddle.finfo(dtype).min).astype( - dtype - ) + + expanded_attn_mask = paddle.where(expanded_attn_mask.cast("bool"), 0.0, paddle.finfo(dtype).min).astype(dtype) return expanded_attn_mask @paddle.jit.not_to_static @@ -1711,14 +1735,28 @@ def forward( batch_size, seq_length = inputs_embeds.shape[:2] else: raise ValueError("You have to specify either input_ids or inputs_embeds") - if self.config.num_nextn_predict_layers > 0: seq_length -= self.config.num_nextn_predict_layers if attention_mask is not None: attention_mask = attention_mask[ :, :, : -self.config.num_nextn_predict_layers, : -self.config.num_nextn_predict_layers - ] + ].contiguous() + + # attn_mask_startend_row_indices: [b, num_head, seq_len] or [b, num_head, seq_len, C], C is 2 or 4 + if attn_mask_startend_row_indices is not None: + if attn_mask_startend_row_indices.ndim == 3: + attn_mask_startend_row_indices = 
attn_mask_startend_row_indices[ + :, + :, + : -self.config.num_nextn_predict_layers, + ].contiguous() + elif attn_mask_startend_row_indices.ndim == 4: + attn_mask_startend_row_indices = attn_mask_startend_row_indices[ + :, :, : -self.config.num_nextn_predict_layers, : + ].contiguous() + else: + raise ValueError("attn_mask_startend_row_indices must be 3D or 4D tensor") if self.enable_recompute and self.training: if use_cache: @@ -1770,14 +1808,12 @@ def forward( inputs_embeds_ori = inputs_embeds if self.config.sequence_parallel: - # [bs, seq_len, num_head * head_dim] -> [bs * seq_len, num_head * head_dim] bs, seq_len, hidden_size = inputs_embeds.shape - inputs_embeds = paddle.reshape(inputs_embeds, [bs * seq_len, hidden_size]) - # [seq_len * bs / n, num_head * head_dim] (n is mp parallelism) + inputs_embeds = paddle.transpose(inputs_embeds, [1, 0, 2]) # [B, S, H] --> [S, B, H] inputs_embeds = ScatterOp.apply(inputs_embeds) # embed positions - hidden_states = inputs_embeds + hidden_states = inputs_embeds.contiguous() # decoder layers all_hidden_states = () if output_hidden_states else None @@ -1785,6 +1821,8 @@ def forward( next_decoder_cache = () if use_cache else None mtp_outputs = [] + moelayer_use_subbatch_recompute = self.config.moe_subbatch_token_num > 0 + for idx in range(self.config.num_hidden_layers): decoder_layer = self.layers[idx] @@ -1794,7 +1832,17 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None has_gradient = not hidden_states.stop_gradient - if ( + if moelayer_use_subbatch_recompute: + layer_outputs = decoder_layer.subbatch_recompute_forward( + hidden_states, + position_ids, + attention_mask, + output_attentions, + past_key_value, + use_cache, + attn_mask_startend_row_indices, + ) + elif ( self.enable_recompute and idx not in self.no_recompute_layers and has_gradient @@ -1842,7 +1890,7 @@ def forward( if self.config.sequence_parallel: hidden_states = GatherOp.apply(hidden_states) - hidden_states = hidden_states.reshape([-1, seq_length, hidden_states.shape[-1]]) + hidden_states = paddle.transpose(hidden_states, [1, 0, 2]) # [S, B, H] --> [B, S, H] inputs_embeds_cur_depth = paddle.cat( [inputs_embeds_ori[:, (nextn + 1) :, :], inputs_embeds_extra[:, : (nextn + 1), :]], axis=1 @@ -1955,62 +2003,6 @@ def add_loss(main_loss, loss): return loss -class DeepseekV2LMHead(nn.Layer): - def __init__(self, config: DeepseekV2Config): - super(DeepseekV2LMHead, self).__init__() - self.config = config - - if config.num_nextn_predict_layers > 0: - self.seq_length = config.seq_length - config.num_nextn_predict_layers - else: - self.seq_length = config.seq_length - - if config.tensor_parallel_degree > 1 and config.vocab_size % config.tensor_parallel_degree == 0: - vocab_size = config.vocab_size // config.tensor_parallel_degree - else: - vocab_size = config.vocab_size - - self.weight = self.create_parameter( - shape=[config.hidden_size, vocab_size], - dtype=paddle.get_default_dtype(), - default_initializer=nn.initializer.XavierNormal(1.0), - ) - # Must set distributed attr for Tensor Parallel ! 
- self.weight.is_distributed = True if (vocab_size != config.vocab_size) else False - if get_env_device() == "xpu": - try: - from paddle_xpu.layers.nn import ( # noqa: F401 - parallel_matmul as xpu_parallel_matmul, - ) - - self.xpu_parallel_matmul = xpu_parallel_matmul() - except ImportError: - self.xpu_parallel_matmul = None - - def forward(self, hidden_states, tensor_parallel_output=None): - if self.config.sequence_parallel: - hidden_states = GatherOp.apply(hidden_states) - hidden_states = paddle.reshape_(hidden_states, [-1, self.seq_length, self.config.hidden_size]) - - if tensor_parallel_output is None: - tensor_parallel_output = self.config.tensor_parallel_output - - if get_env_device() == "xpu" and self.xpu_parallel_matmul is not None: - logits = self.xpu_parallel_matmul( - hidden_states, - self.weight, - transpose_y=False, - tensor_parallel_output=tensor_parallel_output, - training=self.training, - ) - else: - logits = parallel_matmul(hidden_states, self.weight, tensor_parallel_output=tensor_parallel_output) - return logits - - def extra_repr(self): - return f"hidden_size={self.weight.shape[0]}, vocab_size={self.weight.shape[1]}, dtype={self.weight.dtype}" - - class DeepseekV2ForCausalLM(DeepseekV2PretrainedModel): _tied_weights_keys = ["lm_head.weight"] @@ -2019,8 +2011,8 @@ def __init__(self, config: DeepseekV2Config): self.config = config self.deepseek_v2 = DeepseekV2Model(config) self.vocab_size = config.vocab_size - self.lm_head = DeepseekV2LMHead(config) - self.criterion = DeepseekV2PretrainingCriterion(config) + self.lm_head = GeneralLMHead(config) + self.criterion = CriterionLayer(config) def get_input_embeddings(self): return self.deepseek_v2.embed_tokens @@ -2084,7 +2076,6 @@ def forward( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict - if attn_mask_startend_row_indices is not None and attention_mask is not None: logger.warning( "You have provided both attn_mask_startend_row_indices and attention_mask. " @@ -2134,7 +2125,6 @@ def forward( # if labels is None,means we need full output, instead of tensor_parallel_output # tensor_parallel_output is together with ParallelCrossEntropy tensor_parallel_output = self.config.tensor_parallel_output and self.config.tensor_parallel_degree > 1 - logits = self.lm_head(hidden_states, tensor_parallel_output=tensor_parallel_output) mtp_logits = ( [ @@ -2332,3 +2322,278 @@ def forward( hidden_states=transformer_outputs.hidden_states, attentions=transformer_outputs.attentions, ) + + +class DeepseekV2MTPLayerPipe(DeepseekV2MTPLayer): + def forward(self, args): + hidden_states, attention_mask, position_ids, position_embeddings, nbatch_pack_offset = parse_args(args) + + if attention_mask is None: + attn_mask = None + attn_mask_startend_row_indices = None + elif attention_mask.dtype == paddle.int32: + attn_mask = None + attn_mask_startend_row_indices = attention_mask + else: + attn_mask = attention_mask + attn_mask_startend_row_indices = None + assert len(attn_mask.shape) == 4, f"Attention mask should be 4D tensor, but got {attn_mask.shape}." 
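A brief editorial aside: the pipeline layers in this hunk overload the single mask slot that travels between pipeline stages, dispatching on dtype. A minimal sketch of that convention, inferred from the code above (the helper name `split_pipeline_mask` is invented here for illustration and is not part of the patch):

    import paddle

    def split_pipeline_mask(attention_mask):
        # Mirrors the dispatch in DeepseekV2MTPLayerPipe / DeepseekV2DecoderLayerPipe.forward:
        # an int32 tensor is interpreted as attn_mask_startend_row_indices (variable-length attention),
        # anything else as a dense 4D additive attention mask.
        if attention_mask is None:
            return None, None
        if attention_mask.dtype == paddle.int32:
            return None, attention_mask
        assert len(attention_mask.shape) == 4, f"Attention mask should be 4D tensor, but got {attention_mask.shape}."
        return attention_mask, None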
+
+        hidden_states_list = paddle.split(hidden_states, self.config.num_nextn_predict_layers + 1, axis=-1)
+        hidden_states_main_model = hidden_states_list[0]
+        inputs_embeds_cur_depth_list = hidden_states_list[1:]
+        has_gradient = not hidden_states_main_model.stop_gradient
+
+        output_list = [hidden_states_main_model]
+        hidden_states = hidden_states_main_model
+        for depth in range(self.config.num_nextn_predict_layers):
+            inputs_embeds_cur_depth = inputs_embeds_cur_depth_list[depth]
+
+            moelayer_use_subbatch_recompute = self.config.moe_subbatch_token_num > 0
+            if moelayer_use_subbatch_recompute:
+                hidden_states = super().subbatch_recompute_forward(
+                    hidden_states,
+                    inputs_embeds_cur_depth,
+                    position_ids=position_ids,
+                    attention_mask=attn_mask,
+                    attn_mask_startend_row_indices=attn_mask_startend_row_indices,
+                )
+            elif self.enable_recompute and self.config.recompute_granularity == "full" and has_gradient:
+                if attn_mask is not None or attn_mask_startend_row_indices is not None:
+                    hidden_states = recompute(
+                        super().forward,
+                        hidden_states,
+                        inputs_embeds_cur_depth,
+                        position_ids=position_ids,
+                        attention_mask=attn_mask,
+                        attn_mask_startend_row_indices=attn_mask_startend_row_indices,
+                        use_reentrant=self.config.recompute_use_reentrant,
+                    )
+                else:
+                    # for pretrain
+                    hidden_states = recompute(
+                        super().forward,
+                        hidden_states,
+                        inputs_embeds_cur_depth,
+                        position_ids=position_ids,
+                        attn_mask_startend_row_indices=attn_mask_startend_row_indices,
+                        use_reentrant=self.config.recompute_use_reentrant,
+                    )
+            else:
+                hidden_states = super().forward(
+                    hidden_states,
+                    inputs_embeds_cur_depth,
+                    position_ids=position_ids,
+                    attention_mask=attn_mask,
+                    attn_mask_startend_row_indices=attn_mask_startend_row_indices,
+                )
+            output_list.append(hidden_states)
+
+        hidden_states = paddle.concat(output_list, axis=-1)
+
+        ret = (hidden_states,)
+        if attention_mask is not None:
+            ret += (attention_mask.clone(),)
+        if position_ids is not None:
+            ret += (position_ids.clone(),)
+
+        return ret
+
+
+class DeepseekV2EmbeddingPipe(EmbeddingPipe):
+    def forward(self, args):
+        num_nextn_predict_layers = self.config.get("num_nextn_predict_layers", 0)
+        input_ids, attention_mask, position_ids, _, _ = parse_args(args, num_nextn_predict_layers > 0)
+        inputs_embeds = self.embed_tokens(input_ids).astype(self.embed_tokens.weight.dtype)
+
+        batch_size, max_seq_len = input_ids.shape
+        max_seq_len -= self.config.num_nextn_predict_layers
+        if num_nextn_predict_layers > 0:
+            if attention_mask is None:
+                attn_mask = None
+                attn_mask_startend_row_indices = None
+            elif attention_mask.dtype == paddle.int32:
+                attn_mask = None
+                attn_mask_startend_row_indices = attention_mask[:, :, :max_seq_len]
+            else:
+                attn_mask = attention_mask[:, :, :max_seq_len, :max_seq_len]
+                attn_mask_startend_row_indices = None
+                assert len(attn_mask.shape) == 4, f"Attention mask should be 4D tensor, but got {attn_mask.shape}."
+            if attn_mask is not None:
+                assert (
+                    attn_mask_startend_row_indices is None
+                ), "attention_mask and attn_mask_startend_row_indices can not be set at same time"
+                attn_mask = DeepseekV2Model._prepare_decoder_attention_mask(
+                    attn_mask, (batch_size, max_seq_len), 0, inputs_embeds.dtype
+                )
+            attn_mask = attn_mask_startend_row_indices if attn_mask_startend_row_indices is not None else attn_mask
+
+        if num_nextn_predict_layers > 0:
+            inputs_embeds_extra = inputs_embeds[:, -self.config.num_nextn_predict_layers :, :]  # [B, S, D]
+            inputs_embeds = inputs_embeds[:, : -self.config.num_nextn_predict_layers, :]
+            inputs_embeds_ori = inputs_embeds
+            batch_size, seq_length, _ = inputs_embeds.shape
+
+            if self.sequence_parallel:
+                inputs_embeds = paddle.transpose(inputs_embeds, [1, 0, 2])  # [B, S, H] --> [S, B, H]
+                inputs_embeds = ScatterOp.apply(inputs_embeds)
+            embeds_res = [inputs_embeds]
+            for depth in range(num_nextn_predict_layers):
+                inputs_embeds_mtp = paddle.concat(
+                    [
+                        inputs_embeds_ori[:, (depth + 1) :, :],
+                        inputs_embeds_extra[:, : (depth + 1), :],
+                    ],
+                    axis=1,
+                )
+                if self.sequence_parallel:
+                    inputs_embeds_mtp = paddle.transpose(inputs_embeds_mtp, [1, 0, 2])  # [B, S, H] --> [S, B, H]
+                    inputs_embeds_mtp = ScatterOp.apply(inputs_embeds_mtp)
+                embeds_res.append(inputs_embeds_mtp)
+            res = paddle.concat(embeds_res, axis=-1)
+            ret = (res,)
+        else:
+            if self.sequence_parallel:
+                inputs_embeds = paddle.transpose(inputs_embeds, [1, 0, 2])  # [B, S, H] --> [S, B, H]
+                inputs_embeds = ScatterOp.apply(inputs_embeds)
+            ret = (inputs_embeds,)
+
+        if attn_mask is not None:
+            ret += (attn_mask.clone(),)
+        if position_ids is not None:
+            ret += (position_ids.clone(),)
+        return ret
+
+
+class DeepseekV2DecoderLayerPipe(DeepseekV2DecoderLayer):
+    def forward(self, args):
+        hidden_states, attention_mask, position_ids, _, _ = parse_args(args)
+
+        if self.config.num_nextn_predict_layers > 0:
+            hidden_size = hidden_states.shape[-1]
+            batch_size_mtp = hidden_size // (self.config.num_nextn_predict_layers + 1)
+            inputs_embeds_mtp = hidden_states[..., -batch_size_mtp:].contiguous()
+            hidden_states = hidden_states[..., :batch_size_mtp].contiguous()
+
+        if attention_mask is None:
+            attn_mask = None
+            attn_mask_startend_row_indices = None
+        elif attention_mask.dtype == paddle.int32:
+            attn_mask = None
+            attn_mask_startend_row_indices = attention_mask
+        else:
+            attn_mask = attention_mask
+            attn_mask_startend_row_indices = None
+            assert len(attn_mask.shape) == 4, f"Attention mask should be 4D tensor, but got {attn_mask.shape}."
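Editorial note on the multi-token-prediction (MTP) plumbing: DeepseekV2EmbeddingPipe packs the main token embeddings and the `num_nextn_predict_layers` shifted copies along the last (hidden) axis so they can travel between pipeline stages as a single tensor, and each downstream pipe layer slices its share back out, as DeepseekV2DecoderLayerPipe does just above. A rough sketch of that packing scheme under those assumptions (function names invented for illustration, not part of the patch):

    import paddle

    def pack_mtp_streams(main, shifted_list):
        # [B, S, H] plus n tensors of [B, S, H]  ->  [B, S, (n + 1) * H]
        # (layout becomes [S/mp, B, H] per stream when sequence parallel is enabled)
        return paddle.concat([main] + shifted_list, axis=-1)

    def unpack_mtp_streams(packed, num_nextn_predict_layers):
        # Inverse: split the hidden axis back into the main stream and the MTP streams.
        parts = paddle.split(packed, num_nextn_predict_layers + 1, axis=-1)
        return parts[0], parts[1:]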
+
+        has_gradient = not hidden_states.stop_gradient
+
+        moelayer_use_subbatch_recompute = self.config.moe_subbatch_token_num > 0
+        if moelayer_use_subbatch_recompute:
+            hidden_states = super().subbatch_recompute_forward(
+                hidden_states,
+                position_ids=position_ids,
+                attention_mask=attn_mask,
+                attn_mask_startend_row_indices=attn_mask_startend_row_indices,
+            )
+        elif self.enable_recompute and self.config.recompute_granularity == "full" and has_gradient:
+            hidden_states = recompute(
+                super().forward,
+                hidden_states,
+                position_ids=position_ids,
+                attention_mask=attn_mask,
+                attn_mask_startend_row_indices=attn_mask_startend_row_indices,
+                use_reentrant=self.config.recompute_use_reentrant,
+            )
+        else:
+            hidden_states = super().forward(
+                hidden_states,
+                position_ids=position_ids,
+                attention_mask=attn_mask,
+                attn_mask_startend_row_indices=attn_mask_startend_row_indices,
+            )
+
+        if self.config.num_nextn_predict_layers > 0:
+            hidden_states = paddle.concat([hidden_states, inputs_embeds_mtp], axis=-1)
+
+        if isinstance(hidden_states, paddle.Tensor):
+            ret = (hidden_states,)
+        if attention_mask is not None:
+            ret += (attention_mask.clone(),)
+        if position_ids is not None:
+            ret += (position_ids.clone(),)
+        if len(ret) == 1:
+            (ret,) = ret
+        return ret
+
+
+class DeepseekV2LMHeadPipe(GeneralLMHead):
+    def forward(self, args):
+        if self.config.num_nextn_predict_layers > 0:
+            logits = []
+            for _hidden_states in args:
+                logits.append(super().forward(_hidden_states))
+            return logits
+
+        hidden_states, _, _, _, _ = parse_args(args)
+        logits = super().forward(hidden_states)
+        return logits
+
+
+class DeepseekV2PretrainingCriterionPipe(DeepseekV2PretrainingCriterion):
+    def forward(self, logits, labels):
+
+        # in GeneralModelForCausalLMPipe last_stage_keys = ["labels", "loss_mask"]
+        labels = labels[0]
+        if self.config.num_nextn_predict_layers > 0:
+            mtp_logits = logits[1:]
+            logits = logits[0]
+            loss = super().forward(logits, labels, mtp_logits=mtp_logits)
+        else:
+            loss = super().forward(logits, labels)
+        return loss
+
+
+class DeepseekV2RMSNormLayerPipe(RMSNorm):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        if self.config.sequence_parallel:
+            self.enable_sequence_parallel()
+
+    def forward(self, args):
+        hidden_states, _, _, _, _ = parse_args(args)
+
+        if self.config.num_nextn_predict_layers > 0:
+            hidden_states_list = paddle.split(hidden_states, self.config.num_nextn_predict_layers + 1, axis=-1)
+            hidden_states = hidden_states_list[0]
+            hidden_states_mtp = hidden_states_list[-self.config.num_nextn_predict_layers :]
+
+            output_list = [super().forward(hidden_states)]
+            for hidden_states in hidden_states_mtp:
+                output_list.append(super().forward(hidden_states))
+            return output_list
+        else:
+            hidden_states = super().forward(hidden_states)
+            return hidden_states
+
+
+class DeepseekV2ForCausalLMPipe(GeneralModelForCausalLMPipe):
+    config_class = DeepseekV2Config
+    _embedding_pipe_cls = DeepseekV2EmbeddingPipe
+    _decoder_layer_cls = DeepseekV2DecoderLayer
+    _criterion_pipe_cls = DeepseekV2PretrainingCriterionPipe
+    _lmhead_pipe_cls = DeepseekV2LMHeadPipe
+    _decoder_layer_pipe_cls = DeepseekV2DecoderLayerPipe
+    _rms_norm_pipe_cls = DeepseekV2RMSNormLayerPipe
+    _base_model = DeepseekV2PretrainedModel
+
+    _get_tensor_parallel_mappings = DeepseekV2PretrainedModel._get_tensor_parallel_mappings
+    _init_weights = DeepseekV2PretrainedModel._init_weights
+    _keys_to_ignore_on_load_unexpected = DeepseekV2PretrainedModel._keys_to_ignore_on_load_unexpected
+    _get_model_flops =
DeepseekV2PretrainedModel._get_model_flops + _get_hardware_flops = DeepseekV2PretrainedModel._get_hardware_flops + transpose_weight_keys = DeepseekV2PretrainedModel.transpose_weight_keys + + _tied_weights_keys = ["lm_head.weight"] + + _mtp_layer_pipe_cls = DeepseekV2MTPLayerPipe diff --git a/paddleformers/transformers/deepseek_v2/modeling_auto.py b/paddleformers/transformers/deepseek_v2/modeling_auto.py deleted file mode 100644 index 5756f8a0586..00000000000 --- a/paddleformers/transformers/deepseek_v2/modeling_auto.py +++ /dev/null @@ -1,1263 +0,0 @@ -# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. -# Copyright (c) 2023 DeepSeek. All rights reserved. -# -# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX -# and OPT implementations in this library. It has been modified from its -# original forms to accommodate minor architectural differences compared -# to GPT-NeoX and OPT used by the Meta AI team that trained the model. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Paddle DeepSeek_V2 model.""" - -from __future__ import annotations - -import warnings -from typing import List, Optional, Tuple, Union - -import paddle -import paddle.distributed as dist -import paddle.nn.functional as F -from paddle import Tensor, nn -from paddle.distributed.fleet.utils import recompute -from paddle.nn import Linear - -try: - from paddle.incubate.nn.functional import fused_rotary_position_embedding -except ImportError: - fused_rotary_position_embedding = None - -try: - from paddle.nn.functional.flash_attention import flash_attention -except: - flash_attention = None - -from ...utils.log import logger -from ...utils.tools import get_env_device -from ..activations import ACT2FN -from ..llama import fusion_ops -from ..llama.modeling import get_use_casual_mask -from ..model_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast -from ..model_utils import PretrainedModel, register_base_model -from ..moe_gate_auto import PretrainedMoEGate -from ..moe_layer_auto import MoELayer -from .configuration import DeepseekV2Config -from .modeling import ( - DeepseekV2DynamicNTKScalingRotaryEmbedding, - DeepseekV2LinearScalingRotaryEmbedding, - DeepseekV2PretrainingCriterion, - DeepseekV2RMSNorm, - DeepseekV2RotaryEmbedding, - DeepseekV2YarnRotaryEmbedding, - _expand_2d_mask, - _make_causal_mask, - apply_rotary_pos_emb, - get_triangle_upper_mask, - is_casual_mask, - yarn_get_mscale, -) - -__all__ = [ - "DeepseekV2LMHeadAuto", - "DeepseekV2ForCausalLMAuto", - "DeepseekV2ModelAuto", - "DeepseekV2PretrainedModelAuto", -] - - -def is_pp_enable(): - global_mesh = dist.auto_parallel.get_mesh() - return "pp" in global_mesh.dim_names - - -def scaled_dot_product_attention( - query_states, - config, - key_states, - value_states, - attention_mask, - output_attentions, - attn_mask_startend_row_indices=None, - softmax_scale=1.0, - training=True, - sequence_parallel=False, -): - bsz, q_len, num_heads, head_dim = query_states.shape - _, kv_seq_len, v_num_heads, v_head_dim = value_states.shape - - 
if config.use_flash_attention and flash_attention: - # Paddle Flash Attention input [ bz, seqlen, nhead, head_dim] - # Torch Flash Attention input [ bz, nhead, seqlen, head_dim] - - # Note: Flash Attention does not support softmax_scale, so we need to scale the query_states - q_head_dim = query_states.shape[-1] - softmax_scale = softmax_scale * (q_head_dim**0.5) - query_states = query_states * softmax_scale - value_padding = paddle.zeros( - [bsz, kv_seq_len, v_num_heads, head_dim - v_head_dim], - dtype=value_states.dtype, - ) - value_states = paddle.cat([value_states, value_padding], axis=-1) - - outputs = fusion_ops.fusion_flash_attention( - query_states, - config, - key_states, - value_states, - attention_mask, - output_attentions, - attn_mask_startend_row_indices=attn_mask_startend_row_indices, - sequence_parallel=False, - ) - - if isinstance(outputs, tuple): - outputs[0] = outputs[0].reshape([bsz, q_len, v_num_heads, head_dim]) - outputs[0] = outputs[0][..., :v_head_dim] - outputs[0] = outputs[0].reshape([bsz, q_len, -1]) - else: - outputs = outputs.reshape([bsz, q_len, v_num_heads, head_dim]) - outputs = outputs[..., :v_head_dim] - outputs = outputs.reshape([bsz, q_len, -1]) - - if sequence_parallel: - attn_output = outputs.reshape([bsz * q_len, v_head_dim * num_heads]) - else: - attn_output = outputs.reshape([bsz, q_len, v_head_dim * num_heads]) - return attn_output - - else: - # [ bz, seqlen, nhead, head_dim] -> [bs, nhead, seq_len, head_dim] - query_states = paddle.transpose(query_states, [0, 2, 1, 3]) - # merge with the next transpose - key_states = paddle.transpose(key_states, [0, 2, 1, 3]) - value_states = paddle.transpose(value_states, [0, 2, 1, 3]) - - # matmul and divide by sqrt(head_dim) - attn_weights = paddle.matmul(query_states * softmax_scale, key_states.transpose([0, 1, 3, 2])) - - if attn_weights.shape != [bsz, num_heads, q_len, kv_seq_len]: - raise ValueError( - f"Attention weights should be of shape {(bsz, num_heads, q_len, kv_seq_len)}, but is" - f" {attn_weights.shape}" - ) - - if attention_mask is None: - attention_mask = get_triangle_upper_mask(attn_weights) - attention_mask = attention_mask.reshape([bsz, 1, q_len, kv_seq_len]) - if attention_mask.shape != [bsz, 1, q_len, kv_seq_len]: - raise ValueError( - f"Attention mask should be of shape {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.shape}" - ) - - attn_weights = attn_weights + attention_mask - if not paddle.in_dynamic_mode(): - attn_weights = F.softmax(attn_weights, axis=-1, dtype="float32").astype(query_states.dtype) - else: - with paddle.amp.auto_cast(False): - attn_weights = F.softmax(attn_weights, axis=-1, dtype="float32").astype(query_states.dtype) - - attn_weights = F.dropout(attn_weights, p=config.attention_dropout, training=training) - - attn_output = paddle.matmul(attn_weights, value_states) - attn_output = attn_output.transpose([0, 2, 1, 3]) - - if sequence_parallel: - attn_output = attn_output.reshape([bsz * q_len, v_head_dim * num_heads]) - else: - attn_output = attn_output.reshape([bsz, q_len, v_head_dim * num_heads]) - return (attn_output, attn_weights) if output_attentions else attn_output - - -class MoEGate(PretrainedMoEGate): - def __init__(self, config, num_experts, expert_hidden_size, **kwargs): - super().__init__(config, num_experts, expert_hidden_size, **kwargs) - # [hidden_size, n_expert] - - self.scoring_func = config.scoring_func - self.topk_method = config.topk_method - - self.weight = paddle.create_parameter( - shape=[expert_hidden_size, num_experts], - 
dtype=paddle.get_default_dtype(), - is_bias=False, - default_initializer=nn.initializer.Constant(1.0), - ) - - if config.topk_method == "noaux_tc": - self.e_score_correction_bias = paddle.create_parameter( - shape=[num_experts], - dtype=paddle.get_default_dtype(), - default_initializer=nn.initializer.Constant(0.0), - ) - - def forward(self, hidden_states): - """ - Args: - hidden_states (_type_): [batch_size * seq_len, hidden_size] - """ - _, h_dim = hidden_states.shape - - # compute gating score - logits = F.linear(hidden_states, self.weight, None) - - with paddle.amp.auto_cast(False): - scores = self.gate_score_func(logits=logits) - scores = scores.cast(paddle.get_default_dtype()) - - capacity, combine_weights, dispatch_mask, exp_counts, l_aux, l_zloss = self.topkgating(scores) - - return capacity, combine_weights, dispatch_mask, exp_counts, l_aux, l_zloss - - -class AddAuxiliaryLoss(paddle.autograd.PyLayer): - """ - The trick function of adding auxiliary (aux) loss, - which includes the gradient of the aux loss during backpropagation. - """ - - @staticmethod - def forward(ctx, x, loss): - # assert paddle.numel(loss) == 1 - ctx.dtype = loss.dtype - ctx.required_aux_loss = not loss.stop_gradient - return x - - @staticmethod - def backward(ctx, grad_output): - grad_loss = None - if ctx.required_aux_loss: - # grad_loss = paddle.ones(1, dtype=ctx.dtype) - grad_loss = paddle.to_tensor(1, dtype=ctx.dtype) - mesh = grad_output.process_mesh - grad_loss = dist.auto_parallel.api.dtensor_from_local(grad_loss, mesh, [dist.Replicate()]) - return grad_output, grad_loss - - -class DeepseekV2MLPAuto(nn.Layer): - def __init__(self, config: DeepseekV2Config, hidden_size=None, intermediate_size=None, is_moe=False): - super().__init__() - self.config = config - self.hidden_size = config.hidden_size if hidden_size is None else hidden_size - self.intermediate_size = config.intermediate_size if intermediate_size is None else intermediate_size - - self.gate_proj = Linear(self.hidden_size, self.intermediate_size, bias_attr=False) - self.up_proj = Linear(self.hidden_size, self.intermediate_size, bias_attr=False) - self.down_proj = Linear(self.intermediate_size, self.hidden_size, bias_attr=False) - - self.act_fn = ACT2FN[config.hidden_act] - - def redistribute_expert(self, mesh, placements): - """ - Place the experts on different devices. - """ - self.gate_proj.weight = dist.shard_tensor(self.gate_proj.weight, mesh, placements) - if self.gate_proj.bias is not None: - self.gate_proj.bias = dist.shard_tensor(self.gate_proj.bias, mesh, placements) - - self.up_proj.weight = dist.shard_tensor(self.up_proj.weight, mesh, placements) - if self.up_proj.bias is not None: - self.up_proj.bias = dist.shard_tensor(self.up_proj.bias, mesh, placements) - - self.down_proj.weight = dist.shard_tensor(self.down_proj.weight, mesh, placements) - if self.down_proj.bias is not None: - self.down_proj.bias = dist.shard_tensor(self.down_proj.bias, mesh, placements) - - def forward(self, x): - down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) - return down_proj - - -class DeepseekV2MoEAuto(MoELayer): - """ - A mixed expert module containing shared experts. 
- """ - - def __init__(self, config: DeepseekV2Config, ipp=None): - gate = MoEGate( - config=config, - num_experts=config.n_routed_experts, - expert_hidden_size=config.hidden_size, - top_k=config.num_experts_per_tok, - topk_method=config.topk_method, - n_group=config.n_group, - topk_group=config.topk_group, - norm_topk_prob=config.norm_topk_prob, - routed_scaling_factor=config.routed_scaling_factor, - drop_tokens=False, - ipp=ipp, - ) - - super().__init__( - config=config, - moe_num_experts=config.n_routed_experts, - expert_class=DeepseekV2MLPAuto, - expert_kwargs={"config": config, "intermediate_size": config.moe_intermediate_size}, - gate=gate, - capacity=2.0, - ipp=ipp, - ) - self.alpha = config.aux_loss_alpha - if config.n_shared_experts is not None: - intermediate_size = config.moe_intermediate_size * config.n_shared_experts - self.shared_experts = DeepseekV2MLPAuto(config=config, intermediate_size=intermediate_size, is_moe=True) - - def forward(self, hidden_states): - final_hidden_states, l_aux, l_zloss = super().forward(hidden_states) - if self.training and self.alpha > 0.0: - final_hidden_states = AddAuxiliaryLoss.apply(final_hidden_states, l_aux) - - if self.config.n_shared_experts is not None: - shared_expert_output = self.shared_experts(hidden_states) - final_hidden_states = final_hidden_states + shared_expert_output - return final_hidden_states - - -# Copied from transformers.models.llama.modeling_llama.LlamaAttention with Llama->DeepseekV2 -class DeepseekV2AttentionAuto(nn.Layer): - """Multi-headed attention from 'Attention Is All You Need' paper""" - - def __init__(self, config: DeepseekV2Config, layerwise_recompute: bool = False): - super().__init__() - self.config = config - self.attention_dropout = config.attention_dropout - self.hidden_size = config.hidden_size - self.num_heads = config.num_attention_heads - - self.max_position_embeddings = config.max_position_embeddings - self.rope_theta = config.rope_theta - self.q_lora_rank = config.q_lora_rank - self.qk_rope_head_dim = config.qk_rope_head_dim - self.kv_lora_rank = config.kv_lora_rank - self.v_head_dim = config.v_head_dim - self.qk_nope_head_dim = config.qk_nope_head_dim - self.q_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim - - self.is_causal = True - - self.seq_length = config.seq_length - - # Note that we will actually perform a recompute only if both enable_recompute and layerwise_recompute are set to True - # Enable_recompute defaults to False and is controlled by Trainer - self.enable_recompute = False - self.layerwise_recompute = layerwise_recompute - self.recompute_granularity = config.recompute_granularity - - # Note (@DrownFish19): For tensor parallel we consider that q_a_proj and kv_a_proj_with_mqa - # are the small weight and cannot achieve performance gain. So we use the original - # linear layers. We use the tensor parallel linear layers for q_proj,q_b_proj and kv_b_proj - # for which are the large weight and can achieve performance gain. 
- - # fmt: off - # for without tensor parallel - if self.q_lora_rank is None: - self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.q_head_dim, bias_attr=False) - else: - self.q_a_proj = nn.Linear(self.hidden_size, config.q_lora_rank, bias_attr=config.attention_bias) - self.q_a_layernorm = DeepseekV2RMSNorm(config=config, hidden_size=config.q_lora_rank) - self.q_b_proj = nn.Linear(config.q_lora_rank, self.num_heads * self.q_head_dim, bias_attr=False) - - self.kv_a_proj_with_mqa = nn.Linear(self.hidden_size, config.kv_lora_rank + config.qk_rope_head_dim, bias_attr=config.attention_bias) - self.kv_a_layernorm = DeepseekV2RMSNorm(config=config, hidden_size=config.kv_lora_rank) - self.kv_b_proj = nn.Linear(config.kv_lora_rank, self.num_heads * (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim), bias_attr=False) - - self.o_proj = nn.Linear(self.num_heads * self.v_head_dim, self.hidden_size, bias_attr=config.attention_bias) - # fmt: on - - self._init_rope() - - self.softmax_scale = self.q_head_dim ** (-0.5) - if self.config.rope_scaling is not None: - mscale_all_dim = self.config.rope_scaling.get("mscale_all_dim", 0) - scaling_factor = self.config.rope_scaling["factor"] - if mscale_all_dim: - mscale = yarn_get_mscale(scaling_factor, mscale_all_dim) - self.softmax_scale = self.softmax_scale * mscale * mscale - - self.attn_func = scaled_dot_product_attention - - def _init_rope(self): - if self.config.rope_scaling is None: - self.rotary_emb = DeepseekV2RotaryEmbedding( - self.qk_rope_head_dim, - max_position_embeddings=self.max_position_embeddings, - base=self.rope_theta, - ) - else: - scaling_type = self.config.rope_scaling["type"] - scaling_factor = self.config.rope_scaling["factor"] - if scaling_type == "linear": - self.rotary_emb = DeepseekV2LinearScalingRotaryEmbedding( - self.qk_rope_head_dim, - max_position_embeddings=self.max_position_embeddings, - scaling_factor=scaling_factor, - base=self.rope_theta, - ) - elif scaling_type == "dynamic": - self.rotary_emb = DeepseekV2DynamicNTKScalingRotaryEmbedding( - self.qk_rope_head_dim, - max_position_embeddings=self.max_position_embeddings, - scaling_factor=scaling_factor, - base=self.rope_theta, - ) - elif scaling_type == "yarn": - kwargs = { - key: self.config.rope_scaling[key] - for key in [ - "original_max_position_embeddings", - "beta_fast", - "beta_slow", - "mscale", - "mscale_all_dim", - ] - if key in self.config.rope_scaling - } - self.rotary_emb = DeepseekV2YarnRotaryEmbedding( - self.qk_rope_head_dim, - max_position_embeddings=self.max_position_embeddings, - scaling_factor=scaling_factor, - base=self.rope_theta, - **kwargs, - ) - else: - raise ValueError(f"Unknown RoPE scaling type {scaling_type}") - - def _shape(self, tensor: paddle.Tensor, seq_len: int, bsz: int): - return tensor.reshape([bsz, seq_len, self.num_heads, self.v_head_dim]).transpose([1, 0, 2, 3]) - - def forward( - self, - hidden_states: paddle.Tensor, - position_ids: Optional[Tuple[paddle.Tensor]] = None, - past_key_value: Optional[Tuple[paddle.Tensor]] = None, - attention_mask: Optional[paddle.Tensor] = None, - output_attentions: bool = False, - use_cache: bool = False, - attn_mask_startend_row_indices: Optional[paddle.Tensor] = None, - **kwargs, - ) -> Tuple[paddle.Tensor, Optional[paddle.Tensor], Optional[Tuple[paddle.Tensor]]]: - if "padding_mask" in kwargs: - warnings.warn( - "Passing `padding_mask` is deprecated and will be removed in v4.37. 
Please make sure use `attention_mask` instead.`" - ) - bsz, q_len, _ = hidden_states.shape - - # DeepSeekV2 q_lora_rank=1536 - # DeepSeekV2-lite q_lora_rank=None - if self.q_lora_rank is None: - q = self.q_proj(hidden_states) - else: - q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))) - q = q.reshape([bsz, q_len, self.num_heads, self.q_head_dim]) - q_nope, q_pe = paddle.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], axis=-1) - - # DeepSeekV2 kv_lora_rank+qk_rope_head_dim=512+64 - compressed_kv = self.kv_a_proj_with_mqa(hidden_states) - compressed_kv, k_pe = paddle.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], axis=-1) - k_pe = k_pe.reshape([bsz, q_len, 1, self.qk_rope_head_dim]) - - # self.q_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim = 128+64 - # self.num_heads * (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim) = config.qk_nope_head_dim + self.v_head_dim = 128+128 - kv = self.kv_b_proj(self.kv_a_layernorm(compressed_kv)).reshape( - [bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim] - ) - - k_nope, value_states = paddle.split(kv, [self.qk_nope_head_dim, self.v_head_dim], axis=-1) - kv_seq_len = value_states.shape[1] - if past_key_value is not None: - kv_seq_len += past_key_value[0].shape[-3] - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - cos = cos[None, :, None, :] - sin = sin[None, :, None, :] - q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids) - - query_states = paddle.cat([q_nope, q_pe], axis=-1) - key_states = paddle.cat([k_nope, k_pe.expand([bsz, q_len, self.num_heads, k_pe.shape[-1]])], axis=-1) - - # key_states[:, :, :, : self.qk_nope_head_dim] = k_nope - # key_states[:, :, :, self.qk_nope_head_dim :] = k_pe - - # [bs, seq_len, num_head, head_dim] - if past_key_value is not None: - # reuse k, v, self_attention - key_states = paddle.cat([past_key_value[0], key_states], axis=1) - value_states = paddle.cat([past_key_value[1], value_states], axis=1) - past_key_value = (key_states, value_states) if use_cache else None - - has_gradient = not (query_states.stop_gradient and key_states.stop_gradient and value_states.stop_gradient) - if ( - self.enable_recompute - and self.layerwise_recompute - and has_gradient - and self.recompute_granularity == "core_attn" - ): - outputs = recompute( - self.attn_func, - query_states, - self.config, - key_states, - value_states, - attention_mask, - output_attentions, - attn_mask_startend_row_indices=attn_mask_startend_row_indices, - softmax_scale=self.softmax_scale, - training=self.training, - use_reentrant=self.config.recompute_use_reentrant, - ) - else: - outputs = self.attn_func( - query_states, - self.config, - key_states, - value_states, - attention_mask, - output_attentions, - attn_mask_startend_row_indices=attn_mask_startend_row_indices, - softmax_scale=self.softmax_scale, - training=self.training, - ) - if output_attentions: - attn_output, attn_weights = outputs - else: - attn_output = outputs - - # if sequence_parallel is true, out shape are [q_len / n, bs, num_head * head_dim] - # else their shape are [bs, q_len, num_head * head_dim], n is mp parallelism. 
- attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - -class DeepseekV2DecoderLayerAuto(nn.Layer): - def __init__(self, config: DeepseekV2Config, layer_idx: int, layerwise_recompute: bool = False, ipp=None): - super().__init__() - self.config = config - - self.enable_recompute = False - self.layerwise_recompute = layerwise_recompute - self.recompute_granularity = config.recompute_granularity - - self.hidden_size = config.hidden_size - - self.self_attn = DeepseekV2AttentionAuto(config=config, layerwise_recompute=layerwise_recompute) - self.ipp = ipp - - self.mlp = ( - DeepseekV2MoEAuto(config, ipp=self.ipp) - if ( - config.n_routed_experts is not None - and layer_idx >= config.first_k_dense_replace - and layer_idx % config.moe_layer_freq == 0 - ) - else DeepseekV2MLPAuto(config) - ) - self.input_layernorm = DeepseekV2RMSNorm(config) - self.post_attention_layernorm = DeepseekV2RMSNorm(config) - - def forward( - self, - hidden_states: paddle.Tensor, - position_ids: Optional[paddle.Tensor] = None, - attention_mask: Optional[paddle.Tensor] = None, - output_attentions: Optional[bool] = False, - past_key_value: Optional[Tuple[paddle.Tensor]] = None, - use_cache: Optional[bool] = False, - attn_mask_startend_row_indices: Optional[paddle.Tensor] = None, - **kwargs, - ) -> Tuple[paddle.Tensor, Optional[Tuple[paddle.Tensor, paddle.Tensor]]]: - """ - Args: - hidden_states (`paddle.Tensor`): input to the layer of shape `(batch, seq_len, embed_axis)` - attention_mask (`paddle.Tensor`, *optional*): - attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, - query_sequence_length, key_sequence_length)` if default attention is used. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding - (see `past_key_values`). - past_key_value (`Tuple(paddle.Tensor)`, *optional*): cached past key and value projection states - """ - if "padding_mask" in kwargs: - warnings.warn( - "Passing `padding_mask` is deprecated and will be removed in v4.37. 
Please make sure use `attention_mask` instead.`" - ) - residual = hidden_states - - hidden_states = self.input_layernorm(hidden_states) - - # Self Attention - has_gradient = not hidden_states.stop_gradient - if ( - self.enable_recompute - and self.layerwise_recompute - and has_gradient - and self.recompute_granularity == "full_attn" - ): - hidden_states, self_attn_weights, present_key_value = recompute( - self.self_attn, - hidden_states=hidden_states, - position_ids=position_ids, - attention_mask=attention_mask, - output_attentions=output_attentions, - past_key_value=past_key_value, - use_cache=use_cache, - attn_mask_startend_row_indices=attn_mask_startend_row_indices, - **kwargs, - ) - else: - hidden_states, self_attn_weights, present_key_value = self.self_attn( - hidden_states, - position_ids=position_ids, - attention_mask=attention_mask, - output_attentions=output_attentions, - past_key_value=past_key_value, - use_cache=use_cache, - attn_mask_startend_row_indices=attn_mask_startend_row_indices, - **kwargs, - ) - hidden_states = residual + hidden_states - - # Fully Connected - residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states = self.mlp(hidden_states) - hidden_states = residual + hidden_states - - outputs = (hidden_states,) - - if output_attentions: - outputs += (self_attn_weights,) - - if use_cache: - outputs += (present_key_value,) - - if type(outputs) is tuple and len(outputs) == 1: - outputs = outputs[0] - - return outputs - - -class DeepseekV2MTPLayerAuto(DeepseekV2DecoderLayerAuto): - def __init__(self, config: DeepseekV2Config, layer_idx: int, layerwise_recompute: bool = False, ipp=None): - super(DeepseekV2MTPLayerAuto, self).__init__(config, layer_idx, layerwise_recompute, ipp) - - self.enorm = DeepseekV2RMSNorm(config) - self.hnorm = DeepseekV2RMSNorm(config) - self.eh_proj = nn.Linear(2 * config.hidden_size, config.hidden_size) - - def forward( - self, - hidden_states: paddle.Tensor, - nextn_hidden_state: paddle.Tensor, - position_ids: Optional[paddle.Tensor] = None, - attention_mask: Optional[paddle.Tensor] = None, - output_attentions: Optional[bool] = False, - past_key_value: Optional[Tuple[paddle.Tensor]] = None, - use_cache: Optional[bool] = False, - attn_mask_startend_row_indices: Optional[paddle.Tensor] = None, - **kwargs, - ) -> Tuple[paddle.Tensor, Optional[Tuple[paddle.Tensor, paddle.Tensor]]]: - - hidden_states = self.hnorm(hidden_states) - nextn_hidden_state = self.enorm(nextn_hidden_state) - - hidden_states = self.eh_proj(paddle.cat([hidden_states, nextn_hidden_state], axis=-1)) - - layer_outputs = super(DeepseekV2MTPLayerAuto, self).forward( - hidden_states, - position_ids, - attention_mask, - output_attentions, - past_key_value, - use_cache, - attn_mask_startend_row_indices, - **kwargs, - ) - - if type(layer_outputs) is tuple: - hidden_states = layer_outputs[0] - else: - hidden_states = layer_outputs - - return hidden_states - - -class DeepseekV2PretrainedModelAuto(PretrainedModel): - config_class = DeepseekV2Config - base_model_prefix = "deepseek_v2" - _no_split_modules = ["DeepseekV2DecoderLayerAuto"] - - -class GlobalOutputNet(nn.Layer): - def __init__(self, config) -> None: - super().__init__() - self.config = config - - @staticmethod - def _prepare_decoder_attention_mask(attention_mask, input_shape, past_key_values_length, dtype): - if attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - if len(attention_mask.shape) == 2: - expanded_attn_mask = 
_expand_2d_mask(attention_mask, dtype, tgt_length=input_shape[-1]) - # For decoding phase in generation, seq_length = 1, we don't need to add causal mask - if input_shape[-1] > 1: - combined_attention_mask = _make_causal_mask( - input_shape, - past_key_values_length=past_key_values_length, - ) - expanded_attn_mask = expanded_attn_mask & combined_attention_mask - # [bsz, seq_len, seq_len] -> [bsz, 1, seq_len, seq_len] - elif len(attention_mask.shape) == 3: - expanded_attn_mask = attention_mask.unsqueeze(1).astype("bool") - # if attention_mask is already 4-D, do nothing - else: - expanded_attn_mask = attention_mask - else: - expanded_attn_mask = _make_causal_mask( - input_shape, - past_key_values_length=past_key_values_length, - ) - # Convert bool attention_mask to float attention mask, which will be added to attention_scores later - if get_env_device() == "xpu": - x = paddle.to_tensor(0.0, dtype="float32") - y = paddle.to_tensor(-1.7005809656952787e38, dtype="float32") - expanded_attn_mask = paddle.where(expanded_attn_mask, x, y) - else: - expanded_attn_mask = paddle.where(expanded_attn_mask.cast("bool"), 0.0, paddle.finfo(dtype).min).astype( - dtype - ) - return expanded_attn_mask - - def forward( - self, - position_ids, - attention_mask, - seq_length, - batch_size, - seq_length_with_past, - cache_length, - emb_dtype, - attn_mask_startend_row_indices, - ): - if position_ids is None: - position_ids = paddle.arange(cache_length, seq_length + cache_length, dtype=paddle.int64) - position_ids = position_ids.unsqueeze(0) - - if ( - attn_mask_startend_row_indices is not None - or get_use_casual_mask() - or (self.config.use_flash_attention and self.training) - ): - attention_mask = None - else: - # [bs, seq_len] - attention_mask = ( - paddle.ones((batch_size, seq_length_with_past), dtype=paddle.bool) - if attention_mask is None - else attention_mask - ) - attention_mask = self._prepare_decoder_attention_mask( - attention_mask, (batch_size, seq_length), cache_length, emb_dtype - ) # [bs, 1, seq_len, seq_len] - if self.config.use_flash_attention: - attention_mask = None if is_casual_mask(attention_mask) else attention_mask - - return position_ids, attention_mask - - -@register_base_model -class DeepseekV2ModelAuto(DeepseekV2PretrainedModelAuto): - """ - Transformer decoder consisting of *config.num_hidden_layers* layers. 
Each layer is a [`DeepseekV2DecoderLayerAuto`] - - Args: - config: DeepseekV2Config - """ - - def __init__(self, config: DeepseekV2Config): - super().__init__(config) - - self.config = config - self.padding_idx = config.pad_token_id - self.vocab_size = config.vocab_size - - # Recompute defaults to False and is controlled by Trainer - self.enable_recompute = False - self.recompute_granularity = config.recompute_granularity - self.no_recompute_layers = config.no_recompute_layers if config.no_recompute_layers is not None else [] - - self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) - self.global_layer = GlobalOutputNet(config=config) - - def divide_list_indices(n, k): - n = n + self.config.pp_extra_layer_num - base_size = n // k - extra = n % k - - indices = [] - current_index = -1 - - for i in range(k): - current_index += base_size - if i < extra: - current_index += 1 - indices.append(current_index) - return indices - - if is_pp_enable(): - mesh = dist.auto_parallel.get_mesh() - self.pp_indices = divide_list_indices( - config.num_hidden_layers + config.num_nextn_predict_layers, mesh.get_dim_size("pp") - ) - - def get_pp_stage_id(layer_idx): - if not is_pp_enable(): - return None - else: - for idx, end_idx in enumerate(self.pp_indices): - if layer_idx <= end_idx: - return idx - - decoder_layers = [] - for layer_idx in range(config.num_hidden_layers + config.num_nextn_predict_layers): - pp_stage_id = get_pp_stage_id(layer_idx) - logger.info(f"layer_idx:{layer_idx}, pp_stage_id:{pp_stage_id}") - if layer_idx < config.num_hidden_layers: - decoder_layers.append( - DeepseekV2DecoderLayerAuto( - config, layer_idx, layer_idx not in self.no_recompute_layers, pp_stage_id - ) - ) - else: - decoder_layers.append( - DeepseekV2MTPLayerAuto(config, layer_idx, layer_idx not in self.no_recompute_layers, pp_stage_id) - ) - - self.layers = nn.LayerList(decoder_layers) - - self.norm = DeepseekV2RMSNorm(config) - - self.enable_recompute = False - - def get_input_embeddings(self): - return self.embed_tokens - - def set_input_embeddings(self, value): - self.embed_tokens = value - - def forward( - self, - input_ids: paddle.Tensor = None, - position_ids: Optional[paddle.Tensor] = None, - attention_mask: Optional[paddle.Tensor] = None, - inputs_embeds: Optional[paddle.Tensor] = None, - use_cache: Optional[bool] = None, - past_key_values: Optional[List[paddle.Tensor]] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - attn_mask_startend_row_indices: Optional[Tensor] = None, - **kwargs, - ) -> Union[Tuple, BaseModelOutputWithPast]: - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - # retrieve input_ids and inputs_embeds - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") - elif input_ids is not None: - batch_size, seq_length = input_ids.shape[:2] - elif inputs_embeds is not None: - batch_size, seq_length = inputs_embeds.shape[:2] - else: - raise ValueError("You have to specify either input_ids or inputs_embeds") - seq_length -= 
self.config.num_nextn_predict_layers - - if self.enable_recompute and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`transformers." - ) - use_cache = False - - if past_key_values is None: - past_key_values = tuple([None] * len(self.layers)) - # NOTE: to make cache can be clear in-time - past_key_values = list(past_key_values) - - seq_length_with_past = seq_length - cache_length = 0 - if past_key_values[0] is not None: - cache_length = past_key_values[0][0].shape[1] - seq_length_with_past += cache_length - - if inputs_embeds is None: - # [bs, seq_len, dim] - inputs_embeds = self.embed_tokens(input_ids) - - position_ids, attention_mask = self.global_layer( - position_ids, - attention_mask, - seq_length, - batch_size, - seq_length_with_past, - cache_length, - inputs_embeds.dtype, - attn_mask_startend_row_indices, - ) - - if self.config.num_nextn_predict_layers > 0: - inputs_embeds_extra = inputs_embeds[:, -self.config.num_nextn_predict_layers :, :] # [B, S, D] - inputs_embeds = inputs_embeds[:, : -self.config.num_nextn_predict_layers, :] - inputs_embeds_ori = inputs_embeds - - # embed positions - hidden_states = inputs_embeds - - # decoder layers - all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None - next_decoder_cache = () if use_cache else None - mtp_outputs = [] - - for idx in range(self.config.num_hidden_layers): - decoder_layer = self.layers[idx] - - if output_hidden_states: - all_hidden_states += (hidden_states,) - - past_key_value = past_key_values[idx] if past_key_values is not None else None - - has_gradient = not hidden_states.stop_gradient - if ( - self.enable_recompute - and idx not in self.no_recompute_layers - and has_gradient - and self.recompute_granularity == "full" - ): - layer_outputs = self.recompute_training_full( - decoder_layer, - hidden_states, - position_ids, - attention_mask, - output_attentions, - past_key_value, - use_cache, - attn_mask_startend_row_indices, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - position_ids, - attention_mask, - output_attentions, - past_key_value, - use_cache, - attn_mask_startend_row_indices, - ) - - # NOTE: clear outdate cache after it has been used for memory saving - past_key_value = past_key_values[idx] = None - if type(layer_outputs) is tuple: - hidden_states = layer_outputs[0] - else: - hidden_states = layer_outputs - - if use_cache: - next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) - - if output_attentions: - all_self_attns += (layer_outputs[1],) - - if self.config.num_nextn_predict_layers > 0: - mtp_outputs.append(hidden_states) - - for nextn in range(self.config.num_nextn_predict_layers): - decoder_layer = self.layers[nextn + self.config.num_hidden_layers] - - # 构建输入向量 - inputs_embeds_cur_depth = paddle.cat( - [inputs_embeds_ori[:, (nextn + 1) :, :], inputs_embeds_extra[:, : (nextn + 1), :]], axis=1 - ) - - if inputs_embeds_cur_depth.process_mesh != hidden_states.process_mesh: - inputs_embeds_cur_depth = paddle.distributed.reshard( - inputs_embeds_cur_depth, - hidden_states.process_mesh, - inputs_embeds_cur_depth.placements, - ) - # 通过该层的decoder_layer进行预测 - past_key_value = None - layer_outputs = decoder_layer( - hidden_states, - inputs_embeds_cur_depth, - position_ids, - attention_mask, - output_attentions, - past_key_value, - use_cache, - attn_mask_startend_row_indices, - ) - - if isinstance(layer_outputs, (tuple, list)): - 
hidden_states = layer_outputs[0] - else: - hidden_states = layer_outputs - - mtp_outputs.append(hidden_states) - mtp_outputs = [self.norm(hidden_states) for hidden_states in mtp_outputs] - hidden_states, mtp_outputs = mtp_outputs[0], mtp_outputs[1:] - else: - hidden_states = self.norm(hidden_states) - - # add hidden states from the last decoder layer - if output_hidden_states: - all_hidden_states += (hidden_states,) - - next_cache = next_decoder_cache if use_cache else None - - if not return_dict: - return tuple( - v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, mtp_outputs] if v is not None - ) - return BaseModelOutputWithPast( - last_hidden_state=hidden_states, - past_key_values=next_cache, - hidden_states=all_hidden_states, - attentions=all_self_attns, - ) - - -class DeepseekV2LMHeadAuto(nn.Layer): - def __init__(self, config: DeepseekV2Config): - super(DeepseekV2LMHeadAuto, self).__init__() - - self.config = config - - self.weight = self.create_parameter( - shape=[config.hidden_size, config.vocab_size], - dtype=paddle.get_default_dtype(), - default_initializer=nn.initializer.XavierNormal(1.0), - ) - - def forward(self, hidden_states, tensor_parallel_output=None): - if tensor_parallel_output is None: - tensor_parallel_output = self.config.tensor_parallel_output - logits = paddle.matmul(hidden_states, self.weight) - return logits - - -class DeepseekV2ForCausalLMAuto(DeepseekV2PretrainedModelAuto): - _tied_weights_keys = ["lm_head.weight"] - - def __init__(self, config: DeepseekV2Config): - super().__init__(config) - self.config = config - self.deepseek_v2 = DeepseekV2ModelAuto(config) - self.vocab_size = config.vocab_size - self.lm_head = DeepseekV2LMHeadAuto(config) - self.criterion = DeepseekV2PretrainingCriterion(config) - - def get_input_embeddings(self): - return self.deepseek_v2.embed_tokens - - def set_input_embeddings(self, value): - self.deepseek_v2.embed_tokens = value - - def get_output_embeddings(self): - return self.lm_head - - def set_output_embeddings(self, new_embeddings): - self.lm_head = new_embeddings - - def set_decoder(self, decoder): - self.deepseek_v2 = decoder - - def get_decoder(self): - return self.deepseek_v2 - - def forward( - self, - input_ids: paddle.Tensor = None, - position_ids: Optional[paddle.Tensor] = None, - attention_mask: Optional[paddle.Tensor] = None, - inputs_embeds: Optional[paddle.Tensor] = None, - labels: Optional[paddle.Tensor] = None, - use_cache: Optional[bool] = None, - past_key_values: Optional[List[paddle.Tensor]] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - attn_mask_startend_row_indices=None, - ) -> Union[Tuple, CausalLMOutputWithPast]: - r""" - Args: - labels (`paddle.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, transformers., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, transformers., config.vocab_size]`. - - Returns: - - Example: - - ```python - >>> from transformers import AutoTokenizer, DeepseekV2ForCausalLMAuto - - >>> model = DeepseekV2ForCausalLMAuto.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) - >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) - - >>> prompt = "Hey, are you conscious? Can you talk to me?" 
- >>> inputs = tokenizer(prompt, return_tensors="pt") - - >>> # Generate - >>> generate_ids = model.generate(inputs.input_ids, max_length=30) - >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." - ```""" - input_ids.stop_gradient = True - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - if attn_mask_startend_row_indices is not None and attention_mask is not None: - logger.warning( - "You have provided both attn_mask_startend_row_indices and attention_mask. " - "The attn_mask_startend_row_indices will be used." - ) - attention_mask = None - - # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs = self.deepseek_v2( - input_ids=input_ids, - position_ids=position_ids, - attention_mask=attention_mask, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - past_key_values=past_key_values, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - attn_mask_startend_row_indices=attn_mask_startend_row_indices, - ) - - hidden_states = outputs[0] - mtp_outputs = outputs[-1] - - # if labels is None,means we need full output, instead of tensor_parallel_output - # tensor_parallel_output is together with ParallelCrossEntropy - tensor_parallel_output = self.config.tensor_parallel_output and self.config.tensor_parallel_degree > 1 - - logits = self.lm_head(hidden_states, tensor_parallel_output=tensor_parallel_output) - - mtp_logits = [self.lm_head(_hidden_states) for _hidden_states in mtp_outputs] if len(mtp_outputs) > 0 else [] - - return self.criterion(logits, labels, mtp_logits=mtp_logits) - - def prepare_inputs_for_generation( - self, input_ids, use_cache=False, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs - ): - batch_size, seq_length = input_ids.shape - position_ids = kwargs.get("position_ids", paddle.arange(seq_length).expand((batch_size, seq_length))) - if past_key_values: - input_ids = input_ids[:, -1].unsqueeze(axis=-1) - position_ids = position_ids[:, -1].unsqueeze(-1) - - # if `inputs_embeds` are passed, we only want to use them in the 1st generation step - if inputs_embeds is not None and past_key_values is None: - model_inputs = {"inputs_embeds": inputs_embeds} - else: - model_inputs = {"input_ids": input_ids} - - model_inputs.update( - { - "position_ids": position_ids, - "past_key_values": past_key_values, - "use_cache": use_cache, - "attention_mask": attention_mask, - } - ) - return model_inputs - - def _get_model_inputs_spec(self, dtype: str): - return { - "input_ids": paddle.static.InputSpec(shape=[None, None], dtype="int64"), - "attention_mask": paddle.static.InputSpec(shape=[None, None], dtype="int64"), - "position_ids": paddle.static.InputSpec(shape=[None, None], dtype="int64"), - } - - @staticmethod - def update_model_kwargs_for_generation(outputs, model_kwargs, is_encoder_decoder=False): - # update cache - if isinstance(outputs, tuple) and len(outputs) > 1 and not isinstance(outputs[1], paddle.Tensor): - model_kwargs["past_key_values"] = outputs[1] - - if isinstance(outputs, CausalLMOutputWithPast) and "past_key_values" in outputs: - 
model_kwargs["past_key_values"] = outputs.past_key_values - - # update position_ids - if "position_ids" in model_kwargs and model_kwargs["position_ids"] is not None: - position_ids = model_kwargs["position_ids"] - model_kwargs["position_ids"] = paddle.cat([position_ids, position_ids[..., -1:] + 1], axis=-1) - - if not is_encoder_decoder and "attention_mask" in model_kwargs: - # TODO: support attention mask for other models - attention_mask = model_kwargs["attention_mask"] - if len(attention_mask.shape) == 2: - model_kwargs["attention_mask"] = paddle.cat( - [attention_mask, paddle.ones([attention_mask.shape[0], 1], dtype=attention_mask.dtype)], - axis=-1, - ) - elif len(attention_mask.shape) == 4: - model_kwargs["attention_mask"] = paddle.cat( - [attention_mask, paddle.ones([*attention_mask.shape[:3], 1], dtype=attention_mask.dtype)], - axis=-1, - )[:, :, -1:, :] - - return model_kwargs - - @staticmethod - def _reorder_cache(past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),) - return reordered_past - - def auto_dist_config(self, prefix=""): - if prefix != "": - assert prefix.endswith(".") - config = { - "mp_config": { - "parallelize_plan": { - f"{prefix}deepseek_v2.embed_tokens": dist.ColWiseParallel(gather_output=True), - f"{prefix}deepseek_v2.layers.*.self_attn.q_proj": dist.ColWiseParallel(), - f"{prefix}deepseek_v2.layers.*.self_attn.kv_b_proj": dist.ColWiseParallel(), - f"{prefix}deepseek_v2.layers.*.self_attn.o_proj": dist.RowWiseParallel(), - f"{prefix}deepseek_v2.layers.*.mlp.gate_proj": dist.ColWiseParallel(), - f"{prefix}deepseek_v2.layers.*.mlp.up_proj": dist.ColWiseParallel(), - f"{prefix}deepseek_v2.layers.*.mlp.down_proj": dist.RowWiseParallel(), - f"{prefix}lm_head.weight": dist.ColWiseParallel(), - } - }, - } - return config diff --git a/paddleformers/transformers/deepseek_v2/modeling_pp.py b/paddleformers/transformers/deepseek_v2/modeling_pp.py deleted file mode 100644 index 4fc4ac92041..00000000000 --- a/paddleformers/transformers/deepseek_v2/modeling_pp.py +++ /dev/null @@ -1,501 +0,0 @@ -# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- - -from typing import OrderedDict, Tuple, Union - -import paddle -import paddle.distributed.fleet as fleet -import paddle.nn as nn -from paddle.distributed.fleet.meta_parallel import ( - LayerDesc, - PipelineLayer, - SharedLayerDesc, -) -from paddle.distributed.fleet.recompute.recompute import recompute -from paddle.distributed.fleet.utils.sequence_parallel_utils import ScatterOp - -from ...utils.tools import get_env_device -from ..model_utils import PipelinePretrainedModel -from .modeling import ( - DeepseekV2Config, - DeepseekV2DecoderLayer, - DeepseekV2LMHead, - DeepseekV2Model, - DeepseekV2MTPLayer, - DeepseekV2PretrainedModel, - DeepseekV2PretrainingCriterion, - DeepseekV2RMSNorm, -) - -__all__ = [ - "DeepseekV2ForCausalLMPipe", -] - - -def parse_args(args): - if isinstance(args, tuple): - if len(args) == 4: - hidden_states, attention_mask, attn_mask_startend_row_indices, position_ids = args - - elif len(args) == 3: - hidden_states, attention_mask, attn_mask_startend_row_indices = args - position_ids = None - elif len(args) == 2: - hidden_states, attention_mask = args - attn_mask_startend_row_indices, position_ids = None, None - else: - hidden_states = args - attention_mask, attn_mask_startend_row_indices, position_ids = None, None, None - - if position_ids is not None: - position_ids.stop_gradient = True - - if attention_mask is not None: - attention_mask.stop_gradient = True - - if attn_mask_startend_row_indices is not None: - attn_mask_startend_row_indices.stop_gradient = True - - return hidden_states, attention_mask, attn_mask_startend_row_indices, position_ids - - -def return_args(hidden_states, attention_mask=None, attn_mask_startend_row_indices=None, position_ids=None): - ret = (hidden_states,) - - if attention_mask is not None: - ret += (attention_mask.clone(),) - if attn_mask_startend_row_indices is not None: - ret += (attn_mask_startend_row_indices.clone(),) - if position_ids is not None: - ret += (position_ids.clone(),) - if len(ret) == 1: - ret = ret[0] - - return ret - - -def get_attr(layer, name): - if getattr(layer, name, None) is not None: - return getattr(layer, name, None) - else: - return get_attr(layer._layer, name) - - -class DeepseekV2EmbeddingPipe(nn.Layer): - def __init__(self, config: DeepseekV2Config): - super(DeepseekV2EmbeddingPipe, self).__init__() - self.config = config - self.sequence_parallel = config.sequence_parallel - self.hidden_size = config.hidden_size - if config.tensor_parallel_degree > 1 and config.vocab_size % config.tensor_parallel_degree == 0: - self.embed_tokens = fleet.meta_parallel.VocabParallelEmbedding( - config.vocab_size, - config.hidden_size, - weight_attr=paddle.ParamAttr(initializer=nn.initializer.XavierNormal()), - ) - else: - self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size) - - @property - def embedding_weight(self): - return get_attr(self.embed_tokens, "weight") - - def forward(self, args): - """_summary_ - - Args: - input (_type_): _description_ - - Returns: - _type_: _description_ - """ - input_ids, attention_mask, attn_mask_startend_row_indices, position_ids = parse_args(args) - inputs_embeds = self.embed_tokens(input_ids) - - batch_size, seq_length = input_ids.shape - if self.config.num_nextn_predict_layers > 0: - seq_length -= self.config.num_nextn_predict_layers - - if attention_mask is not None: - attention_mask = attention_mask[ - :, :, : -self.config.num_nextn_predict_layers, : -self.config.num_nextn_predict_layers - ] - - if attention_mask is not None: - assert ( - 
attn_mask_startend_row_indices is None - ), "attention_mask and attn_mask_startend_row_indices can not be set at same time" - - attention_mask = DeepseekV2Model._prepare_decoder_attention_mask( - attention_mask, (batch_size, seq_length), 0, inputs_embeds.dtype - ) - attention_mask.stop_gradient = True - if get_env_device() == "npu": - attention_mask = attention_mask.astype("bool") - elif get_env_device() == "npu": - attention_mask = paddle.tril(paddle.ones((seq_length, seq_length), dtype="bool")) - attention_mask.stop_gradient = True - - if self.config.num_nextn_predict_layers > 0: - inputs_embeds_extra = inputs_embeds[:, -self.config.num_nextn_predict_layers :, :] # [B, S, D] - inputs_embeds = inputs_embeds[:, : -self.config.num_nextn_predict_layers, :] - inputs_embeds_ori = inputs_embeds - batch_size, seq_length, _ = inputs_embeds.shape - - if self.sequence_parallel: - # [bs, seq_len, num_head * head_dim] -> [bs * seq_len, num_head * head_dim] - inputs_embeds = paddle.reshape(inputs_embeds, [-1, inputs_embeds.shape[-1]]) - # [seq_len * bs / n, num_head * head_dim] (n is mp parallelism) - inputs_embeds = ScatterOp.apply(inputs_embeds) - embeds_res = [inputs_embeds] - for depth in range(self.config.num_nextn_predict_layers): - inputs_embeds_mtp = paddle.cat( - [ - inputs_embeds_ori[:, (depth + 1) :, :], - inputs_embeds_extra[:, : (depth + 1), :], - ], - axis=1, - ) - if self.sequence_parallel: - inputs_embeds_mtp = inputs_embeds_mtp.reshape([-1, inputs_embeds_mtp.shape[-1]]) - inputs_embeds_mtp = ScatterOp.apply(inputs_embeds_mtp) - embeds_res.append(inputs_embeds_mtp) - # if not self.sequence_parallel - # mtp_embeds: [B*num_nextn_predict_layers, seq_len, hidden_size] - # else: - # mtp_embeds: [B*seq_len*num_nextn_predict_layers, hidden_size] - inputs_embeds = paddle.cat(embeds_res, axis=-1) - return return_args(inputs_embeds, attention_mask, attn_mask_startend_row_indices, position_ids) - else: - if self.sequence_parallel: - inputs_embeds = inputs_embeds.reshape([-1, inputs_embeds.shape[-1]]) - inputs_embeds = ScatterOp.apply(inputs_embeds) - return return_args(inputs_embeds, attention_mask, attn_mask_startend_row_indices, position_ids) - - -class DeepseekV2DecoderLayerPipe(DeepseekV2DecoderLayer): - def forward(self, args): - hidden_states, attention_mask, attn_mask_startend_row_indices, position_ids = parse_args(args) - - if self.config.num_nextn_predict_layers > 0: - batch_size, _, hidden_size = hidden_states.shape - batch_size_mtp = hidden_size // (self.config.num_nextn_predict_layers + 1) - inputs_embeds_mtp = hidden_states[..., -batch_size_mtp:] - hidden_states = hidden_states[..., :batch_size_mtp] - - has_gradient = not hidden_states.stop_gradient - - if attention_mask is not None and attention_mask.dtype == paddle.int32: - attention_mask, attn_mask_startend_row_indices, position_ids = ( - None, - attention_mask, - attn_mask_startend_row_indices, - ) - elif attention_mask is not None and attention_mask.dtype == paddle.int64: - attention_mask, attn_mask_startend_row_indices, position_ids = None, None, attention_mask - elif attn_mask_startend_row_indices is not None and attn_mask_startend_row_indices.dtype == paddle.int64: - attn_mask_startend_row_indices, position_ids = None, attn_mask_startend_row_indices - - if self.enable_recompute and self.config.recompute_granularity == "full" and has_gradient: - if attention_mask is not None or attn_mask_startend_row_indices is not None: - hidden_states = recompute( - super().forward, - hidden_states, - position_ids=position_ids, - 
attention_mask=attention_mask, - attn_mask_startend_row_indices=attn_mask_startend_row_indices, - use_reentrant=False, - ) - else: - # for pretrain - hidden_states = recompute( - super().forward, - hidden_states, - position_ids=position_ids, - attn_mask_startend_row_indices=attn_mask_startend_row_indices, - use_reentrant=self.config.recompute_use_reentrant, - ) - else: - hidden_states = super().forward( - hidden_states, - position_ids=position_ids, - attention_mask=attention_mask, - attn_mask_startend_row_indices=attn_mask_startend_row_indices, - ) - - if self.config.num_nextn_predict_layers > 0: - hidden_states = paddle.cat([hidden_states, inputs_embeds_mtp], axis=-1) - - return return_args(hidden_states, attention_mask, attn_mask_startend_row_indices, position_ids) - - -class DeepseekV2MTPLayerPipe(DeepseekV2MTPLayer): - def forward(self, args): - hidden_states, attention_mask, attn_mask_startend_row_indices, position_ids = parse_args(args) - - hidden_states_list = paddle.split(hidden_states, self.config.num_nextn_predict_layers + 1, axis=-1) - hidden_states_main_model = hidden_states_list[0] - inputs_embeds_cur_depth_list = hidden_states_list[1:] - has_gradient = not hidden_states_main_model.stop_gradient - - if attention_mask is not None and attention_mask.dtype == paddle.int32: - attention_mask, attn_mask_startend_row_indices, position_ids = ( - None, - attention_mask, - attn_mask_startend_row_indices, - ) - elif attention_mask is not None and attention_mask.dtype == paddle.int64: - attention_mask, attn_mask_startend_row_indices, position_ids = None, None, attention_mask - elif attn_mask_startend_row_indices is not None and attn_mask_startend_row_indices.dtype == paddle.int64: - attn_mask_startend_row_indices, position_ids = None, attn_mask_startend_row_indices - - output_list = [hidden_states_main_model] - hidden_states = hidden_states_main_model - for depth in range(self.config.num_nextn_predict_layers): - inputs_embeds_cur_depth = inputs_embeds_cur_depth_list[depth] - if self.enable_recompute and self.config.recompute_granularity == "full" and has_gradient: - if attention_mask is not None or attn_mask_startend_row_indices is not None: - hidden_states = recompute( - super().forward, - hidden_states, - inputs_embeds_cur_depth, - position_ids=position_ids, - attention_mask=attention_mask, - attn_mask_startend_row_indices=attn_mask_startend_row_indices, - use_reentrant=False, - ) - else: - # for pretrain - hidden_states = recompute( - super().forward, - hidden_states, - inputs_embeds_cur_depth, - position_ids=position_ids, - attn_mask_startend_row_indices=attn_mask_startend_row_indices, - use_reentrant=self.config.recompute_use_reentrant, - ) - else: - hidden_states = super().forward( - hidden_states, - inputs_embeds_cur_depth, - position_ids=position_ids, - attention_mask=attention_mask, - attn_mask_startend_row_indices=attn_mask_startend_row_indices, - ) - output_list.append(hidden_states) - - hidden_states = paddle.cat(output_list, axis=-1) - return return_args(hidden_states, attention_mask, attn_mask_startend_row_indices, position_ids) - - -class DeepseekV2RMSNormPipe(nn.Layer): - def __init__(self, config): - super().__init__() - self.config = config - self.norm = DeepseekV2RMSNorm(config) - - def forward(self, args): - hidden_states, attention_mask, attn_mask_startend_row_indices, position_ids = parse_args(args) - - if self.config.num_nextn_predict_layers > 0: - hidden_states_list = paddle.split(hidden_states, self.config.num_nextn_predict_layers + 1, axis=-1) - hidden_states = 
hidden_states_list[0] - hidden_states_mtp = hidden_states_list[-self.config.num_nextn_predict_layers :] - - output_list = [self.norm(hidden_states)] - for hidden_states in hidden_states_mtp: - output_list.append(self.norm(hidden_states)) - return output_list - else: - return self.norm(hidden_states) - - -class DeepseekV2LMHeadPipe(DeepseekV2LMHead): - def __init__(self, config): - super(DeepseekV2LMHeadPipe, self).__init__(config) - - @property - def embedding_weight(self): - return get_attr(self, "weight") - - def forward(self, args: Union[Tuple, paddle.Tensor]): - if self.config.num_nextn_predict_layers > 0: - logits = [] - for _hidden_states in args: - logits.append(super().forward(_hidden_states)) - return logits - hidden_states = args - logits = super().forward(hidden_states) - return logits - - -class DeepseekV2PretrainingCriterionPipe(DeepseekV2PretrainingCriterion): - def forward(self, logits, labels): - if self.config.num_nextn_predict_layers > 0: - mtp_logits = logits[1:] - logits = logits[0] - loss = super().forward(logits, labels, mtp_logits=mtp_logits) - else: - loss = super().forward(logits, labels) - return loss - - -class DeepseekV2ForCausalLMPipe(PipelinePretrainedModel, PipelineLayer): - """DeepseekV2ForPretraining adapted for pipeline parallelism. - - The largest change is flattening the DeepseekV2Model class so we can express it as a - sequence of layers including embedding, transformer layers, and output. - """ - - config_class = DeepseekV2Config - _base_model = DeepseekV2PretrainedModel - _get_tensor_parallel_mappings = DeepseekV2PretrainedModel._get_tensor_parallel_mappings - _init_weights = DeepseekV2PretrainedModel._init_weights - _keys_to_ignore_on_load_unexpected = DeepseekV2PretrainedModel._keys_to_ignore_on_load_unexpected - _get_model_flops = DeepseekV2PretrainedModel._get_model_flops - _get_hardware_flops = DeepseekV2PretrainedModel._get_hardware_flops - - _tied_weights_keys = ["lm_head.weight"] - - # DONOT Add base_model_prefix !!!! 
- - @classmethod - def _prepare_pipeline_inputs_func(cls, inputs): - first_stage_keys = ["input_ids", "attention_mask", "attn_mask_startend_row_indices", "position_ids"] - last_stage_keys = ["labels"] - - def get_expected_keys(inputs, keys): - ret = tuple([inputs.pop(k) if k in inputs else None for k in keys]) - if len(ret) == 1: - ret = ret[0] - return ret - - if type(inputs) is dict or type(inputs) is OrderedDict: - return [ - get_expected_keys(inputs, first_stage_keys), - get_expected_keys(inputs, last_stage_keys), - ] - - keys = list(inputs[0].keys()) - inputs_batch = {key: [data.pop(key) for data in inputs] for key in keys} - return [ - get_expected_keys(inputs_batch, first_stage_keys), - get_expected_keys(inputs_batch, last_stage_keys), - ] - - def __init__(self, config: DeepseekV2Config): - self.config = config - - # Note that we will actually perform a recompute only if both enable_recompute and layerwise_recompute are set to True - # Enable_recompute defaults to False and is controlled by Trainer - self.enable_recompute = False - self.recompute_granularity = self.config.recompute_granularity - self.pp_recompute_interval = self.config.pp_recompute_interval - self.no_recompute_layers = config.no_recompute_layers if config.no_recompute_layers is not None else [] - if self.recompute_granularity == "full": - assert len(self.no_recompute_layers) == 0, "for pp with full recompute, no_recompute_layers is not support" - - virtual_pp_degree = getattr(self.config, "virtual_pp_degree", 1) - - def get_hcg(): - return fleet.get_hybrid_communicate_group() - - hcg = get_hcg() - tensor_parallel_degree = max(hcg.get_model_parallel_world_size(), 1) - tensor_parallel_rank = max(hcg.get_model_parallel_rank(), 0) - - # TODO: fix tensor_parallel_degree rewrite in here - config.tensor_parallel_degree = tensor_parallel_degree - config.tensor_parallel_rank = tensor_parallel_rank - - if config.tie_word_embeddings: - self.add_sequential_layer( - SharedLayerDesc( - "DeepseekV2_shared_weight", - DeepseekV2EmbeddingPipe, - shared_weight_attr="embedding_weight", - config=config, - ), - self._base_model.base_model_prefix, - ) - else: - self.add_sequential_layer( - LayerDesc(DeepseekV2EmbeddingPipe, config=config), self._base_model.base_model_prefix - ) - - for i in range(config.num_hidden_layers): - self.add_sequential_layer( - LayerDesc( - DeepseekV2DecoderLayerPipe, - config=config, - layer_idx=i, - layerwise_recompute=i not in self.no_recompute_layers, - ), - f"{self._base_model.base_model_prefix}.layers.{i}", - ) - for i in range(config.num_nextn_predict_layers): - self.add_sequential_layer( - LayerDesc(DeepseekV2MTPLayerPipe, config=config, layer_idx=config.num_hidden_layers + i), - f"{self._base_model.base_model_prefix}.layers.{config.num_hidden_layers + i}", - ) - - self.add_sequential_layer(LayerDesc(DeepseekV2RMSNormPipe, config=config), self._base_model.base_model_prefix) - - if config.tie_word_embeddings: - self.add_sequential_layer( - SharedLayerDesc( - "DeepseekV2_shared_weight", - DeepseekV2LMHeadPipe, - shared_weight_attr="embedding_weight", - config=config, - **{"transpose_y": True}, - ), - "lm_head", - ) - else: - self.add_sequential_layer(LayerDesc(DeepseekV2LMHeadPipe, config=config), "lm_head") - - recompute_interval = 0 - if self.enable_recompute and self.recompute_granularity == "full": - assert self.config.pp_recompute_interval <= config.num_hidden_layers // ( - virtual_pp_degree * get_hcg().topology().get_dim_size("pipe") - ), "pp recompute interval should smaller than num layers of each 
pp chunk" - recompute_interval = self.config.pp_recompute_interval - - seg_method = "layer:DeepseekV2DecoderLayer|DeepseekV2MTPLayerPipe" - if config.num_hidden_layers % get_hcg().topology().get_dim_size("pipe") != 0: - seg_method = "uniform" - - PipelineLayer.__init__( - self, - layers=self.get_sequential_layers(), - loss_fn=self.get_loss_fn(config), - topology=get_hcg().topology(), - seg_method=seg_method, - recompute_interval=recompute_interval, - recompute_ctx={ - "mp_group": get_hcg().get_model_parallel_group(), - "offload": False, - "partition": False, - }, - num_virtual_pipeline_stages=virtual_pp_degree, - ) - # You should call init here, since there is a diamond inheritance problem - self.apply(self._init_weights) - # DON'T init PipelinePretrainedModel - # PipelinePretrainedModel.__init__(self.super(), config=config) - - def get_loss_fn(self, config): - return DeepseekV2PretrainingCriterionPipe(config) diff --git a/paddleformers/transformers/deepseek_v3/__init__.py b/paddleformers/transformers/deepseek_v3/__init__.py index c2f5412a8ae..7866cf44896 100644 --- a/paddleformers/transformers/deepseek_v3/__init__.py +++ b/paddleformers/transformers/deepseek_v3/__init__.py @@ -24,6 +24,7 @@ "DeepseekV3ForSequenceClassification", "DeepseekV3Model", "DeepseekV3PretrainedModel", + "DeepseekV3ForCausalLMPipe", ], "modeling_auto": [ "DeepseekV3LMHeadAuto", @@ -31,13 +32,11 @@ "DeepseekV3ModelAuto", "DeepseekV3PretrainedModelAuto", ], - "modeling_pp": ["DeepseekV3ForCausalLMPipe"], } if TYPE_CHECKING: from .configuration import * from .modeling import * - from .modeling_auto import * from .modeling_pp import * else: sys.modules[__name__] = _LazyModule( diff --git a/paddleformers/transformers/deepseek_v3/modeling.py b/paddleformers/transformers/deepseek_v3/modeling.py index 51c0d1978fe..a637b8a6715 100644 --- a/paddleformers/transformers/deepseek_v3/modeling.py +++ b/paddleformers/transformers/deepseek_v3/modeling.py @@ -25,12 +25,13 @@ import paddle +from ...nn.criterion.interface import CriterionLayer +from ...nn.lm_head import LMHead as GeneralLMHead from ..deepseek_v2.modeling import ( + DeepseekV2ForCausalLMPipe, DeepseekV2ForSequenceClassification, - DeepseekV2LMHead, DeepseekV2Model, DeepseekV2PretrainedModel, - DeepseekV2PretrainingCriterion, ) from ..model_outputs import CausalLMOutputWithPast from ..model_utils import register_base_model @@ -46,8 +47,20 @@ class DeepseekV3PretrainedModel(DeepseekV2PretrainedModel): config_class = DeepseekV3Config - base_model_prefix = "deepseek_v3" + base_model_prefix = "model" _no_split_modules = ["DeepseekV2DecoderLayer"] + transpose_weight_keys = [ + "kv_a_proj_with_mqa", + "kv_b_proj", + "o_proj", + "q_a_proj", + "q_b_proj", + "gate_proj", + "up_proj", + "down_proj", + "gate", + "eh_proj", + ] @register_base_model @@ -61,16 +74,16 @@ class DeepseekV3ForCausalLM(DeepseekV3PretrainedModel): def __init__(self, config: DeepseekV3Config): super().__init__(config) - self.deepseek_v3 = DeepseekV3Model(config) + self.model = DeepseekV3Model(config) self.vocab_size = config.vocab_size - self.lm_head = DeepseekV2LMHead(config) - self.criterion = DeepseekV2PretrainingCriterion(config) + self.lm_head = GeneralLMHead(config) + self.criterion = CriterionLayer(config) def get_input_embeddings(self): - return self.deepseek_v3.embed_tokens + return self.model.embed_tokens def set_input_embeddings(self, value): - self.deepseek_v3.embed_tokens = value + self.model.embed_tokens = value def get_output_embeddings(self): return self.lm_head @@ -79,10 +92,10 @@ def 
set_output_embeddings(self, new_embeddings): self.lm_head = new_embeddings def set_decoder(self, decoder): - self.deepseek_v3 = decoder + self.model = decoder def get_decoder(self): - return self.deepseek_v3 + return self.model def forward( self, @@ -129,7 +142,7 @@ def forward( return_dict = return_dict if return_dict is not None else self.config.use_return_dict # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs = self.deepseek_v3( + outputs = self.model( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, @@ -168,3 +181,22 @@ def forward( class DeepseekV3ForSequenceClassification(DeepseekV2ForSequenceClassification): def __init__(self, config): super().__init__(config) + + +class DeepseekV3ForCausalLMPipe(DeepseekV2ForCausalLMPipe): + """DeepseekV2ForPretraining adapted for pipeline parallelism. + + The largest change is flattening the DeepseekV2Model class so we can express it as a + sequence of layers including embedding, transformer layers, and output. + """ + + config_class = DeepseekV3Config + _base_model = DeepseekV3PretrainedModel + _get_tensor_parallel_mappings = DeepseekV3PretrainedModel._get_tensor_parallel_mappings + _init_weights = DeepseekV3PretrainedModel._init_weights + _keys_to_ignore_on_load_unexpected = DeepseekV3PretrainedModel._keys_to_ignore_on_load_unexpected + _get_model_flops = DeepseekV3PretrainedModel._get_model_flops + _get_hardware_flops = DeepseekV3PretrainedModel._get_hardware_flops + _tied_weights_keys = ["lm_head.weight"] + base_model_prefix = DeepseekV3PretrainedModel.base_model_prefix + transpose_weight_keys = DeepseekV3PretrainedModel.transpose_weight_keys diff --git a/paddleformers/transformers/deepseek_v3/modeling_auto.py b/paddleformers/transformers/deepseek_v3/modeling_auto.py deleted file mode 100644 index 747efbc59d3..00000000000 --- a/paddleformers/transformers/deepseek_v3/modeling_auto.py +++ /dev/null @@ -1,209 +0,0 @@ -# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. -# Copyright (c) 2023 DeepSeek. All rights reserved. -# -# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX -# and OPT implementations in this library. It has been modified from its -# original forms to accommodate minor architectural differences compared -# to GPT-NeoX and OPT used by the Meta AI team that trained the model. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-"""Paddle DeepSeek_V3 model.""" - -from __future__ import annotations - -from typing import List, Optional, Tuple, Union - -import paddle -import paddle.distributed as dist - -try: - from paddle.incubate.nn.functional import fused_rotary_position_embedding -except ImportError: - fused_rotary_position_embedding = None - -try: - from paddle.nn.functional.flash_attention import flash_attention -except: - flash_attention = None - -from ...utils.log import logger -from ..deepseek_v2.modeling_auto import ( - DeepseekV2LMHeadAuto, - DeepseekV2ModelAuto, - DeepseekV2PretrainedModelAuto, - DeepseekV2PretrainingCriterion, -) -from ..model_outputs import CausalLMOutputWithPast -from ..model_utils import register_base_model -from .configuration import DeepseekV2Config - -__all__ = [ - "DeepseekV3LMHeadAuto", - "DeepseekV3ForCausalLMAuto", - "DeepseekV3ModelAuto", - "DeepseekV3PretrainedModelAuto", -] - - -class DeepseekV3PretrainedModelAuto(DeepseekV2PretrainedModelAuto): - config_class = DeepseekV2Config - base_model_prefix = "deepseek_v3" - _no_split_modules = ["DeepseekV2DecoderLayerAuto"] - - -@register_base_model -class DeepseekV3ModelAuto(DeepseekV2ModelAuto): - def __init__(self, config: DeepseekV2Config): - super().__init__(config) - - -class DeepseekV3LMHeadAuto(DeepseekV2LMHeadAuto): - def __init__(self, config: DeepseekV2Config): - super().__init__(config) - - -class DeepseekV3ForCausalLMAuto(DeepseekV3PretrainedModelAuto): - _tied_weights_keys = ["lm_head.weight"] - - def __init__(self, config: DeepseekV2Config): - super().__init__(config) - self.config = config - self.deepseek_v3 = DeepseekV3ModelAuto(config) - self.vocab_size = config.vocab_size - self.lm_head = DeepseekV3LMHeadAuto(config) - self.criterion = DeepseekV2PretrainingCriterion(config) - - def get_input_embeddings(self): - return self.deepseek_v3.embed_tokens - - def set_input_embeddings(self, value): - self.deepseek_v3.embed_tokens = value - - def get_output_embeddings(self): - return self.lm_head - - def set_output_embeddings(self, new_embeddings): - self.lm_head = new_embeddings - - def set_decoder(self, decoder): - self.deepseek_v3 = decoder - - def get_decoder(self): - return self.deepseek_v3 - - def forward( - self, - input_ids: paddle.Tensor = None, - position_ids: Optional[paddle.Tensor] = None, - attention_mask: Optional[paddle.Tensor] = None, - inputs_embeds: Optional[paddle.Tensor] = None, - labels: Optional[paddle.Tensor] = None, - use_cache: Optional[bool] = None, - past_key_values: Optional[List[paddle.Tensor]] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - attn_mask_startend_row_indices=None, - ) -> Union[Tuple, CausalLMOutputWithPast]: - r""" - Args: - labels (`paddle.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, transformers., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, transformers., config.vocab_size]`. - - Returns: - - Example: - - ```python - >>> from transformers import AutoTokenizer, DeepseekV3ForCausalLMAuto - - >>> model = DeepseekV3ForCausalLMAuto.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) - >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) - - >>> prompt = "Hey, are you conscious? Can you talk to me?" 
- >>> inputs = tokenizer(prompt, return_tensors="pt") - - >>> # Generate - >>> generate_ids = model.generate(inputs.input_ids, max_length=30) - >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." - ```""" - input_ids.stop_gradient = True - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - if attn_mask_startend_row_indices is not None and attention_mask is not None: - logger.warning( - "You have provided both attn_mask_startend_row_indices and attention_mask. " - "The attn_mask_startend_row_indices will be used." - ) - attention_mask = None - - # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs = self.deepseek_v3( - input_ids=input_ids, - position_ids=position_ids, - attention_mask=attention_mask, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - past_key_values=past_key_values, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - attn_mask_startend_row_indices=attn_mask_startend_row_indices, - ) - - hidden_states = outputs[0] - mtp_outputs = outputs[-1] - - # if labels is None,means we need full output, instead of tensor_parallel_output - # tensor_parallel_output is together with ParallelCrossEntropy - tensor_parallel_output = self.config.tensor_parallel_output and self.config.tensor_parallel_degree > 1 - - logits = self.lm_head(hidden_states, tensor_parallel_output=tensor_parallel_output) - mtp_logits = [self.lm_head(_hidden_states) for _hidden_states in mtp_outputs] if len(mtp_outputs) > 0 else [] - - return self.criterion(logits, labels, mtp_logits=mtp_logits) - - def auto_dist_config(self, prefix=""): - if prefix != "": - assert prefix.endswith(".") - config = { - "dp_config": {"sharding_level": 0, "offload": False, "exclude_layer": None}, - "pp_config": { - "split_spec": [f"{prefix}deepseek_v3.layers", f"{prefix}lm_head"], - "global_spec": "deepseek_v3.global_layer", - }, - "mp_config": { - "parallelize_plan": { - f"{prefix}deepseek_v3.embed_tokens": dist.ColWiseParallel(gather_output=True), - f"{prefix}deepseek_v3.layers.*.self_attn.q_b_proj": dist.ColWiseParallel(), - f"{prefix}deepseek_v3.layers.*.self_attn.q_proj": dist.ColWiseParallel(), - f"{prefix}deepseek_v3.layers.*.self_attn.kv_b_proj": dist.ColWiseParallel(), - f"{prefix}deepseek_v3.layers.*.self_attn.o_proj": dist.RowWiseParallel(), - f"{prefix}deepseek_v3.layers.*.mlp.gate_proj": dist.ColWiseParallel(), - f"{prefix}deepseek_v3.layers.*.mlp.up_proj": dist.ColWiseParallel(), - f"{prefix}deepseek_v3.layers.*.mlp.down_proj": dist.RowWiseParallel(), - f"{prefix}deepseek_v3.layers.*.mlp.shared_experts.gate_proj": dist.ColWiseParallel(), - f"{prefix}deepseek_v3.layers.*.mlp.shared_experts.up_proj": dist.ColWiseParallel(), - f"{prefix}deepseek_v3.layers.*.mlp.shared_experts.down_proj": dist.RowWiseParallel(), - f"{prefix}lm_head.weight": dist.ColWiseParallel(), - } - }, - } - return config diff --git a/paddleformers/transformers/deepseek_v3/modeling_pp.py b/paddleformers/transformers/deepseek_v3/modeling_pp.py deleted file mode 100644 index d8e90c2b9fc..00000000000 --- 
a/paddleformers/transformers/deepseek_v3/modeling_pp.py +++ /dev/null @@ -1,40 +0,0 @@ -# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -from ..deepseek_v2.modeling_pp import DeepseekV2ForCausalLMPipe -from .configuration import DeepseekV3Config -from .modeling import DeepseekV3PretrainedModel - -__all__ = [ - "DeepseekV3ForCausalLMPipe", -] - - -class DeepseekV3ForCausalLMPipe(DeepseekV2ForCausalLMPipe): - """DeepseekV2ForPretraining adapted for pipeline parallelism. - - The largest change is flattening the DeepseekV2Model class so we can express it as a - sequence of layers including embedding, transformer layers, and output. - """ - - config_class = DeepseekV3Config - _base_model = DeepseekV3PretrainedModel - _get_tensor_parallel_mappings = DeepseekV3PretrainedModel._get_tensor_parallel_mappings - _init_weights = DeepseekV3PretrainedModel._init_weights - _keys_to_ignore_on_load_unexpected = DeepseekV3PretrainedModel._keys_to_ignore_on_load_unexpected - _get_model_flops = DeepseekV3PretrainedModel._get_model_flops - _get_hardware_flops = DeepseekV3PretrainedModel._get_hardware_flops - _tied_weights_keys = ["lm_head.weight"] - base_model_prefix = DeepseekV3PretrainedModel.base_model_prefix diff --git a/paddleformers/transformers/moe_gate.py b/paddleformers/transformers/moe_gate.py index d666515c44d..ca704d41076 100644 --- a/paddleformers/transformers/moe_gate.py +++ b/paddleformers/transformers/moe_gate.py @@ -313,6 +313,9 @@ def _topk_noaux_tc( ) # [n, e] tmp_scores = scores_for_choice * score_mask # [n, e] topk_weight, topk_idx = paddle.topk(tmp_scores, k=k, axis=-1, sorted=True) + + # The bias term b is used only to adjust affinity scores for Top-K expert selection (routing); it does not affect gating. + # The gate applied during dispatch and to weight the FFN output is computed from the original affinity score s_{i,t} (without the bias). 
        topk_weight = scores.take_along_axis(topk_idx, axis=1) if not self.training else topk_weight

        return topk_weight, topk_idx
@@ -497,7 +500,7 @@ def topkgating(
            top_gate = top_gate * self.routed_scaling_factor

        # get topk mask
-        mask = paddle.zeros_like(gates).put_along_axis(top_idx, paddle.to_tensor(1.0, dtype="float32"), axis=1)
+        mask = paddle.zeros_like(gates).put_along_axis(top_idx, paddle.to_tensor(1.0, dtype=gates.dtype), axis=1)
        if hasattr(self.config, "seq_aux") and self.config.seq_aux:
            l_aux = self._cal_seq_aux_loss(gates_ori, self.top_k, top_idx)
        else:
diff --git a/paddleformers/transformers/moe_layer.py b/paddleformers/transformers/moe_layer.py
index da5392e57d4..692f5fc4c41 100644
--- a/paddleformers/transformers/moe_layer.py
+++ b/paddleformers/transformers/moe_layer.py
@@ -175,13 +175,13 @@ def __init__(
                is_fleet_init = True
            except AttributeError:
                is_fleet_init = False
-
-        if (
-            is_fleet_init
-            and dist.fleet.get_hybrid_communicate_group().get_data_parallel_world_size() > 1
-            and moe_group == "data"
-        ):
-            self.moe_group = dist.fleet.get_hybrid_communicate_group().get_data_parallel_group()
+        if is_fleet_init and dist.get_world_size() > 1:
+            if moe_group == "data":
+                self.moe_group = dist.fleet.get_hybrid_communicate_group().get_data_parallel_group()
+            elif moe_group == "expert":
+                self.moe_group = dist.fleet.get_hybrid_communicate_group().get_expert_parallel_group()
+            else:
+                raise NotImplementedError("moe_group can only be 'data' or 'expert', but got {}".format(moe_group))
            self.moe_rank = dist.get_rank(self.moe_group)
            self.moe_rank = 0 if self.moe_rank < 0 else self.moe_rank
            self.expert_parallel_degree = dist.get_world_size(self.moe_group)
@@ -197,7 +196,6 @@ def __init__(
            self.expert_parallel_degree = 1
            self.moe_num_experts_per_device = self.moe_num_experts
            self.is_dummy_moe = True
-
        self.all_to_all_dropout = all_to_all_dropout
        self.enable_recompute = False
@@ -382,7 +381,6 @@ def expert_forward(self, dispatched_input, tokens_per_expert):
        chunks = paddle.split(dispatched_input, num_or_sections=tokens_per_expert, axis=0)
        for i, chunk in enumerate(chunks):
            chunk = chunk.contiguous()
-            # assert chunk.shape[0] != 0, "Cannot dispatch empty input"
            expert = self.experts[i + self.moe_rank * self.moe_num_experts_per_device]
            outputs += [expert(chunk)]
diff --git a/paddleformers/trl/model_config.py b/paddleformers/trl/model_config.py
index e83ae297cf6..a42018f704b 100644
--- a/paddleformers/trl/model_config.py
+++ b/paddleformers/trl/model_config.py
@@ -153,3 +153,5 @@ class ModelConfig:
    pp_seg_method: Optional[str] = field(
        default="layer:DecoderLayer|EmptyLayer", metadata={"help": "PP Segmentation Method"}
    )
+    using_fake_gate: bool = field(default=False, metadata={"help": "Whether to use a fake gate"})
+    aux_loss_alpha: float = field(default=0.0001, metadata={"help": "Coefficient of the auxiliary load-balancing loss"})
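For reference, below is a minimal standalone sketch (not part of the patch) of the noaux_tc routing rule that the new comment in moe_gate.py describes: the correction bias only influences which experts are selected, while the gate weights applied to the expert outputs are gathered from the original affinity scores. It mirrors the non-training branch of _topk_noaux_tc and omits the group-mask step; the helper name and tensor shapes are illustrative assumptions, not APIs from this repo.

# noaux_tc_route is a hypothetical helper used only for illustration.
import paddle

def noaux_tc_route(scores, e_score_correction_bias, k):
    """Select top-k experts with bias-adjusted scores; weight them with the original scores."""
    # Bias-adjusted affinities are used only for expert selection (routing).
    scores_for_choice = scores + e_score_correction_bias.unsqueeze(0)
    _, topk_idx = paddle.topk(scores_for_choice, k=k, axis=-1, sorted=True)
    # Gate weights come from the original affinity scores s_{i,t}, without the bias.
    topk_weight = scores.take_along_axis(topk_idx, axis=1)
    return topk_weight, topk_idx

# Toy example: the bias nudges expert 1 into the top-2, but its weight stays at the unbiased 0.25.
scores = paddle.to_tensor([[0.30, 0.25, 0.20, 0.25]])  # [n_tokens, n_experts]
bias = paddle.to_tensor([0.00, 0.10, 0.00, 0.00])      # [n_experts]
topk_weight, topk_idx = noaux_tc_route(scores, bias, k=2)
# topk_idx -> [[1, 0]] (selection uses biased scores), topk_weight -> [[0.25, 0.30]] (weights use original scores)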