
Commit 50f279c

talumbau authored and copybara-github committed
Support Gemma 2 loading from Hugging Face
PiperOrigin-RevId: 719670366
1 parent 5ad0128 commit 50f279c

File tree

1 file changed: +31 -6 lines
  • ai_edge_torch/generative/examples/gemma


ai_edge_torch/generative/examples/gemma/gemma2.py

Lines changed: 31 additions & 6 deletions
@@ -43,6 +43,22 @@
     lm_head=None,
 )
 
+ALT_TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
+    ff_up_proj="model.layers.{}.mlp.up_proj",
+    ff_down_proj="model.layers.{}.mlp.down_proj",
+    ff_gate_proj="model.layers.{}.mlp.gate_proj",
+    attn_query_proj="model.layers.{}.self_attn.q_proj",
+    attn_key_proj="model.layers.{}.self_attn.k_proj",
+    attn_value_proj="model.layers.{}.self_attn.v_proj",
+    attn_output_proj="model.layers.{}.self_attn.o_proj",
+    pre_attn_norm="model.layers.{}.input_layernorm",
+    post_attn_norm="model.layers.{}.post_attention_layernorm",
+    pre_ff_norm="model.layers.{}.pre_feedforward_layernorm",
+    post_ff_norm="model.layers.{}.post_feedforward_layernorm",
+    embedding="model.embed_tokens",
+    final_norm="model.norm",
+)
+
 
 class Gemma2Block(attention.TransformerBlock):
 
@@ -281,9 +297,18 @@ def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
 
 
 def build_2b_model(checkpoint_path: str, **kwargs) -> nn.Module:
-  return model_builder.build_decoder_only_model(
-      checkpoint_path=checkpoint_path,
-      config=get_model_config_2b(**kwargs),
-      tensor_names=TENSOR_NAMES,
-      model_class=Gemma2,
-  )
+  try:
+    return model_builder.build_decoder_only_model(
+        checkpoint_path=checkpoint_path,
+        config=get_model_config_2b(**kwargs),
+        tensor_names=TENSOR_NAMES,
+        model_class=Gemma2,
+    )
+  except KeyError as ke:
+    # Also attempt to load with an alternative naming scheme.
+    return model_builder.build_decoder_only_model(
+        checkpoint_path=checkpoint_path,
+        config=get_model_config_2b(**kwargs),
+        tensor_names=ALT_TENSOR_NAMES,
+        model_class=Gemma2,
+    )
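
How the fallback plays out for a Hugging Face checkpoint, as a minimal usage sketch: the ALT_TENSOR_NAMES patterns above ("model.layers.{}.self_attn.q_proj", "model.embed_tokens", ...) follow the Hugging Face Gemma 2 layer naming, so when the default TENSOR_NAMES lookup misses and raises KeyError, build_2b_model retries with the alternative names. The import path below is taken from the changed file's location; the checkpoint directory is a hypothetical placeholder, not part of this commit.

# Minimal usage sketch (assumption: a local directory holding a Gemma 2 2B
# checkpoint in Hugging Face layout; the path below is a placeholder).
from ai_edge_torch.generative.examples.gemma import gemma2

# The first attempt uses TENSOR_NAMES; if those names are absent from the
# checkpoint, the KeyError fallback retries with ALT_TENSOR_NAMES, whose
# "model.layers.{}..." patterns match the Hugging Face tensor names.
model = gemma2.build_2b_model("/path/to/gemma-2-2b-hf")

The try/except ordering keeps the existing naming scheme as the first attempt, so checkpoints that loaded before this change load exactly as before; the alternative names are consulted only when a tensor lookup fails.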
