Commit 96b567c

rtp-llm committed

feat - support tie_word_embeddings option in hf config.json, to fix qwen1.5 0.5b finetune load failure
1 parent 51a1184 commit 96b567c

File tree

14 files changed: +58 -15 lines changed
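
The change follows one pattern across the model loaders below: read an optional tie_word_embeddings key from the Hugging Face config.json and default to False when it is absent, so older configs keep their previous behavior. As context, small Qwen1.5 checkpoints such as qwen1.5 0.5b ship with "tie_word_embeddings": true in their config.json, which is what the commit message refers to. The helper below is a minimal sketch of that read pattern; the function name and path handling are illustrative, not part of this repo.

import json
import os

def read_tie_word_embeddings(ckpt_path: str) -> bool:
    # Illustrative helper (not from the commit): read the optional flag from
    # a Hugging Face config.json, defaulting to False when the key is absent.
    with open(os.path.join(ckpt_path, "config.json")) as f:
        config_json = json.load(f)
    # Same fallback pattern as the per-model loaders in this commit.
    return config_json.get("tie_word_embeddings", False)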

maga_transformer/config/gpt_init_model_parameters.py
Lines changed: 3 additions & 1 deletion

@@ -100,7 +100,8 @@ class GptInitModelParameters:
         "normalize_lm_head_weight",
         "ref_model",
         "is_quant_mode",
-        "model_type"
+        "model_type",
+        "tie_word_embeddings"
     }
 
     def __init__(self,
@@ -134,6 +135,7 @@ def __init__(self,
         self.ref_model: Optional[torch.nn.Module] = None
 
         self.model_type = ModelType.NORMAL
+        self.tie_word_embeddings = False
 
         for k, v in kwargs.items():
             setattr(self, k, v)
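
With the new key added to the class-level name set (its exact role is not shown in this hunk) and given a False default in __init__, tie_word_embeddings can be overridden like any other keyword argument via the setattr loop above. A simplified sketch of that pattern, where ParamsSketch is a stand-in and not the real GptInitModelParameters class:

class ParamsSketch:
    def __init__(self, **kwargs):
        self.model_type = "normal"
        self.tie_word_embeddings = False  # default when config.json omits the key
        # Same loop as in the diff: keyword arguments override the defaults.
        for k, v in kwargs.items():
            setattr(self, k, v)

p = ParamsSketch(tie_word_embeddings=True)
assert p.tie_word_embeddings is True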

maga_transformer/models/bloom.py
Lines changed: 1 addition & 0 deletions

@@ -107,6 +107,7 @@ def from_huggingface(config_json: Dict[str, Any]):
         config.layernorm_eps = config_json['layer_norm_epsilon']
         config.inter_size = hidden_size * 4
         config.special_tokens.eos_token_id = config_json['eos_token_id']
+        config.tie_word_embeddings = config_json.get('tie_word_embeddings', False)
         return config
 
     @classmethod

maga_transformer/models/chat_glm.py
Lines changed: 1 addition & 0 deletions

@@ -84,6 +84,7 @@ def from_huggingface(cls, config_json: Dict[str, Any]):
         config.special_tokens.bos_token_id = config_json.get('bos_token_id', config.special_tokens.bos_token_id)
         config.special_tokens.eos_token_id = config_json.get('eos_token_id', config.special_tokens.eos_token_id)
         config.src_quantization_bit = config_json.get('quantization_bit', 0)
+        config.tie_word_embeddings = config_json.get('tie_word_embeddings', False)
         return config
 
     # override

maga_transformer/models/chat_glm_v2.py
Lines changed: 1 addition & 0 deletions

@@ -52,6 +52,7 @@ def from_huggingface(cls, config_json: Dict[str, Any]):
         config.special_tokens.eos_token_id = config_json['eos_token_id']
         config.src_quantization_bit = config_json.get('quantization_bit', 0)
         config.rotary_embedding_dim = config.size_per_head
+        config.tie_word_embeddings = config_json.get('tie_word_embeddings', False)
         config = cls.get_rotary_embedding_scale(config, config_json)
         return config

maga_transformer/models/falcon.py
Lines changed: 1 addition & 0 deletions

@@ -81,6 +81,7 @@ def _create_config(cls, ckpt_path: str):
         config.special_tokens.bos_token_id = config_json['bos_token_id']
         config.special_tokens.eos_token_id = config_json['eos_token_id']
         config.rotary_embedding_dim = config.size_per_head
+        config.tie_word_embeddings = config_json.get('tie_word_embeddings', False)
         return config
 
 register_model('falcon', Falcon, ["FalconForCausalLM"])

maga_transformer/models/gpt_neox.py
Lines changed: 1 addition & 0 deletions

@@ -64,6 +64,7 @@ def from_huggingface(config_json: Dict[str, Any]):
         config.has_post_decoder_layernorm = True
         config.norm_type = 'layernorm'
         config.use_norm_input_residual = True
+        config.tie_word_embeddings = config_json.get('tie_word_embeddings', False)
 
         return config

maga_transformer/models/llama.py
Lines changed: 2 additions & 0 deletions

@@ -75,6 +75,7 @@ def from_huggingface(config, config_json: Dict[str, Any]):
         config.inter_size = config_json['intermediate_size']
         config.rotary_embedding_base = int(config_json.get('rope_theta', 10000))
         config.rotary_embedding_dim = config.size_per_head
+        config.tie_word_embeddings = config_json.get('tie_word_embeddings', False)
         if config_json.get('rope_scaling', None):
             if config_json['rope_scaling']['type'] == 'dynamic':
                 config.dynamic_embedding_scalar = config_json['rope_scaling']['factor']
@@ -105,6 +106,7 @@ def from_params(config: GptInitModelParameters, params_json: Dict[str, Any]):
                                              params_json['multiple_of'])
         config.special_tokens.eos_token_id = 2
         config.rotary_embedding_dim = config.size_per_head
+        config.tie_word_embeddings = params_json.get('tie_word_embeddings', False)
         return config
 
     @classmethod

maga_transformer/models/phi.py
Lines changed: 2 additions & 1 deletion

@@ -50,7 +50,8 @@ def _create_config(cls, ckpt_path: str):
             activation_type='gelu',
             has_positional_encoding=False,
             has_post_decoder_layernorm=True,
-            has_lm_head_bias=True)
+            has_lm_head_bias=True,
+            tie_word_embeddings = config_dict.get('tie_word_embeddings', False))
         config.head_num_kv = config.head_num
         return config

maga_transformer/models/qwen.py
Lines changed: 1 addition & 0 deletions

@@ -267,6 +267,7 @@ def _from_hf(config: GptInitModelParameters, ckpt_path: str):
         config.rotary_embedding_base = int(config_json.get('rotary_emb_base', 10000))
         config.rotary_embedding_dim = config.size_per_head
         config.special_tokens.eos_token_id = config_json.get("eos_token_id", config.special_tokens.eos_token_id)
+        config.tie_word_embeddings = config_json.get('tie_word_embeddings', False)
 
         quant_config = config_json.get("quantization_config", None)
         if quant_config is not None:

maga_transformer/models/qwen_v2.py
Lines changed: 1 addition & 0 deletions

@@ -193,6 +193,7 @@ def _from_hf(config: GptInitModelParameters, ckpt_path: str):
         config.vocab_size = config_json["vocab_size"]
         config.rotary_embedding_dim = config.size_per_head
         config.layernorm_eps = config_json.get("rms_norm_eps", 1e-06)
+        config.tie_word_embeddings = config_json.get('tie_word_embeddings', False)
 
         quant_config = config_json.get("quantization_config", None)
         if quant_config is not None:
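
Why the flag matters for the qwen1.5 0.5b failure in the commit message: with tied word embeddings, the output projection shares its weight tensor with the input token embedding, so a fine-tuned checkpoint may not store a separate lm_head weight at all, and a loader that expects one fails. The snippet below only illustrates the tying itself; how rtp-llm consumes the flag during weight loading is not shown in this commit's visible hunks, so treat the dimensions and wiring as assumptions.

import torch.nn as nn

vocab_size, hidden_size = 151936, 1024  # Qwen1.5-0.5B-like sizes (illustrative)

embedding = nn.Embedding(vocab_size, hidden_size)
lm_head = nn.Linear(hidden_size, vocab_size, bias=False)

tie_word_embeddings = True  # as read from config.json by the loaders above
if tie_word_embeddings:
    # Tie the LM head to the embedding matrix: both point at the same tensor,
    # so the checkpoint does not need a standalone lm_head weight.
    lm_head.weight = embedding.weight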
