 )
 
 from paddleformers.datasets.dpo import collate_fn, create_dataset
+from paddleformers.nn.attention import AttentionInterface
 from paddleformers.peft import LoRAConfig, LoRAModel
 from paddleformers.trainer import PdArgumentParser, get_last_checkpoint, set_seed
 from paddleformers.transformers import (
@@ -57,6 +58,11 @@ def main():
 
     paddle.set_device(training_args.device)
     set_seed(training_args.seed)
+
+    available_attn_impl = AttentionInterface._global_mapping.keys()
+    if model_args.attn_impl not in available_attn_impl:
+        raise ValueError(f"Invalid attn_impl: {model_args.attn_impl}, available attn_impl: {available_attn_impl}")
+
     if dpo_config.loss_type == "orpo":
         dpo_config.reference_free = True
         dpo_config.sft_loss_ratio = 1.0
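For readers following the change: the new block rejects an unknown `attn_impl` before any model weights are touched, using the keys registered on `AttentionInterface` as the source of truth. A minimal standalone sketch of the same fail-fast pattern, with a hypothetical registry standing in for `AttentionInterface._global_mapping`:

```python
# Sketch only: ATTN_REGISTRY is a stand-in for AttentionInterface._global_mapping,
# whose real keys depend on the installed paddleformers version.
ATTN_REGISTRY = {"eager": object(), "sdpa": object(), "flashmask": object()}

def validate_attn_impl(attn_impl: str) -> str:
    available = ATTN_REGISTRY.keys()
    if attn_impl not in available:
        # Fail fast with the full list of valid names, mirroring the check in the diff.
        raise ValueError(f"Invalid attn_impl: {attn_impl}, available attn_impl: {list(available)}")
    return attn_impl

validate_attn_impl("flashmask")  # passes
# validate_attn_impl("flash")    # would raise ValueError listing the registered names
```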
@@ -113,6 +119,8 @@ def main():
         dtype=dtype,
         download_hub=model_args.download_hub,
     )
+    model_config._attn_implementation = model_args.attn_impl
+
     LlmMetaConfig.set_llm_config(model_config, training_args)
 
     if not dpo_config.reference_free and not dpo_config.lora:
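The hunk above threads the validated choice into the model config before the model is instantiated, so attention layers can read `_attn_implementation` when they are built. A small sketch of that ordering, using a plain namespace object as a stand-in for the real `model_config`:

```python
# Sketch only: SimpleNamespace stands in for the config object returned by
# AutoConfig.from_pretrained in the script.
from types import SimpleNamespace

def prepare_config(attn_impl: str) -> SimpleNamespace:
    model_config = SimpleNamespace(dtype="bfloat16")
    # Set before the model is constructed, as in the diff; layers created later
    # read this field to select their attention backend.
    model_config._attn_implementation = attn_impl
    return model_config

config = prepare_config("flashmask")
print(config._attn_implementation)  # -> flashmask
```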
@@ -151,11 +159,8 @@ def main():
         ref_model = None
     if training_args.pipeline_parallel_degree > 1:
         model.config.dpo_config = None
-    if model_args.flash_mask and not model.config.use_flash_attention:
-        logger.warning("`flash_mask` must use with zero padding and flash attention.")
-        model.config.use_flash_attention = True
 
-    if model_args.flash_mask and not any(isinstance(model, cls) for cls in flash_mask_support_list):
+    if model_args.attn_impl == "flashmask" and not any(isinstance(model, cls) for cls in flash_mask_support_list):
         raise NotImplementedError(f"{model.__class__} not support flash mask.")
 
     if model_args.tokenizer_name_or_path is not None:
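The old boolean `flash_mask` flag and its auto-enabling of flash attention are removed; the support check now keys off `attn_impl == "flashmask"` against `flash_mask_support_list`. A self-contained sketch of that allow-list pattern, with placeholder classes rather than the real paddleformers model classes:

```python
# Sketch only: placeholder classes stand in for entries of flash_mask_support_list.
class SupportedModel: ...
class UnsupportedModel: ...

flash_mask_support_list = [SupportedModel]

def check_flash_mask(model, attn_impl: str) -> None:
    # Same shape as the check in the diff: only allow "flashmask" for whitelisted classes.
    if attn_impl == "flashmask" and not any(isinstance(model, cls) for cls in flash_mask_support_list):
        raise NotImplementedError(f"{model.__class__} does not support flash mask.")

check_flash_mask(SupportedModel(), "flashmask")      # ok
# check_flash_mask(UnsupportedModel(), "flashmask")  # raises NotImplementedError
```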
@@ -219,6 +224,7 @@ def main():
         "greedy_intokens": data_args.greedy_intokens,
         "packing": data_args.packing,
         "mix_strategy": data_args.mix_strategy,
+        "encode_one_turn": data_args.encode_one_turn,
     }
     if training_args.do_train and training_args.should_load_dataset:
         train_dataset = create_dataset(
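Finally, the dataset configuration gains an `encode_one_turn` entry taken from `data_args`, alongside the existing packing and mixing options, and the resulting dict is handed to `create_dataset`. An illustrative sketch of the dict being assembled (the literal values are placeholders, not library defaults):

```python
# Sketch only: literal values below are placeholders for the corresponding
# data_args fields in the script.
dataset_config = {
    "greedy_intokens": True,
    "packing": False,
    "mix_strategy": "random",
    "encode_one_turn": True,  # new flag added by this change
}

# In the script this dict is passed on to create_dataset together with the
# training data arguments (the exact call is outside this hunk).
```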