diff --git a/examples/alignment/dpo/run_dpo.py b/examples/alignment/dpo/run_dpo.py
index 40e66380f5..c34c0e01d5 100644
--- a/examples/alignment/dpo/run_dpo.py
+++ b/examples/alignment/dpo/run_dpo.py
@@ -144,6 +144,11 @@ def main():
         dtype=dtype,
     )
     model_config._attn_implementation = model_args.attn_impl
+    if training_args.use_expert_parallel and training_args.expert_parallel_degree >= 1:
+        model_config.n_group = training_args.expert_parallel_degree
+    model_config.use_fused_head_and_loss_fn = model_args.use_fused_head_and_loss_fn
+    model_config.use_filtered_label_loss = training_args.use_filtered_label_loss
+    model_config.loss_subbatch_sequence_length = training_args.loss_subbatch_sequence_length
 
     LlmMetaConfig.set_llm_config(model_config, training_args)
 
diff --git a/examples/config/dpo_full.yaml b/examples/config/dpo_full.yaml
index 3669198a62..dc9be60647 100644
--- a/examples/config/dpo_full.yaml
+++ b/examples/config/dpo_full.yaml
@@ -47,3 +47,6 @@ recompute: true
 bf16: true
 fp16_opt_level: O2
 unified_checkpoint: true
+use_fused_head_and_loss_fn: false
+use_filtered_label_loss: false
+loss_subbatch_sequence_length: -1
\ No newline at end of file
diff --git a/examples/config/dpo_lora.yaml b/examples/config/dpo_lora.yaml
index 127e6af6b3..1306ab794e 100644
--- a/examples/config/dpo_lora.yaml
+++ b/examples/config/dpo_lora.yaml
@@ -49,3 +49,6 @@ recompute: true
 bf16: true
 fp16_opt_level: O2
 unified_checkpoint: true
+use_fused_head_and_loss_fn: false
+use_filtered_label_loss: false
+loss_subbatch_sequence_length: -1
\ No newline at end of file
diff --git a/examples/config/sft_full.yaml b/examples/config/sft_full.yaml
index c4ad3965a1..a604fa1db7 100644
--- a/examples/config/sft_full.yaml
+++ b/examples/config/sft_full.yaml
@@ -47,3 +47,6 @@ recompute: true
 bf16: true
 fp16_opt_level: O2
 unified_checkpoint: true
+use_fused_head_and_loss_fn: false
+use_filtered_label_loss: false
+loss_subbatch_sequence_length: -1
\ No newline at end of file
diff --git a/examples/config/sft_lora.yaml b/examples/config/sft_lora.yaml
index 8c10c6371d..f7abe2a1f2 100644
--- a/examples/config/sft_lora.yaml
+++ b/examples/config/sft_lora.yaml
@@ -49,3 +49,6 @@ recompute: true
 bf16: true
 fp16_opt_level: O2
 unified_checkpoint: true
+use_fused_head_and_loss_fn: false
+use_filtered_label_loss: false
+loss_subbatch_sequence_length: -1
\ No newline at end of file
diff --git a/examples/run_finetune.py b/examples/run_finetune.py
index f72edb0391..2dcaaa02e6 100644
--- a/examples/run_finetune.py
+++ b/examples/run_finetune.py
@@ -134,6 +134,11 @@ def main():
     model_config.max_sequence_length = training_args.max_seq_len
     model_config.num_nextn_predict_layers = model_args.num_nextn_predict_layers
     model_config._attn_implementation = model_args.attn_impl
+    if training_args.use_expert_parallel and training_args.expert_parallel_degree >= 1:
+        model_config.n_group = training_args.expert_parallel_degree
+    model_config.use_fused_head_and_loss_fn = model_args.use_fused_head_and_loss_fn
+    model_config.use_filtered_label_loss = training_args.use_filtered_label_loss
+    model_config.loss_subbatch_sequence_length = training_args.loss_subbatch_sequence_length
 
     logger.info(f"Final model config: {model_config}")
     logger.info("Creating model")
diff --git a/paddleformers/nn/lm_head.py b/paddleformers/nn/lm_head.py
index 8b2f81cbd0..c88afa5391 100644
--- a/paddleformers/nn/lm_head.py
+++ b/paddleformers/nn/lm_head.py
@@ -93,7 +93,7 @@ def forward(self, hidden_states, tensor_parallel_output=None):
                 hidden_states,
                 self.weight,
                 self.bias,
-                self.config.tie_word_embeddings,
+                True,
             )
 
         return calc_lm_head_logits(
diff --git a/paddleformers/transformers/glm4_moe/modeling.py b/paddleformers/transformers/glm4_moe/modeling.py
index de40038ac6..fe11370dde 100644
--- a/paddleformers/transformers/glm4_moe/modeling.py
+++ b/paddleformers/transformers/glm4_moe/modeling.py
@@ -34,7 +34,6 @@
 from ...nn.norm import Norm as GeneralNorm
 from ...nn.pp_model import GeneralModelForCausalLMPipe
 from ...utils.log import logger
-from ..llama.modeling import get_use_casual_mask
 from ..model_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
 from ..model_utils import PretrainedModel, register_base_model
 from ..moe_gate import PretrainedMoEGate
@@ -1013,27 +1012,18 @@ def forward(
             # [seq_len * bs / n, num_head * head_dim] (n is mp parallelism)
             inputs_embeds = ScatterOp.apply(inputs_embeds)
 
-        if attn_mask_startend_row_indices is not None or get_use_casual_mask():
-            attention_mask = None
-        else:
-            # [bs, seq_len]
-            attention_mask = (
-                paddle.ones((batch_size, seq_length_with_past), dtype=paddle.bool)
-                if attention_mask is None
-                else attention_mask
-            )
+        hidden_states = inputs_embeds
+        if attention_mask is not None:
             causal_mask = self._prepare_decoder_attention_mask(
-                attention_mask=attention_mask,
-                input_shape=(batch_size, seq_length),
-                past_key_values_length=cache_length,
-                dtype=inputs_embeds.dtype,
-            )  # [bs, 1, seq_len, seq_len]
+                attention_mask, hidden_states.shape[:2], cache_length, hidden_states.dtype
+            )
+        else:
+            causal_mask = None
 
         if position_ids is None:
             position_ids = paddle.arange(seq_length, dtype="int64").expand((batch_size, seq_length))
 
-        hidden_states = inputs_embeds
         position_embeddings = self.rotary_emb(hidden_states, position_ids)
 
         # decoder layers
@@ -1205,6 +1195,7 @@ def forward(
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
+            attn_mask_startend_row_indices=attn_mask_startend_row_indices,
         )
 
         hidden_states = outputs[0]  # [bs, seq_len, dim]
diff --git a/paddleformers/trl/model_config.py b/paddleformers/trl/model_config.py
index e83ae297cf..e7427fc1eb 100644
--- a/paddleformers/trl/model_config.py
+++ b/paddleformers/trl/model_config.py
@@ -153,3 +153,6 @@ class ModelConfig:
     pp_seg_method: Optional[str] = field(
         default="layer:DecoderLayer|EmptyLayer", metadata={"help": "PP Segmentation Method"}
     )
+    use_fused_head_and_loss_fn: bool = field(
+        default=False, metadata={"help": "Whether to use fused head and loss function"}
+    )
diff --git a/tests/nn/test_lm_head.py b/tests/nn/test_lm_head.py
index 781ac5fd6e..eea4cac139 100644
--- a/tests/nn/test_lm_head.py
+++ b/tests/nn/test_lm_head.py
@@ -60,7 +60,7 @@ def test_forward_fused_head_loss(self):
         self.assertEqual(output[0].shape, test_input.shape)
         self.assertEqual(output[1].shape, lm_head.weight.shape)
         self.assertIs(output[2], lm_head.bias)
-        self.assertEqual(output[3], config.tie_word_embeddings)
+        self.assertEqual(output[3], True)
 
 
 if __name__ == "__main__":