
Commit 46aa456

talumbau authored and copybara-github committed
Support Gemma 2 loading from Hugging Face
PiperOrigin-RevId: 721195210
1 parent 7f2b452 commit 46aa456

2 files changed: 31 additions, 7 deletions


ai_edge_torch/generative/examples/experimental/gemma/convert_gemma2_gpu_to_tflite.py

Lines changed: 1 addition & 1 deletion
@@ -38,7 +38,7 @@
 )
 _OUTPUT_NAME_PREFIX = flags.DEFINE_string(
     'output_name_prefix',
-    'gemma2',
+    'gemma2_gpu',
     'The prefix of the output tflite model name.',
 )
 _PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(

ai_edge_torch/generative/examples/experimental/gemma/gemma2_gpu.py

Lines changed: 30 additions & 6 deletions
@@ -50,6 +50,21 @@
     lm_head=None,
 )
 
+ALT_TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
+    ff_up_proj="model.layers.{}.mlp.up_proj",
+    ff_down_proj="model.layers.{}.mlp.down_proj",
+    ff_gate_proj="model.layers.{}.mlp.gate_proj",
+    attn_query_proj="model.layers.{}.self_attn.q_proj",
+    attn_key_proj="model.layers.{}.self_attn.k_proj",
+    attn_value_proj="model.layers.{}.self_attn.v_proj",
+    attn_output_proj="model.layers.{}.self_attn.o_proj",
+    pre_attn_norm="model.layers.{}.input_layernorm",
+    post_attn_norm="model.layers.{}.post_attention_layernorm",
+    pre_ff_norm="model.layers.{}.pre_feedforward_layernorm",
+    post_ff_norm="model.layers.{}.post_feedforward_layernorm",
+    embedding="model.embed_tokens",
+    final_norm="model.norm",
+)
 
 class Gemma2Block(attention.TransformerBlock):
 
@@ -289,9 +304,18 @@ def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
 
 
 def build_2b_model(checkpoint_path: str, **kwargs) -> nn.Module:
-  return model_builder.build_decoder_only_model(
-      checkpoint_path=checkpoint_path,
-      config=get_model_config_2b(**kwargs),
-      tensor_names=TENSOR_NAMES,
-      model_class=Gemma2,
-  )
+  try:
+    return model_builder.build_decoder_only_model(
+        checkpoint_path=checkpoint_path,
+        config=get_model_config_2b(**kwargs),
+        tensor_names=TENSOR_NAMES,
+        model_class=Gemma2,
+    )
+  except KeyError as ke:
+    # Also attempt to load with an alternative naming scheme.
+    return model_builder.build_decoder_only_model(
+        checkpoint_path=checkpoint_path,
+        config=get_model_config_2b(**kwargs),
+        tensor_names=ALT_TENSOR_NAMES,
+        model_class=Gemma2,
+    )
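With this change, build_2b_model first attempts loading with the original TENSOR_NAMES and, if a tensor key is missing (KeyError), retries with ALT_TENSOR_NAMES, which follows the "model.layers.{}" layout used by Hugging Face Gemma 2 checkpoints. A minimal usage sketch under that reading; the checkpoint directory path and the kv_cache_max_len keyword are illustrative assumptions, not part of this commit:

# Minimal sketch: build the Gemma 2 2B model from a locally downloaded
# Hugging Face checkpoint directory. The path below is hypothetical.
from ai_edge_torch.generative.examples.experimental.gemma import gemma2_gpu

checkpoint_path = "/tmp/gemma-2-2b-it"  # hypothetical local HF checkpoint

# build_2b_model tries TENSOR_NAMES first; on KeyError it retries with
# ALT_TENSOR_NAMES, matching the "model.layers.{}" names in HF checkpoints.
# kv_cache_max_len is forwarded to get_model_config_2b via **kwargs
# (assumed to be an accepted keyword, mirroring get_fake_model_config).
model = gemma2_gpu.build_2b_model(checkpoint_path, kv_cache_max_len=1024)
model.eval()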
