Skip to content

Commit 7ca795a

Browse files
authored
remove flashmask checker (#2631)
1 parent 5051892 commit 7ca795a

File tree

1 file changed

+0
-42
lines changed

1 file changed

+0
-42
lines changed

examples/run_finetune.py

Lines changed: 0 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -37,26 +37,8 @@
     AutoModelForCausalLM,
     AutoModelForCausalLMPipe,
     AutoTokenizer,
-    DeepseekV2ForCausalLM,
-    DeepseekV2ForCausalLMPipe,
-    DeepseekV3ForCausalLM,
-    DeepseekV3ForCausalLMPipe,
-    Ernie4_5_MoeForCausalLM,
-    Ernie4_5_MoeForCausalLMPipe,
-    Ernie4_5ForCausalLM,
-    Ernie4_5ForCausalLMPipe,
     Llama3Tokenizer,
-    LlamaForCausalLM,
-    LlamaForCausalLMPipe,
     LlamaTokenizer,
-    Qwen2ForCausalLM,
-    Qwen2ForCausalLMPipe,
-    Qwen2MoeForCausalLM,
-    Qwen2MoeForCausalLMPipe,
-    Qwen3ForCausalLM,
-    Qwen3ForCausalLMPipe,
-    Qwen3MoeForCausalLM,
-    Qwen3MoeForCausalLMPipe,
 )
 from paddleformers.transformers.configuration_utils import LlmMetaConfig
 from paddleformers.trl import DataConfig, ModelConfig, SFTConfig, SFTTrainer
@@ -66,27 +48,6 @@
 # Fine-tune Environment Variables to support sharding stage1 overlap optimization.
 os.environ["USE_CASUAL_MASK"] = "False"
 
-flash_mask_support_list = [
-    DeepseekV2ForCausalLM,
-    DeepseekV2ForCausalLMPipe,
-    DeepseekV3ForCausalLM,
-    DeepseekV3ForCausalLMPipe,
-    Ernie4_5ForCausalLM,
-    Ernie4_5ForCausalLMPipe,
-    Ernie4_5_MoeForCausalLM,
-    Ernie4_5_MoeForCausalLMPipe,
-    LlamaForCausalLM,
-    LlamaForCausalLMPipe,
-    Qwen2ForCausalLM,
-    Qwen2ForCausalLMPipe,
-    Qwen2MoeForCausalLM,
-    Qwen2MoeForCausalLMPipe,
-    Qwen3ForCausalLM,
-    Qwen3ForCausalLMPipe,
-    Qwen3MoeForCausalLM,
-    Qwen3MoeForCausalLMPipe,
-]
-
 
 def main():
     parser = PdArgumentParser((ModelConfig, DataConfig, SFTConfig))
@@ -192,9 +153,6 @@ def main():
     else:
         model = model_class.from_config(model_config, dtype=dtype)
 
-    if model_args.attn_impl == "flashmask" and not any(isinstance(model, cls) for cls in flash_mask_support_list):
-        raise NotImplementedError(f"{model.__class__} not support flash mask.")
-
     if training_args.do_train and model_args.neftune:
         # Inspired by https://github.com/neelsjain/NEFTune
         if hasattr(model, "get_input_embeddings"):

0 commit comments

Comments
 (0)