Skip to content

Commit 58fc49f

Browse files
authored
[LLM] Add pipeline and flashmask for Qwen2Moe and Deepseek (PaddlePaddle#9827)
* add modeling_pp * add modeling_pp for qwen2moe * add flashmask and pp for Qwen2MoE and Deepseek * remove * fix fast_tokenizer save * update for topk_weight of noaux_tc * fix for flashmask * add use_expert_parallel for pretrain * fix tokenizer test
1 parent 86286e0 commit 58fc49f

File tree

19 files changed

+1365
-342
lines changed

19 files changed

+1365
-342
lines changed

β€Žllm/run_finetune.pyβ€Ž

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -52,12 +52,18 @@
5252
AutoModelForCausalLM,
5353
AutoModelForCausalLMPipe,
5454
AutoTokenizer,
55+
DeepseekV2ForCausalLM,
56+
DeepseekV2ForCausalLMPipe,
57+
DeepseekV3ForCausalLM,
58+
DeepseekV3ForCausalLMPipe,
5559
Llama3Tokenizer,
5660
LlamaForCausalLM,
5761
LlamaForCausalLMPipe,
5862
LlamaTokenizer,
5963
Qwen2ForCausalLM,
6064
Qwen2ForCausalLMPipe,
65+
Qwen2MoeForCausalLM,
66+
Qwen2MoeForCausalLMPipe,
6167
)
6268
from paddlenlp.transformers.configuration_utils import LlmMetaConfig
6369
from paddlenlp.trl import DataConfig, ModelConfig, SFTConfig, SFTTrainer
@@ -74,7 +80,18 @@
7480
# Fine-tune Environment Variables to support sharding stage1 overlap optimization.
7581
os.environ["USE_CASUAL_MASK"] = "False"
7682

77-
flash_mask_support_list = [LlamaForCausalLM, LlamaForCausalLMPipe, Qwen2ForCausalLM, Qwen2ForCausalLMPipe]
83+
flash_mask_support_list = [
84+
DeepseekV2ForCausalLM,
85+
DeepseekV2ForCausalLMPipe,
86+
DeepseekV3ForCausalLM,
87+
DeepseekV3ForCausalLMPipe,
88+
LlamaForCausalLM,
89+
LlamaForCausalLMPipe,
90+
Qwen2ForCausalLM,
91+
Qwen2ForCausalLMPipe,
92+
Qwen2MoeForCausalLM,
93+
Qwen2MoeForCausalLMPipe,
94+
]
7895

7996

8097
def paddlenlp_verison_check():
@@ -151,7 +168,11 @@ def main():
151168
quantization_config=quantization_config,
152169
)
153170

154-
if "Qwen2Moe" in str(model_config.architectures) and training_args.data_parallel_degree > 1:
171+
architectures_to_check = {"Qwen2Moe", "DeepseekV2", "DeepseekV3"}
172+
if (
173+
any(architecture in str(model_config.architectures) for architecture in architectures_to_check)
174+
and training_args.data_parallel_degree > 1
175+
):
155176
training_args.use_expert_parallel = True
156177

157178
LlmMetaConfig.set_llm_config(model_config, training_args)
@@ -585,7 +606,12 @@ def create_peft_model(model_args, reft_args, training_args, dtype, model_config,
585606
def trans_dataset_to_ids(train_ds, dev_ds, test_ds, model_args, data_args, trans_func, eval_zero_padding):
586607
if train_ds is not None:
587608
train_ds = train_ds.map(
588-
partial(trans_func, is_test=False, zero_padding=data_args.zero_padding, flash_mask=model_args.flash_mask)
609+
partial(
610+
trans_func,
611+
is_test=False,
612+
zero_padding=data_args.zero_padding,
613+
flash_mask=model_args.flash_mask,
614+
)
589615
)
590616
if dev_ds is not None:
591617
dev_ds = dev_ds.map(

β€Žllm/run_pretrain.pyβ€Ž

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -478,7 +478,11 @@ def main():
478478
except:
479479
print("Not register llama pp reshard information.")
480480

481-
if "Qwen2Moe" in str(config.architectures) and training_args.data_parallel_degree > 1:
481+
architectures_to_check = {"Qwen2Moe", "DeepseekV2", "DeepseekV3"}
482+
if (
483+
any(architecture in str(config.architectures) for architecture in architectures_to_check)
484+
and training_args.data_parallel_degree > 1
485+
):
482486
training_args.use_expert_parallel = True
483487

484488
if model_args.continue_training:

β€Žllm/utils/data.pyβ€Ž

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,11 +59,13 @@ def get_convert_example(model):
5959
"gpt",
6060
"yuan",
6161
"jamba",
62+
"deepseek_v2",
63+
"deepseek_v3",
6264
]:
6365
return convert_example_common
6466
else:
6567
raise ValueError(
66-
f"Unknown base_model_prefix: {model.base_model_prefix}. Supported base_model_prefix list: chatglm, bloom, llama, qwen, mixtral, gemma, qwen2, qwen2_moe, yuan, jamba",
68+
f"Unknown base_model_prefix: {model.base_model_prefix}. Supported base_model_prefix list: chatglm, bloom, llama, qwen, mixtral, gemma, qwen2, qwen2_moe, yuan, jamba, deepseek_v2, deepseek_v3",
6769
)
6870

6971

β€Žpaddlenlp/transformers/__init__.pyβ€Ž

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -143,11 +143,8 @@
143143
from .deberta_v2.configuration import *
144144
from .deberta_v2.modeling import *
145145
from .deberta_v2.tokenizer import *
146-
from .deepseek_v2.configuration import *
147-
from .deepseek_v2.modeling import *
148-
from .deepseek_v2.tokenizer_fast import *
149-
from .deepseek_v3.configuration import *
150-
from .deepseek_v3.modeling import *
146+
from .deepseek_v2 import *
147+
from .deepseek_v3 import *
151148
from .distilbert.configuration import *
152149
from .distilbert.modeling import *
153150
from .distilbert.tokenizer import *

β€Žpaddlenlp/transformers/deepseek_v2/__init__.pyβ€Ž

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,4 +14,5 @@
1414

1515
from .configuration import *
1616
from .modeling import *
17+
from .modeling_pp import *
1718
from .tokenizer_fast import *

0 commit comments

Comments
Β (0)