Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,26 @@


def main(_):
  """Converts a Gemma3 checkpoint to TFLite, auto-detecting the model size.

  The size explicitly passed via --model_size always wins; auto-detection
  from the checkpoint's config.json is used only when the flag was left at
  its default.
  """
  model_size = _MODEL_SIZE.value
  # Auto-detect model size from the checkpoint. We only override the flag
  # value when detection succeeds AND the user did not pass the flag.
  detected_size = gemma3.detect_model_size(flags.FLAGS.checkpoint_path)
  if detected_size and _MODEL_SIZE.present:
    if detected_size != _MODEL_SIZE.value:
      print(f"Note: User specified model_size={_MODEL_SIZE.value}, "
            f"but detected {detected_size}. Using user specification.")
  elif detected_size:
    model_size = detected_size
    print(f"Auto-detected model size: {model_size}")

  if model_size == '1b':
    model_builder = gemma3.build_model_1b
  elif model_size == '270m':
    model_builder = gemma3.build_model_270m
  elif model_size == '4b':
    model_builder = gemma3.build_model_4b
  else:
    raise ValueError(f'Unsupported model size: {model_size}')

  converter.build_and_convert_to_tflite_from_flags(model_builder)

Expand Down
108 changes: 108 additions & 0 deletions litert_torch/generative/examples/gemma3/decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,29 @@
lm_head=None,
)

# Tensor-name templates for Gemma3 4B checkpoints in the Hugging Face layout.
# NOTE(review): all decoder weights sit under a "language_model." prefix —
# presumably because the 4B checkpoint bundles a vision tower alongside the
# text decoder (cf. PROJECTION_TENSOR_NAME in gemma3.py); confirm against an
# actual checkpoint.
TENSOR_NAMES_HF_4B = loading_utils.ModelLoader.TensorNames(
    ff_up_proj="language_model.model.layers.{}.mlp.up_proj",
    ff_down_proj="language_model.model.layers.{}.mlp.down_proj",
    ff_gate_proj="language_model.model.layers.{}.mlp.gate_proj",
    attn_query_proj="language_model.model.layers.{}.self_attn.q_proj",
    attn_key_proj="language_model.model.layers.{}.self_attn.k_proj",
    attn_value_proj="language_model.model.layers.{}.self_attn.v_proj",
    attn_output_proj="language_model.model.layers.{}.self_attn.o_proj",
    attn_query_norm="language_model.model.layers.{}.self_attn.q_norm",
    attn_key_norm="language_model.model.layers.{}.self_attn.k_norm",
    pre_attn_norm="language_model.model.layers.{}.input_layernorm",
    post_attn_norm="language_model.model.layers.{}.post_attention_layernorm",
    pre_ff_norm="language_model.model.layers.{}.pre_feedforward_layernorm",
    post_ff_norm="language_model.model.layers.{}.post_feedforward_layernorm",
    embedding="language_model.model.embed_tokens",
    final_norm="language_model.model.norm",
    # No lm_head entry in the checkpoint — presumably the output head is
    # tied to the embedding table; verify against the loader's handling.
    lm_head=None,
)

# Maps a checkpoint-source label to its tensor-name template. The
# build_model_* functions iterate these values and try each mapping until
# one matches the checkpoint.
TENSOR_NAMES_DICT = {
    "safetensors": TENSOR_NAMES_SEP_QKV,
    "kaggle": TENSOR_NAMES_FUSED_QKV,
    "hf_4b": TENSOR_NAMES_HF_4B,
}


Expand Down Expand Up @@ -445,6 +465,60 @@ def get_block_config(idx: int) -> cfg.TransformerBlockConfig:
return config


def get_decoder_config_4b() -> cfg.ModelConfig:
  """Returns the model config for a Gemma3 4B model."""
  norm_config = cfg.NormalizationConfig(
      type=cfg.NormalizationType.RMS_NORM, epsilon=1e-6, zero_centered=True,
  )
  ff_config = cfg.FeedForwardConfig(
      type=cfg.FeedForwardType.GATED,
      activation=cfg.ActivationConfig(cfg.ActivationType.GELU_TANH),
      intermediate_size=10240,
      pre_ff_norm_config=norm_config,
      post_ff_norm_config=norm_config,
  )

  def _block_config(layer_idx: int) -> cfg.TransformerBlockConfig:
    # Every 6th layer uses global attention with the long rope base; all
    # other layers use sliding-window local attention.
    is_global = (layer_idx + 1) % 6 == 0
    attn_config = cfg.AttentionConfig(
        num_heads=8,
        head_dim=256,
        num_query_groups=1,
        rotary_base=1_000_000 if is_global else 10_000,
        rotary_percentage=1.0,
        qkv_transpose_before_split=True,
        query_norm_config=norm_config,
        key_norm_config=norm_config,
        logit_softcap=None,
        sliding_window_size=1024,
        attn_type=(
            cfg.AttentionType.GLOBAL
            if is_global
            else cfg.AttentionType.LOCAL_SLIDING
        ),
    )
    return cfg.TransformerBlockConfig(
        attn_config=attn_config,
        ff_config=ff_config,
        pre_attention_norm_config=norm_config,
        post_attention_norm_config=norm_config,
    )

  num_layers = 34
  embedding_dim = 2560
  return cfg.ModelConfig(
      vocab_size=262_208,
      num_layers=num_layers,
      max_seq_len=32_768,
      embedding_dim=embedding_dim,
      embedding_scale=embedding_dim**0.5,
      block_configs=[_block_config(i) for i in range(num_layers)],
      final_norm_config=norm_config,
      lm_head_use_bias=False,
      final_logit_softcap=None,
  )


def get_fake_decoder_config_1b() -> cfg.ModelConfig:
"""Returns a fake model config for a Gemma3 1B model."""
config = get_decoder_config_1b()
Expand Down Expand Up @@ -481,6 +555,10 @@ def build_model_1b(
)
except KeyError as ke:
continue
raise RuntimeError(
f"Failed to build model from checkpoint at {checkpoint_path}. "
"None of the known tensor name mappings matched the checkpoint."
)


def build_model_270m(
Expand All @@ -503,3 +581,33 @@ def build_model_270m(
)
except KeyError as _:
continue
raise RuntimeError(
f"Failed to build model from checkpoint at {checkpoint_path}. "
"None of the known tensor name mappings matched the checkpoint."
)


def build_model_4b(
    checkpoint_path: str,
    custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
    mask_cache_size: int = 0,
) -> nn.Module:
  """Builds a Gemma3 4B model from a checkpoint.

  Tries every known tensor-name mapping in TENSOR_NAMES_DICT until one
  matches the checkpoint.

  Args:
    checkpoint_path: Path to the checkpoint.
    custom_loader: Optional loader mapping a path to a tensor state dict.
    mask_cache_size: Size of the attention mask cache passed to the builder.

  Returns:
    The constructed Gemma3 4B decoder model.

  Raises:
    RuntimeError: If no known tensor-name mapping matches the checkpoint.
  """
  # TODO(b/403644647): Better error handling for loading checkpoints with
  # different tensor names.
  for tensor_names in TENSOR_NAMES_DICT.values():
    try:
      return model_builder.build_decoder_only_model(
          checkpoint_path=checkpoint_path,
          config=get_decoder_config_4b(),
          tensor_names=tensor_names,
          model_class=Decoder,
          custom_loader=custom_loader,
          mask_cache_size=mask_cache_size,
      )
    # `except KeyError:` — the original bound the exception to an unused
    # name (`as _`); a KeyError means this mapping doesn't match, so move
    # on to the next one.
    except KeyError:
      continue
  raise RuntimeError(
      f"Failed to build model from checkpoint at {checkpoint_path}. "
      "None of the known tensor name mappings matched the checkpoint."
  )
50 changes: 50 additions & 0 deletions litert_torch/generative/examples/gemma3/gemma3.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,45 @@
import litert_torch.generative.layers.model_config as cfg
from litert_torch.generative.utilities import export_config as export_cfg
import litert_torch.generative.utilities.loader as loading_utils
import json
import os
import torch
from torch import nn


PROJECTION_TENSOR_NAME = "multi_modal_projector.linear"


def detect_model_size(checkpoint_path: str) -> Optional[str]:
  """Attempts to detect the model size from config.json in the checkpoint path.

  Reads a Hugging Face style ``config.json`` sitting next to the checkpoint
  and matches its ``num_hidden_layers`` / ``hidden_size`` against the known
  Gemma3 variants.

  Args:
    checkpoint_path: Path to the checkpoint directory.

  Returns:
    '270m', '1b', '4b', or None if detection fails or config.json is missing.
  """
  config_path = os.path.join(checkpoint_path, "config.json")
  if not os.path.exists(config_path):
    return None

  try:
    with open(config_path, "r") as f:
      config = json.load(f)
  # Narrowed from a bare `except Exception`: only I/O failures and malformed
  # JSON (json.JSONDecodeError is a ValueError subclass) mean "can't detect".
  except (OSError, ValueError):
    return None

  num_layers = config.get("num_hidden_layers")
  hidden_size = config.get("hidden_size")

  # Layer-count / hidden-size pairs match the decoder configs:
  # 270m: 18 layers / 640; 1b: 26 layers / 1152; 4b: 34 layers / 2560.
  if num_layers == 18 or hidden_size == 640:
    return "270m"
  if num_layers == 26 or hidden_size == 1152:
    return "1b"
  # Added: the original never detected the 4B variant even though this
  # module ships a 4B builder (see get_decoder_config_4b).
  if num_layers == 34 or hidden_size == 2560:
    return "4b"

  return None


@dataclass
class Gemma3MMConfig:
"""Gemma3 model configurations."""
Expand Down Expand Up @@ -197,3 +229,21 @@ def build_model_270m(
# TODO: Load the parameters of decoder from checkpoint.
model.eval()
return model


def build_model_4b(
    checkpoint_path: str,
    custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
    mask_cache_size: int = 0,
) -> decoder.Decoder:
  """Builds a Gemma3 4B model."""
  if not checkpoint_path:
    # No checkpoint given: construct a randomly initialized decoder.
    # TODO: Load the parameters of decoder from checkpoint.
    model = decoder.Decoder(decoder.get_decoder_config_4b(), mask_cache_size)
  else:
    model = decoder.build_model_4b(
        checkpoint_path, custom_loader, mask_cache_size
    )
  model.eval()
  return model
Loading