
Commit fc31b75

Qualcomm AI Engine Direct - Improve GA Static Phi-4-mini accuracy
Summary:
- Refactor custom annotation for R3
- Fix warning message in quantization
- Add Phi-4-mini setting into README
- Fix segmentation fault when running the model with sharding
- Add a test case for Phi-4 in test_qnn_delegate.py
- Add a new parameter "group_size" in llama.py to set the block size for block quantization
1 parent 3dac421 commit fc31b75

6 files changed: +121 -19 lines

backends/qualcomm/builders/node_visitor.py

Lines changed: 1 addition & 1 deletion
@@ -162,7 +162,7 @@ def make_qnn_per_block_config(self, node: torch.fx.Node, quant_attrs: Dict):
         for ch in range(num_channels):
             max_scale = scales[ch].reshape(1, -1).amax(dim=-1) / num_steps
             q_scales = torch.clamp(
-                input=scales[ch] / max_scale,
+                input=torch.round(input=scales[ch] / max_scale),
                 min=1,
                 max=2**bitwidth_of_scale,
             ).to(quant_scales_dtype)
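
The accuracy gain here comes from rounding the per-block scale ratio before clamping: casting the unrounded ratio to an integer scale dtype truncates toward zero and biases the stored scales low. A minimal sketch of that effect, with hypothetical values for `scales`, `num_steps`, `bitwidth_of_scale`, and the scale dtype (the real values come from the surrounding builder code):

```python
import torch

# Hypothetical per-block scales for one output channel (not taken from a real model).
scales = torch.tensor([0.021, 0.043, 0.088, 0.170])
bitwidth_of_scale = 4
quant_scales_dtype = torch.uint8
num_steps = 2**bitwidth_of_scale - 1

max_scale = scales.amax() / num_steps
ratio = scales / max_scale

# Old behaviour: clamp, then let the dtype cast truncate (1.85 -> 1, 3.79 -> 3, ...).
truncated = torch.clamp(ratio, min=1, max=2**bitwidth_of_scale).to(quant_scales_dtype)
# New behaviour: round to the nearest representable step first (1.85 -> 2, 3.79 -> 4, ...).
rounded = torch.clamp(torch.round(ratio), min=1, max=2**bitwidth_of_scale).to(quant_scales_dtype)

print(truncated.tolist(), rounded.tolist())  # [1, 3, 7, 15] vs [2, 4, 8, 15]
```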

backends/qualcomm/quantizer/custom_annotation.py

Lines changed: 12 additions & 2 deletions
@@ -317,6 +317,7 @@ def annotate_matmul_input1(node: Node, is_qat: str):
                 torch.ops.aten.transpose.int,
                 torch.ops.aten.view.default,
                 torch.ops.aten.reshape.default,
+                torch.ops.aten.slice.Tensor,
             ]:
                 annotate_single_in_single_out(node, quantization_config_8a8w)
                 node = node.args[0]
@@ -340,7 +341,11 @@ def annotate_matmul_input1(node: Node, is_qat: str):
                     node, quantization_config=quantization_config_8a4w_per_channel
                 )
                 break
-            elif node.target in [torch.ops.aten.add.Tensor, torch.ops.aten.sub.Tensor]:
+            elif node.target in [
+                torch.ops.aten.add.Tensor,
+                torch.ops.aten.sub.Tensor,
+                torch.ops.aten.matmul.default,
+            ]:
                 break
             else:
                 print(f"The node ({node}) is not expected in the input1 of the matmul")
@@ -356,7 +361,12 @@ def annotate_matmul_input1(node: Node, is_qat: str):
     )

     for node in gm.graph.nodes:
-        if node.op == "call_function" and node.target == torch.ops.aten.matmul.default:
+        if (
+            node.op == "call_function"
+            and node.target == torch.ops.aten.matmul.default
+            and all(arg.op == "call_function" for arg in node.args)
+        ):
+            # Only apply custom annotation on Q @ K^T @ V
             annotate_matmul(node, quantization_config_16a8w)
             annotate_matmul_input1(node.args[1], is_qat=is_qat)
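
The added `all(arg.op == "call_function" for arg in node.args)` guard restricts the 16a8w annotation to matmuls whose inputs are all produced by other ops in the graph. A standalone torch.fx sketch of what that predicate distinguishes; the toy module and the use of `torch.matmul` as the target are illustrative assumptions, since the annotator itself runs on an exported graph with `torch.ops.aten.matmul.default` targets:

```python
import torch
from torch.fx import symbolic_trace


class Toy(torch.nn.Module):
    def forward(self, x, w):
        a = torch.matmul(x, w)                                 # fed directly by placeholders
        return torch.matmul(torch.relu(a), torch.sigmoid(a))   # fed only by other ops


gm = symbolic_trace(Toy())
for node in gm.graph.nodes:
    if node.op == "call_function" and node.target == torch.matmul:
        fed_by_ops_only = all(arg.op == "call_function" for arg in node.args)
        print(node.name, [arg.op for arg in node.args], fed_by_ops_only)
# matmul   ['placeholder', 'placeholder']      False -> skipped by the guard
# matmul_1 ['call_function', 'call_function']  True  -> gets the custom annotation
```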

backends/qualcomm/runtime/backends/QnnImplementation.cpp

Lines changed: 0 additions & 8 deletions
@@ -51,16 +51,8 @@ Error QnnImplementation::StartBackend(
     const std::string& lib_path,
     const QnnSaver_Config_t** saver_config) {
   Qnn_ErrorHandle_t error = QNN_SUCCESS;
-  // RTLD_GLOBAL is needed on x86 as HTP op package has a requirement for the
-  // symbols in backend to be visible. Using RTLD_LOCAL on Android to allow full
-  // unloading of HTP backend shared library on dlclose() as RTLD_GLOBAL isn't
-  // letting it happen.
   void* lib_handle = nullptr;
-#if defined(__ANDROID__)
-  lib_handle = dlopen(lib_path.c_str(), RTLD_NOW | RTLD_LOCAL);
-#else
   lib_handle = dlopen(lib_path.c_str(), RTLD_NOW | RTLD_GLOBAL);
-#endif
   if (lib_handle == nullptr) {
     QNN_EXECUTORCH_LOG_ERROR(
         "Cannot Open QNN library %s, with error: %s",

backends/qualcomm/tests/test_qnn_delegate.py

Lines changed: 76 additions & 3 deletions
@@ -1127,7 +1127,8 @@ def test_qnn_backend_where(self):
             (torch.randn(30, 20),),
         ]
         for i, module in enumerate(modules):
-            self.lower_module_and_test_output(module, sample_inputs[i])
+            with self.subTest(i=i):
+                self.lower_module_and_test_output(module, sample_inputs[i])

     def test_qnn_backend_masked_fill(self):
         module = MaskedFill()  # noqa: F405
@@ -2556,8 +2557,9 @@ def test_qnn_backend_where(self):
             (torch.randn(30, 20),),
         ]
         for i, module in enumerate(modules):
-            module = self.get_qdq_module(module, sample_inputs[i])
-            self.lower_module_and_test_output(module, sample_inputs[i])
+            with self.subTest(i=i):
+                module = self.get_qdq_module(module, sample_inputs[i])
+                self.lower_module_and_test_output(module, sample_inputs[i])

     def test_qnn_backend_masked_fill(self):
         module = MaskedFill()  # noqa: F405
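
The switch to `self.subTest(i=i)` means a failing (module, input) pair no longer aborts the loop: every pair runs, and failures are reported per index. A self-contained sketch of the behaviour with a toy test case (not from the suite):

```python
import unittest


class SubTestDemo(unittest.TestCase):
    def test_all_cases_run(self):
        cases = [1, -2, 3]  # hypothetical inputs; the middle one fails the check
        for i, value in enumerate(cases):
            with self.subTest(i=i):
                self.assertGreater(value, 0)


if __name__ == "__main__":
    # Reports a single failure labelled (i=1) while still executing i=0 and i=2.
    unittest.main()
```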
@@ -4527,6 +4529,77 @@ def test_llama_stories_110m(self):
         if not self.compile_only and not self.enable_x86_64:
             self.assertGreaterEqual(msg["inference_speed"], 220)  # Lanai

+    def test_static_phi4(self):
+        if not self.required_envs():
+            self.skipTest("missing required envs")
+
+        prompt = "My favourite condiment is "
+        cmds = [
+            "python",
+            f"{self.executorch_root}/examples/qualcomm/oss_scripts/llama/llama.py",
+            "--artifact",
+            self.artifact_dir,
+            "--build_folder",
+            self.build_folder,
+            "--model",
+            self.model,
+            "--ip",
+            self.ip,
+            "--port",
+            str(self.port),
+            "--prompt",
+            f"{prompt}",
+            "--ptq",
+            "16a4w_block",
+            "--group_size",
+            "16",
+            "--decoder_model",
+            "phi_4_mini",
+            "--model_mode",
+            "kv",
+            "--max_seq_len",
+            "1024",
+            "--num_sharding",
+            "8",
+            "--eval_perplexity",
+            "--tasks",
+            "wikitext",
+            "--limit",
+            "1",
+        ]
+        if self.compile_only:
+            cmds.extend(["--compile_only"])
+        elif self.device:
+            cmds.extend(["--device", self.device])
+            if self.host:
+                cmds.extend(["--host", self.host])
+        elif self.enable_x86_64:
+            cmds.extend(["--enable_x86_64"])
+        if self.pre_gen_pte:
+            cmds.extend(["--pre_gen_pte", self.pre_gen_pte])
+            cmds.extend(
+                [
+                    "--quant_attrs_path",
+                    f"{self.pre_gen_pte}/kv_llama_qnn_quant_attrs.json",
+                ]
+            )
+
+        p = subprocess.Popen(cmds, stdout=subprocess.DEVNULL)
+        with Listener((self.ip, self.port)) as listener:
+            conn = listener.accept()
+            p.communicate()
+            msg = json.loads(conn.recv())
+            if "Error" in msg:
+                self.fail(msg["Error"])
+            else:
+                inference_speed_ref = {"SM8650": 14, "SM8750": 19}
+                self.assertLessEqual(msg["wiki_ppl"], 12)
+                self.assertLessEqual(msg["pte_size"], 4000000000)  # 4gb
+                if self.model in inference_speed_ref:
+                    self.assertGreaterEqual(
+                        msg["inference_speed"], inference_speed_ref[self.model]
+                    )
+
     def test_static_qwen2_5(self):
         if not self.required_envs():
             self.skipTest("missing required envs")
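
test_static_phi4 follows the same device-test pattern as the other static LLM tests: the unittest spawns llama.py, listens on (ip, port), and reads back a JSON metrics payload (wiki_ppl, pte_size, inference_speed). A minimal sketch of that handshake using multiprocessing.connection; the address and field names here are illustrative:

```python
import json
from multiprocessing.connection import Client, Listener

ADDRESS = ("127.0.0.1", 6000)  # assumed local address for the sketch


def report(metrics: dict) -> None:
    # Roughly what the spawned llama.py does once evaluation finishes.
    with Client(ADDRESS) as conn:
        conn.send(json.dumps(metrics))


def collect() -> dict:
    # Roughly what the test does around subprocess.Popen(cmds, ...).
    with Listener(ADDRESS) as listener:
        with listener.accept() as conn:
            return json.loads(conn.recv())
```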

examples/qualcomm/oss_scripts/llama/README.md

Lines changed: 7 additions & 0 deletions
@@ -69,6 +69,12 @@ Default example using hybrid mode.
 python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --ptq 16a4w --checkpoint consolidated.00.pth --params params.json --tokenizer_model tokenizer.model --decoder_model llama3_2 --model_mode hybrid --prefill_ar_len 32 --max_seq_len 128 --prompt "what is 1+1"
 ```
 
+#### Phi4-mini-instruct
+Default example using hybrid mode.
+```bash
+python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --ptq 16a4w_block --group_size 16 --checkpoint consolidated.00.pth --params params.json --tokenizer_model tokenizer.model --decoder_model phi_4_mini --model_mode hybrid --prefill_ar_len 128 --max_seq_len 1024 --num_sharding 8 --prompt "I would like to learn python, could you teach me with a simple example?"
+```
+
 #### QWEN2.5 0.5B
 Default example using hybrid mode
 ```bash
@@ -81,6 +87,7 @@ Default example using hybrid mode.
 python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -H mlgtw-linux -s ${SERIAL_NUM} -m ${SOC_MODEL} --ptq 16a8w --tokenizer_bin tokenizer.bin --decoder_model smollm2 --model_mode hybrid --prefill_ar_len 128 --max_seq_len 1024 --prompt "I would like to learn python, could you teach me with a simple example?"
 ```
 
+
 ### KV Cache update mechanism
 We have two distinct mechanisms for updating the key-value (KV) cache, which can be selected at runtime. Shift Pointer and Smart Mask.

examples/qualcomm/oss_scripts/llama/llama.py

Lines changed: 25 additions & 5 deletions
@@ -116,6 +116,8 @@
 FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
 logging.basicConfig(level=logging.INFO, format=FORMAT)
 logging.getLogger().setLevel(logging.INFO)
+# Avoid the error message "Could not initialize NNPACK! Reason: Unsupported hardware."
+torch.backends.nnpack.set_flags(False)
 
 
 def next_power_of_two(n):
@@ -233,10 +235,16 @@ def quantize(
     ).module()

     if quant_dtype == QuantDtype.use_16a4w_block:
+        if args.group_size is None:
+            raise ValueError(
+                "Group size is required when use quant_dtype 16a4w_block"
+            )
         conv_nodes = [
             n for n in fx_graph_module.graph.nodes if "conv" in n.name
         ]
-        block_size_map = {n.name: (1, 64, 1, 1) for n in conv_nodes}
+        block_size_map = {
+            n.name: (1, args.group_size, 1, 1) for n in conv_nodes
+        }
         quantizer.set_block_size_map(block_size_map)

     fx_graph_module = prepare_pt2e(fx_graph_module, quantizer)
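
With `--ptq 16a4w_block`, the block size tuple `(1, group_size, 1, 1)` applies to conv-style weights of shape `[out_channels, in_channels, kH, kW]`: each output channel is split along the input-channel axis into blocks of `group_size`, and every block gets its own quantization scale. A small sketch of that grouping with hypothetical shapes (the actual scale computation lives in the quantizer and backend, not here):

```python
import torch

group_size = 16
weight = torch.randn(8, 64, 1, 1)  # [out_channels, in_channels, kH, kW]

# Split the input-channel axis into 64 // 16 = 4 blocks per output channel.
blocks = weight.reshape(8, 64 // group_size, group_size, 1, 1)

# One scale per block, e.g. symmetric 4-bit weights with qmax = 7.
per_block_scale = blocks.abs().amax(dim=(2, 3, 4)) / 7
print(per_block_scale.shape)  # torch.Size([8, 4])
```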
@@ -584,7 +592,7 @@ def permute(w, heads):
     if args.ptq != "16a8w":
         # 16a8w use 16bit kv io, so skip this custom annotation
         custom_annotations = custom_annotations + (annotate_matmul_16a8w,)
-        if args.decoder_model in {"stories110m", "stories260k"}:
+        if args.decoder_model in {"stories110m", "stories260k", "phi_4_mini"}:
             custom_annotations = custom_annotations + (
                 annotate_linear_16a8w_in_affine_layer,
             )
@@ -801,12 +809,20 @@ def post_process():

     seq_len = args.max_seq_len
     multi_prompts = " ".join([f'--prompt "{prompt}"' for prompt in args.prompt])
+    lookahead_args = " ".join(
+        [
+            f"--window {args.window}",
+            f"--gcap {args.gcap}",
+            f"--ngram {args.ngram}",
+        ]
+    )
     runner_args = " ".join(
         [
             multi_prompts,
             f"--eval_mode {EVAL_MODE[args.model_mode]}",
             f"--temperature {args.temperature}",
             f"--system_prompt '{args.system_prompt}'",
+            lookahead_args if args.model_mode == "lookahead" else "",
         ]
     )
@@ -856,9 +872,6 @@ def post_process():
             "--output_path outputs/outputs.txt",
             f"--performance_output_path {performance_output_path}",
             f"--kv_updater {'SmartMask' if args.kv_updater == smart_mask_updater else 'ShiftPointer'}",
-            f"--window {args.window}",
-            f"--gcap {args.gcap}",
-            f"--ngram {args.ngram}",
             runner_args,
         ]
     )
@@ -1123,6 +1136,13 @@ def _build_parser():
         action="store_true",
         default=False,
     )
+    parser.add_argument(
+        "-G",
+        "--group_size",
+        type=int,
+        default=None,
+        help="group_size used in block quantization for weight quantization.",
+    )

     parser.add_argument("-v", "--verbose", action="store_true")
