
Commit d686472

Qualcomm AI Engine Direct - Static LLM Refactor & Qwen3 1.7B Improvement (#13755)
### Summary
- Refactor llama.py. The current script is hard to customize, especially around quantization configs. As more models have been enabled, it has grown messy, relying on multiple `if`-`else` statements to decide which optimizations apply to which model. We want to move all the model specs under `__init__.py` (a hedged sketch of this idea follows below).
- Hide scale/offset in the model's metadata, so `args.quant_attrs_path` is no longer required when evaluating the perplexity (ppl) score (a hedged sketch appears after the file diffs).
- Enable Qwen3 1.7B with 16a4w_block quantization. It previously used 16a8w, which is much slower. The goal is to maximize token rate while keeping ppl within a 20% margin of the FP CPU baseline.

#### Stats
- token rate = 37 tok/sec
- ppl = 14.79

### Test plan
Tested all scripts to ensure no regression.

cc: @haowhsu-quic
1 parent 4d1da11 commit d686472
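
The first Summary bullet replaces per-model `if`-`else` branching in llama.py with model specs declared in `__init__.py`. Below is a minimal sketch of that idea; the class name `DecoderModelSpec`, the registry name `SUPPORTED_MODELS`, and every field are hypothetical illustrations, not the PR's actual definitions.

```python
# Hypothetical sketch only: names and fields are assumptions, not the PR's code.
from dataclasses import dataclass
from typing import Callable, Optional, Tuple


@dataclass(frozen=True)
class DecoderModelSpec:
    repo: str                          # checkpoint / HF repo identifier
    ptq: str                           # default quant scheme, e.g. "16a4w_block"
    group_size: Optional[int] = None   # only meaningful for block quantization
    num_sharding: int = 1
    r3: bool = False
    masked_softmax: bool = False
    custom_annotations: Tuple[Callable, ...] = ()


# llama.py would look the spec up by --decoder_model instead of branching with
# if/else, which is also why the tests below no longer pass --ptq, --num_sharding,
# --r3, or --enable_masked_softmax explicitly.
SUPPORTED_MODELS = {
    "qwen3-1_7b": DecoderModelSpec(
        repo="Qwen/Qwen3-1.7B",
        ptq="16a4w_block",  # switched from 16a8w for a better token rate
        group_size=16,      # placeholder value for illustration
    ),
}
```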

File tree

8 files changed: +452 additions, -305 deletions


backends/qualcomm/quantizer/custom_annotation.py

Lines changed: 27 additions & 40 deletions
```diff
@@ -92,9 +92,12 @@ def annotate_mimi_decoder(gm: torch.fx.GraphModule):
             break
 
 
-def annotate_linear_16a8w_in_affine_layer(
-    gm: torch.fx.GraphModule, is_qat: bool = False
-) -> None:
+def annotate_output_16a8w(gm: torch.fx.GraphModule, is_qat: bool = False) -> None:
+    """
+    This function is for static LLM models.
+    This function will annotate the last conv(linear), which is the lm_head, as 16a8w.
+    """
+
     def annotate_conv2d(node: Node, quantization_config: QuantizationConfig) -> None:
         input_qspec_map = {}
         input_act = node.args[0]
@@ -163,11 +166,30 @@ def annotate_prefill_kv_output(gm: torch.fx.GraphModule, kv_quant_attrs: dict):
         )
 
 
-def annotate_matmul_16a8w(  # noqa: C901
+def annotate_wv_sha(gm: torch.fx.GraphModule, quantization_config: QuantizationConfig):
+    for node in gm.graph.nodes:
+        if (
+            node.target == torch.ops.aten.conv2d.default
+            and "wv_sha" in node.meta["stack_trace"]
+        ):
+            input_qspec_map = {}
+            input_qspec_map[node.args[0]] = quantization_config.input_activation
+            input_qspec_map[node.args[1]] = quantization_config.weight
+            if len(node.args) > 2 and isinstance(node.args[2], Node):
+                input_qspec_map[node.args[2]] = quantization_config.bias(node)
+            node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation(
+                input_qspec_map=input_qspec_map,
+                output_qspec=quantization_config.output_activation,
+                _annotated=True,
+            )
+
+
+def annotate_kv_8bit(  # noqa: C901
     gm: torch.fx.GraphModule,
     is_qat=False,
 ) -> None:
     """
+    This function is for static LLM models.
     This function is specific for matmul op 16a8w.
     For k, we will tag such as the below, and
     for v, we will tag 8a until conv op.
@@ -213,25 +235,6 @@ def annotate_cat(node: Node, quantization_config: QuantizationConfig):
             _annotated=True,
         )
 
-    def annotate_conv2d(node: Node, quantization_config: QuantizationConfig) -> None:
-        input_qspec_map = {}
-        input_act = node.args[0]
-        input_spec = quantization_config.input_activation
-        input_qspec_map[input_act] = input_spec
-
-        weight = node.args[1]
-        input_qspec_map[weight] = quantization_config.weight
-
-        if len(node.args) > 2 and isinstance(node.args[2], Node):
-            bias = node.args[2]
-            input_qspec_map[bias] = quantization_config.bias(node)
-
-        node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation(
-            input_qspec_map=input_qspec_map,
-            output_qspec=quantization_config.output_activation,
-            _annotated=True,
-        )
-
     def annotate_rms_norm(node: Node, quantization_config: QuantizationConfig) -> None:
         act_node = node.args[0]
         weight_node = node.args[2]
@@ -301,22 +304,10 @@ def annotate_matmul_input1(node: Node, is_qat: str):
             quantization_config_8a8w = get_8a8w_qnn_qat_config(
                 act_symmetric=True, act_observer=MinMaxObserver
             )
-            quantization_config_8a4w_per_channel = get_qat_per_channel_quant_config(
-                act_dtype=torch.uint8,
-                weight_dtype=torch.int4,
-                act_observer=MinMaxObserver,
-                act_symmetric=True,
-            )
         else:
             quantization_config_8a8w = get_8a8w_qnn_ptq_config(
                 act_symmetric=True, act_observer=MinMaxObserver
             )
-            quantization_config_8a4w_per_channel = get_ptq_per_channel_quant_config(
-                act_dtype=torch.uint8,
-                weight_dtype=torch.int4,
-                act_observer=MinMaxObserver,
-                act_symmetric=True,
-            )
         while isinstance(node, Node) and node.op == "call_function":
             if node.target in [
                 torch.ops.aten.permute.default,
@@ -343,15 +334,11 @@ def annotate_matmul_input1(node: Node, is_qat: str):
                 # For k, we tag 8a until add or sub op (rotatary embedding).
                 # The arguments of cat op: (the past kv cache, the new kv cache)
                 node = node.args[0][1]
-            elif node.target == torch.ops.aten.conv2d.default:
-                annotate_conv2d(
-                    node, quantization_config=quantization_config_8a4w_per_channel
-                )
-                break
             elif node.target in [
                 torch.ops.aten.add.Tensor,
                 torch.ops.aten.sub.Tensor,
                 torch.ops.aten.matmul.default,
+                torch.ops.aten.conv2d.default,
             ]:
                 break
             else:
```
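
For context, here is a hedged sketch of how the renamed helpers above (`annotate_output_16a8w`, `annotate_kv_8bit`) might be attached to a quantizer. It assumes `QnnQuantizer` exposes an `add_custom_quant_annotations` hook taking a tuple of callables, as earlier example scripts do; treat the exact API as an assumption and verify against the repo.

```python
# Hedged usage sketch: attaching the custom annotation helpers to a quantizer.
# Assumes QnnQuantizer.add_custom_quant_annotations() accepts a tuple of
# callables that each receive the fx GraphModule being annotated.
from functools import partial

from executorch.backends.qualcomm.quantizer.custom_annotation import (
    annotate_kv_8bit,
    annotate_output_16a8w,
)
from executorch.backends.qualcomm.quantizer.quantizer import QnnQuantizer

quantizer = QnnQuantizer()
quantizer.add_custom_quant_annotations(
    (
        partial(annotate_kv_8bit, is_qat=False),       # 16a8w matmul, 8-bit kv path
        partial(annotate_output_16a8w, is_qat=False),  # lm_head conv as 16a8w
    )
)
# The annotated graph then goes through prepare_pt2e / convert_pt2e as usual.
```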

backends/qualcomm/tests/test_qnn_delegate.py

Lines changed: 0 additions & 24 deletions
```diff
@@ -4582,8 +4582,6 @@ def test_llama3_2_1b(self):
             str(self.port),
             "--prompt",
             f"{prompt}",
-            "--ptq",
-            "16a4w",
             "--temperature",
             "0",
             "--decoder_model",
@@ -4594,8 +4592,6 @@
             "32",
             "--max_seq_len",
             "512",
-            "--num_sharding",
-            "4",
         ]
         if self.compile_only:
             cmds.extend(["--compile_only"])
@@ -4662,8 +4658,6 @@ def test_llama_stories_260k(self):
             str(self.port),
             "--prompt",
             f"{prompt}",
-            "--ptq",
-            "16a4w",
             "--temperature",
             "0",
             "--decoder_model",
@@ -4740,8 +4734,6 @@ def test_llama_stories_110m(self):
             str(self.port),
             "--prompt",
             f"{prompt}",
-            "--ptq",
-            "16a4w",
             "--temperature",
             "0",
             "--decoder_model",
@@ -4806,18 +4798,12 @@ def test_static_phi4(self):
             str(self.port),
             "--prompt",
             f"{prompt}",
-            "--ptq",
-            "16a4w_block",
-            "--group_size",
-            "16",
             "--decoder_model",
             "phi_4_mini",
             "--model_mode",
             "kv",
             "--max_seq_len",
             "1024",
-            "--num_sharding",
-            "8",
             "--eval_perplexity",
             "--tasks",
             "wikitext",
@@ -4877,8 +4863,6 @@ def test_static_qwen2_5(self):
             str(self.port),
             "--prompt",
             f"{prompt}",
-            "--ptq",
-            "16a8w",
             "--decoder_model",
             "qwen2_5-0_5b",
             "--model_mode",
@@ -4890,8 +4874,6 @@
             "wikitext",
             "--limit",
             "1",
-            "--r3",
-            "--enable_masked_softmax",
         ]
         if self.compile_only:
             cmds.extend(["--compile_only"])
@@ -4940,8 +4922,6 @@ def test_static_qwen3(self):
             str(self.port),
             "--prompt",
             f"{prompt}",
-            "--ptq",
-            "16a8w",
             "--decoder_model",
             "qwen3-0_6b",
             "--model_mode",
@@ -4953,8 +4933,6 @@
             "wikitext",
             "--limit",
             "1",
-            "--r3",
-            "--enable_masked_softmax",
         ]
         if self.compile_only:
             cmds.extend(["--compile_only"])
@@ -5003,8 +4981,6 @@ def test_static_smollm2(self):
             str(self.port),
             "--prompt",
             f"{prompt}",
-            "--ptq",
-            "16a8w",
             "--decoder_model",
             "smollm2_135m",
             "--model_mode",
```

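Finally, the second Summary bullet, hiding the logits scale/offset in the model's metadata so `--quant_attrs_path` is no longer needed for perplexity evaluation, can be illustrated with a minimal sketch. The toy module and the method names `get_logits_scale` / `get_logits_offset` are assumptions for illustration, not the PR's actual implementation.

```python
# Minimal sketch: carry quant attrs as program metadata instead of a JSON side
# file. Method names and values here are illustrative assumptions.
import torch
from executorch.exir import to_edge


class TinyDecoder(torch.nn.Module):
    def forward(self, x):
        return torch.nn.functional.linear(x, torch.ones(8, 8))


exported = torch.export.export(TinyDecoder(), (torch.randn(1, 8),))

# Scale/offset observed for the output logits during quantization (placeholders).
logits_scale, logits_offset = 0.015, 0

edge = to_edge(
    exported,
    # constant_methods become queryable metadata methods on the final program,
    # so an eval script can read them directly instead of --quant_attrs_path.
    constant_methods={
        "get_logits_scale": logits_scale,
        "get_logits_offset": logits_offset,
    },
)
program = edge.to_executorch()  # program.buffer can then be written out as a .pte
```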