50 | 50 |     lm_head=None,
51 | 51 | )
52 | 52 |
| 53 | +ALT_TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
| 54 | +    ff_up_proj="model.layers.{}.mlp.up_proj",
| 55 | +    ff_down_proj="model.layers.{}.mlp.down_proj",
| 56 | +    ff_gate_proj="model.layers.{}.mlp.gate_proj",
| 57 | +    attn_query_proj="model.layers.{}.self_attn.q_proj",
| 58 | +    attn_key_proj="model.layers.{}.self_attn.k_proj",
| 59 | +    attn_value_proj="model.layers.{}.self_attn.v_proj",
| 60 | +    attn_output_proj="model.layers.{}.self_attn.o_proj",
| 61 | +    pre_attn_norm="model.layers.{}.input_layernorm",
| 62 | +    post_attn_norm="model.layers.{}.post_attention_layernorm",
| 63 | +    pre_ff_norm="model.layers.{}.pre_feedforward_layernorm",
| 64 | +    post_ff_norm="model.layers.{}.post_feedforward_layernorm",
| 65 | +    embedding="model.embed_tokens",
| 66 | +    final_norm="model.norm",
| 67 | +)
53 | 68 |
54 | 69 | class Gemma2Block(attention.TransformerBlock):
55 | 70 |
@@ -289,9 +304,18 @@ def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
289 | 304 |
290 | 305 |
291 | 306 | def build_2b_model(checkpoint_path: str, **kwargs) -> nn.Module:
292 | | -  return model_builder.build_decoder_only_model(
293 | | -      checkpoint_path=checkpoint_path,
294 | | -      config=get_model_config_2b(**kwargs),
295 | | -      tensor_names=TENSOR_NAMES,
296 | | -      model_class=Gemma2,
297 | | -  )
| 307 | +  try:
| 308 | +    return model_builder.build_decoder_only_model(
| 309 | +        checkpoint_path=checkpoint_path,
| 310 | +        config=get_model_config_2b(**kwargs),
| 311 | +        tensor_names=TENSOR_NAMES,
| 312 | +        model_class=Gemma2,
| 313 | +    )
| 314 | +  except KeyError:
| 315 | +    # Fall back to the alternative tensor naming scheme (split q/k/v projections).
| 316 | +    return model_builder.build_decoder_only_model(
| 317 | +        checkpoint_path=checkpoint_path,
| 318 | +        config=get_model_config_2b(**kwargs),
| 319 | +        tensor_names=ALT_TENSOR_NAMES,
| 320 | +        model_class=Gemma2,
| 321 | +    )
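A minimal usage sketch of the new fallback path (not part of the diff; the import path, checkpoint location, and kv_cache_max_len keyword below are illustrative assumptions):

    # Hypothetical usage: build_2b_model() first tries TENSOR_NAMES and, on a
    # KeyError, retries with ALT_TENSOR_NAMES, so a checkpoint using either
    # naming scheme should load.
    from ai_edge_torch.generative.examples.gemma import gemma2  # assumed module path

    model = gemma2.build_2b_model("/path/to/gemma2-2b", kv_cache_max_len=1024)
    model.eval()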