
Commit c310874

chenweng-quic and Cheng-Hsin Weng authored
Qualcomm AI Engine Direct - GA Static Smollm2 (#13406)
### Summary
Adds SmolLM2 135M as a GA static decoder model on the Qualcomm AI Engine Direct backend.

<img width="1607" height="1117" alt="image" src="https://github.com/user-attachments/assets/acefe148-cfca-42e7-9ea1-07e2df7bd72b" />

### Test plan
```
python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -H <host> -s <device_id> -m SM8650 --ptq 16a8w --decoder_model smollm2_135m --model_mode hybrid --prefill_ar_len 32 --max_seq_len 128 --prompt "What is the capital of France."

python backends/qualcomm/tests/test_qnn_delegate.py -k TestExampleLLMScript.test_smollm2 --device <device_id> --host <host> --model <soc_model> --build_folder build-android --executorch_root . --artifact all_artifact
```

Co-authored-by: Cheng-Hsin Weng <[email protected]>
1 parent 8ccf38f commit c310874

File tree: 8 files changed (+106, −7 lines)

backends/qualcomm/tests/test_qnn_delegate.py

Lines changed: 58 additions & 0 deletions
```diff
@@ -4649,6 +4649,64 @@ def test_static_qwen3(self):
         )
         self.assertGreaterEqual(msg["inference_speed"], 70)  # Lanai
 
+    def test_smollm2(self):
+        if not self.required_envs():
+            self.skipTest("missing required envs")
+
+        prompt = "My favourite condiment is "
+        cmds = [
+            "python",
+            f"{self.executorch_root}/examples/qualcomm/oss_scripts/llama/llama.py",
+            "--artifact",
+            self.artifact_dir,
+            "--build_folder",
+            self.build_folder,
+            "--model",
+            self.model,
+            "--ip",
+            self.ip,
+            "--port",
+            str(self.port),
+            "--prompt",
+            f"{prompt}",
+            "--ptq",
+            "16a8w",
+            "--decoder_model",
+            "smollm2_135m",
+            "--model_mode",
+            "kv",
+            "--temperature",
+            "0",
+            "--prefill_ar_len",
+            "128",
+            "--max_seq_len",
+            "1024",
+            "--eval_perplexity",
+            "--task",
+            "wikitext",
+        ]
+        if self.compile_only:
+            cmds.extend(["--compile_only"])
+        elif self.device:
+            cmds.extend(["--device", self.device])
+            if self.host:
+                cmds.extend(["--host", self.host])
+        elif self.enable_x86_64:
+            cmds.extend(["--enable_x86_64"])
+        if self.pre_gen_pte:
+            cmds.extend(["--pre_gen_pte", self.pre_gen_pte])
+
+        p = subprocess.Popen(cmds, stdout=subprocess.DEVNULL)
+        with Listener((self.ip, self.port)) as listener:
+            conn = listener.accept()
+            p.communicate()
+            msg = json.loads(conn.recv())
+            if "Error" in msg:
+                self.fail(msg["Error"])
+            else:
+                self.assertLessEqual(msg["wiki_ppl"], 25)
+                self.assertGreaterEqual(msg["inference_speed"], 200)
+
 
 class TestExampleOssScript(TestQNN):
     def test_albert(self):
```
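The new test follows the pattern of the other static-LLM tests in this file: it launches `llama.py` as a subprocess and waits on a `multiprocessing.connection.Listener` for a single JSON payload of metrics. For readers unfamiliar with that handshake, here is a minimal sketch of the reporting side, assuming the launched script connects back with `multiprocessing.connection.Client`; the function name `report_results` is hypothetical and not part of this commit.

```python
# Minimal sketch of the result-reporting handshake the test relies on.
# Assumption: the launched script connects back with
# multiprocessing.connection.Client and sends one JSON string; the keys
# "wiki_ppl" and "inference_speed" mirror what the test asserts on.
import json
from multiprocessing.connection import Client


def report_results(ip: str, port: int, wiki_ppl: float, speed: float) -> None:
    # Connect to the Listener the test opened on (ip, port) and send a
    # single JSON payload; an "Error" key would make the test fail fast.
    with Client((ip, port)) as conn:
        conn.send(json.dumps({"wiki_ppl": wiki_ppl, "inference_speed": speed}))


# Example: report_results("127.0.0.1", 8080, wiki_ppl=23.5, speed=215.0)
```

Any payload missing those keys, or carrying an `Error` key, fails the test; otherwise it asserts wikitext perplexity at most 25 and at least 200 tokens/s.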

examples/qualcomm/oss_scripts/llama/README.md

Lines changed: 7 additions & 0 deletions
````diff
@@ -8,6 +8,7 @@ This file provides you the instructions to run LLM Decoder model with different
 4. QWEN2.5 0.5B
 5. QWEN3 0.6B / 1.7B
 6. Phi4-mini-instruct
+7. SMOLLM2 135M
 
 We offer the following modes to execute the model:
 
@@ -74,6 +75,12 @@ Default example using hybrid mode
 python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --temperature 0 --model_mode hybrid --max_seq_len 1024 --prefill_ar_len 128 --ptq 16a8w --enable_masked_softmax --r3 --decoder_model qwen2_5 --prompt "I would like to learn python, could you teach me with a simple example?"
 ```
 
+#### SMOLLM2
+Default example using hybrid mode.
+```bash
+python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -H mlgtw-linux -s ${SERIAL_NUM} -m ${SOC_MODEL} --ptq 16a8w --tokenizer_bin tokenizer.bin --decoder_model smollm2 --model_mode hybrid --prefill_ar_len 128 --max_seq_len 1024 --prompt "I would like to learn python, could you teach me with a simple example?"
+```
+
 ### KV Cache update mechanism
 We have two distinct mechanisms for updating the key-value (KV) cache, which can be selected at runtime. Shift Pointer and Smart Mask.
````

examples/qualcomm/oss_scripts/llama/__init__.py

Lines changed: 19 additions & 1 deletion
```diff
@@ -16,7 +16,9 @@
     convert_weights as convert_qwen2_5_weights,
 )
 from executorch.examples.models.qwen3 import convert_weights as convert_qwen3_weights
-
+from executorch.examples.models.smollm2 import (
+    convert_weights as convert_smollm2_weights,
+)
 from executorch.examples.qualcomm.oss_scripts.llama.decoder_constants import (
     DECODER_MODEL_VERSION,
 )
@@ -52,6 +54,7 @@ class Qwen2_5(HFModel):
     )
     runner_version: str = field(default=DECODER_MODEL_VERSION["qwen2_5"])
     convert_weights = convert_qwen2_5_weights
+    transform_weight = False
 
 
 @register_hf_model("qwen3_0_6b")
@@ -63,6 +66,7 @@ class Qwen3_0_6B(HFModel):
     )
     runner_version: str = field(default=DECODER_MODEL_VERSION["qwen2_5"])
     convert_weights = convert_qwen3_weights
+    transform_weight = False
 
 
 @register_hf_model("qwen3_1_7b")
@@ -74,6 +78,7 @@ class Qwen3_1_7B(HFModel):
     )
     runner_version: str = field(default=DECODER_MODEL_VERSION["qwen2_5"])
     convert_weights = convert_qwen3_weights
+    transform_weight = False
 
 
 @register_hf_model("phi_4_mini")
@@ -85,3 +90,16 @@ class Phi4Mini(HFModel):
     )
     runner_version: str = field(default=DECODER_MODEL_VERSION["phi_4_mini"])
     convert_weights = convert_phi_4_mini_weights
+    transform_weight = False
+
+
+@register_hf_model("smollm2_135m")
+@dataclass(init=False, frozen=True)
+class Smollm2_135M(HFModel):
+    repo_id: str = "HuggingFaceTB/SmolLM2-135M-Instruct"
+    params_path: str = os.path.join(
+        BASE_DIR, "../../../models/smollm2/135M_config.json"
+    )
+    runner_version: str = field(default=DECODER_MODEL_VERSION["smollm2_135m"])
+    convert_weights = convert_smollm2_weights
+    transform_weight = True
```
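The new `Smollm2_135M` entry reuses the existing `@register_hf_model` decorator, which is what makes the generic `SUPPORTED_HF_MODELS` lookups in `llama.py` work. The registry internals are not shown in this diff; a minimal sketch of how such a decorator is typically wired, under that assumption, looks like this:

```python
# Hedged sketch of what a decorator-based registry like @register_hf_model
# typically does; the real implementation lives elsewhere in the repo and
# may differ. SUPPORTED_HF_MODELS is the mapping llama.py indexes into.
SUPPORTED_HF_MODELS = {}


def register_hf_model(name):
    def wrapper(cls):
        # Key an instance by its --decoder_model string so lookups like
        # SUPPORTED_HF_MODELS["smollm2_135m"].transform_weight work.
        SUPPORTED_HF_MODELS[name] = cls()
        return cls

    return wrapper
```

The new `transform_weight` flag rides along as a plain class attribute, so `SUPPORTED_HF_MODELS[args.decoder_model].transform_weight` can gate the RoPE weight permutation per model.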

examples/qualcomm/oss_scripts/llama/decoder_constants.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -18,4 +18,5 @@
     "qwen3_0_6b": "qwen2_5",  # TODO: temp workaround, use special token for qwen3 in runner
     "qwen3_1_7b": "qwen2_5",
     "phi_4_mini": "phi_4_mini",
+    "smollm2_135m": "smollm2_135m",
 }
```

examples/qualcomm/oss_scripts/llama/llama.py

Lines changed: 8 additions & 5 deletions
```diff
@@ -434,6 +434,7 @@ def compile(args, pte_filename, tokenizer):
         state_dict = torch.load(
             checkpoint, weights_only=True, map_location="cpu", mmap=True
         )
+        transform_weight = SUPPORTED_HF_MODELS[args.decoder_model].transform_weight
     else:
         state_dict = torch.load(
             args.checkpoint, weights_only=True, map_location="cpu", mmap=True
@@ -444,7 +445,9 @@ def compile(args, pte_filename, tokenizer):
 
     if args.decoder_model == "stories260k":
         state_dict = {k.replace("_orig_mod.", ""): v for k, v in state_dict.items()}
+        transform_weight = True
 
+    if transform_weight:
         # Change to HuggingFace weight to improve the performance of RoPE in HTP backend.
         def permute(w, heads):
             dim_0 = w.size(0)
@@ -1172,11 +1175,6 @@ def export_llama(args) -> None:
             tokenizer, TiktokenTokenizer
         ), f"Wrong tokenizer provided for llama3_2."
         runtime_tokenizer_path = args.tokenizer_model
-    elif args.decoder_model in {"qwen2_5", "qwen3_0_6b", "qwen3_1_7b"}:
-        model_id = SUPPORTED_HF_MODELS[args.decoder_model].repo_id
-        tokenizer = AutoTokenizer.from_pretrained(model_id)
-        runtime_tokenizer_path = tokenizer.save_pretrained(args.artifact)[-1]
-        tokenizer = get_tokenizer(runtime_tokenizer_path)
     elif args.decoder_model == "phi_4_mini":
         model_id = SUPPORTED_HF_MODELS[args.decoder_model].repo_id
         tokenizer = AutoTokenizer.from_pretrained(model_id)
@@ -1190,6 +1188,11 @@ def export_llama(args) -> None:
             file.seek(0)
             json.dump(data, file, indent=4)
             file.truncate()
+    elif args.decoder_model in SUPPORTED_HF_MODELS:
+        model_id = SUPPORTED_HF_MODELS[args.decoder_model].repo_id
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
+        runtime_tokenizer_path = tokenizer.save_pretrained(args.artifact)[-1]
+        tokenizer = get_tokenizer(runtime_tokenizer_path)
     else:
         raise RuntimeError(f"Unknown decoder_model: {args.decoder_model}.")
```
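The `transform_weight` flag gates the pre-existing block that permutes attention weights into the HuggingFace layout to speed up RoPE on the HTP backend. The diff truncates `permute` after `dim_0 = w.size(0)`; the sketch below fills in the body with the standard Meta-to-HuggingFace RoPE permutation, which is an assumption about code not shown here:

```python
# Hedged sketch of the q/k weight permutation the gated block performs.
# The signature matches the truncated `permute(w, heads)` in the diff;
# the body is the conventional Meta<->HuggingFace RoPE layout conversion
# and is assumed, not copied from the commit.
import torch


def permute(w: torch.Tensor, heads: int) -> torch.Tensor:
    dim_0 = w.size(0)  # heads * head_dim
    dim_1 = w.size(1)  # hidden size
    # Regroup each head's rows so the two rotary halves become contiguous
    # instead of interleaved pair-wise.
    return (
        w.view(heads, dim_0 // heads // 2, 2, dim_1)
        .transpose(1, 2)
        .reshape(dim_0, dim_1)
    )
```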

examples/qualcomm/oss_scripts/llama/qnn_llama_runner.cpp

Lines changed: 10 additions & 1 deletion
```diff
@@ -10,7 +10,7 @@
  * @file
  *
  * This tool can run Llama2 110M, Llama3.2 1B / 3B, Qwen2.5 0.5B, Qwen3 0.6B
- * / 1.7B, phi4-mini-instruct with Qualcomm AI Engine Direct.
+ * / 1.7B, phi4-mini-instruct, Smollm2 135M with Qualcomm AI Engine Direct.
  *
  */
@@ -113,6 +113,15 @@ std::string get_formatted_prompt(
       formatted_prompt.append("<|user|>");
       formatted_prompt.append(prompt);
       formatted_prompt.append("<|end|><|assistant|>");
+    case example::DecoderModelVersion::kSmollm2_135m:
+      if (!system_prompt.empty()) {
+        formatted_prompt.append("<|im_start|>system\n");
+        formatted_prompt.append(system_prompt);
+        formatted_prompt.append("<|im_end|>\n\n");
+      }
+      formatted_prompt.append("<|im_start|>user\n");
+      formatted_prompt.append(prompt);
+      formatted_prompt.append("<|im_end|>\n\n");
       break;
     case example::DecoderModelVersion::kLlama3:
       if (!system_prompt.empty()) {
```
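The new `kSmollm2_135m` case wraps prompts in SmolLM2's ChatML-style template. As a quick illustration (in Python, mirroring the C++ appends above; not the runner code itself):

```python
# Illustration of the ChatML-style prompt the new kSmollm2_135m case builds.
def format_smollm2_prompt(prompt: str, system_prompt: str = "") -> str:
    formatted = ""
    if system_prompt:
        formatted += "<|im_start|>system\n" + system_prompt + "<|im_end|>\n\n"
    formatted += "<|im_start|>user\n" + prompt + "<|im_end|>\n\n"
    return formatted


# format_smollm2_prompt("What is the capital of France.")
# -> "<|im_start|>user\nWhat is the capital of France.<|im_end|>\n\n"
```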

examples/qualcomm/oss_scripts/llama/runner/runner.cpp

Lines changed: 2 additions & 0 deletions
```diff
@@ -132,6 +132,8 @@ Runner<T>::Runner(
     decoder_model_version_ = DecoderModelVersion::kQwen2_5;
   } else if (decoder_model_version == "phi_4_mini") {
     decoder_model_version_ = DecoderModelVersion::kPhi4;
+  } else if (decoder_model_version == "smollm2_135m") {
+    decoder_model_version_ = DecoderModelVersion::kSmollm2_135m;
   } else {
     ET_CHECK_MSG(false, "Unsupported Decoder Model");
   }
```

examples/qualcomm/oss_scripts/llama/runner/runner.h

Lines changed: 1 addition & 0 deletions
```diff
@@ -32,6 +32,7 @@ enum DecoderModelVersion {
   kLlama3,
   kQwen2_5,
   kPhi4,
+  kSmollm2_135m
 };
 
 enum KvBitWidth {
```
