2020"""Paddle Qwen2 model."""
2121from __future__ import annotations
2222
23+ import json
24+ import os
25+ from dataclasses import asdict , dataclass
2326from typing import Dict , Optional , Tuple , Union
2427
2528import paddle
2831from paddle .distributed .fleet .recompute .recompute import recompute
2932from paddle .distributed .fleet .utils .sequence_parallel_utils import ScatterOp
3033
34+ from paddleformers .transformers .gpt_provider import GPTModelProvider
35+
3136from ...nn .attention .interface import ALL_ATTENTION_FUNCTIONS
3237from ...nn .criterion .interface import CriterionLayer
3338from ...nn .embedding import Embedding as GeneralEmbedding
3439from ...nn .linear import Linear as GeneralLinear
3540from ...nn .lm_head import LMHead as GeneralLMHead
3641from ...nn .mlp import MLP as Qwen2MLP
3742from ...nn .norm import Norm as GeneralNorm
38- from ...nn .pp_model import GeneralModelForCausalLMPipe
43+ from ...nn .pp_model import CriterionLayerPipe , GeneralModelForCausalLMPipe
3944from ...utils .log import logger
4045from ..cache_utils import Cache , DynamicCache
4146from ..contrastive_loss import SimpleContrastiveLoss
5560from .configuration import Qwen2Config
5661
5762
@dataclass
class Qwen2ModelProvider(GPTModelProvider):
    """Base provider for Qwen2 models.

    Extends ``GPTModelProvider`` with Qwen2-specific defaults (biased attention
    projections, fused bias activation/dropout, untied embeddings) and a JSON
    ``save_pretrained`` that drops non-serializable config entries.
    """

    model_type = "qwen2"

    # Qwen2 uses bias on the attention QKV projections.
    attention_bias: bool = True

    bias_activation_fusion: bool = True
    bias_dropout_fusion: bool = True

    # Maps provider field names to the underlying model-config names.
    transform_rules = {
        "dtype": "params_dtype",
    }

    persist_layer_norm: bool = True
    share_embeddings_and_output_weights: bool = False

    def save_pretrained(self, save_directory: Union[str, os.PathLike], **kwargs):
        """
        Save a configuration object to the directory `save_directory`, so that it can be re-loaded using the
        [`~PretrainedConfig.from_pretrained`] class method.

        Args:
            save_directory (`str` or `os.PathLike`):
                Directory where the configuration JSON file will be saved (will be created if it does not exist).
            kwargs:
                Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.

        Raises:
            AssertionError: If `save_directory` points to an existing file.
        """
        if os.path.isfile(save_directory):
            raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file")

        os.makedirs(save_directory, exist_ok=True)

        output_config_file = os.path.join(save_directory, self.CONFIG_NAME)
        config_dict = asdict(self)

        # Filter out non-serializable values
        def make_serializable(obj):
            # Non-serializable leaves (functions, partials, ...) map to None and
            # are then filtered out of containers. NOTE(review): explicit None
            # values are also dropped from dicts/lists — preserved from the
            # original behavior.
            if isinstance(obj, dict):
                # Convert each value exactly once, then filter; the previous
                # version re-ran the recursion inside the filter test, doubling
                # the work at every nesting level (exponential in depth).
                converted = {k: make_serializable(v) for k, v in obj.items()}
                return {k: v for k, v in converted.items() if v is not None}
            elif isinstance(obj, (list, tuple)):
                converted = [make_serializable(item) for item in obj]
                return [item for item in converted if item is not None]
            elif isinstance(obj, (str, int, float, bool, type(None))):
                return obj
            else:
                # Skip non-serializable types like partial, function, etc.
                return None

        serializable_config = make_serializable(config_dict)

        with open(output_config_file, "w", encoding="utf-8") as writer:
            writer.write(json.dumps(serializable_config, indent=2, sort_keys=True, ensure_ascii=False) + "\n")
        logger.info(f"Configuration saved in {output_config_file}")
117+
118+
58119def rotate_half (x ):
59120 """Rotates half the hidden dims of the input."""
60121 x1 = x [..., : x .shape [- 1 ] // 2 ]
@@ -267,17 +328,27 @@ class Qwen2PretrainedModel(PretrainedModel):
267328 @classmethod
268329 def _gen_aoa_config (cls , config : Qwen2Config ):
269330 model_prefix = "" if cls == cls .base_model_class else "model."
331+ is_fleet = getattr (cls , "is_fleet" , False )
332+
270333 aoa_config = {
271334 "aoa_statements" : [
272335 f"model.layers.$LAYER_ID.self_attn.o_proj.weight^T -> { model_prefix } layers.$LAYER_ID.self_attn.o_proj.weight" ,
273336 f"model.layers.$LAYER_ID.mlp.down_proj.weight^T -> { model_prefix } layers.$LAYER_ID.mlp.down_proj.weight" ,
274- f"model.embed_tokens.weight -> { model_prefix } embed_tokens.weight" ,
275337 f"model.layers.$LAYER_ID.input_layernorm.weight -> { model_prefix } layers.$LAYER_ID.input_layernorm.weight" ,
276338 f"model.layers.$LAYER_ID.post_attention_layernorm.weight -> { model_prefix } layers.$LAYER_ID.post_attention_layernorm.weight" ,
277339 f"model.norm.weight -> { model_prefix } norm.weight" ,
278340 ]
279341 }
280342
343+ if is_fleet :
344+ aoa_config ["aoa_statements" ] += [
345+ f"model.embed_tokens.weight -> { model_prefix } embedding.embed_tokens.weight" ,
346+ ]
347+ else :
348+ aoa_config ["aoa_statements" ] += [
349+ f"model.embed_tokens.weight -> { model_prefix } embed_tokens.weight" ,
350+ ]
351+
281352 # attention qkv
282353 aoa_config ["aoa_statements" ] += [
283354 f"model.layers.$LAYER_ID.self_attn.q_proj.weight^T, model.layers.$LAYER_ID.self_attn.k_proj.weight^T, model.layers.$LAYER_ID.self_attn.v_proj.weight^T -> { model_prefix } layers.$LAYER_ID.self_attn.qkv_proj.weight, fused_qkv, num_heads={ config .num_attention_heads } , num_key_value_groups={ config .num_key_value_heads } " ,
@@ -293,22 +364,38 @@ def _gen_aoa_config(cls, config: Qwen2Config):
293364
294365 # lm_head
295366 if config .tie_word_embeddings :
296- aoa_config ["aoa_statements" ] += ["model.embed_tokens.weight -> lm_head.weight" ]
367+ if is_fleet :
368+ aoa_config ["aoa_statements" ] += [f"model.embed_tokens.weight -> { model_prefix } lm_head.weight" ]
369+ else :
370+ aoa_config ["aoa_statements" ] += ["model.embed_tokens.weight -> lm_head.weight" ]
371+ else :
372+ if is_fleet :
373+ aoa_config ["aoa_statements" ] += [f"lm_head.weight -> { model_prefix } lm_head.weight" ]
297374
298375 return aoa_config
299376
300377 @classmethod
301378 def _gen_inv_aoa_config (cls , config : Qwen2Config ):
302379 model_prefix = "" if cls == cls .base_model_class else "model."
380+ is_fleet = getattr (cls , "is_fleet" , False )
381+
303382 aoa_statements = [
304383 f"{ model_prefix } layers.$LAYER_ID.self_attn.o_proj.weight^T -> model.layers.$LAYER_ID.self_attn.o_proj.weight" ,
305384 f"{ model_prefix } layers.$LAYER_ID.mlp.down_proj.weight^T -> model.layers.$LAYER_ID.mlp.down_proj.weight" ,
306- f"{ model_prefix } embed_tokens.weight -> model.embed_tokens.weight" ,
307385 f"{ model_prefix } layers.$LAYER_ID.input_layernorm.weight -> model.layers.$LAYER_ID.input_layernorm.weight" ,
308386 f"{ model_prefix } layers.$LAYER_ID.post_attention_layernorm.weight -> model.layers.$LAYER_ID.post_attention_layernorm.weight" ,
309387 f"{ model_prefix } norm.weight -> model.norm.weight" ,
310388 ]
311389
390+ if is_fleet :
391+ aoa_statements += [
392+ f"{ model_prefix } embedding.embed_tokens.weight -> model.embed_tokens.weight" ,
393+ ]
394+ else :
395+ aoa_statements += [
396+ f"{ model_prefix } embed_tokens.weight -> model.embed_tokens.weight" ,
397+ ]
398+
312399 aoa_statements += [
313400 f"{ model_prefix } layers.$LAYER_ID.self_attn.qkv_proj.weight -> model.layers.$LAYER_ID.self_attn.q_proj.weight, model.layers.$LAYER_ID.self_attn.k_proj.weight, model.layers.$LAYER_ID.self_attn.v_proj.weight , fused_qkv, num_heads={ config .num_attention_heads } , num_key_value_groups = { config .num_key_value_heads } " ,
314401 ]
@@ -331,7 +418,13 @@ def _gen_inv_aoa_config(cls, config: Qwen2Config):
331418 ]
332419
333420 if config .tie_word_embeddings :
334- aoa_statements += ["lm_head.weight -> _" ]
421+ if is_fleet :
422+ aoa_statements += [f"{ model_prefix } lm_head.weight -> _" ]
423+ else :
424+ aoa_statements += ["lm_head.weight -> _" ]
425+ else :
426+ if is_fleet :
427+ aoa_statements += [f"{ model_prefix } lm_head.weight -> lm_head.weight" ]
335428
336429 aoa_config = {"aoa_statements" : aoa_statements }
337430 return aoa_config
@@ -574,6 +667,30 @@ def forward(
574667
575668
class Qwen2ForCausalLM(Qwen2PretrainedModel):
    """Fleet-backed Qwen2 causal LM; instantiation yields a provider-built model."""

    is_fleet = True

    def __new__(cls, config):
        # NOTE: `__new__` deliberately returns the object built by
        # `Qwen2ModelProvider.provide(...)`, not an instance of this class.

        # Hybrid parallel config convert: clamp every parallel degree to >= 1.
        for degree_attr in (
            "tensor_model_parallel_size",
            "context_parallel_size",
            "pipeline_model_parallel_size",
            "virtual_pipeline_model_parallel_size",
            "expert_model_parallel_size",
        ):
            setattr(config, degree_attr, max(getattr(config, degree_attr), 1))

        provider = Qwen2ModelProvider.from_config(config)
        # DPO training routes its loss through a pipeline criterion layer.
        criterion = CriterionLayerPipe(config, use_infohub=True) if getattr(config, "dpo_config", None) else None
        fleet_model = provider.provide(loss_fn=criterion)

        # Attach checkpoint-conversion hooks and save-time bookkeeping.
        fleet_model._gen_aoa_config = cls._gen_aoa_config
        fleet_model._gen_inv_aoa_config = cls._gen_inv_aoa_config
        fleet_model.config_to_save = config
        fleet_model.is_fleet = cls.is_fleet
        return fleet_model
691+
692+
693+ class Qwen2ForCausalLMDeprecated (Qwen2PretrainedModel ):
577694 enable_to_static_method = True
578695 _tied_weights_keys = ["lm_head.weight" ]
579696
@@ -903,7 +1020,33 @@ def encode(
9031020 return last_hidden_states
9041021
9051022
906- class Qwen2ForCausalLMPipe (GeneralModelForCausalLMPipe ):
class Qwen2ForCausalLMPipe(Qwen2PretrainedModel, GeneralModelForCausalLMPipe):
    """Pipeline-parallel Qwen2 causal LM; instantiation delegates to the provider."""

    is_fleet = True

    def __new__(cls, config):
        # NOTE: `__new__` deliberately returns the provider-built fleet model,
        # not an instance of this class.

        # Hybrid parallel config convert: clamp every parallel degree to >= 1.
        for degree_attr in (
            "tensor_model_parallel_size",
            "context_parallel_size",
            "pipeline_model_parallel_size",
            "virtual_pipeline_model_parallel_size",
            "expert_model_parallel_size",
        ):
            setattr(config, degree_attr, max(getattr(config, degree_attr), 1))

        provider = Qwen2ModelProvider.from_config(config)
        # DPO training routes its loss through a pipeline criterion layer.
        criterion = CriterionLayerPipe(config, use_infohub=True) if getattr(config, "dpo_config", None) else None
        fleet_model = provider.provide(loss_fn=criterion)

        # Attach checkpoint-conversion hooks and save-time bookkeeping.
        fleet_model._gen_aoa_config = cls._gen_aoa_config
        fleet_model._gen_inv_aoa_config = cls._gen_inv_aoa_config
        # Default the architecture name to the non-pipe class if absent.
        if not hasattr(config, "architectures"):
            config.architectures = [cls.__name__.replace("Pipe", "")]
        fleet_model.config_to_save = config
        fleet_model.is_fleet = cls.is_fleet
        return fleet_model
1047+
1048+
1049+ class Qwen2ForCausalLMPipeDeprecated (GeneralModelForCausalLMPipe ):
9071050 config_class = Qwen2Config
9081051 _decoder_layer_cls = Qwen2DecoderLayer
9091052 _get_tensor_parallel_mappings = Qwen2Model ._get_tensor_parallel_mappings
@@ -924,4 +1067,6 @@ class Qwen2ForCausalLMPipe(GeneralModelForCausalLMPipe):
9241067 "Qwen2ForSequenceClassification" ,
9251068 "Qwen2ForTokenClassification" ,
9261069 "Qwen2SentenceEmbedding" ,
1070+ "Qwen2ForCausalLMDeprecated" ,
1071+ "Qwen2ForCausalLMPipeDeprecated" ,
9271072]
0 commit comments