@@ -26,5 +26,14 @@ You are a helpful assistant.

public static let llama3PromptTemplate = "<|begin_of_text|><|start_header_id|>user<|end_header_id|>%@<|eot_id|><|start_header_id|>assistant<|end_header_id|>"

public static let phi4PromptTemplate = "<|user|>%@<|end|><|assistant|>"

public static let gemma3PromptTemplate = """
<bos><start_of_turn>user


%@<end_of_turn>
<start_of_turn>model

"""
}
@@ -87,6 +87,7 @@ struct ContentView: View {
case llava
case qwen3
case phi4
case gemma3

static func fromPath(_ path: String) -> ModelType {
let filename = (path as NSString).lastPathComponent.lowercased()
@@ -98,7 +99,9 @@ struct ContentView: View {
return .qwen3
} else if filename.hasPrefix("phi4") {
return .phi4
}
} else if filename.hasPrefix("gemma3") {
return .gemma3
}
print("Unknown model type in path: \(path). Model filename should start with one of: llama, llava, qwen3, or phi4")
exit(1)
}
@@ -346,15 +349,15 @@ struct ContentView: View {
}

switch modelType {
case .llama, .qwen3, .phi4:
case .llama, .qwen3, .phi4, .gemma3:
runnerHolder.llamaRunner = runnerHolder.llamaRunner ?? LLaMARunner(modelPath: modelPath, tokenizerPath: tokenizerPath)
case .llava:
runnerHolder.llavaRunner = runnerHolder.llavaRunner ?? LLaVARunner(modelPath: modelPath, tokenizerPath: tokenizerPath)
}

guard !shouldStopGenerating else { return }
switch modelType {
case .llama, .qwen3, .phi4:
case .llama, .qwen3, .phi4, .gemma3:
if let runner = runnerHolder.llamaRunner, !runner.isLoaded() {
var error: Error?
let startLoadTime = Date()
@@ -479,6 +482,8 @@ struct ContentView: View {
prompt = String(format: Constants.llama3PromptTemplate, text)
case .phi4:
prompt = String(format: Constants.phi4PromptTemplate, text)
case .gemma3:
prompt = String(format: Constants.gemma3PromptTemplate, text)
}

try runnerHolder.llamaRunner?.generate(prompt, sequenceLength: seq_len) { token in
16 changes: 16 additions & 0 deletions examples/models/gemma3/__init__.py
@@ -0,0 +1,16 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from executorch.examples.models.gemma3.convert_weights import convert_weights
from executorch.examples.models.llama.model import Llama2Model


class Gemma3Model(Llama2Model):
def __init__(self, **kwargs):
super().__init__(**kwargs)


__all__ = [
"Gemma3Model",
"convert_weights",
]
23 changes: 23 additions & 0 deletions examples/models/gemma3/config/1b_config.json
@@ -0,0 +1,23 @@
{
"dim": 1152,
"ffn_dim_multiplier": 1,
"hidden_dim": 6912,
"n_heads": 4,
"head_dim": 256,
"n_kv_heads": 1,
"n_layers": 26,
"act_fn": "gelu_approx",
"norm_type": "gemma3",
"norm_eps": 1e-06,
"post_attention_norm": true,
"post_ffn_norm": true,
"rope_theta": 1000000.0,
"use_scaled_rope": false,
"apply_embedding": true,
"embedding_scale_factor": 33.941125497,
"vocab_size": 262144,
"use_hf_rope": true,
"attention_qkv_bias": false,
"use_qk_norm": true,
"qk_norm_before_rope": true
}
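An observation on the config above (editor's note, not part of the PR): the embedding_scale_factor value 33.941125497 appears to be sqrt(dim) for dim = 1152, matching Gemma's convention of scaling token embeddings by the square root of the hidden size. A minimal Python check:

import math

# Sketch only: confirms that embedding_scale_factor in 1b_config.json matches
# sqrt(dim) for dim = 1152, to the precision printed in the config.
dim = 1152
print(math.sqrt(dim))  # 33.941125496954285 -> 33.941125497 in the config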
17 changes: 17 additions & 0 deletions examples/models/gemma3/config/gemma3_xnnpack_q8da4w.yaml
@@ -0,0 +1,17 @@
base:
model_class: gemma3_1b
metadata: '{"get_bos_id":[2, 105], "get_eos_ids":[1, 106]}'

model:
use_kv_cache: True
use_sdpa_with_kv_cache: True
dtype_override: fp32
local_global_attention: [512,512,512,512,512,0,512,512,512,512,512,0,512,512,512,512,512,0,512,512,512,512,512,0,512,512]

quantization:
qmode: 8da4w

backend:
xnnpack:
enabled: True
extended_ops: True
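The 26-entry local_global_attention list above follows a regular pattern: five sliding-window layers (window size 512) followed by one layer with value 0, which appears to encode a global-attention layer and matches Gemma 3's 5:1 local/global interleaving. A short illustrative Python sketch that reproduces the same list:

# Illustrative sketch: regenerate the local_global_attention list from the YAML.
# Every sixth layer is marked 0 (global attention); all others use a 512-token
# sliding window, for the 26 layers of the 1B config.
pattern = [0 if (i + 1) % 6 == 0 else 512 for i in range(26)]
assert len(pattern) == 26
print(pattern)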
110 changes: 110 additions & 0 deletions examples/models/gemma3/convert_weights.py
@@ -0,0 +1,110 @@
import argparse

import json
import os
from typing import Dict

import torch
from safetensors.torch import load_file

from torchtune.models.convert_weights import get_mapped_key


# Weight mappings from Gemma 3's checkpoint to ExecuTorch's transformer parameters.
_GEMMA3_TO_EXECUTORCH = {
"model.embed_tokens.weight": "tok_embeddings.weight",
"model.norm.weight": "norm.weight",
"model.layers.{}.self_attn.k_proj.weight": "layers.{}.attention.wk.weight",
"model.layers.{}.self_attn.k_norm.weight": "layers.{}.attention.k_norm_fn.weight",
"model.layers.{}.self_attn.q_proj.weight": "layers.{}.attention.wq.weight",
"model.layers.{}.self_attn.q_norm.weight": "layers.{}.attention.q_norm_fn.weight",
"model.layers.{}.self_attn.v_proj.weight": "layers.{}.attention.wv.weight",
"model.layers.{}.self_attn.o_proj.weight": "layers.{}.attention.wo.weight",
"model.layers.{}.input_layernorm.weight": "layers.{}.attention_norm.weight",
"model.layers.{}.post_attention_layernorm.weight": "layers.{}.post_attention_norm.weight",
"model.layers.{}.pre_feedforward_layernorm.weight": "layers.{}.ffn_norm.weight",
"model.layers.{}.post_feedforward_layernorm.weight": "layers.{}.post_ffn_norm.weight",
"model.layers.{}.mlp.gate_proj.weight": "layers.{}.feed_forward.w1.weight",
"model.layers.{}.mlp.down_proj.weight": "layers.{}.feed_forward.w2.weight",
"model.layers.{}.mlp.up_proj.weight": "layers.{}.feed_forward.w3.weight",
}


def gemma3_to_executorch(
state_dict: Dict[str, torch.Tensor]
) -> Dict[str, torch.Tensor]:
"""
Convert the state dict so that it matches what ExecuTorch's transformer definition expects.
"""
converted_state_dict = {}
for key, value in state_dict.items():
new_key = get_mapped_key(key, _GEMMA3_TO_EXECUTORCH)
converted_state_dict[new_key] = value
converted_state_dict["output.weight"] = converted_state_dict[
"tok_embeddings.weight"
]
return converted_state_dict


def load_checkpoint_from_safetensors(input_dir: str) -> Dict:
index_path = os.path.join(input_dir, "model.safetensors.index.json")
if os.path.exists(index_path):
# Sharded checkpoint.
with open(index_path, "r") as f:
index = json.load(f)
weight_map = index["weight_map"]
checkpoint_shards = sorted(set(weight_map.values()))

# Load all the shards into memory
shard_to_weights = {}
for shard in checkpoint_shards:
shard_to_weights[shard] = load_file(os.path.join(input_dir, shard))

# Merge tensors into consolidated state dict.
merged_state_dict = {}
for weight_name, shard in weight_map.items():
tensor = shard_to_weights[shard][weight_name]
merged_state_dict[weight_name] = tensor
return merged_state_dict
else:
# Single checkpoint.
state_dict = load_file(os.path.join(input_dir, "model.safetensors"))
return state_dict


def load_checkpoint(input_dir: str) -> Dict:
pytorch_path = os.path.join(input_dir, "pytorch_model.bin")
if os.path.exists(pytorch_path):
print("Loading checkpoint from PyTorch .bin file")
return torch.load(pytorch_path, map_location="cpu", weights_only=True)
print("Loading checkpoint from safetensors directory")
return load_checkpoint_from_safetensors(input_dir)


def convert_weights(input_dir: str, output_file: str) -> None:
print("Loading checkpoint...")
sd = load_checkpoint(input_dir)
print("Converting checkpoint...")
sd = gemma3_to_executorch(sd)
print("Saving checkpoint...")
torch.save(sd, output_file)
print("Done.")


def main():
parser = argparse.ArgumentParser(
description="Convert Gemma3 weights to ExecuTorch transformer format."
)
parser.add_argument(
"input_dir",
type=str,
help="Path to directory containing safetensor checkpoint files, or PyTorch checkpoint file.",
)
parser.add_argument("output", type=str, help="Path to the output checkpoint")

args = parser.parse_args()
convert_weights(args.input_dir, args.output)


if __name__ == "__main__":
main()
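A usage sketch for the converter above (the paths are placeholders, not taken from the PR): it can be driven through the CLI defined in main(), or by importing convert_weights, which is re-exported from examples.models.gemma3.

# Sketch only; both paths below are hypothetical placeholders.
from executorch.examples.models.gemma3 import convert_weights

convert_weights(
    input_dir="/path/to/gemma-3-1b-it",    # directory with model.safetensors (or a sharded index)
    output_file="/path/to/gemma3_1b.pth",  # consolidated checkpoint consumed by export_llama
)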
18 changes: 11 additions & 7 deletions examples/models/llama/attention.py
@@ -6,7 +6,7 @@
import torch.nn as nn
import torch.nn.functional as F
from executorch.examples.models.llama.model_args import ModelArgs
from executorch.examples.models.llama.norm import RMSNorm
from executorch.examples.models.llama.norm import Norm
from executorch.examples.models.llama.rope import Rope


@@ -324,7 +324,14 @@ def update(

@register_attention("mha")
class AttentionMHA(Attention):
def __init__(self, args: ModelArgs, layer_id: int, rope: Rope):
def __init__(
self,
args: ModelArgs,
layer_id: int,
rope: Rope,
q_norm_fn: Optional[Norm] = None,
k_norm_fn: Optional[Norm] = None,
):
super().__init__()
self.use_kv_cache = args.use_kv_cache
self.n_heads = args.n_heads
@@ -343,11 +350,8 @@ def __init__(self, args: ModelArgs, layer_id: int, rope: Rope):
self.qk_norm_before_rope = args.qk_norm_before_rope
self.enable_dynamic_shape = args.enable_dynamic_shape

if self.use_qk_norm:
q_norm_dim = self.head_dim
k_norm_dim = self.head_dim
self.q_norm_fn = RMSNorm(q_norm_dim, eps=args.norm_eps)
self.k_norm_fn = RMSNorm(k_norm_dim, eps=args.norm_eps)
self.q_norm_fn = q_norm_fn
self.k_norm_fn = k_norm_fn

self.wq = nn.Linear(
self.dim, self.n_heads * self.head_dim, bias=self.attention_qkv_bias
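The change above stops constructing RMSNorm inside AttentionMHA and instead accepts injected q_norm_fn/k_norm_fn modules, so a model such as Gemma 3 (norm_type "gemma3" in the config) can supply its own norm variant. A minimal sketch of the new wiring, assuming a hypothetical make_norm factory on the caller's side:

from executorch.examples.models.llama.attention import AttentionMHA

# Sketch only: make_norm is a hypothetical factory standing in for whatever
# norm class the transformer builder picks based on args.norm_type.
def build_attention(args, layer_id, rope, make_norm):
    q_norm = make_norm(args.head_dim, eps=args.norm_eps) if args.use_qk_norm else None
    k_norm = make_norm(args.head_dim, eps=args.norm_eps) if args.use_qk_norm else None
    return AttentionMHA(args, layer_id, rope, q_norm_fn=q_norm, k_norm_fn=k_norm)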
6 changes: 6 additions & 0 deletions examples/models/llama/export_llama_lib.py
@@ -107,6 +107,7 @@
"qwen3_0_6b",
"qwen3_1_7b",
"qwen3_4b",
"gemma3_1b",
"phi_4_mini",
"smollm2",
]
@@ -118,6 +119,7 @@
"qwen3_0_6b": "Qwen/Qwen3-0.6B",
"qwen3_1_7b": "Qwen/Qwen3-1.7B",
"qwen3_4b": "Qwen/Qwen3-4B",
"gemma3_1b": "google/gemma-3-1b-it",
}


@@ -609,6 +611,10 @@ def export_llama(
from executorch.examples.models.smollm2 import ( # pyre-ignore[21]
convert_weights,
)
elif model_name.startswith("gemma3"):
from executorch.examples.models.gemma3 import ( # pyre-ignore[21]
convert_weights,
)
else:
raise ValueError(
f"Converting weights to meta format for {model_name} is not yet supported"