|
43 | 43 | lm_head=None, |
44 | 44 | ) |
45 | 45 |
|
# Alternative checkpoint tensor-name mapping. Some published checkpoints
# store weights under the HuggingFace-style "model.layers.{}.*" paths
# instead of the default scheme; this mapping lets the loader resolve them.
ALT_TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(**{
    "ff_up_proj": "model.layers.{}.mlp.up_proj",
    "ff_down_proj": "model.layers.{}.mlp.down_proj",
    "ff_gate_proj": "model.layers.{}.mlp.gate_proj",
    "attn_query_proj": "model.layers.{}.self_attn.q_proj",
    "attn_key_proj": "model.layers.{}.self_attn.k_proj",
    "attn_value_proj": "model.layers.{}.self_attn.v_proj",
    "attn_output_proj": "model.layers.{}.self_attn.o_proj",
    "pre_attn_norm": "model.layers.{}.input_layernorm",
    "post_attn_norm": "model.layers.{}.post_attention_layernorm",
    "pre_ff_norm": "model.layers.{}.pre_feedforward_layernorm",
    "post_ff_norm": "model.layers.{}.post_feedforward_layernorm",
    "embedding": "model.embed_tokens",
    "final_norm": "model.norm",
})

46 | 62 |
|
47 | 63 | class Gemma2Block(attention.TransformerBlock): |
48 | 64 |
|
@@ -281,9 +297,18 @@ def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig: |
281 | 297 |
|
282 | 298 |
|
def build_2b_model(checkpoint_path: str, **kwargs) -> nn.Module:
  """Builds a Gemma2 2B model from a checkpoint.

  Tries the default ``TENSOR_NAMES`` first; if the checkpoint's tensors
  are stored under a different naming scheme (signaled by a ``KeyError``
  during loading), retries once with ``ALT_TENSOR_NAMES``.

  Args:
    checkpoint_path: Path to the model checkpoint to load.
    **kwargs: Forwarded to ``get_model_config_2b`` (e.g. kv_cache_max_len).

  Returns:
    The constructed Gemma2 model with weights loaded.

  Raises:
    KeyError: If the checkpoint matches neither tensor-naming scheme.
  """
  # Build the config outside the try-block so an unrelated KeyError from
  # config construction does not trigger the tensor-name fallback, and so
  # the config is not rebuilt on retry.
  config = get_model_config_2b(**kwargs)
  try:
    return model_builder.build_decoder_only_model(
        checkpoint_path=checkpoint_path,
        config=config,
        tensor_names=TENSOR_NAMES,
        model_class=Gemma2,
    )
  except KeyError:
    # Checkpoint does not use the default names; retry with the
    # alternative (HuggingFace-style) naming scheme.
    return model_builder.build_decoder_only_model(
        checkpoint_path=checkpoint_path,
        config=config,
        tensor_names=ALT_TENSOR_NAMES,
        model_class=Gemma2,
    )
0 commit comments