Commit b743cc1

Qualcomm AI Engine Direct - Improve GA Static Phi-4-mini accuracy (#13573)
Summary:
- Refactor custom annotation for R3
- Fix warning message in quantization
- Add Phi-4-mini setting to the README
- Fix a segmentation fault when running the model with sharding
- Add a test case for Phi-4 to test_qnn_delegate.py
- Add a new parameter "group_size" to llama.py to set the block size used in block quantization

## Sample Script
```
python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} \
--ptq 16a4w_block --group_size 16 --checkpoint consolidated.00.pth --params params.json --num_sharding 4 \
--tokenizer_model tokenizer.model --decoder_model phi_4_mini --model_mode hybrid --prefill_ar_len 128 \
--max_seq_len 1024 --prompt "I would like to learn python, could you teach me with a simple example?"
```

## Result
Stats with QNN 2.37.0 on SM8750

Accuracy: 10.82
Token Rate: 22.727273

Results with --prompt "I would like to learn python, could you teach me with a simple example?":
```
<|user|>I would like to learn python, could you teach me with one simple program?<|end|><|assistant|>Of course! Let's get started with a simple Python program. We'll create a simple program that asks for your name and then greets you.

```python
# Ask for the user's name
name = input("Please enter your name: ")

# Greet the user
print(f"Hello, {name}! Welcome to the world of Python.")
```

To run this program, you would need to copy the code into a Python environment (like an IDE or a Python interpreter). When you run the program, it will prompt you to enter your name, and then it will greet you by name. Enjoy learning Python!<|end|>
```

## Test plan
Added an E2E test to test_qnn_delegate.py.

cc: @haowhsu-quic
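For context on the new `--group_size` flag: with `--ptq 16a4w_block`, weights are quantized in blocks along the input-channel axis, and `group_size` sets how many elements share one 4-bit scale (the llama.py change below builds `block_size_map = {n.name: (1, args.group_size, 1, 1)}` for conv nodes). A minimal, illustrative sketch of that idea, not the actual QNN quantizer implementation:

```python
import torch

def blockwise_scales(weight: torch.Tensor, group_size: int, n_bits: int = 4) -> torch.Tensor:
    """Illustrative sketch: one symmetric scale per block of `group_size`
    input channels, per output channel. Not the actual QNN quantizer code."""
    out_ch, in_ch = weight.shape[0], weight.shape[1]
    assert in_ch % group_size == 0, "in_channels must divide evenly into groups"
    qmax = 2 ** (n_bits - 1) - 1  # 7 for signed 4-bit weights
    # Split the input-channel axis into blocks of `group_size` elements.
    blocks = weight.reshape(out_ch, in_ch // group_size, group_size, -1)
    max_abs = blocks.abs().amax(dim=(2, 3))  # per-block max magnitude
    return max_abs / qmax                    # per-block scale

w = torch.randn(64, 128, 1, 1)               # e.g. a 1x1 conv weight
print(blockwise_scales(w, group_size=16).shape)  # torch.Size([64, 8])
```

Smaller groups track local weight ranges more closely (better accuracy) at the cost of storing more scale metadata.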
1 parent f154d50 commit b743cc1

File tree

- backends/qualcomm/builders/node_visitor.py
- backends/qualcomm/quantizer/custom_annotation.py
- backends/qualcomm/runtime/backends/QnnImplementation.cpp
- backends/qualcomm/tests/test_qnn_delegate.py
- examples/qualcomm/oss_scripts/llama/README.md
- examples/qualcomm/oss_scripts/llama/llama.py

6 files changed: +121 −19 lines changed

backends/qualcomm/builders/node_visitor.py

Lines changed: 1 addition & 1 deletion
@@ -162,7 +162,7 @@ def make_qnn_per_block_config(self, node: torch.fx.Node, quant_attrs: Dict):
         for ch in range(num_channels):
             max_scale = scales[ch].reshape(1, -1).amax(dim=-1) / num_steps
             q_scales = torch.clamp(
-                input=scales[ch] / max_scale,
+                input=torch.round(input=scales[ch] / max_scale),
                 min=1,
                 max=2**bitwidth_of_scale,
             ).to(quant_scales_dtype)
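The fix rounds the per-block scale ratio to the nearest integer before clamping, so the multipliers stored in `quant_scales_dtype` land on exact integer steps instead of being truncated by the dtype cast. A small self-contained sketch of the computation with made-up values (`bitwidth_of_scale` and the uint8 storage dtype are assumptions for illustration):

```python
import torch

# Made-up per-channel block scales; shapes and dtype are assumptions for
# illustration (quant_scales_dtype is taken to be uint8 here).
scales_ch = torch.tensor([0.0125, 0.0500, 0.0999, 0.1600])
bitwidth_of_scale = 4
num_steps = 2**bitwidth_of_scale

max_scale = scales_ch.reshape(1, -1).amax(dim=-1) / num_steps

# Without the round, a fractional ratio like 9.99 is truncated to 9 by the
# integer cast; rounding first snaps each block to its nearest scale step.
q_scales = torch.clamp(
    input=torch.round(scales_ch / max_scale),
    min=1,
    max=2**bitwidth_of_scale,
).to(torch.uint8)
print(q_scales)  # tensor([ 1,  5, 10, 16], dtype=torch.uint8)
```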

backends/qualcomm/quantizer/custom_annotation.py

Lines changed: 12 additions & 2 deletions
@@ -317,6 +317,7 @@ def annotate_matmul_input1(node: Node, is_qat: str):
             torch.ops.aten.transpose.int,
             torch.ops.aten.view.default,
             torch.ops.aten.reshape.default,
+            torch.ops.aten.slice.Tensor,
         ]:
             annotate_single_in_single_out(node, quantization_config_8a8w)
             node = node.args[0]
@@ -340,7 +341,11 @@ def annotate_matmul_input1(node: Node, is_qat: str):
                 node, quantization_config=quantization_config_8a4w_per_channel
             )
             break
-        elif node.target in [torch.ops.aten.add.Tensor, torch.ops.aten.sub.Tensor]:
+        elif node.target in [
+            torch.ops.aten.add.Tensor,
+            torch.ops.aten.sub.Tensor,
+            torch.ops.aten.matmul.default,
+        ]:
             break
         else:
             print(f"The node ({node}) is not expected in the input1 of the matmul")
@@ -356,7 +361,12 @@ def annotate_matmul_input1(node: Node, is_qat: str):
     )

     for node in gm.graph.nodes:
-        if node.op == "call_function" and node.target == torch.ops.aten.matmul.default:
+        if (
+            node.op == "call_function"
+            and node.target == torch.ops.aten.matmul.default
+            and all(arg.op == "call_function" for arg in node.args)
+        ):
+            # Only apply custom annotation on Q @ K^T @ V
             annotate_matmul(node, quantization_config_16a8w)
             annotate_matmul_input1(node.args[1], is_qat=is_qat)
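The annotation walk above follows input1 of a matmul upward through shape-only ops and stops once it reaches a compute op; this commit adds `aten.slice.Tensor` to the pass-through set and `aten.matmul.default` to the stop set. A simplified sketch of that traversal pattern (the annotate helper and config are stand-ins, and the real function handles more branches such as the 8a4w per-channel case visible in the hunk above):

```python
import torch
from torch.fx import Node

# Stand-ins for the ops handled by annotate_matmul_input1 above (the real
# lists may contain additional entries).
PASS_THROUGH_OPS = [
    torch.ops.aten.transpose.int,
    torch.ops.aten.view.default,
    torch.ops.aten.reshape.default,
    torch.ops.aten.slice.Tensor,      # newly accepted by this commit
]
STOP_OPS = [
    torch.ops.aten.add.Tensor,
    torch.ops.aten.sub.Tensor,
    torch.ops.aten.matmul.default,    # newly accepted by this commit
]

def walk_matmul_input1(node: Node, annotate_single_in_single_out, config) -> None:
    """Simplified traversal sketch: annotate shape-only ops feeding input1 of
    a matmul and stop once a compute op is reached."""
    while node.op == "call_function":
        if node.target in PASS_THROUGH_OPS:
            annotate_single_in_single_out(node, config)
            node = node.args[0]
        elif node.target in STOP_OPS:
            break
        else:
            print(f"The node ({node}) is not expected in the input1 of the matmul")
            break
```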

backends/qualcomm/runtime/backends/QnnImplementation.cpp

Lines changed: 0 additions & 8 deletions
@@ -51,16 +51,8 @@ Error QnnImplementation::StartBackend(
     const std::string& lib_path,
     const QnnSaver_Config_t** saver_config) {
   Qnn_ErrorHandle_t error = QNN_SUCCESS;
-  // RTLD_GLOBAL is needed on x86 as HTP op package has a requirement for the
-  // symbols in backend to be visible. Using RTLD_LOCAL on Android to allow full
-  // unloading of HTP backend shared library on dlclose() as RTLD_GLOBAL isn't
-  // letting it happen.
   void* lib_handle = nullptr;
-#if defined(__ANDROID__)
-  lib_handle = dlopen(lib_path.c_str(), RTLD_NOW | RTLD_LOCAL);
-#else
   lib_handle = dlopen(lib_path.c_str(), RTLD_NOW | RTLD_GLOBAL);
-#endif
   if (lib_handle == nullptr) {
     QNN_EXECUTORCH_LOG_ERROR(
         "Cannot Open QNN library %s, with error: %s",

backends/qualcomm/tests/test_qnn_delegate.py

Lines changed: 76 additions & 3 deletions
@@ -1134,7 +1134,8 @@ def test_qnn_backend_where(self):
             (torch.randn(30, 20),),
         ]
         for i, module in enumerate(modules):
-            self.lower_module_and_test_output(module, sample_inputs[i])
+            with self.subTest(i=i):
+                self.lower_module_and_test_output(module, sample_inputs[i])

     def test_qnn_backend_masked_fill(self):
         module = MaskedFill()  # noqa: F405
@@ -2571,8 +2572,9 @@ def test_qnn_backend_where(self):
             (torch.randn(30, 20),),
         ]
         for i, module in enumerate(modules):
-            module = self.get_qdq_module(module, sample_inputs[i])
-            self.lower_module_and_test_output(module, sample_inputs[i])
+            with self.subTest(i=i):
+                module = self.get_qdq_module(module, sample_inputs[i])
+                self.lower_module_and_test_output(module, sample_inputs[i])

     def test_qnn_backend_masked_fill(self):
         module = MaskedFill()  # noqa: F405
@@ -4541,6 +4543,77 @@ def test_llama_stories_110m(self):
         if not self.compile_only and not self.enable_x86_64:
             self.assertGreaterEqual(msg["inference_speed"], 220)  # Lanai

+    def test_static_phi4(self):
+        if not self.required_envs():
+            self.skipTest("missing required envs")
+
+        prompt = "My favourite condiment is "
+        cmds = [
+            "python",
+            f"{self.executorch_root}/examples/qualcomm/oss_scripts/llama/llama.py",
+            "--artifact",
+            self.artifact_dir,
+            "--build_folder",
+            self.build_folder,
+            "--model",
+            self.model,
+            "--ip",
+            self.ip,
+            "--port",
+            str(self.port),
+            "--prompt",
+            f"{prompt}",
+            "--ptq",
+            "16a4w_block",
+            "--group_size",
+            "16",
+            "--decoder_model",
+            "phi_4_mini",
+            "--model_mode",
+            "kv",
+            "--max_seq_len",
+            "1024",
+            "--num_sharding",
+            "8",
+            "--eval_perplexity",
+            "--tasks",
+            "wikitext",
+            "--limit",
+            "1",
+        ]
+        if self.compile_only:
+            cmds.extend(["--compile_only"])
+        elif self.device:
+            cmds.extend(["--device", self.device])
+            if self.host:
+                cmds.extend(["--host", self.host])
+        elif self.enable_x86_64:
+            cmds.extend(["--enable_x86_64"])
+        if self.pre_gen_pte:
+            cmds.extend(["--pre_gen_pte", self.pre_gen_pte])
+            cmds.extend(
+                [
+                    "--quant_attrs_path",
+                    f"{self.pre_gen_pte}/kv_llama_qnn_quant_attrs.json",
+                ]
+            )
+
+        p = subprocess.Popen(cmds, stdout=subprocess.DEVNULL)
+        with Listener((self.ip, self.port)) as listener:
+            conn = listener.accept()
+            p.communicate()
+            msg = json.loads(conn.recv())
+            if "Error" in msg:
+                self.fail(msg["Error"])
+            else:
+                inference_speed_ref = {"SM8650": 14, "SM8750": 19}
+                self.assertLessEqual(msg["wiki_ppl"], 12)
+                self.assertLessEqual(msg["pte_size"], 4000000000)  # 4gb
+                if self.model in inference_speed_ref:
+                    self.assertGreaterEqual(
+                        msg["inference_speed"], inference_speed_ref[self.model]
+                    )
+
     def test_static_qwen2_5(self):
         if not self.required_envs():
             self.skipTest("missing required envs")

examples/qualcomm/oss_scripts/llama/README.md

Lines changed: 7 additions & 0 deletions
@@ -69,6 +69,12 @@ Default example using hybrid mode.
 python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --ptq 16a4w --checkpoint consolidated.00.pth --params params.json --tokenizer_model tokenizer.model --decoder_model llama3_2 --model_mode hybrid --prefill_ar_len 32 --max_seq_len 128 --prompt "what is 1+1"
 ```

+#### Phi4-mini-instruct
+Default example using hybrid mode.
+```bash
+python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --ptq 16a4w_block --group_size 16 --checkpoint consolidated.00.pth --params params.json --tokenizer_model tokenizer.model --decoder_model phi_4_mini --model_mode hybrid --prefill_ar_len 128 --max_seq_len 1024 --num_sharding 8 --prompt "I would like to learn python, could you teach me with a simple example?"
+```
+
 #### QWEN2.5 0.5B
 Default example using hybrid mode
 ```bash
@@ -99,6 +105,7 @@ Default example using hybrid mode.
 python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -H mlgtw-linux -s ${SERIAL_NUM} -m ${SOC_MODEL} --ptq 16a8w --decoder_model smollm2_135m --model_mode hybrid --prefill_ar_len 128 --max_seq_len 1024 --prompt "I would like to learn python, could you teach me with a simple example?"
 ```

+
 ### KV Cache update mechanism
 We have two distinct mechanisms for updating the key-value (KV) cache, which can be selected at runtime. Shift Pointer and Smart Mask.

examples/qualcomm/oss_scripts/llama/llama.py

Lines changed: 25 additions & 5 deletions
@@ -117,6 +117,8 @@
 FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
 logging.basicConfig(level=logging.INFO, format=FORMAT)
 logging.getLogger().setLevel(logging.INFO)
+# Avoid the error message "Could not initialize NNPACK! Reason: Unsupported hardware."
+torch.backends.nnpack.set_flags(False)


 def next_power_of_two(n):
@@ -235,10 +237,16 @@ def quantize(
         ).module()

         if quant_dtype == QuantDtype.use_16a4w_block:
+            if args.group_size is None:
+                raise ValueError(
+                    "Group size is required when use quant_dtype 16a4w_block"
+                )
             conv_nodes = [
                 n for n in fx_graph_module.graph.nodes if "conv" in n.name
             ]
-            block_size_map = {n.name: (1, 64, 1, 1) for n in conv_nodes}
+            block_size_map = {
+                n.name: (1, args.group_size, 1, 1) for n in conv_nodes
+            }
             quantizer.set_block_size_map(block_size_map)

         fx_graph_module = prepare_pt2e(fx_graph_module, quantizer)
@@ -635,7 +643,7 @@ def permute(w, heads):
     if args.ptq != "16a8w":
         # 16a8w use 16bit kv io, so skip this custom annotation
         custom_annotations = custom_annotations + (annotate_matmul_16a8w,)
-        if args.decoder_model in {"stories110m", "stories260k"}:
+        if args.decoder_model in {"stories110m", "stories260k", "phi_4_mini"}:
            custom_annotations = custom_annotations + (
                annotate_linear_16a8w_in_affine_layer,
            )
@@ -853,12 +861,20 @@ def post_process():

     seq_len = args.max_seq_len
     multi_prompts = " ".join([f'--prompt "{prompt}"' for prompt in args.prompt])
+    lookahead_args = " ".join(
+        [
+            f"--window {args.window}",
+            f"--gcap {args.gcap}",
+            f"--ngram {args.ngram}",
+        ]
+    )
     runner_args = " ".join(
         [
             multi_prompts,
             f"--eval_mode {EVAL_MODE[args.model_mode]}",
             f"--temperature {args.temperature}",
             f"--system_prompt '{args.system_prompt}'",
+            lookahead_args if args.model_mode == "lookahead" else "",
         ]
     )

@@ -908,9 +924,6 @@ def post_process():
             "--output_path outputs/outputs.txt",
             f"--performance_output_path {performance_output_path}",
             f"--kv_updater {'SmartMask' if args.kv_updater == smart_mask_updater else 'ShiftPointer'}",
-            f"--window {args.window}",
-            f"--gcap {args.gcap}",
-            f"--ngram {args.ngram}",
             runner_args,
         ]
     )
@@ -1175,6 +1188,13 @@ def _build_parser():
         action="store_true",
         default=False,
     )
+    parser.add_argument(
+        "-G",
+        "--group_size",
+        type=int,
+        default=None,
+        help="group_size used in block quantization for weight quantization.",
+    )

     parser.add_argument("-v", "--verbose", action="store_true")
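With this refactor, the `--window`, `--gcap`, and `--ngram` flags are only forwarded to the on-device runner when `--model_mode lookahead` is selected, instead of unconditionally. A quick illustrative check of that string assembly (stand-in values, not the real argparse objects):

```python
# Mirrors the runner_args assembly above with hard-coded stand-in values.
window, gcap, ngram, model_mode = 8, 8, 3, "lookahead"

lookahead_args = " ".join(
    [f"--window {window}", f"--gcap {gcap}", f"--ngram {ngram}"]
)
runner_args = " ".join(
    [
        '--prompt "hi"',
        "--eval_mode 1",
        "--temperature 0.8",
        lookahead_args if model_mode == "lookahead" else "",
    ]
)
print(runner_args)
# --prompt "hi" --eval_mode 1 --temperature 0.8 --window 8 --gcap 8 --ngram 3
```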
