Commit 2e44184

Qualcomm AI Engine Direct - PTQ for llama3.2 1b/3b (#12700)
### Summary
- add PTQ recipe for llama3.2 1b/3b
- add seq_mse support to help quantize the 1b model
- complement qnn_llama_runner for smollm2

### Test Plan
```bash
python examples/qualcomm/oss_scripts/llama/llama.py \
  -b build-android -H $HOST -s $SN -m SM8750 \
  --temperature 0 --model_mode hybrid --prefill_ar_len 128 --max_seq_len 1024 \
  --decoder_model llama3_2-1b_instruct \
  --prompt "I would like to learn python, could you teach me with a simple example?" \
  --artifact ./llama_artifact --tasks wikitext --limit 1 --compile_only \
  --params ../.llama/checkpoints/Llama3.2-1B-Instruct/params.json \
  --tokenizer_model ../.llama/checkpoints/Llama3.2-1B-Instruct/tokenizer.model \
  --checkpoint ../.llama/checkpoints/Llama3.2-1B-Instruct/consolidated.00.pth
```
1 parent 7f95941 commit 2e44184

File tree

12 files changed, +437 -50 lines


backends/qualcomm/_passes/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -39,6 +39,7 @@
 from .remove_redundancy import RemoveRedundancy
 from .replace_arange_args import ReplaceArangeArgs
 from .replace_inf_values import ReplaceInfValues
+from .seq_mse import SeqMSE
 from .tag_quant_io import TagQuantIO


@@ -78,5 +79,6 @@
     RemoveRedundancy,
     ReplaceArangeArgs,
     ReplaceInfValues,
+    SeqMSE,
     TagQuantIO,
 ]
backends/qualcomm/_passes/seq_mse.py

Lines changed: 223 additions & 0 deletions
@@ -0,0 +1,223 @@
+# Copyright (c) Qualcomm Innovation Center, Inc.
+# All rights reserved
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+import types
+from contextlib import contextmanager
+
+import torch
+import torchao
+from executorch.backends.qualcomm.quantizer.observers.per_block_param_observer import (
+    PerBlockParamObserver,
+)
+from executorch.exir.pass_base import ExportPass, PassResult
+from torchao.quantization.pt2e import PerChannelMinMaxObserver
+
+
+class SeqMseModule(torch.nn.Module):
+    """
+    Args:
+        nominal_weight: Tensor
+            nominal parameters from operator
+        nominal_bias: Tensor
+            nominal parameters from operator
+        operator: fx.Node
+            operator to be executed
+        observer: UniformQuantizationObserverBase
+            parameter observer (specific for weight)
+        num_candidates: int
+            grids to search minimal mse loss
+    """
+
+    def __init__(
+        self,
+        nominal_weight,
+        nominal_bias,
+        operator,
+        observer,
+        num_candidates,
+    ):
+        super().__init__()
+        self.nominal_weight = nominal_weight
+        self.nominal_bias = nominal_bias
+        self.observer = observer
+        self.steps = torch.linspace(
+            1 / num_candidates, 1, steps=num_candidates
+        ).tolist()
+        self.operator = self._make_operator(operator)
+        self.best_candidate_step = 1.0
+
+    def _make_operator(self, aten_op):
+        if aten_op.target == torch.ops.aten.conv2d.default:
+            stride = [1, 1] if len(aten_op.args) < 4 else aten_op.args[3]
+            padding = [0, 0] if len(aten_op.args) < 5 else aten_op.args[4]
+            dilation = [1, 1] if len(aten_op.args) < 6 else aten_op.args[5]
+            groups = 1 if len(aten_op.args) < 7 else aten_op.args[6]
+            has_bias = self.nominal_bias is not None
+            module = torch.nn.Conv2d(
+                in_channels=self.nominal_weight.shape[1],
+                out_channels=self.nominal_weight.shape[0],
+                kernel_size=self.nominal_weight.shape[-2:],
+                stride=stride,
+                padding=padding,
+                dilation=dilation,
+                groups=groups,
+                bias=has_bias,
+            )
+            module.weight.data = self.nominal_weight
+            if has_bias:
+                module.bias.data = self.nominal_bias
+            return module
+        else:
+            raise NotImplementedError(f"target of {aten_op.target} is not implemented")
+
+    def _per_block_qdq(self, scale, zero_point):
+        return torchao.quantization.quant_primitives._fake_quantize_affine(
+            input=self.nominal_weight,
+            block_size=self.observer.block_size,
+            scale=scale,
+            zero_point=zero_point,
+            quant_dtype=self.observer.dtype,
+            quant_min=self.observer.quant_min,
+            quant_max=self.observer.quant_max,
+        )
+
+    def _per_channel_qdq(self, scale, zero_point):
+        return torch.fake_quantize_per_channel_affine(
+            input=self.nominal_weight,
+            scale=scale,
+            zero_point=zero_point,
+            axis=0,
+            quant_min=self.observer.quant_min,
+            quant_max=self.observer.quant_max,
+        )
+
+    def _fake_quant(self, scale, zero_point):
+        dispatcher = {
+            PerChannelMinMaxObserver: self._per_channel_qdq,
+            PerBlockParamObserver: self._per_block_qdq,
+        }
+        return dispatcher[type(self.observer)](scale, zero_point)
+
+    def _find_best_candidate(self, nominal_input, nominal_output):
+        # calculate current baseline
+        scale, zero_point = self.observer.calculate_qparams()
+        zero_point = zero_point.to(torch.int32)
+        self.operator.weight.data = self._fake_quant(scale, zero_point)
+        candidate, current_loss = (
+            1,
+            torch.nn.functional.mse_loss(
+                self.operator(nominal_input), nominal_output
+            ).item(),
+        )
+        for step in self.steps:
+            self.operator.weight.data = self._fake_quant(scale * step, zero_point)
+            loss = torch.nn.functional.mse_loss(
+                self.operator(nominal_input), nominal_output
+            ).item()
+            if loss < current_loss:
+                candidate, current_loss = step, loss
+        return candidate
+
+    def forward(self, nominal_input, nominal_output):
+        self.best_candidate_step = self._find_best_candidate(
+            nominal_input=nominal_input, nominal_output=nominal_output
+        )
+
+
+class InsertSeqMse(ExportPass):
+    """
+    Insert Seq Mse Observer to find the best quant config for certain node's weight.
+    """
+
+    seq_mse_ops = {torch.ops.aten.conv2d.default}
+
+    def __init__(self, num_candidates=1000):
+        super(InsertSeqMse, self).__init__()
+        self.num_candidates = num_candidates
+
+    def _insert_seq_mse(
+        self, graph_module: torch.fx.GraphModule
+    ) -> torch.fx.GraphModule:
+        count = 0
+        for node in graph_module.graph.nodes:
+            if node.target in self.seq_mse_ops:
+                # extract observer
+                weight_node_obs = node.args[1]
+                observer = getattr(graph_module, weight_node_obs.name)
+                # extract parameters
+                weight_node = weight_node_obs.args[0]
+                weight_tensor = graph_module.get_parameter(weight_node.target).detach()
+                bias_tensor = None
+                if len(node.args) > 2 and node.args[2] is not None:
+                    bias_tensor = graph_module.get_parameter(
+                        node.args[2].args[0].target
+                    ).detach()
+
+                with graph_module.graph.inserting_after(node):
+                    seq_mse_mod = SeqMseModule(
+                        nominal_weight=weight_tensor,
+                        nominal_bias=bias_tensor,
+                        operator=node,
+                        observer=observer,
+                        num_candidates=self.num_candidates,
+                    )
+                    module_name = f"seq_mse_{count}"
+                    count += 1
+                    setattr(graph_module, module_name, seq_mse_mod)
+                    input_nodes = (node.args[0], node)
+                    graph_module.graph.create_node(
+                        "call_module", module_name, input_nodes, {}
+                    )
+
+    def call(self, graph_module: torch.fx.GraphModule):
+        self._insert_seq_mse(graph_module)
+        graph_module.recompile()
+        return PassResult(graph_module, True)
+
+
+class RemoveSeqMse(ExportPass):
+    """
+    Remove Seq Mse before invoking convert_pt2e and update final quantization encoding.
+    """
+
+    def __init__(self):
+        super(RemoveSeqMse, self).__init__()
+
+    def _remove_seq_mse(
+        self, graph_module: torch.fx.GraphModule
+    ) -> torch.fx.GraphModule:
+        node_to_erase = []
+        for node in graph_module.graph.nodes:
+            if node.op == "call_module":
+                # try extracting SeqMse module
+                module = getattr(graph_module, node.target)
+                if isinstance(module, SeqMseModule):
+                    # rewrite observer method for pre-calculated scale
+                    scale, zero_point = module.observer.calculate_qparams()
+                    module.observer.updated_encoding = (
+                        scale * module.best_candidate_step,
+                        zero_point,
+                    )
+                    module.observer.calculate_qparams = types.MethodType(
+                        lambda s: s.updated_encoding, module.observer
+                    )
+                    node_to_erase.append(node)
+
+        for node in node_to_erase:
+            graph_module.graph.erase_node(node)
+
+    def call(self, graph_module: torch.fx.GraphModule):
+        self._remove_seq_mse(graph_module)
+        graph_module.recompile()
+        return PassResult(graph_module, True)
+
+
+@contextmanager
+def SeqMSE(prepared_gm, num_candidates):
+    prepared_gm = InsertSeqMse(num_candidates)(prepared_gm).graph_module
+    try:
+        yield
+    finally:
+        prepared_gm = RemoveSeqMse()(prepared_gm).graph_module

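For orientation (not part of the diff): the `SeqMSE` context manager above wraps the calibration run, so `InsertSeqMse` attaches a `SeqMseModule` after each supported conv, the calibration forward passes sweep `num_candidates` scale shrink factors per weight for the lowest per-op output MSE, and `RemoveSeqMse` patches the winning scale back into the weight observer before `convert_pt2e`. A minimal usage sketch, modeled on the unit test added later in this commit; the `prepare_pt2e`/`convert_pt2e` import path and the bare `QnnQuantizer()` default config are assumptions that may differ by torch/torchao version (the test uses the repo's `make_quantizer()` helper instead):

```python
# Hedged sketch of the Seq-MSE calibration flow; the toy Conv2d model and
# quantizer construction are illustrative, not the commit's own recipe.
import torch

from executorch.backends.qualcomm._passes.seq_mse import SeqMSE
from executorch.backends.qualcomm.quantizer.quantizer import QnnQuantizer
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e


class ToyConv(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(512, 32, kernel_size=1)

    def forward(self, x):
        return self.conv(x)


sample_input = (torch.randn(1, 512, 1, 32),)
graph_module = torch.export.export(ToyConv(), sample_input).module()
prepared = prepare_pt2e(graph_module, QnnQuantizer())

# Calibrate inside the SeqMSE context: each observed conv weight is swept over
# 100 shrunken scale candidates, keeping the one with minimal MSE between the
# float and fake-quantized op outputs.
with SeqMSE(prepared, num_candidates=100):
    prepared(*sample_input)

converted = convert_pt2e(prepared)  # quantizes with the tuned scales
```

Note that the candidate steps lie in (0, 1], so the search only shrinks the observer's baseline scale; in the worst case it falls back to the observer's own min-max estimate.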
backends/qualcomm/quantizer/custom_annotation.py

Lines changed: 19 additions & 0 deletions
@@ -31,6 +31,25 @@
 )


+def annotate_down_proj(
+    gm: torch.fx.GraphModule, quantization_config: QuantizationConfig
+):
+    for node in gm.graph.nodes:
+        if (
+            node.target == torch.ops.aten.conv2d.default
+            and any(s in node.meta["stack_trace"] for s in ["forward_feedfoward_conv"])
+            and node.args[0].target == torch.ops.aten.mul.Tensor
+        ):
+            input_qspec_map = {}
+            input_qspec_map[node.args[0]] = quantization_config.input_activation
+            input_qspec_map[node.args[1]] = quantization_config.weight
+            node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation(
+                input_qspec_map=input_qspec_map,
+                output_qspec=quantization_config.output_activation,
+                _annotated=True,
+            )
+
+
 def annotate_eurobert(gm: torch.fx.GraphModule):
     """
     QNN does not support int32 -> signed 16bit quant
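For context (not part of the diff): custom annotations like `annotate_down_proj` are typically bound to a quantization config and registered on the QNN quantizer by the PTQ recipe. A hedged sketch of that wiring; the `get_16a8w_qnn_ptq_config` helper and the `add_custom_quant_annotations` hook are assumptions based on how other custom annotations in this backend are registered, and the llama recipe in this commit may choose a different config for the FFN down projection:

```python
# Hedged sketch only: config helper and registration hook are assumptions;
# consult the llama.py recipe for the exact wiring used by this commit.
from functools import partial

from executorch.backends.qualcomm.quantizer.custom_annotation import (
    annotate_down_proj,
)
from executorch.backends.qualcomm.quantizer.qconfig import (
    get_16a8w_qnn_ptq_config,  # assumed helper; any QuantizationConfig works
)
from executorch.backends.qualcomm.quantizer.quantizer import QnnQuantizer

quantizer = QnnQuantizer()
# Bind the config to the annotation and hand it to the quantizer so the FFN
# down-projection convs get the dedicated quantization spec.
quantizer.add_custom_quant_annotations(
    (partial(annotate_down_proj, quantization_config=get_16a8w_qnn_ptq_config()),)
)
```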

backends/qualcomm/quantizer/observers/per_block_param_observer.py

Lines changed: 1 addition & 0 deletions
@@ -34,6 +34,7 @@ def __init__(
             eps=eps,
             **kwargs,
         )
+        self.dtype = dtype
         self.block_size = block_size
         # TODO: expand this when QNN starts to support more configurations
         self.bitwidth_of_scale = 4

backends/qualcomm/tests/test_qnn_delegate.py

Lines changed: 46 additions & 25 deletions
@@ -4636,6 +4636,33 @@ def test_qnn_backend_generate_optrace(self):
             qhas_data = json.load(qhas_file)
             self.assertIn("data", qhas_data)

+    def test_qnn_backend_seq_mse(self):
+        from executorch.backends.qualcomm._passes.seq_mse import SeqMSE
+
+        o_ch, i_ch, kernel, padding = 32, 512, (1, 1), 0
+        module = Conv2dSingle(  # noqa: F405
+            in_channel=i_ch,
+            out_channel=o_ch,
+            kernel_size=kernel,
+            padding=padding,
+        )
+        sample_input = (torch.randn(1, i_ch, 1, o_ch),)
+        # per-channel / per-block
+        quantizers = [
+            make_quantizer(),
+            make_quantizer(quant_dtype=QuantDtype.use_16a4w_block),
+        ]
+        quantizers[-1].set_block_size_map({"conv2d": (1, 32, 1, 1)})
+
+        for i, quantizer in enumerate(quantizers):
+            with self.subTest(i=i):
+                ep = torch.export.export(module, sample_input).module()
+                prepared = prepare_pt2e(ep, quantizer)
+                with SeqMSE(prepared, 100):
+                    prepared(*sample_input)
+                converted = convert_pt2e(prepared)
+                self.lower_module_and_test_output(converted, sample_input)
+

 class TestExampleLLMScript(TestQNN):
     def test_static_gemma3_1b(self):
@@ -4709,7 +4736,7 @@ def test_static_gemma3_1b(self):
                         msg["inference_speed"], inference_speed_ref[self.model]
                     )

-    def test_llama3_2_1b(self):
+    def test_llama3_2_instruct(self):
         if not self.required_envs():
             self.skipTest("missing required envs")
         assert (
@@ -4741,13 +4768,16 @@ def test_llama3_2_1b(self):
             "--temperature",
             "0",
             "--decoder_model",
-            "llama3_2",
+            "llama3_2-1b_instruct",
             "--model_mode",
-            "hybrid",
-            "--prefill_ar_len",
-            "32",
+            "kv",
             "--max_seq_len",
-            "512",
+            "1024",
+            "--eval_perplexity",
+            "--tasks",
+            "wikitext",
+            "--limit",
+            "1",
         ]
         if self.compile_only:
             cmds.extend(["--compile_only"])
@@ -4760,7 +4790,6 @@
         if self.pre_gen_pte:
             cmds.extend(["--pre_gen_pte", self.pre_gen_pte])

-        golden_start_with = "<|start_header_id|>user<|end_header_id|>"
         p = subprocess.Popen(cmds, stdout=subprocess.DEVNULL)
         with Listener((self.ip, self.port)) as listener:
             conn = listener.accept()
@@ -4769,19 +4798,17 @@
             if "Error" in msg:
                 self.fail(msg["Error"])
             else:
-                if not self.compile_only:
-                    model_out = msg["result"][0]
-                    self.assertTrue(
-                        model_out.startswith(golden_start_with),
-                        f"Expected Output: {golden_start_with}. Actual Output: {model_out}",
+                inference_speed_ref = {"SM8650": 37, "SM8750": 49}
+                if (
+                    not self.compile_only
+                    and not self.enable_x86_64
+                    and self.model in inference_speed_ref
+                ):
+                    self.assertLessEqual(msg["pte_size"], 1_500_000_000)
+                    self.assertLessEqual(msg["wiki_ppl"], 15)
+                    self.assertGreaterEqual(
+                        msg["inference_speed"], inference_speed_ref[self.model]
                     )
-                # x86 does not allow weight sharing, so we don't check pte size.
-                # Inference speed on x86 is slow, so we only check when running on Android
-                if not self.enable_x86_64:
-                    pte_size = msg["pte_size"]
-                    self.assertLessEqual(pte_size, 1_300_000_000)  # 1.3GB
-                if not self.compile_only and not self.enable_x86_64:
-                    self.assertGreaterEqual(msg["inference_speed"], 66)  # Lanai

     def test_llama_stories_260k(self):
         if not self.required_envs():
@@ -4976,12 +5003,6 @@ def test_static_phi4(self):
             cmds.extend(["--enable_x86_64"])
         if self.pre_gen_pte:
             cmds.extend(["--pre_gen_pte", self.pre_gen_pte])
-            cmds.extend(
-                [
-                    "--quant_attrs_path",
-                    f"{self.pre_gen_pte}/kv_llama_qnn_quant_attrs.json",
-                ]
-            )

         p = subprocess.Popen(cmds, stdout=subprocess.DEVNULL)
         with Listener((self.ip, self.port)) as listener: