83 changes: 77 additions & 6 deletions backends/qualcomm/tests/test_qnn_delegate.py
@@ -4638,6 +4638,77 @@ def test_qnn_backend_generate_optrace(self):


class TestExampleLLMScript(TestQNN):
def test_static_gemma3_1b(self):
if not self.required_envs():
self.skipTest("missing required envs")

prompt = "My favourite condiment is "
cmds = [
"python",
f"{self.executorch_root}/examples/qualcomm/oss_scripts/llama/llama.py",
"--artifact",
self.artifact_dir,
"--build_folder",
self.build_folder,
"--model",
self.model,
"--ip",
self.ip,
"--port",
str(self.port),
"--prompt",
f"{prompt}",
"--ptq",
"16a4w_block",
"--temperature",
"0",
"--decoder_model",
"gemma3-1b",
"--model_mode",
"kv",
"--max_seq_len",
"1024",
"--eval_perplexity",
"--tasks",
"wikitext",
"--limit",
"1",
"--enable_masked_softmax",
]
if self.compile_only:
cmds.extend(["--compile_only"])
elif self.device:
cmds.extend(["--device", self.device])
if self.host:
cmds.extend(["--host", self.host])
elif self.enable_x86_64:
cmds.extend(["--enable_x86_64"])
if self.pre_gen_pte:
cmds.extend(["--pre_gen_pte", self.pre_gen_pte])

p = subprocess.Popen(cmds, stdout=subprocess.DEVNULL)
with Listener((self.ip, self.port)) as listener:
conn = listener.accept()
p.communicate()
msg = json.loads(conn.recv())
if "Error" in msg:
self.fail(msg["Error"])
else:
if not self.compile_only:
self.assertLessEqual(msg["wiki_ppl"], 23)
if not self.enable_x86_64:
pte_size = msg["pte_size"]
self.assertLessEqual(pte_size, 1_200_000_000) # 1.2GB
inference_speed_ref = {"SM8650": 70, "SM8750": 100}
if (
not self.compile_only
and not self.enable_x86_64
and self.model in inference_speed_ref
):
self.assertGreaterEqual(
msg["inference_speed"], inference_speed_ref[self.model]
)

def test_llama3_2_1b(self):
if not self.required_envs():
self.skipTest("missing required envs")
@@ -4708,7 +4779,7 @@ def test_llama3_2_1b(self):
# Inference speed on x86 is slow, so we only check when running on Android
if not self.enable_x86_64:
pte_size = msg["pte_size"]
self.assertLessEqual(pte_size, 1300000000)
self.assertLessEqual(pte_size, 1_300_000_000) # 1.3GB
if not self.compile_only and not self.enable_x86_64:
self.assertGreaterEqual(msg["inference_speed"], 66) # Lanai

@@ -4784,7 +4855,7 @@ def test_llama_stories_260k(self):
# x86 does not allow weight sharing, so we don't check pte size
if not self.enable_x86_64:
pte_size = msg["pte_size"]
self.assertLessEqual(pte_size, 2020000)
self.assertLessEqual(pte_size, 2_020_000) # 2MB
if not self.compile_only and not self.enable_x86_64:
self.assertGreaterEqual(msg["inference_speed"], 1600) # Lanai

@@ -4859,7 +4930,7 @@ def test_llama_stories_110m(self):
# x86 does not allow weight sharing, so we don't check pte size
if not self.enable_x86_64:
pte_size = msg["pte_size"]
self.assertLessEqual(pte_size, 130000000)
self.assertLessEqual(pte_size, 130_000_000) # 130MB
if not self.compile_only and not self.enable_x86_64:
self.assertGreaterEqual(msg["inference_speed"], 220) # Lanai

@@ -4922,7 +4993,7 @@ def test_static_phi4(self):
else:
inference_speed_ref = {"SM8650": 14, "SM8750": 19}
self.assertLessEqual(msg["wiki_ppl"], 12)
self.assertLessEqual(msg["pte_size"], 4000000000) # 4gb
self.assertLessEqual(msg["pte_size"], 4_000_000_000) # 4GB
if self.model in inference_speed_ref:
self.assertGreaterEqual(
msg["inference_speed"], inference_speed_ref[self.model]
@@ -4981,7 +5052,7 @@ def test_static_qwen2_5(self):
else:
inference_speed_ref = {"SM8650": 115, "SM8750": 155}
self.assertLessEqual(msg["wiki_ppl"], 15)
self.assertLessEqual(msg["pte_size"], 600000000) # 600mb
self.assertLessEqual(msg["pte_size"], 600_000_000) # 600MB
if self.model in inference_speed_ref:
self.assertGreaterEqual(
msg["inference_speed"], inference_speed_ref[self.model]
@@ -5040,7 +5111,7 @@ def test_static_qwen3(self):
else:
inference_speed_ref = {"SM8650": 38, "SM8750": 56}
self.assertLessEqual(msg["wiki_ppl"], 18)
self.assertLessEqual(msg["pte_size"], 950_000_000) # 950mb
self.assertLessEqual(msg["pte_size"], 950_000_000) # 950MB
if self.model in inference_speed_ref:
self.assertGreaterEqual(
msg["inference_speed"], inference_speed_ref[self.model]
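Note on the handshake used by the new `test_static_gemma3_1b` (and the other LLM tests in this file): the test launches `llama.py` as a subprocess and blocks on a `multiprocessing.connection.Listener` until the script sends back a JSON string with its metrics. The sketch below shows the reporting side of that contract; `report_metrics` and its call site are illustrative, not the script's actual code, but the payload keys match what the test asserts on (`wiki_ppl`, `pte_size`, `inference_speed`, or `Error` on failure).

```python
# Minimal sketch of the metrics handshake, assuming the example script connects
# back to the test's (--ip, --port). report_metrics() is an illustrative name.
import json
from multiprocessing.connection import Client


def report_metrics(ip: str, port: int, wiki_ppl: float, pte_size: int, tok_per_sec: float) -> None:
    payload = {
        "wiki_ppl": wiki_ppl,            # wikitext perplexity, checked against the <= 23 bound
        "pte_size": pte_size,            # exported .pte size in bytes (<= 1.2 GB for gemma3-1b)
        "inference_speed": tok_per_sec,  # decode speed, compared against the per-SoC reference
    }
    conn = Client((ip, port))            # connects to the Listener opened by the test
    conn.send(json.dumps(payload))       # the test does json.loads(conn.recv())
    conn.close()
```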
16 changes: 16 additions & 0 deletions examples/models/gemma3/__init__.py
@@ -0,0 +1,16 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from executorch.examples.models.gemma3.convert_weights import convert_weights
from executorch.examples.models.llama.model import Llama2Model


class Gemma3Model(Llama2Model):
def __init__(self, **kwargs):
super().__init__(**kwargs)


__all__ = [
"Gemma3Model",
"convert_weights",
]
23 changes: 23 additions & 0 deletions examples/models/gemma3/config/1b_config.json
@@ -0,0 +1,23 @@
{
"dim": 1152,
"ffn_dim_multiplier": 1,
"hidden_dim": 6912,
"n_heads": 4,
"head_dim": 256,
"n_kv_heads": 1,
"n_layers": 26,
"act_fn": "gelu_approx",
"norm_type": "gemma3",
"norm_eps": 1e-06,
"post_attention_norm": true,
"post_ffn_norm": true,
"rope_theta": 1000000.0,
"use_scaled_rope": false,
"apply_embedding": true,
"embedding_scale_factor": 33.941125497,
"vocab_size": 262144,
"use_hf_rope": true,
"attention_qkv_bias": false,
"use_qk_norm": true,
"qk_norm_before_rope": true
}
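To see how this JSON lines up with the `ModelArgs` additions later in the diff, here is a hedged loading sketch. It assumes every key in the file maps one-to-one onto a `ModelArgs` field and that the dataclass lives at `executorch.examples.models.llama.model_args`; both are consistent with this PR but not verified here.

```python
# Illustrative only: materialize 1b_config.json into the ModelArgs dataclass
# extended by this PR. Assumes every JSON key is a ModelArgs field.
import json

from executorch.examples.models.llama.model_args import ModelArgs  # import path assumed

with open("examples/models/gemma3/config/1b_config.json") as f:
    params = json.load(f)

args = ModelArgs(**params)
# "act_fn": "gelu_approx" arrives as a string and is coerced to ActFn.GELU_APPROX
# by the post-init conversion added in this PR.
print(args.norm_type, args.act_fn, args.embedding_scale_factor)
```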
110 changes: 110 additions & 0 deletions examples/models/gemma3/convert_weights.py
@@ -0,0 +1,110 @@
import argparse

import json
import os
from typing import Dict

import torch
from safetensors.torch import load_file

from torchtune.models.convert_weights import get_mapped_key


# Weight mappings from Gemma 3's checkpoint to ExecuTorch's transformer parameters.
_GEMMA3_TO_EXECUTORCH = {
"model.embed_tokens.weight": "tok_embeddings.weight",
"model.norm.weight": "norm.weight",
"model.layers.{}.self_attn.k_proj.weight": "layers.{}.attention.wk.weight",
"model.layers.{}.self_attn.k_norm.weight": "layers.{}.attention.k_norm_fn.weight",
"model.layers.{}.self_attn.q_proj.weight": "layers.{}.attention.wq.weight",
"model.layers.{}.self_attn.q_norm.weight": "layers.{}.attention.q_norm_fn.weight",
"model.layers.{}.self_attn.v_proj.weight": "layers.{}.attention.wv.weight",
"model.layers.{}.self_attn.o_proj.weight": "layers.{}.attention.wo.weight",
"model.layers.{}.input_layernorm.weight": "layers.{}.attention_norm.weight",
"model.layers.{}.post_attention_layernorm.weight": "layers.{}.post_attention_norm.weight",
"model.layers.{}.pre_feedforward_layernorm.weight": "layers.{}.ffn_norm.weight",
"model.layers.{}.post_feedforward_layernorm.weight": "layers.{}.post_ffn_norm.weight",
"model.layers.{}.mlp.gate_proj.weight": "layers.{}.feed_forward.w1.weight",
"model.layers.{}.mlp.down_proj.weight": "layers.{}.feed_forward.w2.weight",
"model.layers.{}.mlp.up_proj.weight": "layers.{}.feed_forward.w3.weight",
}


def gemma3_to_executorch(
state_dict: Dict[str, torch.Tensor]
) -> Dict[str, torch.Tensor]:
"""
Convert the state dict so that it matches what ExecuTorch's transformer definition expects.
"""
converted_state_dict = {}
for key, value in state_dict.items():
new_key = get_mapped_key(key, _GEMMA3_TO_EXECUTORCH)
converted_state_dict[new_key] = value
converted_state_dict["output.weight"] = converted_state_dict[
"tok_embeddings.weight"
]
return converted_state_dict


def load_checkpoint_from_safetensors(input_dir: str) -> Dict:
index_path = os.path.join(input_dir, "model.safetensors.index.json")
if os.path.exists(index_path):
# Sharded checkpoint.
with open(index_path, "r") as f:
index = json.load(f)
weight_map = index["weight_map"]
checkpoint_shards = sorted(set(weight_map.values()))

# Load all the shards into memory
shard_to_weights = {}
for shard in checkpoint_shards:
shard_to_weights[shard] = load_file(os.path.join(input_dir, shard))

# Merge tensors into consolidated state dict.
merged_state_dict = {}
for weight_name, shard in weight_map.items():
tensor = shard_to_weights[shard][weight_name]
merged_state_dict[weight_name] = tensor
return merged_state_dict
else:
# Single checkpoint.
state_dict = load_file(os.path.join(input_dir, "model.safetensors"))
return state_dict


def load_checkpoint(input_dir: str) -> Dict:
pytorch_path = os.path.join(input_dir, "pytorch_model.bin")
if os.path.exists(pytorch_path):
print("Loading checkpoint from PyTorch .bin file")
return torch.load(pytorch_path, map_location="cpu", weights_only=True)
print("Loading checkpoint from safetensors directory")
return load_checkpoint_from_safetensors(input_dir)


def convert_weights(input_dir: str, output_file: str) -> None:
print("Loading checkpoint...")
sd = load_checkpoint(input_dir)
print("Converting checkpoint...")
sd = gemma3_to_executorch(sd)
print("Saving checkpoint...")
torch.save(sd, output_file)
print("Done.")


def main():
parser = argparse.ArgumentParser(
description="Convert Gemma3 weights to ExecuTorch transformer format."
)
parser.add_argument(
"input_dir",
type=str,
help="Path to directory containing safetensor checkpoint files, or PyTorch checkpoint file.",
)
parser.add_argument("output", type=str, help="Path to the output checkpoint")

args = parser.parse_args()
convert_weights(args.input_dir, args.output)


if __name__ == "__main__":
main()
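Usage is positional, per the argparse above: `python examples/models/gemma3/convert_weights.py <input_dir> <output>`. The converter can also be called directly, since the package `__init__.py` re-exports it; the paths below are placeholders.

```python
# Placeholder paths; the checkpoint directory must contain either
# model.safetensors (optionally sharded with model.safetensors.index.json)
# or a pytorch_model.bin file.
from executorch.examples.models.gemma3 import convert_weights

convert_weights(
    "/path/to/gemma-3-1b-it",             # downloaded Hugging Face checkpoint directory
    "/path/to/gemma3_1b_executorch.pth",  # consolidated state dict consumed by the export flow
)
```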
41 changes: 41 additions & 0 deletions examples/models/llama/model_args.py
@@ -1,7 +1,39 @@
import dataclasses
from dataclasses import dataclass
from enum import Enum
from functools import partial
from typing import Any, Dict, Optional

import torch.nn.functional as F


class ActFn(Enum):
SILU = "silu"
GELU = "gelu"
GELU_APPROX = "gelu_approx"

@classmethod
def from_string(cls, value: str) -> "ActFn":
"""Convert string to ActFn enum."""
try:
return cls(value)
except ValueError:
valid_values = [e.value for e in cls]
raise ValueError(
f"Invalid activation function: {value}. Valid options: {valid_values}"
)

def get_function(self):
"""Return the corresponding activation function."""
if self == ActFn.SILU:
return F.silu
elif self == ActFn.GELU:
return F.gelu
elif self == ActFn.GELU_APPROX:
return partial(F.gelu, approximate="tanh")
else:
raise ValueError(f"Unsupported activation function: {self}")


@dataclass
class ModelArgs:
@@ -15,13 +47,17 @@ class ModelArgs:
multiple_of: int = 256 # make SwiGLU hidden layer size multiple of large power of 2
ffn_dim_multiplier: Optional[float] = None
norm_eps: float = 1e-5
post_attention_norm: bool = False
post_ffn_norm: bool = False
max_batch_size: int = 1
max_seq_len: int = 2048
max_context_len: int = 2048
moe: bool = False # True to enable the MoE (Mixture of Experts)
num_experts: int = 8 # Number of experts
num_activated_experts: int = 2 # Number of experts to activate
attention_type: str = "mha" # Attention type, registered in attention.py
norm_type: str = "rmsnorm" # Normalization type, registered in norm.py
act_fn: ActFn = dataclasses.field(default=ActFn.SILU) # Activation function type
attention_qkv_bias: bool = False
use_kv_cache: bool = False # Use key/value cache
use_sdpa_with_kv_cache_op: bool = (
@@ -37,6 +73,7 @@
# A dictionary mapping from pruned token-id to original token-id
output_prune_map: Optional[Dict[int, int]] = None
apply_embedding: bool = True # Use embedding inside the transformer
embedding_scale_factor: float = 1.0 # Multiple by which to scale embeddings.
apply_output: bool = True # Use output layer (unembedding) inside the transformer
use_qk_norm: bool = False # apply normalization to q and k in the attention
qk_norm_before_rope: bool = False # when to apply qk norm
@@ -103,3 +140,7 @@ def find_multiple(n: int, k: int) -> int:

if self.head_dim is None:
self.head_dim = self.dim // self.n_heads

# Convert string act_fn to enum if needed
if isinstance(self.act_fn, str):
self.act_fn = ActFn.from_string(self.act_fn)
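A quick, self-contained check of the new `ActFn` plumbing. The import path is assumed to be `executorch.examples.models.llama.model_args`, matching the file edited above.

```python
# Standalone sanity check of the ActFn enum added above; import path assumed.
import torch

from executorch.examples.models.llama.model_args import ActFn

act = ActFn.from_string("gelu_approx")  # unknown names raise ValueError listing the valid options
fn = act.get_function()                 # partial(F.gelu, approximate="tanh")
x = torch.randn(2, 4)
assert fn(x).shape == x.shape           # activation is applied elementwise
```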
1 change: 1 addition & 0 deletions examples/qualcomm/oss_scripts/llama/CMakeLists.txt
@@ -28,6 +28,7 @@ list(
${CMAKE_CURRENT_LIST_DIR}/qnn_llama_runner.cpp
${CMAKE_CURRENT_LIST_DIR}/runner/runner.cpp
${CMAKE_CURRENT_LIST_DIR}/runner/runner.h
${CMAKE_CURRENT_LIST_DIR}/runner/cache_utils.h
${CMAKE_CURRENT_LIST_DIR}/runner/decoder_runner.cpp
${CMAKE_CURRENT_LIST_DIR}/runner/decoder_runner.h
${CMAKE_CURRENT_LIST_DIR}/runner/prompt_processor.cpp