Commit d0e1d1f
Qualcomm AI Engine Direct - GA Static Qwen3
Summary:
- Support Qwen3-0.6B
- Support Qwen3-1.7B
- Refactor HF model registration for static llama
1 parent 2f782bf commit d0e1d1f

File tree

4 files changed: +180 -35 lines

backends/qualcomm/tests/test_qnn_delegate.py

Lines changed: 59 additions & 0 deletions
@@ -4372,6 +4372,65 @@ def test_qwen2_5(self):
         )
         self.assertGreaterEqual(msg["inference_speed"], 95)  # Lanai
 
+    def test_qwen3(self):
+        if not self.required_envs():
+            self.skipTest("missing required envs")
+
+        prompt = "My favourite condiment is "
+        cmds = [
+            "python",
+            f"{self.executorch_root}/examples/qualcomm/oss_scripts/llama/llama.py",
+            "--artifact",
+            self.artifact_dir,
+            "--build_folder",
+            self.build_folder,
+            "--model",
+            self.model,
+            "--ip",
+            self.ip,
+            "--port",
+            str(self.port),
+            "--prompt",
+            f"{prompt}",
+            "--ptq",
+            "16a8w",
+            "--decoder_model",
+            "qwen3_0_6b",
+            "--model_mode",
+            "hybrid",
+            "--prefill_ar_len",
+            "32",
+            "--max_seq_len",
+            "128",
+        ]
+        if self.compile_only:
+            cmds.extend(["--compile_only"])
+        elif self.device:
+            cmds.extend(["--device", self.device])
+            if self.host:
+                cmds.extend(["--host", self.host])
+        elif self.enable_x86_64:
+            cmds.extend(["--enable_x86_64"])
+        if self.pre_gen_pte:
+            cmds.extend(["--pre_gen_pte", self.pre_gen_pte])
+
+        # Accuracy is poor for now; just check that the user's prompt is returned.
+        golden_start_with = "My favourite condiment is "
+        p = subprocess.Popen(cmds, stdout=subprocess.DEVNULL)
+        with Listener((self.ip, self.port)) as listener:
+            conn = listener.accept()
+            p.communicate()
+            msg = json.loads(conn.recv())
+            if "Error" in msg:
+                self.fail(msg["Error"])
+            else:
+                model_out = msg["result"][0]
+                self.assertTrue(
+                    model_out.startswith(golden_start_with),
+                    f"Expected Output: {golden_start_with}. Actual Output: {model_out}",
+                )
+                self.assertGreaterEqual(msg["inference_speed"], 70)  # Lanai
+
 
 class TestExampleOssScript(TestQNN):
     def test_albert(self):

examples/qualcomm/oss_scripts/llama/__init__.py

Lines changed: 63 additions & 0 deletions
@@ -0,0 +1,63 @@
+import os
+from abc import ABC
+from dataclasses import dataclass
+from typing import Callable, Dict, Type
+
+from executorch.examples.models.qwen2_5 import (
+    convert_weights as convert_qwen2_5_weights,
+)
+from executorch.examples.models.qwen3 import convert_weights as convert_qwen3_weights
+
+BASE_DIR = os.path.dirname(__file__)
+
+
+@dataclass(init=False, frozen=True)
+class HFModel(ABC):
+    repo_id: str
+    params_path: str
+    runner_version: str
+    convert_weights: Callable
+
+
+SUPPORTED_HF_MODELS: Dict[str, HFModel] = {}
+
+
+def register_hf_model(name: str):
+    def decorator(cls: Type[HFModel]):
+        SUPPORTED_HF_MODELS[name.lower()] = cls()
+        return cls()
+
+    return decorator
+
+
+@register_hf_model("qwen2_5")
+@dataclass(init=False, frozen=True)
+class Qwen2_5(HFModel):
+    repo_id: str = "Qwen/Qwen2.5-0.5B"
+    params_path: str = os.path.join(
+        BASE_DIR, "../../../models/qwen2_5/config/0_5b_config.json"
+    )
+    runner_version: str = "qwen2_5"
+    convert_weights = convert_qwen2_5_weights
+
+
+@register_hf_model("qwen3_0_6b")
+@dataclass(init=False, frozen=True)
+class Qwen3_0_6B(HFModel):
+    repo_id: str = "Qwen/Qwen3-0.6B"
+    params_path: str = os.path.join(
+        BASE_DIR, "../../../models/qwen3/config/0_6b_config.json"
+    )
+    runner_version: str = "qwen2_5"
+    convert_weights = convert_qwen3_weights
+
+
+@register_hf_model("qwen3_1_7b")
+@dataclass(init=False, frozen=True)
+class Qwen3_1_7B(HFModel):
+    repo_id: str = "Qwen/Qwen3-1.7B"
+    params_path: str = os.path.join(
+        BASE_DIR, "../../../models/qwen3/config/1_7b_config.json"
+    )
+    runner_version: str = "qwen2_5"
+    convert_weights = convert_qwen3_weights
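
For illustration, the snippet below sketches how a further decoder could be registered through this mechanism if it were appended to the module above. The key, class name, and repo id are hypothetical, and it reuses the Qwen3 converter and config path so the sketch stays self-contained; it is not part of this commit.

# Illustrative sketch only (not part of this commit), appended to the module above.
@register_hf_model("my_decoder")  # hypothetical key
@dataclass(init=False, frozen=True)
class MyDecoder(HFModel):
    repo_id: str = "my-org/My-Decoder-1B"  # hypothetical HF repo
    params_path: str = os.path.join(
        BASE_DIR, "../../../models/qwen3/config/0_6b_config.json"  # reuse an existing config for the sketch
    )
    runner_version: str = "qwen2_5"
    convert_weights = convert_qwen3_weights  # reuse an existing converter for the sketch


# llama.py then resolves the entry by key, exactly as shown in the hunks below.
entry = SUPPORTED_HF_MODELS["my_decoder"]
print(entry.repo_id, entry.params_path)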

examples/qualcomm/oss_scripts/llama/llama.py

Lines changed: 12 additions & 27 deletions
@@ -65,6 +65,7 @@
 from executorch.examples.models.llama.source_transformation.quantize import (
     get_quant_embedding_transform,
 )
+from executorch.examples.qualcomm.oss_scripts.llama import SUPPORTED_HF_MODELS
 from executorch.examples.qualcomm.oss_scripts.llama.model.static_llama import (
     LlamaModel,
     ModelArgs,
@@ -103,8 +104,6 @@
 logging.basicConfig(level=logging.INFO, format=FORMAT)
 logging.getLogger().setLevel(logging.INFO)
 
-HUGGING_FACE_REPO_IDS = {"qwen2_5": "Qwen/Qwen2.5-0.5B"}
-
 
 def next_power_of_two(n):
     if n == 0:
@@ -409,7 +408,6 @@ def quantize(
 
         self.has_quant_io = True
         fx_graph_module = None
-
         with torch.no_grad():
            fx_graph_module = torch.export.export(
                self.llama_graph_module, self.inputs, strict=True
@@ -517,18 +515,7 @@ def compile(args, pte_filename, tokenizer):
     if args.params:
         params_path = args.params
     else:
-        if args.decoder_model == "qwen2_5":
-            cur_dir = os.path.dirname(__file__)
-            params_path = os.path.join(
-                cur_dir,
-                "..",
-                "..",
-                "..",
-                "models",
-                "qwen2_5",
-                "config",
-                "0_5b_config.json",
-            )
+        params_path = SUPPORTED_HF_MODELS[args.decoder_model].params_path
     with open(params_path) as f:
         kv_config = ModelArgs(**json.load(f))
@@ -601,13 +588,10 @@ def compile(args, pte_filename, tokenizer):
         raise RuntimeError(f"Unknown model_mode: {args.model_mode}.")
 
     if args.checkpoint is None:  # HF models
-        model_id = HUGGING_FACE_REPO_IDS[args.decoder_model]
-        if args.decoder_model == "qwen2_5":
-            from executorch.examples.models.qwen2_5 import (  # pyre-ignore[21]
-                convert_weights,
-            )
-
-        checkpoint = download_and_convert_hf_checkpoint(model_id, convert_weights)
+        checkpoint = download_and_convert_hf_checkpoint(
+            SUPPORTED_HF_MODELS[args.decoder_model].repo_id,
+            SUPPORTED_HF_MODELS[args.decoder_model].convert_weights,
+        )
         state_dict = torch.load(
             checkpoint, weights_only=True, map_location="cpu", mmap=True
         )
@@ -1041,8 +1025,9 @@ def _build_parser():
 
     parser.add_argument(
         "--decoder_model",
-        choices=["stories260k", "stories110m", "llama3_2", "qwen2_5"],
-        help="The Llama model to export. Current available options are: [stories260k, stories110m, llama3_2, qwen2_5]",
+        choices=["stories260k", "stories110m", "llama3_2"]
+        + list(SUPPORTED_HF_MODELS.keys()),
+        help=f"The Llama model to export. Currently available options are: [stories260k, stories110m, llama3_2] + {list(SUPPORTED_HF_MODELS.keys())}",
         required=True,
     )
@@ -1244,12 +1229,12 @@ def export_llama(args) -> None:
         ), f"Wrong tokenizer provided for llama3_2."
         runtime_tokenizer_path = args.tokenizer_model
         decoder_model_version = "llama3"
-    elif args.decoder_model == "qwen2_5":
-        model_id = HUGGING_FACE_REPO_IDS[args.decoder_model]
+    elif args.decoder_model in {"qwen2_5", "qwen3_0_6b", "qwen3_1_7b"}:
+        model_id = SUPPORTED_HF_MODELS[args.decoder_model].repo_id
         tokenizer = AutoTokenizer.from_pretrained(model_id)
         runtime_tokenizer_path = tokenizer.save_pretrained(args.artifact)[-1]
         tokenizer = get_tokenizer(runtime_tokenizer_path)
-        decoder_model_version = args.decoder_model
+        decoder_model_version = SUPPORTED_HF_MODELS[args.decoder_model].runner_version
 
         with open(runtime_tokenizer_path, "r+") as file:
             data = json.load(file)
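
Taken together, these hunks replace the per-model if/else branches with registry lookups. The sketch below (not part of the diff) condenses that resolution flow into one place; the helper name resolve_decoder is hypothetical, while SUPPORTED_HF_MODELS and ModelArgs come from imports already shown in the changed files.

# Hypothetical helper (not in the commit) mirroring the lookup pattern used by
# compile() and export_llama(): one registry entry supplies the params JSON,
# HF repo id, weight converter, and runtime decoder version.
import json

from executorch.examples.models.llama.model_args import ModelArgs
from executorch.examples.qualcomm.oss_scripts.llama import SUPPORTED_HF_MODELS


def resolve_decoder(decoder_model: str):
    entry = SUPPORTED_HF_MODELS[decoder_model]  # e.g. "qwen3_0_6b"
    with open(entry.params_path) as f:
        kv_config = ModelArgs(**json.load(f))  # as in compile()
    # entry.repo_id / entry.convert_weights feed download_and_convert_hf_checkpoint(),
    # and entry.runner_version becomes decoder_model_version, as in export_llama().
    return kv_config, entry.repo_id, entry.convert_weights, entry.runner_version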

examples/qualcomm/oss_scripts/llama/model/static_llama.py

Lines changed: 46 additions & 8 deletions
@@ -13,7 +13,10 @@
 import torch.nn as nn
 import torch.nn.functional as F
 from executorch.examples.models.llama.model_args import ModelArgs
-from executorch.examples.models.llama.rope import precompute_freqs_cis
+from executorch.examples.models.llama.rope import (
+    hf_precompute_freqs_cis,
+    precompute_freqs_cis,
+)
 
 
 def apply_rotary_emb_single(
@@ -46,6 +49,14 @@ def __init__(self, config: ModelArgs, output_new_cache_only=False):
         self.max_seq_len = config.max_seq_len
         self.output_new_cache_only = output_new_cache_only
         self.enable_masked_softmax = getattr(config, "enable_masked_softmax", False)
+        self.use_qk_norm = config.use_qk_norm
+        self.qk_norm_before_rope = config.qk_norm_before_rope
+
+        if self.use_qk_norm:
+            q_norm_dim = self.head_dim
+            k_norm_dim = self.head_dim
+            self.q_norm_fn = torch.nn.RMSNorm(q_norm_dim, eps=config.norm_eps)
+            self.k_norm_fn = torch.nn.RMSNorm(k_norm_dim, eps=config.norm_eps)
 
         self.wq = nn.Linear(
             self.dim,
@@ -170,10 +181,19 @@ def forward_sha(
             .reshape(bsz, seq_len, self.head_dim)
             for wv_sha in self.wv_sha
         ]
+
         for i in range(len(q)):
+            if self.use_qk_norm and self.qk_norm_before_rope:
+                q[i] = self.q_norm_fn(q[i])
             q[i] = apply_rotary_emb_single(q[i], freqs_cos, freqs_sin)
+            if self.use_qk_norm and not self.qk_norm_before_rope:
+                q[i] = self.q_norm_fn(q[i])
         for i in range(len(k)):
+            if self.use_qk_norm and self.qk_norm_before_rope:
+                k[i] = self.k_norm_fn(k[i])
             k[i] = apply_rotary_emb_single(k[i], freqs_cos, freqs_sin).transpose(1, 2)
+            if self.use_qk_norm and not self.qk_norm_before_rope:
+                k[i] = self.k_norm_fn(k[i])
 
         output_y = []
         kh, vh = [], []
@@ -230,9 +250,17 @@ def forward(
         k = k.view(bsz, seq_len, self.n_kv_heads, self.head_dim)
         v = v.view(bsz, seq_len, self.n_kv_heads, self.head_dim)
 
+        if self.use_qk_norm and self.qk_norm_before_rope:
+            q = self.q_norm_fn(q)
+            k = self.k_norm_fn(k)
+
         q = apply_rotary_emb_single(q, freqs_cos, freqs_sin)
         k = apply_rotary_emb_single(k, freqs_cos, freqs_sin).permute(0, 2, 3, 1)
 
+        if self.use_qk_norm and not self.qk_norm_before_rope:
+            q = self.q_norm_fn(q)
+            k = self.k_norm_fn(k)
+
         output_kh, output_vh, output_y = [], [], []
         kh, vh = [], []
         # kv cache mode
@@ -384,13 +412,23 @@ def __init__(
         self.norm = torch.nn.RMSNorm(config.dim, eps=config.norm_eps)
         self.output = nn.Linear(config.dim, config.vocab_size, bias=False)
         self.tok_embeddings = nn.Embedding(config.vocab_size, config.dim)
-        freqs_cos, freqs_sin = precompute_freqs_cis(
-            config.head_dim,
-            config.max_seq_len,
-            config.rope_freq_base,
-            config.use_scaled_rope,
-            config.rope_scale_factor,
-        )
+        if config.use_hf_rope:
+            freqs_cos, freqs_sin = hf_precompute_freqs_cis(
+                config.head_dim,
+                config.max_seq_len,
+                config.rope_freq_base,
+                config.partial_rotary_factor,
+            )
+            freqs_cos = freqs_cos[:, : freqs_cos.shape[-1] // 2]
+            freqs_sin = freqs_sin[:, : freqs_sin.shape[-1] // 2]
+        else:
+            freqs_cos, freqs_sin = precompute_freqs_cis(
+                config.head_dim,
+                config.max_seq_len,
+                config.rope_freq_base,
+                config.use_scaled_rope,
+                config.rope_scale_factor,
+            )
         self.register_buffer("freqs_cos", freqs_cos, persistent=False)
         self.register_buffer("freqs_sin", freqs_sin, persistent=False)
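
As a standalone illustration of the ordering these attention changes implement (not code from the commit), the toy snippet below applies RMSNorm either before or after a rotary step depending on a flag, mirroring use_qk_norm and qk_norm_before_rope; the rotary function is stubbed out as identity and the shapes are assumed values.

# Toy sketch of the QK-norm-around-RoPE ordering added above; not from the commit.
import torch


def qk_norm_around_rope(x, rotary_fn, norm_fn, before_rope: bool):
    # Mirrors forward()/forward_sha(): norm runs before RoPE when
    # qk_norm_before_rope is set, otherwise after RoPE.
    if before_rope:
        x = norm_fn(x)
    x = rotary_fn(x)
    if not before_rope:
        x = norm_fn(x)
    return x


head_dim = 128  # assumed value for the sketch
norm_fn = torch.nn.RMSNorm(head_dim, eps=1e-6)  # same module used in __init__ above


def rotary_fn(t):
    return t  # stand-in for apply_rotary_emb_single(t, freqs_cos, freqs_sin)


q = torch.randn(1, 32, 8, head_dim)  # (bsz, seq_len, n_heads, head_dim)
out = qk_norm_around_rope(q, rotary_fn, norm_fn, before_rope=True)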
