 # Copyright (c) OpenMMLab. All rights reserved.
-import json
 import math
-import os.path as osp
 import re

 import torch
@@ -17,7 +15,7 @@ class LlamaReader(BaseReader):
1715 """LlamaReader."""
1816
1917 attn_layer_prefix = 'model.layers'
20- attn_layer_patten = r'model.layers.([0-9]+).'
18+ attn_layer_patten = r'model\ .layers\ .([0-9]+).'
2119 tok_embeddings_key = 'model.embed_tokens.weight'
2220 norm_weight_key = 'model.norm.weight'
2321 output_weight_key = 'lm_head.weight'
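Aside (not part of the diff): the only change in this hunk escapes the dots in attn_layer_patten. An unescaped '.' in a regex matches any character, so the old pattern would also match strings that merely resemble a 'model.layers.<n>.' weight key; the escaped version requires literal dots. A minimal sketch, assuming nothing beyond the two pattern strings shown above:

import re

old_patten = re.compile(r'model.layers.([0-9]+).')
new_patten = re.compile(r'model\.layers\.([0-9]+).')
assert old_patten.match('modelXlayersX7X') is not None  # unescaped '.' matches any character
assert new_patten.match('modelXlayersX7X') is None      # escaped '\.' only matches a literal dot
assert new_patten.match('model.layers.7.self_attn.q_proj.weight').group(1) == '7'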
@@ -118,79 +116,76 @@ def readers(self):
 
     def model_info(self):
         """Read model info."""
-        params_path = osp.join(self.model_path, 'config.json')
-        with open(params_path) as f:
-            model_arg = json.load(f)
-            num_layer = model_arg['num_hidden_layers']
-            norm_eps = model_arg['rms_norm_eps']
-            attn_head_num = model_arg['num_attention_heads']
-            vocab_size = model_arg['vocab_size']
-            inter_size = model_arg['intermediate_size']
-            if 'num_key_value_heads' in model_arg:
-                kv_head_num = model_arg['num_key_value_heads']
-            else:
-                kv_head_num = model_arg['num_attention_heads']
-            hidden_units = model_arg['hidden_size']
-            head_dim = model_arg.get('head_dim', hidden_units // attn_head_num)
-            # compute rope param
-            rope_theta = float(model_arg.get('rope_theta', 10000.0))
-            max_position_embeddings = int(model_arg.get('max_position_embeddings', 0))
-            rope_param = RopeParam(type='default', base=rope_theta, dim=head_dim)
-            rope_scaling = model_arg.get('rope_scaling', None)
-            if isinstance(rope_scaling, dict):
-                llama2_scaling_type = rope_scaling.get('type', '')
-                llama3_scaling_type = rope_scaling.get('rope_type', '')
-                if llama2_scaling_type and llama3_scaling_type \
-                        and llama2_scaling_type != llama3_scaling_type:
-                    raise ValueError(f'Ambiguous rope_scaling in config: {model_arg}')
-                scaling_type = llama2_scaling_type if llama2_scaling_type \
-                    else llama3_scaling_type
-                if rope_scaling.get('mrope_section') is not None:
-                    # TODO: treat mrope as an option to the common rope functions
-                    scaling_type = 'mrope'
-                scaling_factor = rope_scaling.get('factor', 0.0)
-                if scaling_type == 'default':
-                    pass
-                elif scaling_type == 'dynamic':
-                    rope_param.type = 'dynamic'
-                    rope_param.factor = scaling_factor
-                    rope_param.max_position_embeddings = max_position_embeddings
-                elif scaling_type == 'linear':
-                    rope_param.type = 'linear'
-                    rope_param.factor = scaling_factor
-                elif scaling_type == 'llama3':
-                    low_freq_factor = rope_scaling.get('low_freq_factor', 1.0)
-                    high_freq_factor = rope_scaling.get('high_freq_factor', 1.0)
-                    original_max_position_embeddings = model_arg['rope_scaling'].get(
-                        'original_max_position_embeddings', 0)
-                    rope_param.type = 'llama3'
-                    rope_param.factor = scaling_factor
-                    rope_param.low_freq_factor = low_freq_factor
-                    rope_param.high_freq_factor = high_freq_factor
-                    rope_param.original_max_position_embeddings = original_max_position_embeddings
-                elif scaling_type == 'yarn':
-                    attention_factor = rope_scaling.get('attention_factor', None)
-                    if attention_factor is None:
-                        attention_factor = 0.1 * math.log(scaling_factor) + 1.0
-                    beta_fast = rope_scaling.get('beta_fast', 32.0)
-                    beta_slow = rope_scaling.get('beta_slow', 1.0)
-                    rope_param.type = 'yarn'
-                    if 'original_max_position_embeddings' in rope_scaling:
-                        original_max_position_embeddings = rope_scaling['original_max_position_embeddings']
-                        scaling_factor = max_position_embeddings / original_max_position_embeddings
-                    else:
-                        original_max_position_embeddings = max_position_embeddings
-                    rope_param.factor = scaling_factor
-                    rope_param.max_position_embeddings = original_max_position_embeddings
-                    rope_param.attention_factor = attention_factor
-                    rope_param.beta_fast = beta_fast
-                    rope_param.beta_slow = beta_slow
-                elif scaling_type == 'mrope':
-                    mrope_section = rope_scaling.get('mrope_section')
-                    rope_param.type = 'mrope'
-                    rope_param.mrope_section = mrope_section
+        model_arg = self.model_config
+        num_layer = model_arg['num_hidden_layers']
+        norm_eps = model_arg['rms_norm_eps']
+        attn_head_num = model_arg['num_attention_heads']
+        vocab_size = model_arg['vocab_size']
+        inter_size = model_arg['intermediate_size']
+        if 'num_key_value_heads' in model_arg:
+            kv_head_num = model_arg['num_key_value_heads']
+        else:
+            kv_head_num = model_arg['num_attention_heads']
+        hidden_units = model_arg['hidden_size']
+        head_dim = model_arg.get('head_dim', hidden_units // attn_head_num)
+        # compute rope param
+        rope_theta = float(model_arg.get('rope_theta', 10000.0))
+        max_position_embeddings = int(model_arg.get('max_position_embeddings', 0))
+        rope_param = RopeParam(type='default', base=rope_theta, dim=head_dim)
+        rope_scaling = model_arg.get('rope_scaling', None)
+        if isinstance(rope_scaling, dict):
+            llama2_scaling_type = rope_scaling.get('type', '')
+            llama3_scaling_type = rope_scaling.get('rope_type', '')
+            if llama2_scaling_type and llama3_scaling_type \
+                    and llama2_scaling_type != llama3_scaling_type:
+                raise ValueError(f'Ambiguous rope_scaling in config: {model_arg}')
+            scaling_type = llama2_scaling_type if llama2_scaling_type \
+                else llama3_scaling_type
+            if rope_scaling.get('mrope_section') is not None:
+                # TODO: treat mrope as an option to the common rope functions
+                scaling_type = 'mrope'
+            scaling_factor = rope_scaling.get('factor', 0.0)
+            if scaling_type == 'default':
+                pass
+            elif scaling_type == 'dynamic':
+                rope_param.type = 'dynamic'
+                rope_param.factor = scaling_factor
+                rope_param.max_position_embeddings = max_position_embeddings
+            elif scaling_type == 'linear':
+                rope_param.type = 'linear'
+                rope_param.factor = scaling_factor
+            elif scaling_type == 'llama3':
+                low_freq_factor = rope_scaling.get('low_freq_factor', 1.0)
+                high_freq_factor = rope_scaling.get('high_freq_factor', 1.0)
+                original_max_position_embeddings = model_arg['rope_scaling'].get('original_max_position_embeddings', 0)
+                rope_param.type = 'llama3'
+                rope_param.factor = scaling_factor
+                rope_param.low_freq_factor = low_freq_factor
+                rope_param.high_freq_factor = high_freq_factor
+                rope_param.original_max_position_embeddings = original_max_position_embeddings
+            elif scaling_type == 'yarn':
+                attention_factor = rope_scaling.get('attention_factor', None)
+                if attention_factor is None:
+                    attention_factor = 0.1 * math.log(scaling_factor) + 1.0
+                beta_fast = rope_scaling.get('beta_fast', 32.0)
+                beta_slow = rope_scaling.get('beta_slow', 1.0)
+                rope_param.type = 'yarn'
+                if 'original_max_position_embeddings' in rope_scaling:
+                    original_max_position_embeddings = rope_scaling['original_max_position_embeddings']
+                    scaling_factor = max_position_embeddings / original_max_position_embeddings
                 else:
-                    raise RuntimeError(f'Unsupported rope type: {scaling_type}')
+                    original_max_position_embeddings = max_position_embeddings
+                rope_param.factor = scaling_factor
+                rope_param.max_position_embeddings = original_max_position_embeddings
+                rope_param.attention_factor = attention_factor
+                rope_param.beta_fast = beta_fast
+                rope_param.beta_slow = beta_slow
+            elif scaling_type == 'mrope':
+                mrope_section = rope_scaling.get('mrope_section')
+                rope_param.type = 'mrope'
+                rope_param.mrope_section = mrope_section
+            else:
+                raise RuntimeError(f'Unsupported rope type: {scaling_type}')

         return dict(size_per_head=head_dim,
                     num_layer=num_layer,
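For reference, a hedged sketch of the rope_scaling blocks that the rewritten model_info() now reads from self.model_config instead of re-opening config.json. The key names are exactly those queried in the diff above; the numeric values are illustrative only and do not come from this commit:

# Illustrative config fragments (keys mirror the parser above; values are made up).
llama3_style_scaling = {
    'rope_type': 'llama3',                     # read via rope_scaling.get('rope_type', '')
    'factor': 8.0,                             # -> rope_param.factor
    'low_freq_factor': 1.0,                    # -> rope_param.low_freq_factor
    'high_freq_factor': 4.0,                   # -> rope_param.high_freq_factor
    'original_max_position_embeddings': 8192,  # -> rope_param.original_max_position_embeddings
}

yarn_style_scaling = {
    'rope_type': 'yarn',
    'factor': 4.0,              # attention_factor falls back to 0.1 * math.log(4.0) + 1.0 when unset
    'beta_fast': 32.0,
    'beta_slow': 1.0,
    'original_max_position_embeddings': 32768,
}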