@@ -37,17 +37,64 @@ def should_skip_tensor(name: str, has_text: bool, has_vision: bool, has_llava: b
     return False
 
 
-def get_tensor_name(name: str) -> str:
-    if "projection" in name:
-        return name
-    if "mm_projector" in name:
-        name = name.replace("model.mm_projector", "mm")
-        name = re.sub(r'mm\.mlp\.mlp', 'mm.model.mlp', name, count=1)
-        name = re.sub(r'mm\.peg\.peg', 'mm.model.peg', name, count=1)
-        return name
+def get_tensor_name_from_janus(name: str) -> str:
+    name = re.sub(r'^vision_tower\.blocks\.(\d+)\.attn\.qkv\.(weight|bias)$', r'v.blk.\1.attn_qkv.\2', name)
+    name = re.sub(r'^vision_tower\.blocks\.(\d+)\.norm1\.(.*)$', r'v.blk.\1.ln1.\2', name)
+    name = re.sub(r'^vision_tower\.blocks\.(\d+)\.attn\.proj\.(.*)$', r'v.blk.\1.attn_out.\2', name)
+    name = re.sub(r'^vision_tower\.blocks\.(\d+)\.norm2\.(.*)$', r'v.blk.\1.ln2.\2', name)
+    name = re.sub(r'^vision_tower\.blocks\.(\d+)\.mlp\.fc1\.(.*)$', r'v.blk.\1.ffn_down.\2', name)
+    name = re.sub(r'^vision_tower\.blocks\.(\d+)\.mlp\.fc2\.(.*)$', r'v.blk.\1.ffn_up.\2', name)
+    name = re.sub(r'^vision_tower\.patch_embed\.proj\.(.*)$', r'v.patch_embd.\1', name)
+    name = re.sub(r'^vision_tower\.pos_embed$', r'v.position_embd.weight', name)
+    name = re.sub(r'^vision_tower\.norm\.(weight|bias)$', r'v.post_ln.\1', name)
+
+    name = name.replace("vision_tower", "v")
+    name = name.replace("text_model", "t")
+    name = name.replace("vision_model", "v")
+    name = name.replace("encoder.layers", "blk")
+    name = name.replace("blocks", "blk")
+    name = name.replace("embeddings.", "")
+    name = name.replace("_proj", "")
+    name = name.replace("self_attn.", "attn_")
+    name = name.replace("layer_norm", "ln")
+    name = name.replace("layernorm", "ln")
+    name = name.replace("mlp.fc1", "ffn_down")
+    name = name.replace("mlp.fc2", "ffn_up")
+    name = name.replace("embedding", "embd")
+    name = name.replace("final", "post")
+    name = name.replace("layrnorm", "ln")
+
+    return name
+
+
+def process_and_save_tensor(tensor: torch.Tensor, new_name: str, ftype: int, fout) -> None:
+    """Process a tensor (squeeze, cast dtype, log) and save it to `fout`."""
+    data = tensor.squeeze().numpy()
+    n_dims = len(data.shape)
+    ftype_str = {0: "f32", 1: "f16"}
 
-    return name.replace("text_model", "t").replace("vision_model", "v").replace("encoder.layers", "blk").replace("embeddings.", "").replace("_proj", "").replace("self_attn.", "attn_").replace("layer_norm", "ln").replace("layernorm", "ln").replace("mlp.fc1", "ffn_down").replace("mlp.fc2", "ffn_up").replace("embedding", "embd").replace("final", "post").replace("layrnorm", "ln")
+    ftype_cur = 0
+    if n_dims == 4:
+        print(f"tensor {new_name} is always saved in f16")
+        data = data.astype(np.float16)
+        ftype_cur = 1
+    elif ftype == 1:
+        if new_name.endswith(".weight") and n_dims == 2:
+            print("  Converting to float16")
+            data = data.astype(np.float16)
+            ftype_cur = 1
+        else:
+            print("  Converting to float32")
+            data = data.astype(np.float32)
+            ftype_cur = 0
+    else:
+        if data.dtype != np.float32:
+            print("  Converting to float32")
+            data = data.astype(np.float32)
+            ftype_cur = 0
 
+    print(f"{new_name} - {ftype_str[ftype_cur]} - shape = {data.shape}")
+    fout.add_tensor(new_name, data)
 
 def bytes_to_unicode():
     """
@@ -261,35 +308,17 @@ def bytes_to_unicode():
261308 print (f"skipping parameter: { name } " )
262309 continue
263310
264- name = get_tensor_name (name )
265- data = data .squeeze ().numpy ()
311+ name = get_tensor_name_from_janus (name )
266312
267- n_dims = len (data .shape )
313+ # Handle the qkv projection weights and biases
314+ if "qkv" in name :
315+ q_tensor , k_tensor , v_tensor = torch .chunk (data , 3 , dim = 0 )
268316
269- # ftype == 0 -> float32, ftype == 1 -> float16
270- ftype_cur = 0
271- if n_dims == 4 :
272- print (f"tensor { name } is always saved in f16" )
273- data = data .astype (np .float16 )
274- ftype_cur = 1
275- elif ftype == 1 :
276- if name [- 7 :] == ".weight" and n_dims == 2 :
277- print (" Converting to float16" )
278- data = data .astype (np .float16 )
279- ftype_cur = 1
280- else :
281- print (" Converting to float32" )
282- data = data .astype (np .float32 )
283- ftype_cur = 0
317+ process_and_save_tensor (q_tensor , name .replace ("qkv" , "q" ), ftype , fout )
318+ process_and_save_tensor (k_tensor , name .replace ("qkv" , "k" ), ftype , fout )
319+ process_and_save_tensor (v_tensor , name .replace ("qkv" , "v" ), ftype , fout )
284320 else :
285- if data .dtype != np .float32 :
286- print (" Converting to float32" )
287- data = data .astype (np .float32 )
288- ftype_cur = 0
289-
290- print (f"{ name } - { ftype_str [ftype_cur ]} - shape = { data .shape } " )
291- fout .add_tensor (name , data )
292-
321+ process_and_save_tensor (data , name , ftype , fout )
293322
294323fout .write_header_to_file ()
295324fout .write_kv_data_to_file ()
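
For context, a minimal standalone sketch of what the renaming and qkv handling in this diff do to a fused attention weight. The regex is copied from get_tensor_name_from_janus above; the 1024 embedding dimension is made up for illustration and is not taken from the Janus checkpoint.

# Illustrative sketch only, not part of the conversion script.
# Shows how a fused "qkv" tensor name is remapped and how torch.chunk
# splits the fused projection into separate q/k/v tensors along dim 0.
import re

import torch

name = "vision_tower.blocks.0.attn.qkv.weight"
# Same rewrite rule as in get_tensor_name_from_janus() above.
new_name = re.sub(r'^vision_tower\.blocks\.(\d+)\.attn\.qkv\.(weight|bias)$',
                  r'v.blk.\1.attn_qkv.\2', name)
print(new_name)  # v.blk.0.attn_qkv.weight

qkv = torch.zeros(3 * 1024, 1024)      # fused projection, rows stacked as [q; k; v]
q, k, v = torch.chunk(qkv, 3, dim=0)   # same split as in the conversion loop above
for part, t in (("q", q), ("k", k), ("v", v)):
    print(new_name.replace("qkv", part), tuple(t.shape))
# v.blk.0.attn_q.weight (1024, 1024)
# v.blk.0.attn_k.weight (1024, 1024)
# v.blk.0.attn_v.weight (1024, 1024)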