import torch


def save(full_model, path, model_type='BIAS'):
    """Extract only the trainable adapter weights from `full_model` and save
    them to `path`, wrapped together with a small config dict."""
    if model_type == 'BIAS':
        # Trainable parameters for the bias-tuning variant: visual-block norms,
        # attention/MLP weights and biases, LLaMA gates/biases/norms, and the
        # projection layers.
        keys = [
            f'visual_blocks.{i}.{key}.{suffix}'
            for i in range(8)
            for key in ['norm1', 'attn.qkv', 'attn.proj', 'norm2', 'mlp.fc1', 'mlp.fc2']
            for suffix in ['weight', 'bias']
        ] + [
            f'llama.layers.{i}.{key}'
            for i in range(32)
            for key in ['attention.gate', 'attention.wq.bias', 'attention.wo.bias',
                        'feed_forward.w1.bias', 'feed_forward.w2.bias', 'feed_forward.w3.bias',
                        'attention_norm.weight', 'ffn_norm.weight']
        ] + [
            f'{base_key}.{suffix}'
            for base_key in ['clip_proj_norm', 'visual_proj_norm', 'visual_proj', 'clip_proj']
            for suffix in ['weight', 'bias']
        ] + ['llama.norm.weight', 'visual_query.weight', 'adapter_query.weight']

    elif model_type == 'LORA':
        # Same set of bias/norm/projection parameters, plus the low-rank LoRA
        # matrices and the extra gate in each LLaMA layer.
        keys = [
            f'visual_blocks.{i}.{key}.{suffix}'
            for i in range(8)
            for key in [f'norm{j}' for j in range(1, 3)] + ['attn.qkv', 'attn.proj', 'mlp.fc1', 'mlp.fc2']
            for suffix in ['weight', 'bias']
        ] + [
            f'llama.layers.{i}.{key}'
            for i in range(32)
            for key in ['attention.gate', 'attention.wq.bias', 'attention.wo.bias',
                        'feed_forward.w1.bias', 'feed_forward.w2.bias', 'feed_forward.w3.bias',
                        'attention_norm.weight', 'ffn_norm.weight']
            + [f'attention.lora_wk_l{j}.weight' for j in range(1, 3)]
            + [f'attention.lora_wo_l{j}.weight' for j in range(1, 3)]
            + [f'feed_forward.lora_w{k}_l{j}.weight' for k in range(1, 4) for j in range(1, 3)]
            + [f'attention.lora_wq_l{j}.weight' for j in range(1, 3)]
            + [f'attention.lora_wv_l{j}.weight' for j in range(1, 3)]
            + ['attention.new_gate']
        ] + [
            f'{base_key}.{suffix}'
            for base_key in ['clip_proj_norm', 'visual_proj_norm', 'visual_proj', 'clip_proj']
            for suffix in ['weight', 'bias']
        ] + ['llama.norm.weight', 'visual_query.weight', 'adapter_query.weight']

    # TODO: Add other model types

    # Pull only the selected tensors out of the full state dict.
    full_model_state_dict = full_model.state_dict()
    small_weights = {key: full_model_state_dict[key] for key in keys}
    if model_type == 'BIAS':
        wrapped_small_weights = {'model': small_weights,
                                 'config': {'w_bias': True, 'w_lora': False, 'lora_rank': 16}}
    elif model_type == 'LORA':
        wrapped_small_weights = {'model': small_weights,
                                 'config': {'w_bias': True, 'w_lora': True, 'lora_rank': 16}}
    # Save the wrapped small weights
    torch.save(wrapped_small_weights, path)
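# Example usage (a minimal sketch; `build_full_model` is a hypothetical
# stand-in for however the full adapter model is constructed and fine-tuned
# in this repo):
#
#     model = build_full_model(w_bias=True, w_lora=True, lora_rank=16)
#     ...  # fine-tune the adapter parameters
#     save(model, 'adapter-lora.pth', model_type='LORA')
#
# The saved checkpoint holds only the adapter tensors listed above plus the
# config dict, so it stays much smaller than the model's full state_dict.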