Commit 3798202

Qualcomm AI Engine Direct - Scripts and accuracy improvement for Qwen 2.5 - 1.5B and Qwen 3 - 0.6B/1.7B
1 parent 9359481 commit 3798202

File tree

10 files changed: +214 -52 lines changed

10 files changed

+214
-52
lines changed

backends/qualcomm/_passes/annotate_quant_attrs.py

Lines changed: 29 additions & 11 deletions
@@ -30,9 +30,14 @@ class AnnotateQuantAttrs(ExportPass):
     generated after quantization process.
     """

-    def __init__(self, edge_program: torch.export.ExportedProgram):
+    def __init__(
+        self,
+        edge_program: torch.export.ExportedProgram,
+        skip_advanced_requant: bool = False,
+    ):
         super(AnnotateQuantAttrs, self).__init__()
         self.edge_program = edge_program
+        self.skip_advanced_requant = skip_advanced_requant

     def _annotate_source_nodes(
         self, quant_node: torch.fx.Node, quant_attrs: Dict[str, Any]
@@ -82,16 +87,29 @@ def _annotate_requant(self, n):
         # TODO: Store multiple pairs of requantize attributes when we have an op builder
         # that has multiple outputs that requires quant attributes.

-        if any(
-            q_attrs[attr] != dq_attrs[attr]
-            for attr in [
-                QCOM_SCALE,
-                QCOM_ZERO_POINT,
-                QCOM_QUANT_MIN,
-                QCOM_QUANT_MAX,
-                QCOM_DTYPE,
-            ]
-        ):
+        # Determine if requantization is needed based on configuration and attribute mismatch.
+        is_requant_needed = False
+        if self.skip_advanced_requant:
+            # In skip_advanced_requant mode, only consider requant if dtypes differ.
+            if q_attrs[QCOM_DTYPE] != dq_attrs[QCOM_DTYPE]:
+                is_requant_needed = True
+        else:
+            # In full requant mode, consider requant if any key attribute differs.
+            # This aims to improve accuracy by adjusting scale, zero_point, etc.
+            # Users can disable this if it causes regressions.
+            if any(
+                q_attrs[attr] != dq_attrs[attr]
+                for attr in [
+                    QCOM_SCALE,
+                    QCOM_ZERO_POINT,
+                    QCOM_QUANT_MIN,
+                    QCOM_QUANT_MAX,
+                    QCOM_DTYPE,
+                ]
+            ):
+                is_requant_needed = True
+
+        if is_requant_needed:
             dq_attrs[QCOM_ENCODING] = q_attrs[QCOM_ENCODING]
             user_node = list(dq_node.users)[0]
             n.args[0].meta.setdefault(QCOM_REQUANTIZE, {})

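For readers who want to see the effect of the new flag in isolation, here is a minimal, self-contained sketch (not part of the commit; the QCOM_* constants are stand-in string keys) showing how the two modes treat a scale-only mismatch between quantize and dequantize attributes:

```python
# Toy illustration of the requant decision above. The real pass reads these keys from
# node metadata; here they are plain dicts and the QCOM_* names are stand-in strings.
QCOM_SCALE, QCOM_ZERO_POINT = "scale", "zero_point"
QCOM_QUANT_MIN, QCOM_QUANT_MAX, QCOM_DTYPE = "quant_min", "quant_max", "dtype"

q_attrs = {
    QCOM_SCALE: 0.02, QCOM_ZERO_POINT: 0,
    QCOM_QUANT_MIN: -128, QCOM_QUANT_MAX: 127, QCOM_DTYPE: "int8",
}
dq_attrs = {**q_attrs, QCOM_SCALE: 0.03}  # only the scale differs

def is_requant_needed(q, dq, skip_advanced_requant: bool) -> bool:
    if skip_advanced_requant:
        # In the conservative mode, only a dtype mismatch triggers requantization.
        return q[QCOM_DTYPE] != dq[QCOM_DTYPE]
    # In the full mode, any mismatched attribute triggers requantization.
    keys = [QCOM_SCALE, QCOM_ZERO_POINT, QCOM_QUANT_MIN, QCOM_QUANT_MAX, QCOM_DTYPE]
    return any(q[k] != dq[k] for k in keys)

print(is_requant_needed(q_attrs, dq_attrs, skip_advanced_requant=False))  # True
print(is_requant_needed(q_attrs, dq_attrs, skip_advanced_requant=True))   # False
```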
backends/qualcomm/tests/test_qnn_delegate.py

Lines changed: 6 additions & 4 deletions
@@ -4550,7 +4550,7 @@ def test_static_qwen2_5(self):
             "--ptq",
             "16a8w",
             "--decoder_model",
-            "qwen2_5",
+            "qwen2_5-0_5b",
             "--model_mode",
             "kv",
             "--max_seq_len",
@@ -4613,13 +4613,15 @@ def test_static_qwen3(self):
             "--ptq",
             "16a8w",
             "--decoder_model",
-            "qwen3_0_6b",
+            "qwen3-0_6b",
             "--model_mode",
             "hybrid",
             "--prefill_ar_len",
             "32",
             "--max_seq_len",
             "128",
+            "--r3",
+            "--enable_masked_softmax",
         ]
         if self.compile_only:
             cmds.extend(["--compile_only"])
@@ -4632,8 +4634,8 @@
         if self.pre_gen_pte:
             cmds.extend(["--pre_gen_pte", self.pre_gen_pte])

-        # Accuracy is bad for now. Just check user's prompt is returned.
-        golden_start_with = "My favourite condiment is "
+        # TODO: Change to PPL evaluation
+        golden_start_with = "<|im_start|>user"
         p = subprocess.Popen(cmds, stdout=subprocess.DEVNULL)
         with Listener((self.ip, self.port)) as listener:
             conn = listener.accept()

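A hedged aside on the new golden prefix: Qwen3 is registered as an instruct model in this commit, and ChatML-style chat templates wrap each user turn in <|im_start|>/<|im_end|> markers, so a transcript that echoes the templated prompt begins with the user marker rather than the raw prompt text. The template string below is illustrative; the real one comes from the model's tokenizer config.

```python
# Illustrative only: the ChatML-style wrapping applied to a user turn by Qwen chat
# templates. The concrete template is defined by the tokenizer, not by this test.
prompt = "My favourite condiment is"
templated = (
    f"<|im_start|>user\n{prompt}<|im_end|>\n"
    "<|im_start|>assistant\n"
)

# The test now only checks this prefix until the TODO above replaces it with a PPL check.
assert templated.startswith("<|im_start|>user")
```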
examples/qualcomm/oss_scripts/llama/README.md

Lines changed: 23 additions & 5 deletions
@@ -5,7 +5,7 @@ This file provides you the instructions to run LLM Decoder model with different
 1. LLAMA2 Stories 110M
 2. LLAMA3.2 1B
 3. LLAMA3.2 3B
-4. QWEN2.5 0.5B
+4. QWEN2.5 0.5B / 1.5B
 5. QWEN3 0.6B / 1.7B
 6. Phi4-mini-instruct
 7. SMOLLM2 135M
@@ -72,7 +72,25 @@ python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL
 #### QWEN2.5 0.5B
 Default example using hybrid mode
 ```bash
-python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --temperature 0 --model_mode hybrid --max_seq_len 1024 --prefill_ar_len 128 --ptq 16a8w --enable_masked_softmax --r3 --decoder_model qwen2_5 --prompt "I would like to learn python, could you teach me with a simple example?"
+python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --temperature 0 --model_mode hybrid --max_seq_len 1024 --prefill_ar_len 128 --ptq 16a8w --enable_masked_softmax --r3 --decoder_model qwen2_5-0_5b --prompt "I would like to learn python, could you teach me with a simple example?"
+```
+
+#### QWEN2.5 1.5B
+Default example using hybrid mode
+```bash
+python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --temperature 0 --model_mode hybrid --max_seq_len 1024 --prefill_ar_len 128 --ptq 16a8w --enable_masked_softmax --r3 --decoder_model qwen2_5-1_5b --prompt "I would like to learn python, could you teach me with a simple example?"
+```
+
+#### QWEN3 0.6B
+Default example using hybrid mode
+```bash
+python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --temperature 0 --model_mode hybrid --max_seq_len 1024 --prefill_ar_len 128 --ptq 16a8w --enable_masked_softmax --r3 --decoder_model qwen3-0_6b --prompt "I would like to learn python, could you teach me with a simple example?"
+```
+
+#### QWEN3 1.7B
+Default example using hybrid mode
+```bash
+python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --temperature 0 --model_mode hybrid --max_seq_len 1024 --prefill_ar_len 128 --ptq 16a8w --enable_masked_softmax --r3 --decoder_model qwen3-1_7b --prompt "I would like to learn python, could you teach me with a simple example?"
 ```

 #### SMOLLM2
@@ -175,18 +193,18 @@ To evaluate the perplexity across all 3 phases, users should provide the `--eval

 For example, using the Qwen model and 1 wikitext sample as the evaluation task, users can assess all 3 phases perplexity score in a single run by including the appropriate configuration:
 ```bash
-python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --prompt "What is 1+1?" --temperature 0 --model_mode kv --max_seq_len 1024 --ptq 16a8w --decoder_model qwen2_5 --eval_perplexity --tasks wikitext --limit 1
+python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --prompt "What is 1+1?" --temperature 0 --model_mode kv --max_seq_len 1024 --ptq 16a8w --decoder_model qwen2_5-0_5b --eval_perplexity --tasks wikitext --limit 1
 ```

 For the example script above, 1 wikitext sample is used to evaluate all 3 phases. However, there are cases where a user may want to use one sample for quantization calibration and multiple samples for perplexity evaluation. In this case, the process should be split into two runs. In the 1st run, the model is compiled using one sample. In the 2nd run, the user can provide a different configuration for QNN device execution.
 Example:
 ```bash
 # 1st run to compile with --limit 1
-python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --prompt "What is 1+1?" --temperature 0 --model_mode kv --max_seq_len 1024 --ptq 16a8w --decoder_model qwen2_5 --eval_perplexity --tasks wikitext --limit 1 --compile_only
+python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --prompt "What is 1+1?" --temperature 0 --model_mode kv --max_seq_len 1024 --ptq 16a8w --decoder_model qwen2_5-0_5b --eval_perplexity --tasks wikitext --limit 1 --compile_only
 ```
 ```bash
 # 2nd run to perform QNN device execution with --limit 3
-python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --prompt "What is 1+1?" --temperature 0 --model_mode kv --max_seq_len 1024 --ptq 16a8w --decoder_model qwen2_5 --eval_perplexity --tasks wikitext --limit 3 --pre_gen_pte ${PATH_TO_ARTIFACT_IN_1ST_RUN} --quant_attrs_path ${PATH_TO_ARTIFACT_IN_1ST_RUN}/kv_llama_qnn_quant_attrs.json
+python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --prompt "What is 1+1?" --temperature 0 --model_mode kv --max_seq_len 1024 --ptq 16a8w --decoder_model qwen2_5-0_5b --eval_perplexity --tasks wikitext --limit 3 --pre_gen_pte ${PATH_TO_ARTIFACT_IN_1ST_RUN} --quant_attrs_path ${PATH_TO_ARTIFACT_IN_1ST_RUN}/kv_llama_qnn_quant_attrs.json
 ```

 #### Tasks quantization calibration

examples/qualcomm/oss_scripts/llama/__init__.py

Lines changed: 32 additions & 11 deletions
@@ -6,7 +6,7 @@

 import os
 from abc import ABC
-from dataclasses import dataclass, field
+from dataclasses import dataclass
 from typing import Callable, Dict, Type

 from executorch.examples.models.phi_4_mini import (
@@ -28,10 +28,19 @@

 @dataclass(init=False, frozen=True)
 class HFModel(ABC):
+    """Base class for all Hugging Face models.
+
+    repo_id: Hugging Face repo ID.
+    params_path: Path to the model's config.json. If the corresponding .json does not exist yet, please create one.
+    convert_weights: Converts Hugging Face weight parameters to the static decoder's parameter naming.
+    transform_weight: Set to True to transform Hugging Face weights, improving RoPE performance on the HTP backend.
+    instruct_model: True if the model uses chat templates. Check the Hugging Face model card to confirm the model uses chat templates.
+    """
     repo_id: str
     params_path: str
-    runner_version: str
     convert_weights: Callable
+    transform_weight: bool
+    instruct_model: bool


 SUPPORTED_HF_MODELS: Dict[str, HFModel] = {}
@@ -45,40 +54,52 @@ def decorator(cls: Type[HFModel]):
     return decorator


-@register_hf_model("qwen2_5")
+@register_hf_model("qwen2_5-0_5b")
 @dataclass(init=False, frozen=True)
-class Qwen2_5(HFModel):
+class Qwen2_5_0_5B(HFModel):
     repo_id: str = "Qwen/Qwen2.5-0.5B"
     params_path: str = os.path.join(
         BASE_DIR, "../../../models/qwen2_5/config/0_5b_config.json"
     )
-    runner_version: str = field(default=DECODER_MODEL_VERSION["qwen2_5"])
     convert_weights = convert_qwen2_5_weights
     transform_weight = False
+    instruct_model = False


-@register_hf_model("qwen3_0_6b")
+@register_hf_model("qwen2_5-1_5b")
+@dataclass(init=False, frozen=True)
+class Qwen2_5_1_5B(HFModel):
+    repo_id: str = "Qwen/Qwen2.5-1.5B"
+    params_path: str = os.path.join(
+        BASE_DIR, "../../../models/qwen2_5/config/1_5b_config.json"
+    )
+    convert_weights = convert_qwen2_5_weights
+    transform_weight = False
+    instruct_model = False
+
+
+@register_hf_model("qwen3-0_6b")
 @dataclass(init=False, frozen=True)
 class Qwen3_0_6B(HFModel):
     repo_id: str = "Qwen/Qwen3-0.6B"
     params_path: str = os.path.join(
         BASE_DIR, "../../../models/qwen3/config/0_6b_config.json"
     )
-    runner_version: str = field(default=DECODER_MODEL_VERSION["qwen2_5"])
     convert_weights = convert_qwen3_weights
     transform_weight = False
+    instruct_model = True


-@register_hf_model("qwen3_1_7b")
+@register_hf_model("qwen3-1_7b")
 @dataclass(init=False, frozen=True)
 class Qwen3_1_7B(HFModel):
     repo_id: str = "Qwen/Qwen3-1.7B"
     params_path: str = os.path.join(
         BASE_DIR, "../../../models/qwen3/config/1_7b_config.json"
     )
-    runner_version: str = field(default=DECODER_MODEL_VERSION["qwen2_5"])
     convert_weights = convert_qwen3_weights
     transform_weight = False
+    instruct_model = True


 @register_hf_model("phi_4_mini")
@@ -88,9 +109,9 @@ class Phi4Mini(HFModel):
     params_path: str = os.path.join(
         BASE_DIR, "../../../models/phi_4_mini/config/config.json"
     )
-    runner_version: str = field(default=DECODER_MODEL_VERSION["phi_4_mini"])
     convert_weights = convert_phi_4_mini_weights
     transform_weight = False
+    instruct_model = True


 @register_hf_model("smollm2_135m")
@@ -100,6 +121,6 @@ class Smollm2_135M(HFModel):
     params_path: str = os.path.join(
         BASE_DIR, "../../../models/smollm2/135M_config.json"
    )
-    runner_version: str = field(default=DECODER_MODEL_VERSION["smollm2_135m"])
     convert_weights = convert_smollm2_weights
     transform_weight = True
+    instruct_model = True

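As a hedged sketch of the registration pattern above (the registry key, class name, and config path below are hypothetical and not part of this commit), adding another decoder would reuse the module's existing register_hf_model, HFModel, BASE_DIR, and convert_qwen2_5_weights names:

```python
# Hypothetical example of registering one more decoder with the pattern above.
# Assumes it lives in the same module, so register_hf_model, HFModel, BASE_DIR and
# convert_qwen2_5_weights are already in scope; the key and config path are made up.
import os
from dataclasses import dataclass


@register_hf_model("qwen2_5-3b")  # hypothetical registry key
@dataclass(init=False, frozen=True)
class Qwen2_5_3B(HFModel):
    repo_id: str = "Qwen/Qwen2.5-3B"
    params_path: str = os.path.join(
        BASE_DIR, "../../../models/qwen2_5/config/3b_config.json"  # create this .json if it does not exist
    )
    convert_weights = convert_qwen2_5_weights
    transform_weight = False   # no RoPE weight transform for this family
    instruct_model = False     # base (non-chat-template) checkpoint
```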
examples/qualcomm/oss_scripts/llama/decoder_constants.py

Lines changed: 5 additions & 3 deletions
@@ -10,13 +10,15 @@
     "lookahead": 2,
 }

+# The dict's value is mainly for runner to decide what special tokens are required to wrap the prompt.
 DECODER_MODEL_VERSION = {
     "stories260k": "llama2",
     "stories110m": "llama2",
     "llama3_2": "llama3",
-    "qwen2_5": "qwen2_5",
-    "qwen3_0_6b": "qwen2_5",  # TODO: temp workaround, use special token for qwen3 in runner
-    "qwen3_1_7b": "qwen2_5",
+    "qwen2_5-0_5b": "qwen2_5",
+    "qwen2_5-1_5b": "qwen2_5",
+    "qwen3-0_6b": "qwen3",
+    "qwen3-1_7b": "qwen3",
     "phi_4_mini": "phi_4_mini",
     "smollm2_135m": "smollm2_135m",
 }

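The comment added above explains that the mapped value drives which special tokens the runner wraps around the prompt. A rough Python sketch of that idea follows; the real runner is C++ and its token tables are not part of this diff, so the dict excerpt and wrapper strings below are illustrative only.

```python
# Illustrative sketch only: how the version string could drive prompt wrapping.
# DECODER_MODEL_VERSION here is a two-entry excerpt of the dict above; the wrapper
# table is hypothetical - the real runner keeps its own per-version special tokens.
DECODER_MODEL_VERSION = {"qwen2_5-0_5b": "qwen2_5", "qwen3-0_6b": "qwen3"}
CHAT_WRAPPERS = {
    "qwen2_5": ("<|im_start|>user\n", "<|im_end|>\n<|im_start|>assistant\n"),
    "qwen3": ("<|im_start|>user\n", "<|im_end|>\n<|im_start|>assistant\n"),
}

def wrap_prompt(decoder_model: str, prompt: str) -> str:
    version = DECODER_MODEL_VERSION[decoder_model]  # e.g. "qwen3-0_6b" -> "qwen3"
    prefix, suffix = CHAT_WRAPPERS.get(version, ("", ""))
    return f"{prefix}{prompt}{suffix}"

print(wrap_prompt("qwen3-0_6b", "What is 1+1?"))
```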
examples/qualcomm/oss_scripts/llama/decoder_utils.py

Lines changed: 34 additions & 7 deletions
@@ -458,22 +458,34 @@ def prefill_inference(


 def graph_module_inference(
-    args,
-    use_kv_cache,
+    use_kv_cache: bool,
     get_example_inputs: Callable,
     module: torch.fx.GraphModule,
     tokenizer,
     ar_len=1,
     max_seq_len=512,
     kv_updater=smart_mask_updater,
+    prompt=None,
+    tasks=None,
+    tasks_limit=1,
+    num_fewshot=None,
     use_i64_token=False,
     event_name: Optional[str] = None,
 ):
-    if args.tasks is None:
+    """
+    This function supports model execution from a static nn.Module decoder model
+    all the way to an edge program.
+    Users may provide either the prompt or tasks for execution, but not both.
+    """
+    # Check that exactly one of prompt or tasks is provided.
+    assert (tasks is None) != (
+        prompt is None
+    ), "Please provide either tasks or prompt - not both or neither"
+    if tasks is None:
         if use_kv_cache:
             kv_inference(
                 get_example_inputs,
-                args.prompt[0],
+                prompt,
                 module,
                 tokenizer,
                 ar_len,
@@ -485,7 +497,7 @@
         else:
             prefill_inference(
                 get_example_inputs,
-                args.prompt[0],
+                prompt,
                 module,
                 tokenizer,
                 max_seq_len,
@@ -507,9 +519,24 @@
     with torch.no_grad():
         eval_results = simple_evaluate(
             model=calibration_wrapper,
-            tasks=args.tasks,
-            limit=args.limit,
+            tasks=tasks,
+            num_fewshot=num_fewshot,
+            limit=tasks_limit,
         )
     logging.info(f"Perplexity evaluation summary for {event_name}")
     for task, res in eval_results["results"].items():
         logging.info(f"{task}: {res}")
+
+
+def apply_prompt_template(
+    chat_template: Callable, prompt: str, system_prompt: str = None
+):
+    messages = [{"role": "user", "content": prompt}]
+    if system_prompt:
+        messages.append({"role": "system", "content": system_prompt})
+
+    template_prompt = chat_template(
+        messages, tokenize=False, add_generation_prompt=True
+    )
+    logging.info(f"Prompt after applying template: {template_prompt}")
+    return template_prompt

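A hedged usage sketch for the new apply_prompt_template helper: judging by the tokenize/add_generation_prompt keywords it forwards, the chat_template argument is expected to behave like transformers' tokenizer.apply_chat_template. The example below assumes the transformers package and Hugging Face hub access, and that apply_prompt_template is imported from this module.

```python
# Hedged usage sketch (assumes the transformers package and Hugging Face hub access).
# apply_prompt_template is called with the tokenizer's chat-template method, mirroring
# the keyword arguments it forwards (tokenize=False, add_generation_prompt=True).
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")
templated = apply_prompt_template(
    tokenizer.apply_chat_template,
    prompt="I would like to learn python, could you teach me with a simple example?",
)
# templated is the string that kv/prefill inference would tokenize for an instruct model.
print(templated)
```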