
Commit 3a71f1f

Qualcomm AI Engine Direct - GA Static SmolLM3 3B
Summary:
- Add end-to-end script for GA Static SmolLM3 3B
- Perf: with 16a4w block quantization, token rate in kv mode is ~30 tokens/sec on SM8750
- Accuracy: wikitext PPL ~8.345 (FP) -> ~8.976 (HTP)
- Add model params file and model weight converter
1 parent 6ed10e5 commit 3a71f1f


15 files changed: +376 -50 lines changed


backends/qualcomm/quantizer/custom_annotation.py

Lines changed: 40 additions & 4 deletions
```diff
@@ -3,6 +3,7 @@
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
+from enum import Enum, unique
 from typing import Sequence
 
 import torch
@@ -31,6 +32,17 @@
 )
 
 
+@unique
+class StaticLLMQuantConfig(Enum):
+    """
+    Layer namespace configuration for Qualcomm's static LLaMA quantization.
+    """
+
+    wq_sha = "wq_sha"  # Query weight (single head)
+    wk_sha = "wk_sha"  # Key weight (single head)
+    wv_sha = "wv_sha"  # Value weight (single head)
+
+
 def annotate_eurobert(gm: torch.fx.GraphModule):
     """
     QNN does not support int32 -> signed 16bit quant
@@ -166,11 +178,35 @@ def annotate_prefill_kv_output(gm: torch.fx.GraphModule, kv_quant_attrs: dict):
         )
 
 
-def annotate_wv_sha(gm: torch.fx.GraphModule, quantization_config: QuantizationConfig):
+def annotate_qkv_proj_sha(
+    gm: torch.fx.GraphModule,
+    quantization_config: QuantizationConfig,
+    qkv_tags: set[StaticLLMQuantConfig],
+):
+    """
+    Annotates QKV projection layers in a GraphModule for quantization,
+    specifically layers defined in StaticLLMQuantConfig.
+
+    Args:
+        qkv_tags (set[StaticLLMQuantConfig]): A set of enum tags indicating which QKV layers
+            (e.g., wq, wk, wv) should be annotated for quantization. Only tags defined in
+            StaticLLMQuantConfig are allowed.
+
+    Raises:
+        ValueError: If any tag in `qkv_tags` is not among the allowed enum members.
+    """
+
+    # Get all valid tags from the StaticLLMQuantConfig enum
+    allowed_tags = set(StaticLLMQuantConfig)
+    invalid_tags = qkv_tags - allowed_tags
+    if invalid_tags:
+        raise ValueError(
+            f"Invalid qkv tags: {invalid_tags}. Allowed tags are: {allowed_tags}"
+        )
+
     for node in gm.graph.nodes:
-        if (
-            node.target == torch.ops.aten.conv2d.default
-            and "wv_sha" in node.meta["stack_trace"]
+        if node.target == torch.ops.aten.conv2d.default and any(
+            tag.value in node.meta["stack_trace"] for tag in qkv_tags
         ):
             input_qspec_map = {}
             input_qspec_map[node.args[0]] = quantization_config.input_activation
```
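
For orientation, a minimal sketch (not part of the diff) of how the new annotator can be bound to a fixed tag set. The `quantization_config` argument is assumed to be whatever QuantizationConfig the caller already builds (e.g. the 16a4w config mentioned in the summary); only `StaticLLMQuantConfig` and `annotate_qkv_proj_sha` come from the file above.

```python
from functools import partial

from executorch.backends.qualcomm.quantizer.custom_annotation import (
    StaticLLMQuantConfig,
    annotate_qkv_proj_sha,
)


def make_qkv_annotation(quantization_config):
    """Bind the Q/K/V tags so the result can be called with just a GraphModule.

    `quantization_config` is assumed to be a QuantizationConfig built elsewhere;
    it is not created here.
    """
    return partial(
        annotate_qkv_proj_sha,
        quantization_config=quantization_config,
        qkv_tags={
            StaticLLMQuantConfig.wq_sha,
            StaticLLMQuantConfig.wk_sha,
            StaticLLMQuantConfig.wv_sha,
        },
    )
```

Passing any value outside `StaticLLMQuantConfig` in `qkv_tags` raises the `ValueError` added above, so a typo in a layer tag fails loudly instead of silently skipping annotation.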

backends/qualcomm/tests/test_qnn_delegate.py

Lines changed: 81 additions & 18 deletions
```diff
@@ -5117,6 +5117,60 @@ def test_static_qwen3(self):
                         msg["inference_speed"], inference_speed_ref[self.model]
                     )
 
+    def test_qwen2_5(self):
+        if not self.required_envs([]):
+            self.skipTest("missing required envs")
+        prompt = "My favourite condiment is "
+        cmds = [
+            "python",
+            f"{self.executorch_root}/examples/qualcomm/oss_scripts/qwen2_5/qwen2_5.py",
+            "--prompt",
+            prompt,
+            "--decoder_model",
+            "qwen2.5_0.5B",
+            "--ptq",
+            "16a8w",
+            "--enable_spinquant_r3",
+            "--max_seq_len",
+            "128",
+            "--artifact",
+            self.artifact_dir,
+            "--build_folder",
+            self.build_folder,
+            "--model",
+            self.model,
+            "--ip",
+            self.ip,
+            "--port",
+            str(self.port),
+        ]
+        if self.compile_only:
+            cmds.extend(["--compile_only"])
+        elif self.device:
+            cmds.extend(["--device", self.device])
+            if self.host:
+                cmds.extend(["--host", self.host])
+        elif self.enable_x86_64:
+            cmds.extend(["--enable_x86_64"])
+        if self.pre_gen_pte:
+            cmds.extend(["--pre_gen_pte", self.pre_gen_pte])
+
+        golden_start_with = "My favourite condiment is iced tea."
+        p = subprocess.Popen(cmds, stdout=subprocess.DEVNULL)
+        with Listener((self.ip, self.port)) as listener:
+            conn = listener.accept()
+            p.communicate()
+            msg = json.loads(conn.recv())
+            if "Error" in msg:
+                self.fail(msg["Error"])
+            else:
+                if not self.compile_only:
+                    model_out = msg["result"][0]
+                    self.assertTrue(
+                        model_out.startswith(golden_start_with),
+                        f"Expected Output: '{golden_start_with}' Actual Output: '{model_out}'",
+                    )
+
     def test_static_smollm2(self):
         if not self.required_envs():
             self.skipTest("missing required envs")
@@ -5150,6 +5204,8 @@ def test_static_smollm2(self):
             "--eval_perplexity",
             "--task",
             "wikitext",
+            "--limit",
+            "1",
         ]
         if self.compile_only:
             cmds.extend(["--compile_only"])
@@ -5173,22 +5229,14 @@ def test_static_smollm2(self):
                 self.assertLessEqual(msg["wiki_ppl"], 25)
                 self.assertGreaterEqual(msg["inference_speed"], 200)
 
-    def test_qwen2_5(self):
-        if not self.required_envs([]):
+    def test_static_smollm3(self):
+        if not self.required_envs():
             self.skipTest("missing required envs")
+
         prompt = "My favourite condiment is "
         cmds = [
             "python",
-            f"{self.executorch_root}/examples/qualcomm/oss_scripts/qwen2_5/qwen2_5.py",
-            "--prompt",
-            prompt,
-            "--decoder_model",
-            "qwen2.5_0.5B",
-            "--ptq",
-            "16a8w",
-            "--enable_spinquant_r3",
-            "--max_seq_len",
-            "128",
+            f"{self.executorch_root}/examples/qualcomm/oss_scripts/llama/llama.py",
             "--artifact",
             self.artifact_dir,
             "--build_folder",
@@ -5199,6 +5247,21 @@ def test_qwen2_5(self):
             self.ip,
             "--port",
             str(self.port),
+            "--prompt",
+            f"{prompt}",
+            "--decoder_model",
+            "smollm3-3b",
+            "--model_mode",
+            "kv",
+            "--temperature",
+            "0",
+            "--max_seq_len",
+            "1024",
+            "--eval_perplexity",
+            "--task",
+            "wikitext",
+            "--limit",
+            "1",
         ]
         if self.compile_only:
             cmds.extend(["--compile_only"])
@@ -5211,7 +5274,6 @@ def test_qwen2_5(self):
         if self.pre_gen_pte:
             cmds.extend(["--pre_gen_pte", self.pre_gen_pte])
 
-        golden_start_with = "My favourite condiment is iced tea."
         p = subprocess.Popen(cmds, stdout=subprocess.DEVNULL)
         with Listener((self.ip, self.port)) as listener:
             conn = listener.accept()
@@ -5220,11 +5282,12 @@ def test_qwen2_5(self):
             if "Error" in msg:
                 self.fail(msg["Error"])
             else:
-                if not self.compile_only:
-                    model_out = msg["result"][0]
-                    self.assertTrue(
-                        model_out.startswith(golden_start_with),
-                        f"Expected Output: '{golden_start_with}' Actual Output: '{model_out}'",
+                inference_speed_ref = {"SM8650": 23, "SM8750": 28}
+                self.assertLessEqual(msg["wiki_ppl"], 10)
+                self.assertLessEqual(msg["pte_size"], 2_600_000_000)  # 2.6GB
+                if self.model in inference_speed_ref:
+                    self.assertGreaterEqual(
+                        msg["inference_speed"], inference_speed_ref[self.model]
                     )
 
 
```
New file: SmolLM3 3B model params JSON

Lines changed: 14 additions & 0 deletions
```json
{
  "dim": 2048,
  "ffn_dim_multiplier": 1,
  "hidden_dim": 11008,
  "n_heads": 16,
  "n_kv_heads": 4,
  "n_layers": 36,
  "norm_eps": 1e-06,
  "rope_theta": 5000000.0,
  "use_scaled_rope": false,
  "vocab_size": 128256,
  "use_hf_rope": false,
  "attention_qkv_bias": false
}
```
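
For a quick sanity check, the attention geometry implied by these parameters (plain arithmetic, assuming the usual Llama-style convention that head_dim = dim / n_heads):

```python
# Derived attention dimensions for SmolLM3 3B, from the params above.
dim, n_heads, n_kv_heads = 2048, 16, 4

head_dim = dim // n_heads               # 128
gqa_group_size = n_heads // n_kv_heads  # 4 query heads share each KV head
kv_width = n_kv_heads * head_dim        # 512 values per token for K (and for V)

print(head_dim, gqa_group_size, kv_width)
```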
New file: SmolLM3 model definition

Lines changed: 16 additions & 0 deletions
```python
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from executorch.examples.models.llama.model import Llama2Model
from executorch.examples.models.smollm3.convert_weights import convert_weights


class SmolLM3Model(Llama2Model):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)


__all__ = [
    "SmolLM3Model",
    "convert_weights",
]
```
New file: SmolLM3 weight converter

Lines changed: 112 additions & 0 deletions
```python
import argparse
import json
import os
from typing import Dict

import torch

from safetensors.torch import load_file

from torchtune.models.convert_weights import get_mapped_key


_SMOLLM_TO_META = {
    "model.embed_tokens.weight": "tok_embeddings.weight",
    "model.norm.weight": "norm.weight",
    "model.layers.{}.self_attn.k_proj.weight": "layers.{}.attention.wk.weight",
    "model.layers.{}.self_attn.q_proj.weight": "layers.{}.attention.wq.weight",
    "model.layers.{}.self_attn.v_proj.weight": "layers.{}.attention.wv.weight",
    "model.layers.{}.self_attn.o_proj.weight": "layers.{}.attention.wo.weight",
    "model.layers.{}.input_layernorm.weight": "layers.{}.attention_norm.weight",
    "model.layers.{}.post_attention_layernorm.weight": "layers.{}.ffn_norm.weight",
    "model.layers.{}.mlp.gate_proj.weight": "layers.{}.feed_forward.w1.weight",
    "model.layers.{}.mlp.down_proj.weight": "layers.{}.feed_forward.w2.weight",
    "model.layers.{}.mlp.up_proj.weight": "layers.{}.feed_forward.w3.weight",
}


def smollm_to_meta(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    """
    Convert a state dict from SmolLM3's Hugging Face format to Meta's format. This
    function doesn't handle any sharding or splitting of state dicts. It follows
    the state_dict IN -> state_dict OUT pattern.

    Args:
        state_dict (Dict[str, torch.Tensor]): State dict in Hugging Face format.

    Returns:
        Dict[str, torch.Tensor]: State dict in Meta's format.
    """
    converted_state_dict = {}
    for key, value in state_dict.items():
        new_key = get_mapped_key(key, _SMOLLM_TO_META)
        converted_state_dict[new_key] = value
    # Reuse the token embedding as the LM head (the source checkpoint has no
    # separate output projection).
    converted_state_dict["output.weight"] = converted_state_dict[
        "tok_embeddings.weight"
    ]

    return converted_state_dict


def load_checkpoint_from_safetensors(input_dir: str) -> Dict:
    index_path = os.path.join(input_dir, "model.safetensors.index.json")
    if os.path.exists(index_path):
        # Sharded checkpoint.
        with open(index_path, "r") as f:
            index = json.load(f)
        weight_map = index["weight_map"]
        checkpoint_shards = sorted(set(weight_map.values()))

        # Load all the shards into memory
        shard_to_weights = {}
        for shard in checkpoint_shards:
            shard_to_weights[shard] = load_file(os.path.join(input_dir, shard))

        # Merge tensors into consolidated state dict.
        merged_state_dict = {}
        for weight_name, shard in weight_map.items():
            tensor = shard_to_weights[shard][weight_name]
            merged_state_dict[weight_name] = tensor
        return merged_state_dict
    else:
        # Single checkpoint.
        state_dict = load_file(os.path.join(input_dir, "model.safetensors"))
        return state_dict


def load_checkpoint(input_dir: str) -> Dict:
    pytorch_path = os.path.join(input_dir, "pytorch_model.bin")
    if os.path.exists(pytorch_path):
        print("Loading checkpoint from PyTorch .bin file")
        return torch.load(pytorch_path, map_location="cpu", weights_only=True)
    print("Loading checkpoint from safetensors directory")
    return load_checkpoint_from_safetensors(input_dir)


def convert_weights(input_dir: str, output_file: str) -> None:
    print("Loading checkpoint...")
    sd = load_checkpoint(input_dir)
    print("Converting checkpoint...")
    sd = smollm_to_meta(sd)
    print("Saving checkpoint...")
    torch.save(sd, output_file)
    print("Done.")


def main():
    parser = argparse.ArgumentParser(
        description="Convert SmolLM weights to Meta format."
    )
    parser.add_argument(
        "input_dir",
        type=str,
        help="Path to directory containing checkpoint files",
    )
    parser.add_argument("output", type=str, help="Path to the output checkpoint")

    args = parser.parse_args()
    convert_weights(args.input_dir, args.output)


if __name__ == "__main__":
    main()
```
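
A hedged usage sketch (not part of the commit): the paths below are placeholders for a locally downloaded SmolLM3 3B checkpoint directory and the desired output file; only `convert_weights` itself comes from the new module.

```python
from executorch.examples.models.smollm3.convert_weights import convert_weights

# Placeholder paths: the input directory should contain model.safetensors
# (optionally sharded, with model.safetensors.index.json) or pytorch_model.bin;
# the output is a single Meta-format .pth state dict.
convert_weights(
    input_dir="/path/to/SmolLM3-3B",
    output_file="/path/to/smollm3_3b_meta.pth",
)
```

The same conversion is available from the command line through the module's `main()`, which takes the checkpoint directory and output path as positional arguments.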

examples/qualcomm/oss_scripts/llama/README.md

Lines changed: 10 additions & 3 deletions
````diff
@@ -9,7 +9,8 @@ This file provides you the instructions to run LLM Decoder model with different
 5. Phi4-mini-instruct
 6. QWEN2.5 0.5B / 1.5B
 7. QWEN3 0.6B / 1.7B
-8. SMOLLM2 135M
+8. SmolLM2 135M
+9. SmolLM3 3B
 
 
 We offer the following modes to execute the model:
@@ -107,10 +108,16 @@ Default example using hybrid mode
 python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --temperature 0 --model_mode kv --max_seq_len 1024 --decoder_model qwen3-1_7b --prompt "I would like to learn python, could you teach me with a simple example?" --tasks wikitext --limit 1
 ```
 
-#### SMOLLM2
+#### SmolLM2
 Default example using hybrid mode.
 ```bash
-python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -H mlgtw-linux -s ${SERIAL_NUM} -m ${SOC_MODEL} --decoder_model smollm2_135m --model_mode hybrid --prefill_ar_len 128 --max_seq_len 1024 --prompt "I would like to learn python, could you teach me with a simple example?" --tasks wikitext --limit 1
+python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --decoder_model smollm2_135m --model_mode hybrid --prefill_ar_len 128 --max_seq_len 1024 --prompt "I would like to learn python, could you teach me with a simple example?" --tasks wikitext --limit 1
+```
+
+#### SmolLM3
+Default example using kv mode.
+```bash
+python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --decoder_model smollm3-3b --model_mode kv --max_seq_len 1024 --prompt "I would like to learn python, could you teach me with a simple example?" --tasks wikitext --limit 1
 ```
 
 
````