
Commit f154d50

Qualcomm AI Engine Direct - Scripts and accuracy improvement for Qwen3_0.6B/1.7B and Qwen 2.5_1.5B (#13544)
### Summary
- Adding static Qwen 2.5 1.5B to the script.
- Adding static Qwen 3 0.6B/1.7B to the script.
- Adding back `skip_advanced_requant`.
- Adding prompt + special tokens for calibration, which helps certain models improve accuracy.

#### Example Script
```bash
python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -H haowhsu-linux -s 5f396958 -m SM8750 --prompt "How many r's in strawberries?" --temperature 0 --model_mode kv --max_seq_len 1024 --ptq 16a8w --decoder_model qwen3-0_6b --tasks wikitext --limit 1 --artifact ./qwen3-0_6b
```

#### Statistics on SM8750, seq_len=1024
- Qwen2.5 1.5B: ~34 tok/sec; QNN on-device PPL = 9.4 (CPU FP = 9.1)
- Qwen3 0.6B: ~56 tok/sec; QNN on-device PPL = 16.8 (CPU FP = 16.26)
- Qwen3 1.7B: ~14 tok/sec; QNN on-device PPL = 14.1 (CPU FP = 13.52)

### Test plan
E2E in test_qnn_delegate.py
1 parent 0f444ab commit f154d50

File tree

10 files changed: +228 −66 lines


backends/qualcomm/_passes/annotate_quant_attrs.py

Lines changed: 29 additions & 11 deletions
```diff
@@ -30,9 +30,14 @@ class AnnotateQuantAttrs(ExportPass):
     generated after quantization process.
     """
 
-    def __init__(self, edge_program: torch.export.ExportedProgram):
+    def __init__(
+        self,
+        edge_program: torch.export.ExportedProgram,
+        skip_advanced_requant: bool = False,
+    ):
         super(AnnotateQuantAttrs, self).__init__()
         self.edge_program = edge_program
+        self.skip_advanced_requant = skip_advanced_requant
 
     def _annotate_source_nodes(
         self, quant_node: torch.fx.Node, quant_attrs: Dict[str, Any]
@@ -82,16 +87,29 @@ def _annotate_requant(self, n):
         # TODO: Store multiple pairs of requantize attributes when we have an op builder
         # that has multiple outputs that requires quant attributes.
 
-        if any(
-            q_attrs[attr] != dq_attrs[attr]
-            for attr in [
-                QCOM_SCALE,
-                QCOM_ZERO_POINT,
-                QCOM_QUANT_MIN,
-                QCOM_QUANT_MAX,
-                QCOM_DTYPE,
-            ]
-        ):
+        # Determine if requantization is needed based on configuration and attribute mismatch.
+        is_requant_needed = False
+        if self.skip_advanced_requant:
+            # In skip_advanced_requant mode, only consider requant if dtypes differ.
+            if q_attrs[QCOM_DTYPE] != dq_attrs[QCOM_DTYPE]:
+                is_requant_needed = True
+        else:
+            # In full requant mode, consider requant if any key attribute differs.
+            # This aims to improve accuracy by adjusting scale, zero_point, etc.
+            # Users can disable this if it causes regressions.
+            if any(
+                q_attrs[attr] != dq_attrs[attr]
+                for attr in [
+                    QCOM_SCALE,
+                    QCOM_ZERO_POINT,
+                    QCOM_QUANT_MIN,
+                    QCOM_QUANT_MAX,
+                    QCOM_DTYPE,
+                ]
+            ):
+                is_requant_needed = True
+
+        if is_requant_needed:
             dq_attrs[QCOM_ENCODING] = q_attrs[QCOM_ENCODING]
             user_node = list(dq_node.users)[0]
             n.args[0].meta.setdefault(QCOM_REQUANTIZE, {})
```
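In short, `skip_advanced_requant` narrows the requantization trigger from "any quantization attribute differs" to "only the dtype differs". A minimal, self-contained sketch of that decision follows; the `needs_requant` helper and the literal key strings are illustrative stand-ins, not part of the pass:

```python
# Illustrative sketch of the decision introduced above; not the actual ExportPass code.
# The QCOM_* names below are assumed placeholder keys standing in for the real constants.
QCOM_SCALE, QCOM_ZERO_POINT = "scale", "zero_point"
QCOM_QUANT_MIN, QCOM_QUANT_MAX, QCOM_DTYPE = "quant_min", "quant_max", "dtype"


def needs_requant(q_attrs: dict, dq_attrs: dict, skip_advanced_requant: bool) -> bool:
    if skip_advanced_requant:
        # Conservative mode: requantize only when the dtypes disagree.
        return q_attrs[QCOM_DTYPE] != dq_attrs[QCOM_DTYPE]
    # Full mode: requantize when any key quantization attribute disagrees.
    keys = (QCOM_SCALE, QCOM_ZERO_POINT, QCOM_QUANT_MIN, QCOM_QUANT_MAX, QCOM_DTYPE)
    return any(q_attrs[k] != dq_attrs[k] for k in keys)


# Same dtype but different scale: requant fires only in full mode.
q = {"scale": 0.02, "zero_point": 0, "quant_min": -128, "quant_max": 127, "dtype": "int8"}
dq = {**q, "scale": 0.03}
assert needs_requant(q, dq, skip_advanced_requant=False)
assert not needs_requant(q, dq, skip_advanced_requant=True)
```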

backends/qualcomm/tests/test_qnn_delegate.py

Lines changed: 18 additions & 14 deletions
```diff
@@ -4564,7 +4564,7 @@ def test_static_qwen2_5(self):
             "--ptq",
             "16a8w",
             "--decoder_model",
-            "qwen2_5",
+            "qwen2_5-0_5b",
             "--model_mode",
             "kv",
             "--max_seq_len",
@@ -4627,13 +4627,18 @@ def test_static_qwen3(self):
             "--ptq",
             "16a8w",
             "--decoder_model",
-            "qwen3_0_6b",
+            "qwen3-0_6b",
             "--model_mode",
-            "hybrid",
-            "--prefill_ar_len",
-            "32",
+            "kv",
             "--max_seq_len",
-            "128",
+            "1024",
+            "--eval_perplexity",
+            "--tasks",
+            "wikitext",
+            "--limit",
+            "1",
+            "--r3",
+            "--enable_masked_softmax",
         ]
         if self.compile_only:
             cmds.extend(["--compile_only"])
@@ -4646,8 +4651,6 @@ def test_static_qwen3(self):
         if self.pre_gen_pte:
             cmds.extend(["--pre_gen_pte", self.pre_gen_pte])
 
-        # Accuracy is bad for now. Just check user's prompt is returned.
-        golden_start_with = "My favourite condiment is "
         p = subprocess.Popen(cmds, stdout=subprocess.DEVNULL)
         with Listener((self.ip, self.port)) as listener:
             conn = listener.accept()
@@ -4656,12 +4659,13 @@ def test_static_qwen3(self):
             if "Error" in msg:
                 self.fail(msg["Error"])
             else:
-                model_out = msg["result"][0]
-                self.assertTrue(
-                    model_out.startswith(golden_start_with),
-                    f"Expected Output: {golden_start_with}. Actual Output: {model_out}",
-                )
-                self.assertGreaterEqual(msg["inference_speed"], 70)  # Lanai
+                inference_speed_ref = {"SM8650": 38, "SM8750": 56}
+                self.assertLessEqual(msg["wiki_ppl"], 18)
+                self.assertLessEqual(msg["pte_size"], 950_000_000)  # 950mb
+                if self.model in inference_speed_ref:
+                    self.assertGreaterEqual(
+                        msg["inference_speed"], inference_speed_ref[self.model]
+                    )
 
     def test_smollm2(self):
         if not self.required_envs():
```

examples/qualcomm/oss_scripts/llama/README.md

Lines changed: 24 additions & 6 deletions
````diff
@@ -5,7 +5,7 @@ This file provides you the instructions to run LLM Decoder model with different
 1. LLAMA2 Stories 110M
 2. LLAMA3.2 1B
 3. LLAMA3.2 3B
-4. QWEN2.5 0.5B
+4. QWEN2.5 0.5B / 1.5B
 5. QWEN3 0.6B / 1.7B
 6. Phi4-mini-instruct
 7. SMOLLM2 135M
@@ -72,13 +72,31 @@ python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL
 #### QWEN2.5 0.5B
 Default example using hybrid mode
 ```bash
-python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --temperature 0 --model_mode hybrid --max_seq_len 1024 --prefill_ar_len 128 --ptq 16a8w --enable_masked_softmax --r3 --decoder_model qwen2_5 --prompt "I would like to learn python, could you teach me with a simple example?"
+python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --temperature 0 --model_mode hybrid --max_seq_len 1024 --prefill_ar_len 128 --ptq 16a8w --enable_masked_softmax --r3 --decoder_model qwen2_5-0_5b --prompt "I would like to learn python, could you teach me with a simple example?"
+```
+
+#### QWEN2.5 1.5B
+Default example using hybrid mode
+```bash
+python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --temperature 0 --model_mode hybrid --max_seq_len 1024 --prefill_ar_len 128 --ptq 16a8w --enable_masked_softmax --r3 --decoder_model qwen2_5-1_5b --prompt "I would like to learn python, could you teach me with a simple example?"
+```
+
+#### QWEN3 0.6B
+Default example using hybrid mode
+```bash
+python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --temperature 0 --model_mode hybrid --max_seq_len 1024 --prefill_ar_len 128 --ptq 16a8w --enable_masked_softmax --r3 --decoder_model qwen3-0_6b --prompt "I would like to learn python, could you teach me with a simple example?"
+```
+
+#### QWEN3 1.7B
+Default example using hybrid mode
+```bash
+python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --temperature 0 --model_mode hybrid --max_seq_len 1024 --prefill_ar_len 128 --ptq 16a8w --enable_masked_softmax --r3 --decoder_model qwen3-1_7b --prompt "I would like to learn python, could you teach me with a simple example?"
 ```
 
 #### SMOLLM2
 Default example using hybrid mode.
 ```bash
-python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -H mlgtw-linux -s ${SERIAL_NUM} -m ${SOC_MODEL} --ptq 16a8w --tokenizer_bin tokenizer.bin --decoder_model smollm2 --model_mode hybrid --prefill_ar_len 128 --max_seq_len 1024 --prompt "I would like to learn python, could you teach me with a simple example?"
+python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -H mlgtw-linux -s ${SERIAL_NUM} -m ${SOC_MODEL} --ptq 16a8w --decoder_model smollm2_135m --model_mode hybrid --prefill_ar_len 128 --max_seq_len 1024 --prompt "I would like to learn python, could you teach me with a simple example?"
 ```
 
 ### KV Cache update mechanism
@@ -175,18 +193,18 @@ To evaluate the perplexity across all 3 phases, users should provide the `--eval
 
 For example, using the Qwen model and 1 wikitext sample as the evaluation task, users can assess all 3 phases perplexity score in a single run by including the appropriate configuration:
 ```bash
-python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --prompt "What is 1+1?" --temperature 0 --model_mode kv --max_seq_len 1024 --ptq 16a8w --decoder_model qwen2_5 --eval_perplexity --tasks wikitext --limit 1
+python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --prompt "What is 1+1?" --temperature 0 --model_mode kv --max_seq_len 1024 --ptq 16a8w --decoder_model qwen2_5-0_5b --eval_perplexity --tasks wikitext --limit 1
 ```
 
 For the example script above, 1 wikitext sample is used to evaluate all 3 phases. However, there are cases where a user may want to use one sample for quantization calibration and multiple samples for perplexity evaluation. In this case, the process should be split into two runs. In the 1st run, the model is compiled using one sample. In the 2nd run, the user can provide a different configuration for QNN device execution.
 Example:
 ```bash
 # 1st run to compile with --limit 1
-python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --prompt "What is 1+1?" --temperature 0 --model_mode kv --max_seq_len 1024 --ptq 16a8w --decoder_model qwen2_5 --eval_perplexity --tasks wikitext --limit 1 --compile_only
+python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --prompt "What is 1+1?" --temperature 0 --model_mode kv --max_seq_len 1024 --ptq 16a8w --decoder_model qwen2_5-0_5b --eval_perplexity --tasks wikitext --limit 1 --compile_only
 ```
 ```bash
 # 2nd run to perform QNN device execution with --limit 3
-python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --prompt "What is 1+1?" --temperature 0 --model_mode kv --max_seq_len 1024 --ptq 16a8w --decoder_model qwen2_5 --eval_perplexity --tasks wikitext --limit 3 --pre_gen_pte ${PATH_TO_ARTIFACT_IN_1ST_RUN} --quant_attrs_path ${PATH_TO_ARTIFACT_IN_1ST_RUN}/kv_llama_qnn_quant_attrs.json
+python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --prompt "What is 1+1?" --temperature 0 --model_mode kv --max_seq_len 1024 --ptq 16a8w --decoder_model qwen2_5-0_5b --eval_perplexity --tasks wikitext --limit 3 --pre_gen_pte ${PATH_TO_ARTIFACT_IN_1ST_RUN} --quant_attrs_path ${PATH_TO_ARTIFACT_IN_1ST_RUN}/kv_llama_qnn_quant_attrs.json
 ```
 
 #### Tasks quantization calibration
````

examples/qualcomm/oss_scripts/llama/__init__.py

Lines changed: 33 additions & 14 deletions
```diff
@@ -6,7 +6,7 @@
 
 import os
 from abc import ABC
-from dataclasses import dataclass, field
+from dataclasses import dataclass
 from typing import Callable, Dict, Type
 
 from executorch.examples.models.phi_4_mini import (
@@ -19,19 +19,26 @@
 from executorch.examples.models.smollm2 import (
     convert_weights as convert_smollm2_weights,
 )
-from executorch.examples.qualcomm.oss_scripts.llama.decoder_constants import (
-    DECODER_MODEL_VERSION,
-)
 
 BASE_DIR = os.path.dirname(__file__)
 
 
 @dataclass(init=False, frozen=True)
 class HFModel(ABC):
+    """Base class for all Hugging Face models.
+
+    repo_id: Hugging Face repo ID.
+    params_path: Path to the model's config.json. If the corresponding .json does not yet exist, please create one.
+    convert_weights: Used to convert Hugging Face weight parameters to the static decoder's parameter naming.
+    transform_weight: Set to True to transform Hugging Face weights to improve the performance of RoPE in the HTP backend.
+    instruct_model: True if the model uses chat templates. Check the Hugging Face model card to ensure the model uses chat templates.
+    """
+
     repo_id: str
     params_path: str
-    runner_version: str
     convert_weights: Callable
+    transform_weight: bool
+    instruct_model: bool
 
 
 SUPPORTED_HF_MODELS: Dict[str, HFModel] = {}
@@ -45,40 +52,52 @@ def decorator(cls: Type[HFModel]):
     return decorator
 
 
-@register_hf_model("qwen2_5")
+@register_hf_model("qwen2_5-0_5b")
 @dataclass(init=False, frozen=True)
-class Qwen2_5(HFModel):
+class Qwen2_5_0_5B(HFModel):
     repo_id: str = "Qwen/Qwen2.5-0.5B"
     params_path: str = os.path.join(
         BASE_DIR, "../../../models/qwen2_5/config/0_5b_config.json"
     )
-    runner_version: str = field(default=DECODER_MODEL_VERSION["qwen2_5"])
     convert_weights = convert_qwen2_5_weights
     transform_weight = False
+    instruct_model = False
+
+
+@register_hf_model("qwen2_5-1_5b")
+@dataclass(init=False, frozen=True)
+class Qwen2_5_1_5B(HFModel):
+    repo_id: str = "Qwen/Qwen2.5-1.5B"
+    params_path: str = os.path.join(
+        BASE_DIR, "../../../models/qwen2_5/config/1_5b_config.json"
+    )
+    convert_weights = convert_qwen2_5_weights
+    transform_weight = False
+    instruct_model = False
 
 
-@register_hf_model("qwen3_0_6b")
+@register_hf_model("qwen3-0_6b")
 @dataclass(init=False, frozen=True)
 class Qwen3_0_6B(HFModel):
     repo_id: str = "Qwen/Qwen3-0.6B"
     params_path: str = os.path.join(
         BASE_DIR, "../../../models/qwen3/config/0_6b_config.json"
     )
-    runner_version: str = field(default=DECODER_MODEL_VERSION["qwen2_5"])
     convert_weights = convert_qwen3_weights
     transform_weight = False
+    instruct_model = True
 
 
-@register_hf_model("qwen3_1_7b")
+@register_hf_model("qwen3-1_7b")
 @dataclass(init=False, frozen=True)
 class Qwen3_1_7B(HFModel):
     repo_id: str = "Qwen/Qwen3-1.7B"
     params_path: str = os.path.join(
         BASE_DIR, "../../../models/qwen3/config/1_7b_config.json"
     )
-    runner_version: str = field(default=DECODER_MODEL_VERSION["qwen2_5"])
     convert_weights = convert_qwen3_weights
     transform_weight = False
+    instruct_model = True
 
 
 @register_hf_model("phi_4_mini")
@@ -88,9 +107,9 @@ class Phi4Mini(HFModel):
     params_path: str = os.path.join(
         BASE_DIR, "../../../models/phi_4_mini/config/config.json"
     )
-    runner_version: str = field(default=DECODER_MODEL_VERSION["phi_4_mini"])
     convert_weights = convert_phi_4_mini_weights
     transform_weight = False
+    instruct_model = True
 
 
 @register_hf_model("smollm2_135m")
@@ -100,6 +119,6 @@ class Smollm2_135M(HFModel):
     params_path: str = os.path.join(
         BASE_DIR, "../../../models/smollm2/135M_config.json"
    )
-    runner_version: str = field(default=DECODER_MODEL_VERSION["smollm2_135m"])
     convert_weights = convert_smollm2_weights
     transform_weight = True
+    instruct_model = True
```
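With this registry, adding another decoder is one more decorated dataclass in the same `__init__.py`. A hypothetical sketch following the field layout above; the `qwen2_5-3b` key, class name, and config path are placeholders for illustration, not entries added by this commit:

```python
# Hypothetical registration following the pattern above; the key, class, and
# config path are illustrative placeholders. Assumes it lives in the same module,
# so register_hf_model, HFModel, BASE_DIR, and convert_qwen2_5_weights are in scope.
@register_hf_model("qwen2_5-3b")
@dataclass(init=False, frozen=True)
class Qwen2_5_3B(HFModel):
    repo_id: str = "Qwen/Qwen2.5-3B"
    params_path: str = os.path.join(
        BASE_DIR, "../../../models/qwen2_5/config/3b_config.json"  # config must exist first
    )
    convert_weights = convert_qwen2_5_weights
    transform_weight = False
    instruct_model = False  # base (non-Instruct) checkpoint, so no chat template
```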

examples/qualcomm/oss_scripts/llama/decoder_constants.py

Lines changed: 5 additions & 3 deletions
```diff
@@ -10,13 +10,15 @@
     "lookahead": 2,
 }
 
+# The dict's value is mainly for runner to decide what special tokens are required to wrap the prompt.
 DECODER_MODEL_VERSION = {
     "stories260k": "llama2",
     "stories110m": "llama2",
     "llama3_2": "llama3",
-    "qwen2_5": "qwen2_5",
-    "qwen3_0_6b": "qwen2_5",  # TODO: temp workaround, use special token for qwen3 in runner
-    "qwen3_1_7b": "qwen2_5",
+    "qwen2_5-0_5b": "qwen2_5",
+    "qwen2_5-1_5b": "qwen2_5",
+    "qwen3-0_6b": "qwen3",
+    "qwen3-1_7b": "qwen3",
     "phi_4_mini": "phi_4_mini",
     "smollm2_135m": "smollm2_135m",
 }
```
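The version string is what the on-device runner (not part of this diff) consults to pick the special tokens that wrap the prompt. Below is a rough Python illustration only, under the assumption that the qwen families use ChatML-style `<|im_start|>`/`<|im_end|>` markers; the `wrap_prompt` helper does not exist in the codebase:

```python
# Purely illustrative: the real token handling lives in the C++ runner, not here.
DECODER_MODEL_VERSION = {  # excerpt of the mapping above
    "qwen2_5-0_5b": "qwen2_5",
    "qwen3-0_6b": "qwen3",
}


def wrap_prompt(decoder_model: str, prompt: str) -> str:
    version = DECODER_MODEL_VERSION[decoder_model]
    if version in ("qwen2_5", "qwen3"):
        # ChatML-style wrapping used by the Qwen families (assumed here).
        return f"<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n"
    return prompt  # other families would use their own special tokens


print(wrap_prompt("qwen3-0_6b", "How many r's in strawberries?"))
```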

examples/qualcomm/oss_scripts/llama/decoder_utils.py

Lines changed: 34 additions & 7 deletions
```diff
@@ -458,22 +458,34 @@ def prefill_inference(
 
 
 def graph_module_inference(
-    args,
-    use_kv_cache,
+    use_kv_cache: bool,
     get_example_inputs: Callable,
     module: torch.fx.GraphModule,
     tokenizer,
     ar_len=1,
     max_seq_len=512,
     kv_updater=smart_mask_updater,
+    prompt=None,
+    tasks=None,
+    tasks_limit=1,
+    num_fewshot=None,
     use_i64_token=False,
     event_name: Optional[str] = None,
 ):
-    if args.tasks is None:
+    """
+    This function supports model execution from static nn.Module decoder model
+    all the way to edge program.
+    Users could choose to provide either the prompt or tasks for execution but not both.
+    """
+    # Checks 1 and only 1 is provided.
+    assert (tasks is None) != (
+        prompt is None
+    ), "Please provide either tasks or prompt - not both or neither"
+    if tasks is None:
         if use_kv_cache:
             kv_inference(
                 get_example_inputs,
-                args.prompt[0],
+                prompt,
                 module,
                 tokenizer,
                 ar_len,
@@ -485,7 +497,7 @@
         else:
             prefill_inference(
                 get_example_inputs,
-                args.prompt[0],
+                prompt,
                 module,
                 tokenizer,
                 max_seq_len,
@@ -507,9 +519,24 @@
         with torch.no_grad():
             eval_results = simple_evaluate(
                 model=calibration_wrapper,
-                tasks=args.tasks,
-                limit=args.limit,
+                tasks=tasks,
+                num_fewshot=num_fewshot,
+                limit=tasks_limit,
             )
         logging.info(f"Perplexity evaluation summary for {event_name}")
         for task, res in eval_results["results"].items():
             logging.info(f"{task}: {res}")
+
+
+def apply_prompt_template(
+    chat_template: Callable, prompt: str, system_prompt: str = None
+):
+    messages = [{"role": "user", "content": prompt}]
+    if system_prompt:
+        messages.append({"role": "system", "content": system_prompt})
+
+    template_prompt = chat_template(
+        messages, tokenize=False, add_generation_prompt=True
+    )
+    logging.info(f"Prompt after applying template: {template_prompt}")
+    return template_prompt
```
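A natural way to drive the new helper is to pass a Hugging Face tokenizer's `apply_chat_template` method as the `chat_template` callable; the snippet below is a usage sketch under that assumption (the model choice and prompt are arbitrary examples):

```python
# Usage sketch for apply_prompt_template; assumes a Hugging Face tokenizer whose
# apply_chat_template accepts (messages, tokenize=False, add_generation_prompt=True).
from transformers import AutoTokenizer

from executorch.examples.qualcomm.oss_scripts.llama.decoder_utils import (
    apply_prompt_template,
)

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")
wrapped = apply_prompt_template(
    chat_template=tokenizer.apply_chat_template,
    prompt="How many r's in strawberries?",
)
print(wrapped)  # prompt wrapped in the model's special tokens, ready for calibration
```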
