Skip to content

Commit ed15c99

Browse files
authored
add-qwen2/3-fleet (#3965)
1 parent 288101c commit ed15c99

File tree

11 files changed

+366
-62
lines changed

11 files changed

+366
-62
lines changed

paddleformers/transformers/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,8 @@
185185
"Qwen2ForSequenceClassification",
186186
"Qwen2ForTokenClassification",
187187
"Qwen2SentenceEmbedding",
188+
"Qwen2ForCausalLMDeprecated",
189+
"Qwen2ForCausalLMPipeDeprecated",
188190
],
189191
"qwen2.tokenizer": ["Qwen2Tokenizer"],
190192
"qwen2.tokenizer_fast": ["Qwen2TokenizerFast"],
@@ -242,6 +244,8 @@
242244
"Qwen3ForSequenceClassification",
243245
"Qwen3ForTokenClassification",
244246
"Qwen3SentenceEmbedding",
247+
"Qwen3ForCausalLMDeprecated",
248+
"Qwen3ForCausalLMPipeDeprecated",
245249
],
246250
"qwen3_moe.configuration": ["Qwen3MoeConfig"],
247251
"qwen3_moe.modeling": [

paddleformers/transformers/qwen2/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@
3030
"Qwen2ForSequenceClassification",
3131
"Qwen2ForTokenClassification",
3232
"Qwen2SentenceEmbedding",
33+
"Qwen2ForCausalLMDeprecated",
34+
"Qwen2ForCausalLMPipeDeprecated",
3335
],
3436
}
3537

paddleformers/transformers/qwen2/modeling.py

Lines changed: 151 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,9 @@
2020
"""Paddle Qwen2 model."""
2121
from __future__ import annotations
2222

23+
import json
24+
import os
25+
from dataclasses import asdict, dataclass
2326
from typing import Dict, Optional, Tuple, Union
2427

2528
import paddle
@@ -28,14 +31,16 @@
2831
from paddle.distributed.fleet.recompute.recompute import recompute
2932
from paddle.distributed.fleet.utils.sequence_parallel_utils import ScatterOp
3033

34+
from paddleformers.transformers.gpt_provider import GPTModelProvider
35+
3136
from ...nn.attention.interface import ALL_ATTENTION_FUNCTIONS
3237
from ...nn.criterion.interface import CriterionLayer
3338
from ...nn.embedding import Embedding as GeneralEmbedding
3439
from ...nn.linear import Linear as GeneralLinear
3540
from ...nn.lm_head import LMHead as GeneralLMHead
3641
from ...nn.mlp import MLP as Qwen2MLP
3742
from ...nn.norm import Norm as GeneralNorm
38-
from ...nn.pp_model import GeneralModelForCausalLMPipe
43+
from ...nn.pp_model import CriterionLayerPipe, GeneralModelForCausalLMPipe
3944
from ...utils.log import logger
4045
from ..cache_utils import Cache, DynamicCache
4146
from ..contrastive_loss import SimpleContrastiveLoss
@@ -55,6 +60,62 @@
5560
from .configuration import Qwen2Config
5661

5762

63+
@dataclass
class Qwen2ModelProvider(GPTModelProvider):
    """Base provider for Qwen2 Models.

    Declares the Qwen2-specific defaults on top of ``GPTModelProvider`` and
    knows how to dump itself to a JSON configuration file.
    """

    # Un-annotated, so this is a plain class attribute, not a dataclass field.
    model_type = "qwen2"

    attention_bias: bool = True

    bias_activation_fusion: bool = True
    bias_dropout_fusion: bool = True

    # Un-annotated class attribute (not a dataclass field).
    # NOTE(review): presumably maps source-config key names to provider field
    # names when building the provider from a config — confirm in the base class.
    transform_rules = {
        "dtype": "params_dtype",
    }

    persist_layer_norm: bool = True
    share_embeddings_and_output_weights: bool = False

    def save_pretrained(self, save_directory: Union[str, os.PathLike], **kwargs):
        """
        Write this provider's dataclass fields as JSON into `save_directory`.

        Only JSON-friendly values are written: dicts, lists/tuples (emitted as
        lists), strings, numbers and booleans. Everything else (partials,
        functions, ...) is skipped. Because ``None`` doubles as the "skip"
        marker, entries whose value is ``None`` are omitted as well.

        Args:
            save_directory (`str` or `os.PathLike`):
                Directory where the configuration JSON file will be saved (will be created if it does not exist).
            kwargs:
                Accepted for signature compatibility; not used by this implementation.

        Raises:
            AssertionError: If `save_directory` points to an existing file.
        """
        if os.path.isfile(save_directory):
            raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file")

        os.makedirs(save_directory, exist_ok=True)

        # CONFIG_NAME is not defined in this class; it is expected to come from
        # the base class.
        output_config_file = os.path.join(save_directory, self.CONFIG_NAME)
        config_dict = asdict(self)

        # Filter out non-serializable values
        def make_serializable(obj):
            """Return a JSON-friendly copy of `obj`, or None if it must be skipped."""
            if isinstance(obj, dict):
                kept = {}
                for key, value in obj.items():
                    # Convert each value exactly once; converting it a second
                    # time for the filter would double the work at every
                    # nesting level.
                    converted = make_serializable(value)
                    if converted is not None:
                        kept[key] = converted
                return kept
            if isinstance(obj, (list, tuple)):
                converted_items = (make_serializable(item) for item in obj)
                return [item for item in converted_items if item is not None]
            if isinstance(obj, (str, int, float, bool, type(None))):
                return obj
            # Skip non-serializable types like partial, function, etc.
            return None

        serializable_config = make_serializable(config_dict)

        with open(output_config_file, "w", encoding="utf-8") as writer:
            writer.write(json.dumps(serializable_config, indent=2, sort_keys=True, ensure_ascii=False) + "\n")
        logger.info(f"Configuration saved in {output_config_file}")
117+
118+
58119
def rotate_half(x):
59120
"""Rotates half the hidden dims of the input."""
60121
x1 = x[..., : x.shape[-1] // 2]
@@ -267,17 +328,27 @@ class Qwen2PretrainedModel(PretrainedModel):
267328
@classmethod
268329
def _gen_aoa_config(cls, config: Qwen2Config):
269330
model_prefix = "" if cls == cls.base_model_class else "model."
331+
is_fleet = getattr(cls, "is_fleet", False)
332+
270333
aoa_config = {
271334
"aoa_statements": [
272335
f"model.layers.$LAYER_ID.self_attn.o_proj.weight^T -> {model_prefix}layers.$LAYER_ID.self_attn.o_proj.weight",
273336
f"model.layers.$LAYER_ID.mlp.down_proj.weight^T -> {model_prefix}layers.$LAYER_ID.mlp.down_proj.weight",
274-
f"model.embed_tokens.weight -> {model_prefix}embed_tokens.weight",
275337
f"model.layers.$LAYER_ID.input_layernorm.weight -> {model_prefix}layers.$LAYER_ID.input_layernorm.weight",
276338
f"model.layers.$LAYER_ID.post_attention_layernorm.weight -> {model_prefix}layers.$LAYER_ID.post_attention_layernorm.weight",
277339
f"model.norm.weight -> {model_prefix}norm.weight",
278340
]
279341
}
280342

343+
if is_fleet:
344+
aoa_config["aoa_statements"] += [
345+
f"model.embed_tokens.weight -> {model_prefix}embedding.embed_tokens.weight",
346+
]
347+
else:
348+
aoa_config["aoa_statements"] += [
349+
f"model.embed_tokens.weight -> {model_prefix}embed_tokens.weight",
350+
]
351+
281352
# attention qkv
282353
aoa_config["aoa_statements"] += [
283354
f"model.layers.$LAYER_ID.self_attn.q_proj.weight^T, model.layers.$LAYER_ID.self_attn.k_proj.weight^T, model.layers.$LAYER_ID.self_attn.v_proj.weight^T -> {model_prefix}layers.$LAYER_ID.self_attn.qkv_proj.weight, fused_qkv, num_heads={config.num_attention_heads}, num_key_value_groups={config.num_key_value_heads}",
@@ -293,22 +364,38 @@ def _gen_aoa_config(cls, config: Qwen2Config):
293364

294365
# lm_head
295366
if config.tie_word_embeddings:
296-
aoa_config["aoa_statements"] += ["model.embed_tokens.weight -> lm_head.weight"]
367+
if is_fleet:
368+
aoa_config["aoa_statements"] += [f"model.embed_tokens.weight -> {model_prefix}lm_head.weight"]
369+
else:
370+
aoa_config["aoa_statements"] += ["model.embed_tokens.weight -> lm_head.weight"]
371+
else:
372+
if is_fleet:
373+
aoa_config["aoa_statements"] += [f"lm_head.weight -> {model_prefix}lm_head.weight"]
297374

298375
return aoa_config
299376

300377
@classmethod
301378
def _gen_inv_aoa_config(cls, config: Qwen2Config):
302379
model_prefix = "" if cls == cls.base_model_class else "model."
380+
is_fleet = getattr(cls, "is_fleet", False)
381+
303382
aoa_statements = [
304383
f"{model_prefix}layers.$LAYER_ID.self_attn.o_proj.weight^T -> model.layers.$LAYER_ID.self_attn.o_proj.weight",
305384
f"{model_prefix}layers.$LAYER_ID.mlp.down_proj.weight^T -> model.layers.$LAYER_ID.mlp.down_proj.weight",
306-
f"{model_prefix}embed_tokens.weight -> model.embed_tokens.weight",
307385
f"{model_prefix}layers.$LAYER_ID.input_layernorm.weight -> model.layers.$LAYER_ID.input_layernorm.weight",
308386
f"{model_prefix}layers.$LAYER_ID.post_attention_layernorm.weight -> model.layers.$LAYER_ID.post_attention_layernorm.weight",
309387
f"{model_prefix}norm.weight -> model.norm.weight",
310388
]
311389

390+
if is_fleet:
391+
aoa_statements += [
392+
f"{model_prefix}embedding.embed_tokens.weight -> model.embed_tokens.weight",
393+
]
394+
else:
395+
aoa_statements += [
396+
f"{model_prefix}embed_tokens.weight -> model.embed_tokens.weight",
397+
]
398+
312399
aoa_statements += [
313400
f"{model_prefix}layers.$LAYER_ID.self_attn.qkv_proj.weight -> model.layers.$LAYER_ID.self_attn.q_proj.weight, model.layers.$LAYER_ID.self_attn.k_proj.weight, model.layers.$LAYER_ID.self_attn.v_proj.weight , fused_qkv, num_heads={config.num_attention_heads}, num_key_value_groups = {config.num_key_value_heads}",
314401
]
@@ -331,7 +418,13 @@ def _gen_inv_aoa_config(cls, config: Qwen2Config):
331418
]
332419

333420
if config.tie_word_embeddings:
334-
aoa_statements += ["lm_head.weight -> _"]
421+
if is_fleet:
422+
aoa_statements += [f"{model_prefix}lm_head.weight -> _"]
423+
else:
424+
aoa_statements += ["lm_head.weight -> _"]
425+
else:
426+
if is_fleet:
427+
aoa_statements += [f"{model_prefix}lm_head.weight -> lm_head.weight"]
335428

336429
aoa_config = {"aoa_statements": aoa_statements}
337430
return aoa_config
@@ -574,6 +667,30 @@ def forward(
574667

575668

576669
class Qwen2ForCausalLM(Qwen2PretrainedModel):
    """Factory-style entry point for the fleet-backed Qwen2 causal LM.

    Constructing this class does not yield a regular instance: ``__new__``
    returns the object produced by ``Qwen2ModelProvider.provide``.
    """

    is_fleet = True

    def __new__(cls, config):
        """Build the model from `config` via ``Qwen2ModelProvider`` and return it.

        `config` is modified in place: every parallel-degree attribute listed
        below is raised to at least 1.
        """
        # Hybrid parallel config convert.
        for degree_attr in (
            "tensor_model_parallel_size",
            "context_parallel_size",
            "pipeline_model_parallel_size",
            "virtual_pipeline_model_parallel_size",
            "expert_model_parallel_size",
        ):
            setattr(config, degree_attr, max(getattr(config, degree_attr), 1))

        provider = Qwen2ModelProvider.from_config(config)

        # A criterion layer is only attached when the config carries a truthy
        # `dpo_config`; otherwise the provider is given no loss function.
        criterion = None
        if getattr(config, "dpo_config", None):
            criterion = CriterionLayerPipe(config, use_infohub=True)

        fleet_model = provider.provide(loss_fn=criterion)

        # Attach this class's AOA config generators and bookkeeping attributes
        # to the returned model object.
        fleet_model._gen_aoa_config = cls._gen_aoa_config
        fleet_model._gen_inv_aoa_config = cls._gen_inv_aoa_config
        fleet_model.config_to_save = config
        fleet_model.is_fleet = cls.is_fleet
        return fleet_model
691+
692+
693+
class Qwen2ForCausalLMDeprecated(Qwen2PretrainedModel):
577694
enable_to_static_method = True
578695
_tied_weights_keys = ["lm_head.weight"]
579696

@@ -903,7 +1020,33 @@ def encode(
9031020
return last_hidden_states
9041021

9051022

906-
class Qwen2ForCausalLMPipe(GeneralModelForCausalLMPipe):
1023+
class Qwen2ForCausalLMPipe(Qwen2PretrainedModel, GeneralModelForCausalLMPipe):
    """Factory-style entry point for the fleet-backed Qwen2 pipeline model.

    Constructing this class does not yield a regular instance: ``__new__``
    returns the object produced by ``Qwen2ModelProvider.provide``.
    """

    is_fleet = True

    def __new__(cls, config):
        """Build the model from `config` via ``Qwen2ModelProvider`` and return it.

        `config` is modified in place: every parallel-degree attribute listed
        below is raised to at least 1, and `architectures` is filled in when
        the attribute is absent.
        """
        # Hybrid parallel config convert.
        for degree_attr in (
            "tensor_model_parallel_size",
            "context_parallel_size",
            "pipeline_model_parallel_size",
            "virtual_pipeline_model_parallel_size",
            "expert_model_parallel_size",
        ):
            setattr(config, degree_attr, max(getattr(config, degree_attr), 1))

        provider = Qwen2ModelProvider.from_config(config)

        # A criterion layer is only attached when the config carries a truthy
        # `dpo_config`; otherwise the provider is given no loss function.
        criterion = None
        if getattr(config, "dpo_config", None):
            criterion = CriterionLayerPipe(config, use_infohub=True)

        fleet_model = provider.provide(loss_fn=criterion)

        # Attach this class's AOA config generators to the returned model.
        fleet_model._gen_aoa_config = cls._gen_aoa_config
        fleet_model._gen_inv_aoa_config = cls._gen_inv_aoa_config

        # Record the architecture under the non-pipe class name.
        # NOTE(review): if the config class always defines `architectures`
        # (possibly as None), this branch never runs — confirm.
        if not hasattr(config, "architectures"):
            config.architectures = [cls.__name__.replace("Pipe", "")]

        fleet_model.config_to_save = config
        fleet_model.is_fleet = cls.is_fleet
        return fleet_model
1047+
1048+
1049+
class Qwen2ForCausalLMPipeDeprecated(GeneralModelForCausalLMPipe):
9071050
config_class = Qwen2Config
9081051
_decoder_layer_cls = Qwen2DecoderLayer
9091052
_get_tensor_parallel_mappings = Qwen2Model._get_tensor_parallel_mappings
@@ -924,4 +1067,6 @@ class Qwen2ForCausalLMPipe(GeneralModelForCausalLMPipe):
9241067
"Qwen2ForSequenceClassification",
9251068
"Qwen2ForTokenClassification",
9261069
"Qwen2SentenceEmbedding",
1070+
"Qwen2ForCausalLMDeprecated",
1071+
"Qwen2ForCausalLMPipeDeprecated",
9271072
]

paddleformers/transformers/qwen3/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@
3030
"Qwen3ForSequenceClassification",
3131
"Qwen3ForTokenClassification",
3232
"Qwen3SentenceEmbedding",
33+
"Qwen3ForCausalLMDeprecated",
34+
"Qwen3ForCausalLMPipeDeprecated",
3335
],
3436
}
3537

0 commit comments

Comments
 (0)