Commit 0882c9b

Qualcomm AI Engine Direct - GA Static Gemma-2b-instruct (#14459)
### Summary:
- e2e script for Gemma-2b-it in static llama version
- add model params file & model weight converter

### Test plan
```bash
python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --temperature 0 --model_mode hybrid --max_seq_len 1024 --prefill_ar_len 128 --decoder_model gemma-2b --prompt "I would like to learn python, could you teach me with a simple example?" --tasks wikitext --limit 1
```
1 parent 4372a14 commit 0882c9b

11 files changed: +265 −11 lines

backends/qualcomm/tests/test_qnn_delegate.py

Lines changed: 59 additions & 0 deletions
```diff
@@ -4968,6 +4968,65 @@ def test_qnn_backend_seq_mse(self):
 
 
 class TestExampleLLMScript(TestQNN):
+    def test_static_gemma_2b(self):
+        if not self.required_envs():
+            self.skipTest("missing required envs")
+
+        prompt = "My favourite condiment is "
+        cmds = [
+            "python",
+            f"{self.executorch_root}/examples/qualcomm/oss_scripts/llama/llama.py",
+            "--artifact",
+            self.artifact_dir,
+            "--build_folder",
+            self.build_folder,
+            "--model",
+            self.model,
+            "--ip",
+            self.ip,
+            "--port",
+            str(self.port),
+            "--prompt",
+            f"{prompt}",
+            "--decoder_model",
+            "gemma-2b",
+            "--model_mode",
+            "kv",
+            "--max_seq_len",
+            "1024",
+            "--eval_perplexity",
+            "--tasks",
+            "wikitext",
+            "--limit",
+            "1",
+        ]
+        if self.compile_only:
+            cmds.extend(["--compile_only"])
+        elif self.device:
+            cmds.extend(["--device", self.device])
+            if self.host:
+                cmds.extend(["--host", self.host])
+        elif self.enable_x86_64:
+            cmds.extend(["--enable_x86_64"])
+        if self.pre_gen_pte:
+            cmds.extend(["--pre_gen_pte", self.pre_gen_pte])
+
+        p = subprocess.Popen(cmds, stdout=subprocess.DEVNULL)
+        with Listener((self.ip, self.port)) as listener:
+            conn = listener.accept()
+            p.communicate()
+            msg = json.loads(conn.recv())
+            if "Error" in msg:
+                self.fail(msg["Error"])
+            else:
+                inference_speed_ref = {"SM8650": 32, "SM8750": 36}
+                self.assertLessEqual(msg["wiki_ppl"], 35)
+                self.assertLessEqual(msg["pte_size"], 2_700_000_000)  # 2.7GB
+                if self.model in inference_speed_ref:
+                    self.assertGreaterEqual(
+                        msg["inference_speed"], inference_speed_ref[self.model]
+                    )
+
     def test_static_gemma3_1b(self):
         if not self.required_envs():
             self.skipTest("missing required envs")
```

examples/models/gemma/__init__.py

Lines changed: 16 additions & 0 deletions
```diff
@@ -0,0 +1,16 @@
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from executorch.examples.models.gemma.convert_weights import convert_weights
+from executorch.examples.models.llama.model import Llama2Model
+
+
+class GemmaModel(Llama2Model):
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+
+
+__all__ = [
+    "GemmaModel",
+    "convert_weights",
+]
```

examples/models/gemma/config/2b_config.json

Lines changed: 19 additions & 0 deletions
```diff
@@ -0,0 +1,19 @@
+{
+  "dim": 2048,
+  "ffn_dim_multiplier": 1,
+  "hidden_dim": 16384,
+  "n_heads": 8,
+  "head_dim": 256,
+  "n_kv_heads": 1,
+  "n_layers": 18,
+  "act_fn": "gelu",
+  "norm_type": "gemma3",
+  "norm_eps": 1e-06,
+  "rope_theta": 10000.0,
+  "use_scaled_rope": false,
+  "apply_embedding": true,
+  "embedding_scale_factor": 45.254833995939045,
+  "vocab_size": 256000,
+  "use_hf_rope": true,
+  "attention_qkv_bias": false
+}
```

examples/models/gemma/convert_weights.py

Lines changed: 104 additions & 0 deletions
```diff
@@ -0,0 +1,104 @@
+import argparse
+
+import json
+import os
+from typing import Dict
+
+import torch
+from safetensors.torch import load_file
+
+from torchtune.models.convert_weights import get_mapped_key
+
+
+# Weight mappings from Gemma's checkpoint to ExecuTorch's transformer parameters.
+_GEMMA_TO_EXECUTORCH = {
+    "model.embed_tokens.weight": "tok_embeddings.weight",
+    "model.norm.weight": "norm.weight",
+    "model.layers.{}.self_attn.k_proj.weight": "layers.{}.attention.wk.weight",
+    "model.layers.{}.self_attn.q_proj.weight": "layers.{}.attention.wq.weight",
+    "model.layers.{}.self_attn.v_proj.weight": "layers.{}.attention.wv.weight",
+    "model.layers.{}.self_attn.o_proj.weight": "layers.{}.attention.wo.weight",
+    "model.layers.{}.input_layernorm.weight": "layers.{}.attention_norm.weight",
+    "model.layers.{}.post_attention_layernorm.weight": "layers.{}.ffn_norm.weight",
+    "model.layers.{}.mlp.gate_proj.weight": "layers.{}.feed_forward.w1.weight",
+    "model.layers.{}.mlp.down_proj.weight": "layers.{}.feed_forward.w2.weight",
+    "model.layers.{}.mlp.up_proj.weight": "layers.{}.feed_forward.w3.weight",
+}
+
+
+def gemma_to_executorch(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
+    """
+    Convert the state dict so that it matches what ExecuTorch's transformer definition expects.
+    """
+    converted_state_dict = {}
+    for key, value in state_dict.items():
+        new_key = get_mapped_key(key, _GEMMA_TO_EXECUTORCH)
+        converted_state_dict[new_key] = value
+    converted_state_dict["output.weight"] = converted_state_dict[
+        "tok_embeddings.weight"
+    ]
+    return converted_state_dict
+
+
+def load_checkpoint_from_safetensors(input_dir: str) -> Dict:
+    index_path = os.path.join(input_dir, "model.safetensors.index.json")
+    if os.path.exists(index_path):
+        # Sharded checkpoint.
+        with open(index_path, "r") as f:
+            index = json.load(f)
+        weight_map = index["weight_map"]
+        checkpoint_shards = sorted(set(weight_map.values()))
+
+        # Load all the shards into memory
+        shard_to_weights = {}
+        for shard in checkpoint_shards:
+            shard_to_weights[shard] = load_file(os.path.join(input_dir, shard))
+
+        # Merge tensors into consolidated state dict.
+        merged_state_dict = {}
+        for weight_name, shard in weight_map.items():
+            tensor = shard_to_weights[shard][weight_name]
+            merged_state_dict[weight_name] = tensor
+        return merged_state_dict
+    else:
+        # Single checkpoint.
+        state_dict = load_file(os.path.join(input_dir, "model.safetensors"))
+        return state_dict
+
+
+def load_checkpoint(input_dir: str) -> Dict:
+    pytorch_path = os.path.join(input_dir, "pytorch_model.bin")
+    if os.path.exists(pytorch_path):
+        print("Loading checkpoint from PyTorch .bin file")
+        return torch.load(pytorch_path, map_location="cpu", weights_only=True)
+    print("Loading checkpoint from safetensors directory")
+    return load_checkpoint_from_safetensors(input_dir)
+
+
+def convert_weights(input_dir: str, output_file: str) -> None:
+    print("Loading checkpoint...")
+    sd = load_checkpoint(input_dir)
+    print("Converting checkpoint...")
+    sd = gemma_to_executorch(sd)
+    print("Saving checkpoint...")
+    torch.save(sd, output_file)
+    print("Done.")
+
+
+def main():
+    parser = argparse.ArgumentParser(
+        description="Convert Gemma weights to ExecuTorch transformer format."
+    )
+    parser.add_argument(
+        "input_dir",
+        type=str,
+        help="Path to directory containing safetensor checkpoint files, or PyTorch checkpoint file.",
+    )
+    parser.add_argument("output", type=str, help="Path to the output checkpoint")
+
+    args = parser.parse_args()
+    convert_weights(args.input_dir, args.output)
+
+
+if __name__ == "__main__":
+    main()
```
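
For reference, a minimal sketch of how the converter added above might be driven on its own, assuming a Gemma-2b-it checkpoint has already been downloaded from Hugging Face; the paths are placeholders, and the same conversion is also exposed as a CLI (`python examples/models/gemma/convert_weights.py <input_dir> <output>`).

```python
# Sketch only (not part of this commit): fold a downloaded Gemma-2b-it
# checkpoint into a single llama-style state dict.
from executorch.examples.models.gemma import convert_weights

convert_weights(
    input_dir="/path/to/gemma-2b-it",  # placeholder: HF checkpoint directory
    output_file="gemma_2b_it.pth",     # placeholder: consolidated output checkpoint
)
```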

examples/qualcomm/oss_scripts/llama/README.md

Lines changed: 14 additions & 6 deletions
````diff
@@ -5,12 +5,13 @@ This file provides you the instructions to run LLM Decoder model with different
 1. LLAMA2 Stories 110M
 2. LLAMA3.2 1B
 3. LLAMA3.2 3B
-4. Gemma3 1B
-5. Phi4-mini-instruct
-6. QWEN2.5 0.5B / 1.5B
-7. QWEN3 0.6B / 1.7B
-8. SmolLM2 135M
-9. SmolLM3 3B
+4. Gemma 2B
+5. Gemma3 1B
+6. Phi4-mini-instruct
+7. QWEN2.5 0.5B / 1.5B
+8. QWEN3 0.6B / 1.7B
+9. SmolLM2 135M
+10. SmolLM3 3B
 
 
 We offer the following modes to execute the model:
@@ -78,6 +79,13 @@ Default example using kv mode.
 python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --checkpoint consolidated.00.pth --params params.json --tokenizer_model tokenizer.model --decoder_model llama3_2-3b_instruct --model_mode kv --max_seq_len 1024 --prompt "I would like to learn python, could you teach me with a simple example?" --tasks wikitext --limit 1
 ```
 
+#### Gemma 2B
+Default example using hybrid mode
+```bash
+python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --temperature 0 --model_mode hybrid --max_seq_len 1024 --prefill_ar_len 128 --decoder_model gemma-2b --prompt "I would like to learn python, could you teach me with a simple example?" --tasks wikitext --limit 1
+```
+
+
 #### Gemma3 1B
 Default example using hybrid mode
 ```bash
````

examples/qualcomm/oss_scripts/llama/__init__.py

Lines changed: 31 additions & 0 deletions
```diff
@@ -24,6 +24,7 @@
 )
 from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype
 
+from executorch.examples.models.gemma import convert_weights as convert_gemma_weights
 from executorch.examples.models.gemma3 import convert_weights as convert_gemma3_weights
 from executorch.examples.models.phi_4_mini import (
     convert_weights as convert_phi_4_mini_weights,
@@ -300,6 +301,36 @@ class Llama3_2_3B_Instruct(LLMModelConfig):
     )
 
 
+@register_llm_model("gemma-2b")
+@dataclass(init=False, frozen=True)
+class Gemma_2B(LLMModelConfig):
+    repo_id: str = "google/gemma-2b-it"
+    params_path: str = os.path.join(
+        BASE_DIR, "../../../models/gemma/config/2b_config.json"
+    )
+    convert_weights = convert_gemma_weights
+    transform_weight = False
+    instruct_model = True
+
+    num_sharding = 4
+    # quant config
+    ptq = QuantDtype.use_16a4w_block
+    group_size = 64
+    masked_softmax = True
+    seq_mse_candidates = 0
+    r1 = False
+    r2 = False
+    r3 = False
+    quantization_config_wv_sha_16a8w = get_ptq_per_channel_quant_config(
+        torch.uint16, weight_dtype=torch.int8, act_observer=MinMaxObserver
+    )
+    custom_annotation = (
+        annotate_kv_8bit,
+        annotate_output_16a8w,
+        partial(annotate_wv_sha, quantization_config=quantization_config_wv_sha_16a8w),
+    )
+
+
 @register_llm_model("gemma3-1b")
 @dataclass(init=False, frozen=True)
 class Gemma3(LLMModelConfig):
```

examples/qualcomm/oss_scripts/llama/decoder_constants.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -14,6 +14,7 @@
 DECODER_MODEL_VERSION = {
     "stories260k": "llama2",
     "stories110m": "llama2",
+    "gemma-2b": "gemma",
     "gemma3-1b": "gemma3",
     "phi_4_mini": "phi_4_mini",
     "llama3_2-1b_instruct": "llama3",
```

examples/qualcomm/oss_scripts/llama/llama.py

Lines changed: 13 additions & 3 deletions
```diff
@@ -327,6 +327,13 @@ def quantize(
                     chat_template, args.prompt[0], args.system_prompt
                 )
             )
+
+            # Gemma may produce unexpected output if the prompt contains an extra <bos> token.
+            # This can happen after applying a prompt template, which might inject <bos> unintentionally.
+            # To prevent decoding issues, we explicitly remove <bos> token
+            if chat_template and args.decoder_model in {"gemma-2b", "gemma3-1b"}:
+                prompt = prompt.replace("<bos>", "")
+
             graph_module_inference(
                 use_kv_cache=self.llama_meta["get_use_kv_cache"],
                 get_example_inputs=self.get_example_inputs,
@@ -534,14 +541,13 @@ def compile(
             state_dict = torch.load(
                 checkpoint, weights_only=True, map_location="cpu", mmap=True
             )
-            if args.decoder_model == "gemma3-1b":
+            if args.decoder_model in {"gemma-2b", "gemma3-1b"}:
                 for k, v in state_dict.items():
                     if "norm" not in k:
                         continue
                     # Llama does x.to(float16) * w whilst Gemma3 is (x * w).to(float16)
                     # See https://github.com/huggingface/transformers/pull/29402
                     state_dict[k] = v.float() + torch.ones(v.shape, dtype=torch.float32)
-
         else:
             state_dict = torch.load(
                 args.checkpoint, weights_only=True, map_location="cpu", mmap=True
@@ -1286,7 +1292,11 @@ def export_llama(args) -> None:
         )
         tokenizer_artifacts = tokenizer.save_pretrained(args.artifact)
         tokenizer_config = tokenizer_artifacts[0]
-        runtime_tokenizer_path = tokenizer_artifacts[-1]
+        if args.decoder_model == "gemma-2b":
+            # For Gemma, use tokenizer.model as it doesn't provide pre_tokenizer in tokenizer.json.
+            runtime_tokenizer_path = tokenizer_artifacts[-3]
+        else:
+            runtime_tokenizer_path = tokenizer_artifacts[-1]
         tokenizer = get_tokenizer(runtime_tokenizer_path, tokenizer_config)
 
         # TODO: Remove this once error is resolved.
```
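
A note on the norm-weight adjustment in the middle hunk: Gemma-family checkpoints store RMSNorm weights as an offset from 1 (the Hugging Face module scales the normalized activations by `1 + weight`), while the static llama definition used here scales by the weight as-is, so adding ones at load time folds that offset into the checkpoint. A small illustrative sketch of the equivalence (not code from this commit; `rms_norm` below is a stand-in, not the project's implementation):

```python
import torch


def rms_norm(x, weight, eps=1e-6):
    # llama-style RMSNorm: normalize, then multiply by the weight as stored.
    x_normed = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
    return x_normed * weight


x = torch.randn(2, 8)
gemma_w = torch.randn(8)  # norm weight as stored in a Gemma checkpoint

# Gemma's own norm multiplies the normalized activations by (1 + weight) ...
reference = rms_norm(x, 1.0 + gemma_w)

# ... so adding ones to the checkpoint weight (as the hunk above does) lets the
# plain "x * w" norm reproduce the same output.
folded_w = gemma_w.float() + torch.ones(gemma_w.shape, dtype=torch.float32)
assert torch.allclose(rms_norm(x, folded_w), reference)
```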

examples/qualcomm/oss_scripts/llama/qnn_llama_runner.cpp

Lines changed: 2 additions & 1 deletion
```diff
@@ -9,7 +9,7 @@
 /**
  * @file
  *
- * This tool can run Llama2 110M, Llama3.2 1B / 3B, Gemma3 1B,
+ * This tool can run Llama2 110M, Llama3.2 1B / 3B, Gemma 2B, Gemma3 1B,
  * phi4-mini-instruct, Qwen2.5 0.5B / 1.5B, Qwen3 0.6B / 1.7B, SmolLM2 135M,
  * SmolLM3 3B with Qualcomm AI Engine Direct.
  *
@@ -117,6 +117,7 @@ std::string get_formatted_prompt(
       formatted_prompt.append(
           "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n");
       break;
+    case example::DecoderModelVersion::kGemma:
     case example::DecoderModelVersion::kGemma3:
       formatted_prompt.append("<start_of_turn>user\n");
       formatted_prompt.append(prompt);
```

examples/qualcomm/oss_scripts/llama/runner/runner.cpp

Lines changed: 5 additions & 1 deletion
```diff
@@ -122,6 +122,8 @@ Runner<T>::Runner(
     decoder_model_version_ = DecoderModelVersion::kLlama2;
   } else if (decoder_model_version == "llama3") {
     decoder_model_version_ = DecoderModelVersion::kLlama3;
+  } else if (decoder_model_version == "gemma") {
+    decoder_model_version_ = DecoderModelVersion::kGemma;
   } else if (decoder_model_version == "gemma3") {
     decoder_model_version_ = DecoderModelVersion::kGemma3;
     cache_mode_ = CacheMode::HybridCache;
@@ -199,7 +201,9 @@ Error Runner<T>::load() {
       decoder_model_version_ == DecoderModelVersion::kSmollm2_135m ||
       decoder_model_version_ == DecoderModelVersion::kSmollm3) {
     eos_ids->insert(tokenizer_->encode("<|im_end|>", 0, 0).get()[0]);
-  } else if (decoder_model_version_ == DecoderModelVersion::kGemma3) {
+  } else if (
+      decoder_model_version_ == DecoderModelVersion::kGemma ||
+      decoder_model_version_ == DecoderModelVersion::kGemma3) {
     eos_ids->insert(tokenizer_->encode("<end_of_turn>", 0, 0).get()[0]);
   }
 
```

0 commit comments

Comments
 (0)