
Commit c9e9b68

fix pp_seg_method and unify training attention with attn_impl (#2572)
1 parent 0031d69 commit c9e9b68

15 files changed: +30 / -40 lines changed

examples/alignment/dpo/dpo_argument.py
Lines changed: 3 additions & 0 deletions

@@ -161,3 +161,6 @@ class DPOModelArgument:
     lora_alpha: int = field(default=-1, metadata={"help": "lora_alpha"})
     rslora_plus: bool = field(default=False, metadata={"help": "Strengthen lora performance"})
     use_quick_lora: bool = field(default=True, metadata={"help": "quick lora"})
+
+    # Attention
+    attn_impl: str = field(default="flashmask", metadata={"help": "Attention implementation"})
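Making the backend a dataclass field means it arrives through the normal argument-parsing path instead of a pair of booleans. A minimal sketch of how the field would be filled in, assuming PdArgumentParser follows the usual dataclass-arguments pattern of its Hugging Face counterpart (the CLI values and the local import are illustrative):

# Illustrative only: parse --attn_impl into the DPOModelArgument dataclass.
from paddleformers.trainer import PdArgumentParser

from dpo_argument import DPOModelArgument  # assumes running from examples/alignment/dpo/

parser = PdArgumentParser((DPOModelArgument,))
(model_args,) = parser.parse_args_into_dataclasses(args=["--attn_impl", "flashmask"])
assert model_args.attn_impl == "flashmask"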

examples/alignment/dpo/run_dpo.py
Lines changed: 10 additions & 4 deletions

@@ -27,6 +27,7 @@
 )
 
 from paddleformers.datasets.dpo import collate_fn, create_dataset
+from paddleformers.nn.attention import AttentionInterface
 from paddleformers.peft import LoRAConfig, LoRAModel
 from paddleformers.trainer import PdArgumentParser, get_last_checkpoint, set_seed
 from paddleformers.transformers import (
@@ -57,6 +58,11 @@ def main():
 
     paddle.set_device(training_args.device)
     set_seed(training_args.seed)
+
+    avaible_attn_impl = AttentionInterface._global_mapping.keys()
+    if model_args.attn_impl not in avaible_attn_impl:
+        raise ValueError(f"Invalid attn_impl: {model_args.attn_impl}, available attn_impl: {avaible_attn_impl}")
+
     if dpo_config.loss_type == "orpo":
         dpo_config.reference_free = True
         dpo_config.sft_loss_ratio = 1.0
@@ -113,6 +119,8 @@ def main():
         dtype=dtype,
         download_hub=model_args.download_hub,
     )
+    model_config._attn_implementation = model_args.attn_impl
+
     LlmMetaConfig.set_llm_config(model_config, training_args)
 
     if not dpo_config.reference_free and not dpo_config.lora:
@@ -151,11 +159,8 @@ def main():
         ref_model = None
     if training_args.pipeline_parallel_degree > 1:
         model.config.dpo_config = None
-    if model_args.flash_mask and not model.config.use_flash_attention:
-        logger.warning("`flash_mask` must use with zero padding and flash attention.")
-        model.config.use_flash_attention = True
 
-    if model_args.flash_mask and not any(isinstance(model, cls) for cls in flash_mask_support_list):
+    if model_args.attn_impl == "flashmask" and not any(isinstance(model, cls) for cls in flash_mask_support_list):
         raise NotImplementedError(f"{model.__class__} not support flash mask.")
 
     if model_args.tokenizer_name_or_path is not None:
@@ -219,6 +224,7 @@ def main():
         "greedy_intokens": data_args.greedy_intokens,
         "packing": data_args.packing,
         "mix_strategy": data_args.mix_strategy,
+        "encode_one_turn": data_args.encode_one_turn,
     }
     if training_args.do_train and training_args.should_load_dataset:
         train_dataset = create_dataset(
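run_dpo.py now validates the requested backend against the attention registry before wiring it into the model config. A standalone sketch of that check, assuming AttentionInterface._global_mapping maps backend names to their implementations (the helper name resolve_attn_impl is made up for illustration; the exact set of registered keys depends on the installed build):

from paddleformers.nn.attention import AttentionInterface


def resolve_attn_impl(model_config, attn_impl: str):
    # Reject anything that is not a registered attention backend.
    available = AttentionInterface._global_mapping.keys()
    if attn_impl not in available:
        raise ValueError(f"Invalid attn_impl: {attn_impl}, available attn_impl: {available}")
    # Record the chosen backend on the config, as the diff above does.
    model_config._attn_implementation = attn_impl
    return model_config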

examples/config/ernie4_5/sft_argument_ernie4_5_0p3b.json
Lines changed: 1 addition & 3 deletions

@@ -34,10 +34,8 @@
     "tensor_parallel_degree": 2,
     "pipeline_parallel_degree": 2,
     "sharding": "stage2",
-    "zero_padding": true,
-    "flash_mask": true,
     "unified_checkpoint": true,
-    "use_flash_attention": true,
+    "attn_impl": "flashmask",
     "sequence_parallel": true,
     "report_to": "none",
     "convert_from_hf": true,

examples/config/ernie4_5_moe/sft_argument_ernie4_5_21b_a3b.json
Lines changed: 1 addition & 3 deletions

@@ -34,10 +34,8 @@
     "tensor_parallel_degree": 4,
     "pipeline_parallel_degree": 2,
     "sharding": "stage2",
-    "zero_padding": true,
-    "flash_mask": true,
     "unified_checkpoint": true,
-    "use_flash_attention": true,
+    "attn_impl": "flashmask",
     "sequence_parallel": true,
     "report_to": "none",
     "convert_from_hf": true,

examples/config/gpt_oss/sft_argument_gptoss_20b.json
Lines changed: 0 additions & 1 deletion

@@ -35,7 +35,6 @@
     "tensor_parallel_degree": 4,
     "pipeline_parallel_degree": 1,
     "sharding": "stage2",
-    "zero_padding": false,
     "unified_checkpoint": true,
     "use_flash_attention": false,
     "lora": true,

examples/config/qwen/dpo_argument_qwen2_0p5b.json
Lines changed: 1 addition & 3 deletions

@@ -32,15 +32,13 @@
     "load_best_model_at_end": true,
     "tensor_parallel_degree": 1,
     "sharding": "stage1",
-    "use_flash_attention": false,
-    "flash_mask": false,
+    "attn_impl": "flashmask",
     "recompute": true,
     "recompute_granularity": "full",
     "benchmark": false,
     "unified_checkpoint": true,
     "autotuner_benchmark":false,
     "beta": 0.1,
     "loss_type": "sigmoid",
-    "greedy_zero_padding": false,
     "label_smoothing": 0.0
 }

examples/config/qwen/dpo_lora_argument_qwen2_0p5b.json
Lines changed: 1 addition & 1 deletion

@@ -32,7 +32,7 @@
     "load_best_model_at_end": true,
     "tensor_parallel_degree": 1,
     "sharding": "stage1",
-    "use_flash_attention": true,
+    "attn_impl": "flashmask",
     "recompute": false,
     "recompute_granularity": "full",
     "beta": 0.1,

examples/config/qwen/lora_argument_qwen2_0p5b.json
Lines changed: 1 addition & 3 deletions

@@ -35,10 +35,8 @@
     "pipeline_parallel_degree": 1,
     "sharding": "stage2",
     "lora": true,
-    "zero_padding": true,
-    "flash_mask": true,
     "unified_checkpoint": true,
-    "use_flash_attention": true,
+    "attn_impl": "flashmask",
     "convert_from_hf": false,
     "save_to_hf": false,
     "pissa": false,

examples/config/qwen/sft_argument_qwen2_0p5b.json
Lines changed: 1 addition & 3 deletions

@@ -34,10 +34,8 @@
     "tensor_parallel_degree": 1,
     "pipeline_parallel_degree": 1,
     "sharding": "stage2",
-    "zero_padding": true,
-    "flash_mask": true,
     "unified_checkpoint": true,
-    "use_flash_attention": true,
+    "attn_impl": "flashmask",
     "convert_from_hf": false,
     "save_to_hf": false,
     "encode_one_turn": true

examples/run_finetune.py
Lines changed: 8 additions & 8 deletions

@@ -22,6 +22,7 @@
 from paddleformers.datasets.data_utils import estimate_training
 from paddleformers.datasets.finetuning import collate_fn
 from paddleformers.datasets.finetuning import create_dataset as create_dataset_sft
+from paddleformers.nn.attention import AttentionInterface
 from paddleformers.peft import LoRAConfig, LoRAModel
 from paddleformers.trainer import (
     IntervalStrategy,
@@ -161,7 +162,12 @@ def main():
         model_config.fuse_attention_qkv = model_args.fuse_attention_qkv
     if model_args.fuse_attention_ffn is not None:
         model_config.fuse_attention_ffn = model_args.fuse_attention_ffn
-    model_config.pp_seg_method = training_args.pp_seg_method
+
+    avaible_attn_impl = AttentionInterface._global_mapping.keys()
+    if model_args.attn_impl not in avaible_attn_impl:
+        raise ValueError(f"Invalid attn_impl: {model_args.attn_impl}, available attn_impl: {avaible_attn_impl}")
+
+    model_config.pp_seg_method = model_args.pp_seg_method
     model_config.seq_length = training_args.max_seq_len
     model_config.max_sequence_length = training_args.max_seq_len
     model_config.num_nextn_predict_layers = model_args.num_nextn_predict_layers
@@ -185,13 +191,7 @@ def main():
     else:
         model = model_class.from_config(model_config, dtype=dtype)
 
-    if model_args.flash_mask and (not data_args.zero_padding or not model.config.use_flash_attention):
-        logger.warning("`flash_mask` must use with zero padding and flash attention.")
-        data_args.zero_padding = True
-        model.config.use_flash_attention = True
-        model.config._attn_implementation = "flashmask"
-
-    if model_args.flash_mask and not any(isinstance(model, cls) for cls in flash_mask_support_list):
+    if model_args.attn_impl == "flashmask" and not any(isinstance(model, cls) for cls in flash_mask_support_list):
        raise NotImplementedError(f"{model.__class__} not support flash mask.")
 
     if training_args.do_train and model_args.neftune:
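With the flash_mask/zero_padding booleans gone, run_finetune.py gates flashmask purely on the requested backend and on whether the model class supports it. A condensed sketch of that gate, assuming flash_mask_support_list is the list of model classes with flashmask support that the script already defines (the wrapper name is made up for illustration):

def check_flashmask_support(model, model_args, flash_mask_support_list):
    # Only models on the support list may run with the flashmask backend;
    # other backends pass through unchanged via model_config._attn_implementation.
    if model_args.attn_impl == "flashmask" and not any(
        isinstance(model, cls) for cls in flash_mask_support_list
    ):
        raise NotImplementedError(f"{model.__class__} not support flash mask.")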
