Skip to content

Commit 96e82f5

Browse files
Qualcomm AI Engine Direct - GA Static Qwen3 (#13086)
Summary: - support Qwen3-0.6B - support Qwen3-1.7B - refactor HF model registration for static llama Script ``` bash python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s $DEVICE -m SM8750 --prompt "I would like to learn python, could you teach me with a simple example?" --temperature 0 --model_mode hybrid --prefill_ar_len 32 --max_seq_len 128 --ptq 16a4w --decoder_model qwen3 ``` Stat <img width="1709" height="789" alt="ee13db20-f529-413f-95c8-b6ce4bfcb4f4" src="https://github.com/user-attachments/assets/0bea9db8-f10c-4d5b-96fa-31cd775d0a74" /> ### Test plan Note: We only run Qwen3-0.6B for CI ``` bash python backends/qualcomm/tests/test_qnn_delegate.py -k TestExampleLLMScript.test_qwen3 --model SM8650 --build_folder build-android/ --executorch_root . -s $DEVICE --artifact ./qwen3 ``` cc: @haowhsu-quic, @cccclai
1 parent 52ce330 commit 96e82f5

File tree

6 files changed

+203
-35
lines changed

6 files changed

+203
-35
lines changed

backends/qualcomm/tests/test_qnn_delegate.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4588,6 +4588,65 @@ def test_static_qwen2_5(self):
45884588
msg["inference_speed"], inference_speed_ref[self.model]
45894589
)
45904590

4591+
def test_qwen3(self):
4592+
if not self.required_envs():
4593+
self.skipTest("missing required envs")
4594+
4595+
prompt = "My favourite condiment is "
4596+
cmds = [
4597+
"python",
4598+
f"{self.executorch_root}/examples/qualcomm/oss_scripts/llama/llama.py",
4599+
"--artifact",
4600+
self.artifact_dir,
4601+
"--build_folder",
4602+
self.build_folder,
4603+
"--model",
4604+
self.model,
4605+
"--ip",
4606+
self.ip,
4607+
"--port",
4608+
str(self.port),
4609+
"--prompt",
4610+
f"{prompt}",
4611+
"--ptq",
4612+
"16a8w",
4613+
"--decoder_model",
4614+
"qwen3_0.6b",
4615+
"--model_mode",
4616+
"hybrid",
4617+
"--prefill_ar_len",
4618+
"32",
4619+
"--max_seq_len",
4620+
"128",
4621+
]
4622+
if self.compile_only:
4623+
cmds.extend(["--compile_only"])
4624+
elif self.device:
4625+
cmds.extend(["--device", self.device])
4626+
if self.host:
4627+
cmds.extend(["--host", self.host])
4628+
elif self.enable_x86_64:
4629+
cmds.extend(["--enable_x86_64"])
4630+
if self.pre_gen_pte:
4631+
cmds.extend(["--pre_gen_pte", self.pre_gen_pte])
4632+
4633+
# Accuracy is bad for now. Just check user's prompt is returned.
4634+
golden_start_with = "My favourite condiment is "
4635+
p = subprocess.Popen(cmds, stdout=subprocess.DEVNULL)
4636+
with Listener((self.ip, self.port)) as listener:
4637+
conn = listener.accept()
4638+
p.communicate()
4639+
msg = json.loads(conn.recv())
4640+
if "Error" in msg:
4641+
self.fail(msg["Error"])
4642+
else:
4643+
model_out = msg["result"][0]
4644+
self.assertTrue(
4645+
model_out.startswith(golden_start_with),
4646+
f"Expected Output: {golden_start_with}. Actual Output: {model_out}",
4647+
)
4648+
self.assertGreaterEqual(msg["inference_speed"], 70) # Lanai
4649+
45914650

45924651
class TestExampleOssScript(TestQNN):
45934652
def test_albert(self):
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
# Copyright (c) Qualcomm Innovation Center, Inc.
2+
# All rights reserved
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import os
8+
from abc import ABC
9+
from dataclasses import dataclass, field
10+
from typing import Callable, Dict, Type
11+
12+
from executorch.examples.models.qwen2_5 import (
13+
convert_weights as convert_qwen2_5_weights,
14+
)
15+
from executorch.examples.models.qwen3 import convert_weights as convert_qwen3_weights
16+
17+
from executorch.examples.qualcomm.oss_scripts.llama.decoder_constants import (
18+
DECODER_MODEL_VERSION,
19+
)
20+
21+
BASE_DIR = os.path.dirname(__file__)
22+
23+
24+
@dataclass(init=False, frozen=True)
25+
class HFModel(ABC):
26+
repo_id: str
27+
params_path: str
28+
runner_version: str
29+
convert_weights: Callable
30+
31+
32+
SUPPORTED_HF_MODELS: Dict[str, Type[HFModel]] = {}
33+
34+
35+
def register_hf_model(name: str):
36+
def decorator(cls: Type[HFModel]):
37+
SUPPORTED_HF_MODELS[name.lower()] = cls()
38+
return cls()
39+
40+
return decorator
41+
42+
43+
@register_hf_model("qwen2_5")
44+
@dataclass(init=False, frozen=True)
45+
class Qwen2_5(HFModel):
46+
repo_id: str = "Qwen/Qwen2.5-0.5B"
47+
params_path: str = os.path.join(
48+
BASE_DIR, "../../../models/qwen2_5/config/0_5b_config.json"
49+
)
50+
runner_version: str = field(default=DECODER_MODEL_VERSION["qwen2_5"])
51+
convert_weights = convert_qwen2_5_weights
52+
53+
54+
@register_hf_model("qwen3_0_6b")
55+
@dataclass(init=False, frozen=True)
56+
class Qwen3_0_6B(HFModel):
57+
repo_id: str = "Qwen/Qwen3-0.6B"
58+
params_path: str = os.path.join(
59+
BASE_DIR, "../../../models/qwen3/config/0_6b_config.json"
60+
)
61+
runner_version: str = field(default=DECODER_MODEL_VERSION["qwen2_5"])
62+
convert_weights = convert_qwen3_weights
63+
64+
65+
@register_hf_model("qwen3_1_7b")
66+
@dataclass(init=False, frozen=True)
67+
class Qwen3_1_7B(HFModel):
68+
repo_id: str = "Qwen/Qwen/Qwen3-1.7B"
69+
params_path: str = os.path.join(
70+
BASE_DIR, "../../../models/qwen3/config/1_7b_config.json"
71+
)
72+
runner_version: str = field(default=DECODER_MODEL_VERSION["qwen2_5"])
73+
convert_weights = convert_qwen3_weights

examples/qualcomm/oss_scripts/llama/decoder_constants.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,6 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7-
HUGGING_FACE_REPO_IDS = {"qwen2_5": "Qwen/Qwen2.5-0.5B"}
8-
97
EVAL_MODE = {
108
"kv": 0,
119
"hybrid": 1,

examples/qualcomm/oss_scripts/llama/decoder_utils.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
DECODER_MODEL_VERSION,
2020
EVAL_MODE,
2121
)
22-
2322
from executorch.examples.qualcomm.utils import make_output_dir, SimpleADB
2423
from executorch.exir._serialize._program import deserialize_pte_binary
2524
from pytorch_tokenizers.hf_tokenizer import HuggingFaceTokenizer

examples/qualcomm/oss_scripts/llama/llama.py

Lines changed: 23 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -64,10 +64,10 @@
6464
from executorch.examples.models.llama.source_transformation.quantize import (
6565
get_quant_embedding_transform,
6666
)
67+
from executorch.examples.qualcomm.oss_scripts.llama import SUPPORTED_HF_MODELS
6768
from executorch.examples.qualcomm.oss_scripts.llama.decoder_constants import (
6869
DECODER_MODEL_VERSION,
6970
EVAL_MODE,
70-
HUGGING_FACE_REPO_IDS,
7171
)
7272
from executorch.examples.qualcomm.oss_scripts.llama.decoder_utils import (
7373
graph_module_inference,
@@ -227,7 +227,6 @@ def quantize(
227227

228228
self.has_quant_io = True
229229
fx_graph_module = None
230-
231230
with torch.no_grad():
232231
fx_graph_module = torch.export.export(
233232
self.llama_graph_module, self.inputs, strict=True
@@ -351,14 +350,11 @@ def compile(args, pte_filename, tokenizer):
351350

352351
kv_config, prefill_config = None, None
353352
if args.params:
354-
with open(args.params) as f:
355-
kv_config = ModelArgs(**json.load(f))
356-
elif args.decoder_model == "qwen2_5":
357-
from importlib.resources import files
358-
359-
data_dir = files("executorch").joinpath("examples/models/qwen2_5/config")
360-
config_file = data_dir.joinpath("0_5b_config.json")
361-
kv_config = ModelArgs(**json.loads(config_file.read_text()))
353+
params_path = args.params
354+
else:
355+
params_path = SUPPORTED_HF_MODELS[args.decoder_model].params_path
356+
with open(params_path) as f:
357+
kv_config = ModelArgs(**json.load(f))
362358

363359
# TODO: support batch inputs if necessary
364360
kv_config.max_batch_size = 1
@@ -430,13 +426,10 @@ def compile(args, pte_filename, tokenizer):
430426
raise RuntimeError(f"Unknown model_mode: {args.model_mode}.")
431427

432428
if args.checkpoint is None: # HF models
433-
model_id = HUGGING_FACE_REPO_IDS[args.decoder_model]
434-
if args.decoder_model == "qwen2_5":
435-
from executorch.examples.models.qwen2_5 import ( # pyre-ignore[21]
436-
convert_weights,
437-
)
438-
439-
checkpoint = download_and_convert_hf_checkpoint(model_id, convert_weights)
429+
checkpoint = download_and_convert_hf_checkpoint(
430+
SUPPORTED_HF_MODELS[args.decoder_model].repo_id,
431+
SUPPORTED_HF_MODELS[args.decoder_model].convert_weights,
432+
)
440433
state_dict = torch.load(
441434
checkpoint, weights_only=True, map_location="cpu", mmap=True
442435
)
@@ -964,8 +957,9 @@ def _build_parser():
964957

965958
parser.add_argument(
966959
"--decoder_model",
967-
choices=["stories260k", "stories110m", "llama3_2", "qwen2_5"],
968-
help="The Llama model to export. Current available options are: [stories260k, stories110m, llama3_2, qwen2_5]",
960+
choices=["stories260k", "stories110m", "llama3_2"]
961+
+ list(SUPPORTED_HF_MODELS.keys()),
962+
help=f"The Llama model to export. Current available options are: [stories260k, stories110m, llama3_2] + {SUPPORTED_HF_MODELS.keys()}",
969963
required=True,
970964
)
971965

@@ -1176,11 +1170,19 @@ def export_llama(args) -> None:
11761170
tokenizer, TiktokenTokenizer
11771171
), f"Wrong tokenizer provided for llama3_2."
11781172
runtime_tokenizer_path = args.tokenizer_model
1179-
elif args.decoder_model == "qwen2_5":
1180-
model_id = HUGGING_FACE_REPO_IDS[args.decoder_model]
1173+
elif args.decoder_model in {"qwen2_5", "qwen3_0_6b", "qwen3_1_7b"}:
1174+
model_id = SUPPORTED_HF_MODELS[args.decoder_model].repo_id
11811175
tokenizer = AutoTokenizer.from_pretrained(model_id)
11821176
runtime_tokenizer_path = tokenizer.save_pretrained(args.artifact)[-1]
11831177
tokenizer = get_tokenizer(runtime_tokenizer_path)
1178+
with open(runtime_tokenizer_path, "r+") as file:
1179+
data = json.load(file)
1180+
# TODO: Encountered the following error during runtime, so switched behavior for now.
1181+
# Error: libc++abi: terminating due to uncaught exception of type std::runtime_error: Unsupported Normalizer type: NFC.
1182+
data.pop("normalizer")
1183+
file.seek(0)
1184+
json.dump(data, file, indent=4)
1185+
file.truncate()
11841186
else:
11851187
raise RuntimeError(f"Unknown decoder_model: {args.decoder_model}.")
11861188

examples/qualcomm/oss_scripts/llama/model/static_llama.py

Lines changed: 48 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,10 @@
1515
import torch.nn as nn
1616
import torch.nn.functional as F
1717
from executorch.examples.models.llama.model_args import ModelArgs
18-
from executorch.examples.models.llama.rope import precompute_freqs_cis
18+
from executorch.examples.models.llama.rope import (
19+
hf_precompute_freqs_cis,
20+
precompute_freqs_cis,
21+
)
1922

2023

2124
def apply_rotary_emb_single(
@@ -48,6 +51,14 @@ def __init__(self, config: ModelArgs, output_new_cache_only=False):
4851
self.max_seq_len = config.max_seq_len
4952
self.output_new_cache_only = output_new_cache_only
5053
self.enable_masked_softmax = getattr(config, "enable_masked_softmax", False)
54+
self.use_qk_norm = config.use_qk_norm
55+
self.qk_norm_before_rope = config.qk_norm_before_rope
56+
57+
if self.use_qk_norm:
58+
q_norm_dim = self.head_dim
59+
k_norm_dim = self.head_dim
60+
self.q_norm_fn = torch.nn.RMSNorm(q_norm_dim, eps=config.norm_eps)
61+
self.k_norm_fn = torch.nn.RMSNorm(k_norm_dim, eps=config.norm_eps)
5162

5263
self.wq = nn.Linear(
5364
self.dim,
@@ -151,7 +162,7 @@ def prepare_sha(self):
151162
)
152163
self.wo_sha.weight.data.copy_(self.wo.weight[:, :, None, None])
153164

154-
def forward_sha(
165+
def forward_sha( # noqa: C901
155166
self,
156167
hidden_states: torch.Tensor,
157168
freqs_cos: torch.Tensor,
@@ -184,15 +195,23 @@ def forward_sha(
184195
.reshape(bsz, seq_len, self.head_dim)
185196
for wv_sha in self.wv_sha
186197
]
198+
187199
for i in range(len(q)):
200+
if self.use_qk_norm and self.qk_norm_before_rope:
201+
q[i] = self.q_norm_fn(q[i])
188202
q[i] = apply_rotary_emb_single(q[i], freqs_cos, freqs_sin)
189203
if hasattr(self.config, "enable_r3") and self.config.enable_r3:
190204
q[i] = torch.matmul(q[i], self.r3_weight.T)
205+
if self.use_qk_norm and not self.qk_norm_before_rope:
206+
q[i] = self.q_norm_fn(q[i])
191207
for i in range(len(k)):
192-
k[i] = apply_rotary_emb_single(k[i], freqs_cos, freqs_sin)
208+
if self.use_qk_norm and self.qk_norm_before_rope:
209+
k[i] = self.k_norm_fn(k[i])
210+
k[i] = apply_rotary_emb_single(k[i], freqs_cos, freqs_sin).transpose(1, 2)
193211
if hasattr(self.config, "enable_r3") and self.config.enable_r3:
194212
k[i] = torch.matmul(k[i], self.r3_weight.T)
195-
k[i] = k[i].transpose(1, 2)
213+
if self.use_qk_norm and not self.qk_norm_before_rope:
214+
k[i] = self.k_norm_fn(k[i])
196215

197216
output_y = []
198217
kh, vh = [], []
@@ -249,9 +268,17 @@ def forward(
249268
k = k.view(bsz, seq_len, self.n_kv_heads, self.head_dim)
250269
v = v.view(bsz, seq_len, self.n_kv_heads, self.head_dim)
251270

271+
if self.use_qk_norm and self.qk_norm_before_rope:
272+
q = self.q_norm_fn(q)
273+
k = self.k_norm_fn(k)
274+
252275
q = apply_rotary_emb_single(q, freqs_cos, freqs_sin)
253276
k = apply_rotary_emb_single(k, freqs_cos, freqs_sin).permute(0, 2, 3, 1)
254277

278+
if self.use_qk_norm and not self.qk_norm_before_rope:
279+
q = self.q_norm_fn(q)
280+
k = self.k_norm_fn(k)
281+
255282
output_kh, output_vh, output_y = [], [], []
256283
kh, vh = [], []
257284
# kv cache mode
@@ -403,13 +430,23 @@ def __init__(
403430
self.norm = torch.nn.RMSNorm(config.dim, eps=config.norm_eps)
404431
self.output = nn.Linear(config.dim, config.vocab_size, bias=False)
405432
self.tok_embeddings = nn.Embedding(config.vocab_size, config.dim)
406-
freqs_cos, freqs_sin = precompute_freqs_cis(
407-
config.head_dim,
408-
config.max_seq_len,
409-
config.rope_freq_base,
410-
config.use_scaled_rope,
411-
config.rope_scale_factor,
412-
)
433+
if config.use_hf_rope:
434+
freqs_cos, freqs_sin = hf_precompute_freqs_cis(
435+
config.head_dim,
436+
config.max_seq_len,
437+
config.rope_freq_base,
438+
config.partial_rotary_factor,
439+
)
440+
freqs_cos = freqs_cos[:, : freqs_cos.shape[-1] // 2]
441+
freqs_sin = freqs_sin[:, : freqs_sin.shape[-1] // 2]
442+
else:
443+
freqs_cos, freqs_sin = precompute_freqs_cis(
444+
config.head_dim,
445+
config.max_seq_len,
446+
config.rope_freq_base,
447+
config.use_scaled_rope,
448+
config.rope_scale_factor,
449+
)
413450
self.register_buffer("freqs_cos", freqs_cos, persistent=False)
414451
self.register_buffer("freqs_sin", freqs_sin, persistent=False)
415452

0 commit comments

Comments
 (0)