@@ -236,7 +236,7 @@ def copy_weights_phi(
236
236
"lm_head.weight" : "lm_head.weight" ,
237
237
"lm_head.bias" : "lm_head.bias" ,
238
238
}
239
- if config .name .startswith (("Phi -3" , "phi-4" )):
239
+ if config .name .lower (). startswith (("phi -3" , "phi-4" )):
240
240
weight_map .update (
241
241
{
242
242
"transformer.h.{}.attn.qkv.weight" : "model.layers.{}.self_attn.qkv_proj.weight" ,
@@ -249,10 +249,12 @@ def copy_weights_phi(
249
249
gate_up_proj_weights = defaultdict (dict )
250
250
251
251
for from_name , param in lit_weights .items ():
252
+ if from_name == "lm_head.weight" and config .name .startswith ("Phi-4" ):
253
+ continue
252
254
name_template , layer_idx = layer_template (from_name )
253
255
param = load_param (param , from_name , None )
254
256
if from_name .endswith ((".attn.qkv.weight" , ".attn.qkv.bias" )):
255
- if config .name .startswith ("Phi -3" ):
257
+ if config .name .lower (). startswith (( "phi -3", "phi-4" ) ):
256
258
to_names = (weight_map [name_template ].format (layer_idx ),)
257
259
params = (param ,)
258
260
else :
@@ -282,7 +284,7 @@ def copy_weights_phi(
282
284
param = saver .store_early (param )
283
285
state_dict [to_name ] = param
284
286
285
- if config .name .startswith ("Phi -3" ):
287
+ if config .name .lower (). startswith (( "phi -3", "phi-4" ) ):
286
288
for layer_idx in list (gate_up_proj_weights ):
287
289
fc_1_weight = gate_up_proj_weights [layer_idx ]["fc_1" ]
288
290
fc_2_weight = gate_up_proj_weights [layer_idx ]["fc_2" ]
0 commit comments