diff --git a/examples/README.md b/examples/README.md
index 6db135c5bdc..5af8347d130 100644
--- a/examples/README.md
+++ b/examples/README.md
@@ -20,18 +20,39 @@
 wget https://bj.bcebos.com/paddlenlp/datasets/examples/alpaca_demo.gz
 tar -xvf alpaca_demo.gz
 ```
+### Model Download
+```bash
+# PaddleNLP/Qwen2-0.5B-Instruct
+aistudio download --model PaddleNLP/Qwen2-0.5B-Instruct --local_dir PaddleNLP/Qwen2-0.5B-Instruct
+
+# baidu/ERNIE-4.5-0.3B-PT
+aistudio download --model PaddlePaddle/ERNIE-4.5-0.3B-PT --local_dir baidu/ERNIE-4.5-0.3B-PT
+
+# baidu/ERNIE-4.5-21B-A3B-PT
+aistudio download --model PaddlePaddle/ERNIE-4.5-21B-A3B-PT --local_dir baidu/ERNIE-4.5-21B-A3B-PT
+```
 
 ### Full-Parameter Fine-Tuning: SFT
 
 Single GPU
 ```bash
-# Requires roughly 12 GB of GPU memory
+# Fine-tuning Qwen2-0.5B-Instruct requires roughly 12 GB of GPU memory
 python -u run_finetune.py ./config/qwen/sft_argument_qwen2_0p5b.json
+
+# Fine-tuning ERNIE-4.5-0.3B-PT
+python -u run_finetune.py ./config/ernie4_5/sft_argument_ernie4_5_0p3b.json
 ```
 
 Multi-GPU
 ```bash
+# SFT Qwen2-0.5B-Instruct
 python -u -m paddle.distributed.launch --devices "0,1,2,3,4,5,6,7" run_finetune.py ./config/qwen/sft_argument_qwen2_0p5b.json
+
+# SFT ERNIE-4.5-0.3B-PT
+python -u -m paddle.distributed.launch --devices "0,1,2,3,4,5,6,7" run_finetune.py ./config/ernie4_5/sft_argument_ernie4_5_0p3b.json
+
+# SFT ERNIE-4.5-21B-A3B-PT
+python -u -m paddle.distributed.launch --devices "0,1,2,3,4,5,6,7" run_finetune.py ./config/ernie4_5_moe/sft_argument_ernie4_5_21b_a3b.json
 ```
 
 ### LoRA
diff --git a/examples/config/ernie4_5/sft_argument_ernie4_5_0p3b.json b/examples/config/ernie4_5/sft_argument_ernie4_5_0p3b.json
index d0d03113d98..f37308a0981 100644
--- a/examples/config/ernie4_5/sft_argument_ernie4_5_0p3b.json
+++ b/examples/config/ernie4_5/sft_argument_ernie4_5_0p3b.json
@@ -21,8 +21,6 @@
     "max_steps": 100,
     "evaluation_strategy": "epoch",
     "save_strategy": "epoch",
-    "src_length": 1024,
-    "max_length": 2048,
     "bf16": true,
     "fp16_opt_level": "O2",
     "do_train": true,
@@ -33,15 +31,17 @@
     "metric_for_best_model": "accuracy",
     "recompute": true,
     "save_total_limit": 1,
-    "tensor_parallel_degree": 2,
-    "pipeline_parallel_degree": 2,
+    "tensor_parallel_degree": 1,
+    "pipeline_parallel_degree": 1,
     "sharding": "stage2",
     "zero_padding": true,
     "flash_mask": true,
     "unified_checkpoint": true,
     "use_flash_attention": true,
-    "sequence_parallel": true,
+    "use_attn_mask_startend_row_indices": true,
+    "sequence_parallel": false,
     "report_to": "none",
     "convert_from_hf": true,
+    "save_to_hf": true,
     "pp_seg_method": "layer:DecoderLayer|EmptyLayer"
 }
diff --git a/examples/config/ernie4_5_moe/sft_argument_ernie4_5_21b_a3b.json b/examples/config/ernie4_5_moe/sft_argument_ernie4_5_21b_a3b.json
index caea3249d7c..689ecdbb9f2 100644
--- a/examples/config/ernie4_5_moe/sft_argument_ernie4_5_21b_a3b.json
+++ b/examples/config/ernie4_5_moe/sft_argument_ernie4_5_21b_a3b.json
@@ -21,8 +21,6 @@
     "max_steps": 100,
     "evaluation_strategy": "epoch",
     "save_strategy": "epoch",
-    "src_length": 1024,
-    "max_length": 2048,
     "bf16": true,
     "fp16_opt_level": "O2",
     "do_train": true,
@@ -43,5 +41,6 @@
     "sequence_parallel": true,
     "report_to": "none",
     "convert_from_hf": true,
+    "save_to_hf": true,
     "pp_seg_method": "layer:DecoderLayer|EmptyLayer"
 }
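The two ERNIE configs above differ mainly in parallelism: the 0.3B config now fits a single card (`tensor_parallel_degree` and `pipeline_parallel_degree` both 1, `sequence_parallel` off), while the 21B-A3B MoE config keeps its multi-card layout. As a rough pre-launch sanity check, the minimum world size a config implies is `tensor_parallel_degree * pipeline_parallel_degree`; the helper below is an illustrative sketch, not part of the repo:

```python
import json

def min_world_size(config_path: str) -> int:
    """Smallest GPU count a config's static parallelism can run on (tp * pp).

    Sharding (e.g. "stage2") and data parallelism replicate this unit across
    additional GPUs; they do not lower the minimum.
    """
    with open(config_path) as f:
        cfg = json.load(f)
    return cfg.get("tensor_parallel_degree", 1) * cfg.get("pipeline_parallel_degree", 1)

print(min_world_size("./config/ernie4_5/sft_argument_ernie4_5_0p3b.json"))  # expected: 1
```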
= "flashmask" + elif model_args.flash_mask and not model_args.use_attn_mask_startend_row_indices: + model_config._attn_implementation = "sdpa" + else: + model_config._attn_implementation = "eager" + model_class = AutoModelForCausalLM if training_args.pipeline_parallel_degree > 1: if data_args.eval_with_do_generation and training_args.do_eval: @@ -174,7 +181,6 @@ def main(): logger.warning("`flash_mask` must use with zero padding and flash attention.") data_args.zero_padding = True model.config.use_flash_attention = True - model.config._attn_implementation = "flashmask" if model_args.flash_mask and not any(isinstance(model, cls) for cls in flash_mask_support_list): raise NotImplementedError(f"{model.__class__} not support flash mask.") diff --git a/paddleformers/transformers/ernie4_5/modeling.py b/paddleformers/transformers/ernie4_5/modeling.py index 18aeb83eb13..9b8a73fcfe2 100644 --- a/paddleformers/transformers/ernie4_5/modeling.py +++ b/paddleformers/transformers/ernie4_5/modeling.py @@ -834,7 +834,8 @@ def forward( # Pretrain & Eval must have labels assert labels is not None - return self.criterion(logits, labels, loss_mask) + loss, _ = self.criterion(logits, labels, loss_mask) + return loss, logits class Ernie4_5ForCausalLMPipe(GeneralModelForCausalLMPipe): diff --git a/paddleformers/transformers/ernie4_5_moe/modeling.py b/paddleformers/transformers/ernie4_5_moe/modeling.py index 6c85a01c8a9..6e64221d3f4 100644 --- a/paddleformers/transformers/ernie4_5_moe/modeling.py +++ b/paddleformers/transformers/ernie4_5_moe/modeling.py @@ -1157,7 +1157,8 @@ def forward( # Pretrain & Eval must have labels assert labels is not None - return self.criterion(logits, labels, loss_mask, router_loss=router_loss, mtp_logits=mtp_logits) + loss, _ = self.criterion(logits, labels, loss_mask, router_loss=router_loss, mtp_logits=mtp_logits) + return loss, logits class Ernie4_5_MoeForCausalLMPipe(GeneralModelForCausalLMPipe): diff --git a/paddleformers/transformers/llama/fusion_ops.py b/paddleformers/transformers/llama/fusion_ops.py index 41b78936b62..3254854a9d5 100644 --- a/paddleformers/transformers/llama/fusion_ops.py +++ b/paddleformers/transformers/llama/fusion_ops.py @@ -248,8 +248,6 @@ def fusion_flash_attention( else: if attn_mask_startend_row_indices is not None: assert alibi is None, "flashmask_attention or flash_attention_with_sparse_mask not support alibi" - if len(attn_mask_startend_row_indices.shape) == 2: - attn_mask_startend_row_indices = paddle.unsqueeze(attn_mask_startend_row_indices, axis=1) if hasattr(F, "flashmask_attention"): attn_output = no_recompute( @@ -257,7 +255,7 @@ def fusion_flash_attention( query_states, key_states, value_states, - startend_row_indices=attn_mask_startend_row_indices.unsqueeze(-1), + startend_row_indices=attn_mask_startend_row_indices, causal=True, enable=skip_recompute, ) diff --git a/paddleformers/utils/masking_utils.py b/paddleformers/utils/masking_utils.py index b47c2124a79..c6e2d056ee1 100644 --- a/paddleformers/utils/masking_utils.py +++ b/paddleformers/utils/masking_utils.py @@ -31,9 +31,9 @@ def _gen_from_sparse_attn_mask_indices(attn_mask_start_row_indices, dtype): Returns: paddle.Tensor: The dense attention mask recovered from attn_mask_start_row_indices. 
""" - batch_size, _, max_seq_len = attn_mask_start_row_indices.shape + batch_size, _, max_seq_len, _ = attn_mask_start_row_indices.shape base = paddle.arange(max_seq_len, dtype="int32").unsqueeze(1).expand([batch_size, -1, max_seq_len]).unsqueeze(1) - mask_indices = attn_mask_start_row_indices.unsqueeze(1) + mask_indices = attn_mask_start_row_indices tril = paddle.tril( paddle.ones([max_seq_len, max_seq_len], dtype="bool").expand([batch_size, 1, max_seq_len, max_seq_len])