 37 |  37 |     AutoModelForCausalLM,
 38 |  38 |     AutoModelForCausalLMPipe,
 39 |  39 |     AutoTokenizer,
 40 |     | -   DeepseekV2ForCausalLM,
 41 |     | -   DeepseekV2ForCausalLMPipe,
 42 |     | -   DeepseekV3ForCausalLM,
 43 |     | -   DeepseekV3ForCausalLMPipe,
 44 |     | -   Ernie4_5_MoeForCausalLM,
 45 |     | -   Ernie4_5_MoeForCausalLMPipe,
 46 |     | -   Ernie4_5ForCausalLM,
 47 |     | -   Ernie4_5ForCausalLMPipe,
 48 |  40 |     Llama3Tokenizer,
 49 |     | -   LlamaForCausalLM,
 50 |     | -   LlamaForCausalLMPipe,
 51 |  41 |     LlamaTokenizer,
 52 |     | -   Qwen2ForCausalLM,
 53 |     | -   Qwen2ForCausalLMPipe,
 54 |     | -   Qwen2MoeForCausalLM,
 55 |     | -   Qwen2MoeForCausalLMPipe,
 56 |     | -   Qwen3ForCausalLM,
 57 |     | -   Qwen3ForCausalLMPipe,
 58 |     | -   Qwen3MoeForCausalLM,
 59 |     | -   Qwen3MoeForCausalLMPipe,
 60 |  42 | )
 61 |  43 | from paddleformers.transformers.configuration_utils import LlmMetaConfig
 62 |  44 | from paddleformers.trl import DataConfig, ModelConfig, SFTConfig, SFTTrainer

 66 |  48 | # Fine-tune Environment Variables to support sharding stage1 overlap optimization.
 67 |  49 | os.environ["USE_CASUAL_MASK"] = "False"
 68 |  50 |
 69 |     | -flash_mask_support_list = [
 70 |     | -    DeepseekV2ForCausalLM,
 71 |     | -    DeepseekV2ForCausalLMPipe,
 72 |     | -    DeepseekV3ForCausalLM,
 73 |     | -    DeepseekV3ForCausalLMPipe,
 74 |     | -    Ernie4_5ForCausalLM,
 75 |     | -    Ernie4_5ForCausalLMPipe,
 76 |     | -    Ernie4_5_MoeForCausalLM,
 77 |     | -    Ernie4_5_MoeForCausalLMPipe,
 78 |     | -    LlamaForCausalLM,
 79 |     | -    LlamaForCausalLMPipe,
 80 |     | -    Qwen2ForCausalLM,
 81 |     | -    Qwen2ForCausalLMPipe,
 82 |     | -    Qwen2MoeForCausalLM,
 83 |     | -    Qwen2MoeForCausalLMPipe,
 84 |     | -    Qwen3ForCausalLM,
 85 |     | -    Qwen3ForCausalLMPipe,
 86 |     | -    Qwen3MoeForCausalLM,
 87 |     | -    Qwen3MoeForCausalLMPipe,
 88 |     | -]
 89 |     | -
 90 |  51 |
 91 |  52 | def main():
 92 |  53 |     parser = PdArgumentParser((ModelConfig, DataConfig, SFTConfig))
@@ -192,9 +153,6 @@ def main():

192 | 153 |     else:
193 | 154 |         model = model_class.from_config(model_config, dtype=dtype)
194 | 155 |
195 |     | -    if model_args.attn_impl == "flashmask" and not any(isinstance(model, cls) for cls in flash_mask_support_list):
196 |     | -        raise NotImplementedError(f"{model.__class__} not support flash mask.")
197 |     | -
198 | 156 |     if training_args.do_train and model_args.neftune:
199 | 157 |         # Inspired by https://github.com/neelsjain/NEFTune
200 | 158 |         if hasattr(model, "get_input_embeddings"):