Skip to content

Commit c04d659

Browse files
snnn and copybara-github
authored and committed
Add support for 4B model in GenAI gemma3 example
PiperOrigin-RevId: 866650745
1 parent 4c84cc3 commit c04d659

File tree

3 files changed

+175
-3
lines changed

3 files changed

+175
-3
lines changed

litert_torch/generative/examples/gemma3/convert_gemma3_to_tflite.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,12 +31,26 @@
3131

3232

3333
def main(_):
  # Start from the flag value; auto-detection below may refine it.
  model_size = _MODEL_SIZE.value
  # Probe the checkpoint for its size regardless; only act on the result
  # when it is conclusive.
  detected_size = gemma3.detect_model_size(flags.FLAGS.checkpoint_path)
  if detected_size and _MODEL_SIZE.present:
    # An explicitly passed flag always wins; just surface any disagreement
    # so the user can double-check their invocation.
    if detected_size != _MODEL_SIZE.value:
      print(f"Note: User specified model_size={_MODEL_SIZE.value}, "
            f"but detected {detected_size}. Using user specification.")
  elif detected_size:
    model_size = detected_size
    print(f"Auto-detected model size: {model_size}")

  # Dispatch table keeps the size -> builder mapping in one place.
  builders = {
      '1b': gemma3.build_model_1b,
      '270m': gemma3.build_model_270m,
      '4b': gemma3.build_model_4b,
  }
  model_builder = builders.get(model_size)
  if model_builder is None:
    raise ValueError(f'Unsupported model size: {model_size}')

  converter.build_and_convert_to_tflite_from_flags(model_builder)
4256

litert_torch/generative/examples/gemma3/decoder.py

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,9 +67,29 @@
6767
lm_head=None,
6868
)
6969

70+
# Tensor-name mapping for HuggingFace-format Gemma3 4B checkpoints.  The 4B
# release ships the text decoder as a sub-module of a larger model, so every
# weight name carries the "language_model." prefix; the "{}" placeholders are
# filled with the layer index by the loader.
TENSOR_NAMES_HF_4B = loading_utils.ModelLoader.TensorNames(
    ff_up_proj="language_model.model.layers.{}.mlp.up_proj",
    ff_down_proj="language_model.model.layers.{}.mlp.down_proj",
    ff_gate_proj="language_model.model.layers.{}.mlp.gate_proj",
    attn_query_proj="language_model.model.layers.{}.self_attn.q_proj",
    attn_key_proj="language_model.model.layers.{}.self_attn.k_proj",
    attn_value_proj="language_model.model.layers.{}.self_attn.v_proj",
    attn_output_proj="language_model.model.layers.{}.self_attn.o_proj",
    attn_query_norm="language_model.model.layers.{}.self_attn.q_norm",
    attn_key_norm="language_model.model.layers.{}.self_attn.k_norm",
    pre_attn_norm="language_model.model.layers.{}.input_layernorm",
    post_attn_norm="language_model.model.layers.{}.post_attention_layernorm",
    pre_ff_norm="language_model.model.layers.{}.pre_feedforward_layernorm",
    post_ff_norm="language_model.model.layers.{}.post_feedforward_layernorm",
    embedding="language_model.model.embed_tokens",
    final_norm="language_model.model.norm",
    # No separate lm_head tensor — presumably the output projection shares the
    # embedding weights (as in the other mappings above); confirm against the
    # checkpoint if this changes.
    lm_head=None,
)

# Known checkpoint formats, keyed by an identifier.  The build_model_* helpers
# try each mapping in turn until one matches the checkpoint being loaded.
TENSOR_NAMES_DICT = {
    "safetensors": TENSOR_NAMES_SEP_QKV,
    "kaggle": TENSOR_NAMES_FUSED_QKV,
    "hf_4b": TENSOR_NAMES_HF_4B,
}
7494

7595

@@ -445,6 +465,60 @@ def get_block_config(idx: int) -> cfg.TransformerBlockConfig:
445465
return config
446466

447467

468+
def get_decoder_config_4b() -> cfg.ModelConfig:
  """Returns the model config for a Gemma3 4B model."""
  norm = cfg.NormalizationConfig(
      type=cfg.NormalizationType.RMS_NORM, epsilon=1e-6, zero_centered=True,
  )
  ff = cfg.FeedForwardConfig(
      type=cfg.FeedForwardType.GATED,
      activation=cfg.ActivationConfig(cfg.ActivationType.GELU_TANH),
      intermediate_size=10240,
      pre_ff_norm_config=norm,
      post_ff_norm_config=norm,
  )

  def _block(idx: int) -> cfg.TransformerBlockConfig:
    # Every sixth layer uses global attention with the long-context rotary
    # base; all other layers are local sliding-window attention.
    is_global = (idx + 1) % 6 == 0
    attn = cfg.AttentionConfig(
        num_heads=8,
        head_dim=256,
        num_query_groups=1,
        rotary_base=1_000_000 if is_global else 10_000,
        rotary_percentage=1.0,
        qkv_transpose_before_split=True,
        query_norm_config=norm,
        key_norm_config=norm,
        logit_softcap=None,
        sliding_window_size=1024,
        attn_type=(
            cfg.AttentionType.GLOBAL
            if is_global
            else cfg.AttentionType.LOCAL_SLIDING
        ),
    )
    return cfg.TransformerBlockConfig(
        attn_config=attn,
        ff_config=ff,
        pre_attention_norm_config=norm,
        post_attention_norm_config=norm,
    )

  num_layers = 34
  embedding_dim = 2560
  return cfg.ModelConfig(
      vocab_size=262_208,
      num_layers=num_layers,
      max_seq_len=32_768,
      embedding_dim=embedding_dim,
      embedding_scale=embedding_dim**0.5,
      block_configs=[_block(i) for i in range(num_layers)],
      final_norm_config=norm,
      lm_head_use_bias=False,
      final_logit_softcap=None,
  )
520+
521+
448522
def get_fake_decoder_config_1b() -> cfg.ModelConfig:
449523
"""Returns a fake model config for a Gemma3 1B model."""
450524
config = get_decoder_config_1b()
@@ -481,6 +555,10 @@ def build_model_1b(
481555
)
482556
except KeyError as ke:
483557
continue
558+
raise RuntimeError(
559+
f"Failed to build model from checkpoint at {checkpoint_path}. "
560+
"None of the known tensor name mappings matched the checkpoint."
561+
)
484562

485563

486564
def build_model_270m(
@@ -503,3 +581,33 @@ def build_model_270m(
503581
)
504582
except KeyError as _:
505583
continue
584+
raise RuntimeError(
585+
f"Failed to build model from checkpoint at {checkpoint_path}. "
586+
"None of the known tensor name mappings matched the checkpoint."
587+
)
588+
589+
590+
def build_model_4b(
    checkpoint_path: str,
    custom_loader: Optional[Callable[[str], Dict[str, torch.Tensor]]] = None,
    mask_cache_size: int = 0,
) -> nn.Module:
  """Builds a Gemma3 4B model.

  Tries every known tensor-name mapping in TENSOR_NAMES_DICT until one
  matches the checkpoint.

  Args:
    checkpoint_path: Path to the model checkpoint.
    custom_loader: Optional callable that loads the checkpoint into a tensor
      dict; the default loader is used when None.
    mask_cache_size: Size of the attention-mask cache; 0 disables caching.

  Returns:
    The loaded decoder-only model.

  Raises:
    RuntimeError: If no known tensor name mapping matches the checkpoint.
  """
  # TODO(b/403644647): Better error handling for loading checkpoints with
  # different tensor names.
  for tensor_names in TENSOR_NAMES_DICT.values():
    try:
      return model_builder.build_decoder_only_model(
          checkpoint_path=checkpoint_path,
          config=get_decoder_config_4b(),
          tensor_names=tensor_names,
          model_class=Decoder,
          custom_loader=custom_loader,
          mask_cache_size=mask_cache_size,
      )
    except KeyError:  # This mapping doesn't match the checkpoint; try next.
      continue
  raise RuntimeError(
      f"Failed to build model from checkpoint at {checkpoint_path}. "
      "None of the known tensor name mappings matched the checkpoint."
  )

litert_torch/generative/examples/gemma3/gemma3.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,13 +25,45 @@
2525
import litert_torch.generative.layers.model_config as cfg
2626
from litert_torch.generative.utilities import export_config as export_cfg
2727
import litert_torch.generative.utilities.loader as loading_utils
28+
import json
29+
import os
2830
import torch
2931
from torch import nn
3032

3133

3234
PROJECTION_TENSOR_NAME = "multi_modal_projector.linear"
3335

3436

37+
def detect_model_size(checkpoint_path: str) -> Optional[str]:
  """Attempts to detect the model size from config.json in the checkpoint path.

  Recognizes the decoder geometry of the 270m, 1b, and 4b variants by layer
  count or hidden size.  Multimodal checkpoints (e.g. the HuggingFace 4B
  release) nest the decoder parameters under "text_config", so that
  sub-config is preferred when present.

  Args:
    checkpoint_path: Path to the checkpoint directory.

  Returns:
    '270m', '1b', '4b', or None if detection fails or config.json is missing.
  """
  config_path = os.path.join(checkpoint_path, "config.json")
  if not os.path.exists(config_path):
    return None

  try:
    with open(config_path, "r") as f:
      config = json.load(f)
  except (OSError, json.JSONDecodeError):
    # Unreadable or malformed file: detection failed; let the caller decide.
    return None

  if not isinstance(config, dict):
    return None
  # Multimodal configs nest the language-model parameters under "text_config".
  params = config.get("text_config", config)
  if not isinstance(params, dict):
    return None

  num_layers = params.get("num_hidden_layers")
  hidden_size = params.get("hidden_size")

  # Either value alone is enough to identify the variant, which lets partial
  # configs still be detected.
  if num_layers == 18 or hidden_size == 640:
    return "270m"
  if num_layers == 26 or hidden_size == 1152:
    return "1b"
  if num_layers == 34 or hidden_size == 2560:
    return "4b"
  return None
65+
66+
3567
@dataclass
3668
class Gemma3MMConfig:
3769
"""Gemma3 model configurations."""
@@ -197,3 +229,21 @@ def build_model_270m(
197229
# TODO: Load the parameters of decoder from checkpoint.
198230
model.eval()
199231
return model
232+
233+
234+
def build_model_4b(
    checkpoint_path: str,
    custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
    mask_cache_size: int = 0,
) -> decoder.Decoder:
  """Builds a Gemma3 4B model."""
  if not checkpoint_path:
    # No checkpoint: build a decoder straight from the config.
    # TODO: Load the parameters of decoder from checkpoint.
    model = decoder.Decoder(decoder.get_decoder_config_4b(), mask_cache_size)
  else:
    model = decoder.build_model_4b(
        checkpoint_path, custom_loader, mask_cache_size
    )
  model.eval()
  return model

0 commit comments

Comments
 (0)