@@ -81,14 +81,14 @@ def _load_pretrained_weights(self):

        # Mapping (`hf: ours`) of decoder layers
        for i in range(12):
-            mapping[f'transformer.h.{i}.ln_1.weight'] = f'blocks.{i}.pre_norm.weight'
-            mapping[f'transformer.h.{i}.ln_1.bias'] = f'blocks.{i}.pre_norm.bias'
+            mapping[f'transformer.h.{i}.ln_1.weight'] = f'blocks.{i}.attn_norm.weight'
+            mapping[f'transformer.h.{i}.ln_1.bias'] = f'blocks.{i}.attn_norm.bias'
            mapping[f'transformer.h.{i}.attn.c_attn.weight'] = f'blocks.{i}.attn.qkv_projection.weight'
            mapping[f'transformer.h.{i}.attn.c_attn.bias'] = f'blocks.{i}.attn.qkv_projection.bias'
            mapping[f'transformer.h.{i}.attn.c_proj.weight'] = f'blocks.{i}.attn.output_projection.weight'
            mapping[f'transformer.h.{i}.attn.c_proj.bias'] = f'blocks.{i}.attn.output_projection.bias'
-            mapping[f'transformer.h.{i}.ln_2.weight'] = f'blocks.{i}.post_norm.weight'
-            mapping[f'transformer.h.{i}.ln_2.bias'] = f'blocks.{i}.post_norm.bias'
+            mapping[f'transformer.h.{i}.ln_2.weight'] = f'blocks.{i}.ffn_norm.weight'
+            mapping[f'transformer.h.{i}.ln_2.bias'] = f'blocks.{i}.ffn_norm.bias'
            mapping[f'transformer.h.{i}.mlp.c_fc.weight'] = f'blocks.{i}.ffn.linear_in.weight'
            mapping[f'transformer.h.{i}.mlp.c_fc.bias'] = f'blocks.{i}.ffn.linear_in.bias'
            mapping[f'transformer.h.{i}.mlp.c_proj.weight'] = f'blocks.{i}.ffn.linear_out.weight'
@@ -110,7 +110,11 @@ def _load_pretrained_weights(self):
            new_state_dict[layer] = torch.transpose(new_state_dict[layer], 0, 1)

        # Load our model. We use `strict=False` because the state dict does not have LoRA weights
-        self.model.load_state_dict(new_state_dict, strict=False)
+        missing_keys, unexpected_keys = self.model.load_state_dict(new_state_dict, strict=False)
+
+        # Make sure that only the LoRA weights were left unloaded
+        assert all('lora' in key for key in missing_keys)
+        assert not unexpected_keys

    def initialize(self):
        """