
Commit 3f2cad4

Qualcomm AI Engine Direct - PTQ for llama3.2 1b/3b
- add PTQ recipe for llama3.2 1b/3b
- add seq_mse support to help quantize the 1b model
- complement qnn_llama_runner for smollm2
Parent: 43bd889
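For orientation, below is a minimal sketch of how the new SeqMSE helper slots into the PT2E PTQ flow, modeled on the test_qnn_backend_seq_mse test added in this commit. The toy conv model, the candidate count, and the import locations of make_quantizer and prepare_pt2e/convert_pt2e are illustrative assumptions, not part of the recipe itself.

# Minimal sketch of calibration with SeqMSE (assumptions flagged in comments).
import torch
from executorch.backends.qualcomm._passes.seq_mse import SeqMSE
# Assumed import locations for the helpers used by the in-tree test.
from executorch.backends.qualcomm.tests.utils import make_quantizer
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e

# Toy stand-in; the actual recipe targets the llama3.2 1b/3b decoder convs.
model = torch.nn.Conv2d(512, 32, kernel_size=1).eval()
sample_input = (torch.randn(1, 512, 1, 32),)

quantizer = make_quantizer()  # default per-channel PTQ config
ep = torch.export.export(model, sample_input).module()
prepared = prepare_pt2e(ep, quantizer)

# Calibrate inside the SeqMSE context: InsertSeqMse attaches a search module
# after each conv2d, the calibration run drives the scale search, and
# RemoveSeqMse folds the best candidate scale back into the weight observer.
with SeqMSE(prepared, num_candidates=100):
    prepared(*sample_input)

quantized = convert_pt2e(prepared)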

File tree: 12 files changed (+435, -50 lines)


backends/qualcomm/_passes/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -39,6 +39,7 @@
 from .remove_redundancy import RemoveRedundancy
 from .replace_arange_args import ReplaceArangeArgs
 from .replace_inf_values import ReplaceInfValues
+from .seq_mse import SeqMSE
 from .tag_quant_io import TagQuantIO


@@ -78,5 +79,6 @@
     RemoveRedundancy,
     ReplaceArangeArgs,
     ReplaceInfValues,
+    SeqMSE,
     TagQuantIO,
 ]

backends/qualcomm/_passes/seq_mse.py

Lines changed: 223 additions & 0 deletions (new file)

# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import types
from contextlib import contextmanager

import torch
import torchao
from executorch.backends.qualcomm.quantizer.observers.per_block_param_observer import (
    PerBlockParamObserver,
)
from executorch.exir.pass_base import ExportPass, PassResult
from torchao.quantization.pt2e import PerChannelMinMaxObserver


class SeqMseModule(torch.nn.Module):
    """
    Args:
        nominal_weight: Tensor
            nominal parameters from operator
        nominal_bias: Tensor
            nominal parameters from operator
        operator: fx.Node
            operator to be executed
        observer: UniformQuantizationObserverBase
            parameter observer (specific for weight)
        num_candidates: int
            grids to search minimal mse loss
    """

    def __init__(
        self,
        nominal_weight,
        nominal_bias,
        operator,
        observer,
        num_candidates,
    ):
        super().__init__()
        self.nominal_weight = nominal_weight
        self.nominal_bias = nominal_bias
        self.observer = observer
        self.steps = torch.linspace(
            1 / num_candidates, 1, steps=num_candidates
        ).tolist()
        self.operator = self._make_operator(operator)
        self.best_candidate_step = 1.0

    def _make_operator(self, aten_op):
        if aten_op.target == torch.ops.aten.conv2d.default:
            stride = [1, 1] if len(aten_op.args) < 4 else aten_op.args[3]
            padding = [0, 0] if len(aten_op.args) < 5 else aten_op.args[4]
            dilation = [1, 1] if len(aten_op.args) < 6 else aten_op.args[5]
            groups = 1 if len(aten_op.args) < 7 else aten_op.args[6]
            has_bias = self.nominal_bias is not None
            module = torch.nn.Conv2d(
                in_channels=self.nominal_weight.shape[1],
                out_channels=self.nominal_weight.shape[0],
                kernel_size=self.nominal_weight.shape[-2:],
                stride=stride,
                padding=padding,
                dilation=dilation,
                groups=groups,
                bias=has_bias,
            )
            module.weight.data = self.nominal_weight
            if has_bias:
                module.bias.data = self.nominal_bias
            return module
        else:
            raise NotImplementedError(f"target of {aten_op.target} is not implemented")

    def _per_block_qdq(self, scale, zero_point):
        return torchao.quantization.quant_primitives._fake_quantize_affine(
            input=self.nominal_weight,
            block_size=self.observer.block_size,
            scale=scale,
            zero_point=zero_point,
            quant_dtype=self.observer.dtype,
            quant_min=self.observer.quant_min,
            quant_max=self.observer.quant_max,
        )

    def _per_channel_qdq(self, scale, zero_point):
        return torch.fake_quantize_per_channel_affine(
            input=self.nominal_weight,
            scale=scale,
            zero_point=zero_point,
            axis=0,
            quant_min=self.observer.quant_min,
            quant_max=self.observer.quant_max,
        )

    def _fake_quant(self, scale, zero_point):
        dispatcher = {
            PerChannelMinMaxObserver: self._per_channel_qdq,
            PerBlockParamObserver: self._per_block_qdq,
        }
        return dispatcher[type(self.observer)](scale, zero_point)

    def _find_best_candidate(self, nominal_input, nominal_output):
        # calculate current baseline
        scale, zero_point = self.observer.calculate_qparams()
        zero_point = zero_point.to(torch.int32)
        self.operator.weight.data = self._fake_quant(scale, zero_point)
        candidate, current_loss = (
            1,
            torch.nn.functional.mse_loss(
                self.operator(nominal_input), nominal_output
            ).item(),
        )
        for step in self.steps:
            self.operator.weight.data = self._fake_quant(scale * step, zero_point)
            loss = torch.nn.functional.mse_loss(
                self.operator(nominal_input), nominal_output
            ).item()
            if loss < current_loss:
                candidate, current_loss = step, loss
        return candidate

    def forward(self, nominal_input, nominal_output):
        self.best_candidate_step = self._find_best_candidate(
            nominal_input=nominal_input, nominal_output=nominal_output
        )


class InsertSeqMse(ExportPass):
    """
    Insert Seq Mse Observer to find the best quant config for certain node's weight.
    """

    seq_mse_ops = {torch.ops.aten.conv2d.default}

    def __init__(self, num_candidates=1000):
        super(InsertSeqMse, self).__init__()
        self.num_candidates = num_candidates

    def _insert_seq_mse(
        self, graph_module: torch.fx.GraphModule
    ) -> torch.fx.GraphModule:
        count = 0
        for node in graph_module.graph.nodes:
            if node.target in self.seq_mse_ops:
                # extract observer
                weight_node_obs = node.args[1]
                observer = getattr(graph_module, weight_node_obs.name)
                # extract parameters
                weight_node = weight_node_obs.args[0]
                weight_tensor = graph_module.get_parameter(weight_node.target).detach()
                bias_tensor = None
                if len(node.args) > 2 and node.args[2] is not None:
                    bias_tensor = graph_module.get_parameter(
                        node.args[2].args[0].target
                    ).detach()

                with graph_module.graph.inserting_after(node):
                    seq_mse_mod = SeqMseModule(
                        nominal_weight=weight_tensor,
                        nominal_bias=bias_tensor,
                        operator=node,
                        observer=observer,
                        num_candidates=self.num_candidates,
                    )
                    module_name = f"seq_mse_{count}"
                    count += 1
                    setattr(graph_module, module_name, seq_mse_mod)
                    input_nodes = (node.args[0], node)
                    graph_module.graph.create_node(
                        "call_module", module_name, input_nodes, {}
                    )

    def call(self, graph_module: torch.fx.GraphModule):
        self._insert_seq_mse(graph_module)
        graph_module.recompile()
        return PassResult(graph_module, True)


class RemoveSeqMse(ExportPass):
    """
    Remove Seq Mse before invoking convert_pt2e and update final quantization encoding.
    """

    def __init__(self):
        super(RemoveSeqMse, self).__init__()

    def _remove_seq_mse(
        self, graph_module: torch.fx.GraphModule
    ) -> torch.fx.GraphModule:
        node_to_erase = []
        for node in graph_module.graph.nodes:
            if node.op == "call_module":
                # try extracting SeqMse module
                module = getattr(graph_module, node.target)
                if isinstance(module, SeqMseModule):
                    # rewrite observer method for pre-calculated scale
                    scale, zero_point = module.observer.calculate_qparams()
                    module.observer.updated_encoding = (
                        scale * module.best_candidate_step,
                        zero_point,
                    )
                    module.observer.calculate_qparams = types.MethodType(
                        lambda s: s.updated_encoding, module.observer
                    )
                    node_to_erase.append(node)

        for node in node_to_erase:
            graph_module.graph.erase_node(node)

    def call(self, graph_module: torch.fx.GraphModule):
        self._remove_seq_mse(graph_module)
        graph_module.recompile()
        return PassResult(graph_module, True)


@contextmanager
def SeqMSE(prepared_gm, num_candidates):
    prepared_gm = InsertSeqMse(num_candidates)(prepared_gm).graph_module
    try:
        yield
    finally:
        prepared_gm = RemoveSeqMse()(prepared_gm).graph_module
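
In short, SeqMseModule caches each matched conv's float weight and float output during calibration, sweeps candidate scales base_scale * k for k in linspace(1/num_candidates, 1, num_candidates), and keeps the multiplier that gives the lowest MSE between the fake-quantized layer output and the float output; RemoveSeqMse then patches the observer's calculate_qparams to return that rescaled encoding. A simplified, self-contained per-tensor illustration of the search (the actual pass dispatches per-channel or per-block observers on fx nodes):

# Per-tensor toy version of the SeqMSE scale sweep; tensor shapes and the
# helper name are illustrative only, not part of the commit.
import torch
import torch.nn.functional as F

def search_best_step(weight, x, num_candidates=100, qmin=-128, qmax=127):
    float_out = F.conv2d(x, weight)                # nominal (float) output
    base_scale = weight.abs().max().item() / qmax  # min-max style baseline scale
    best_step, best_loss = 1.0, float("inf")
    for step in torch.linspace(1 / num_candidates, 1, num_candidates).tolist():
        # fake-quantize the weight with a shrunken candidate scale
        w_q = torch.fake_quantize_per_tensor_affine(
            weight, base_scale * step, 0, qmin, qmax
        )
        loss = F.mse_loss(F.conv2d(x, w_q), float_out).item()
        if loss < best_loss:
            best_step, best_loss = step, loss
    # RemoveSeqMse would fold base_scale * best_step into the final encoding.
    return best_step

weight = torch.randn(8, 4, 1, 1)
x = torch.randn(1, 4, 16, 16)
print(search_best_step(weight, x))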

backends/qualcomm/quantizer/custom_annotation.py

Lines changed: 19 additions & 0 deletions
@@ -31,6 +31,25 @@
 )


+def annotate_down_proj(
+    gm: torch.fx.GraphModule, quantization_config: QuantizationConfig
+):
+    for node in gm.graph.nodes:
+        if (
+            node.target == torch.ops.aten.conv2d.default
+            and any(s in node.meta["stack_trace"] for s in ["forward_feedfoward_conv"])
+            and node.args[0].target == torch.ops.aten.mul.Tensor
+        ):
+            input_qspec_map = {}
+            input_qspec_map[node.args[0]] = quantization_config.input_activation
+            input_qspec_map[node.args[1]] = quantization_config.weight
+            node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation(
+                input_qspec_map=input_qspec_map,
+                output_qspec=quantization_config.output_activation,
+                _annotated=True,
+            )
+
+
 def annotate_eurobert(gm: torch.fx.GraphModule):
     """
     QNN does not support int32 -> signed 16bit quant
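
annotate_down_proj pins the feed-forward down-projection convs (matched by "forward_feedfoward_conv" in the node's stack trace, with a mul feeding the activation input) to an explicit activation/weight qspec. A rough sketch of how such a callback is typically registered follows; the add_custom_quant_annotations hook and the qconfig helper named below are assumptions based on how other custom annotations in this backend are wired, not something shown in this diff.

# Hedged sketch: registering annotate_down_proj as a custom annotation.
from functools import partial

from executorch.backends.qualcomm.quantizer.custom_annotation import annotate_down_proj
# Assumed qconfig helper; substitute whichever QuantizationConfig the recipe uses.
from executorch.backends.qualcomm.quantizer.qconfig import get_16a8w_qnn_ptq_config

custom_annotations = (
    partial(annotate_down_proj, quantization_config=get_16a8w_qnn_ptq_config()),
)
# quantizer.add_custom_quant_annotations(custom_annotations)  # assumed QnnQuantizer hook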

backends/qualcomm/quantizer/observers/per_block_param_observer.py

Lines changed: 1 addition & 0 deletions
@@ -34,6 +34,7 @@ def __init__(
             eps=eps,
             **kwargs,
         )
+        self.dtype = dtype
         self.block_size = block_size
         # TODO: expand this when QNN starts to support more configurations
         self.bitwidth_of_scale = 4

backends/qualcomm/tests/test_qnn_delegate.py

Lines changed: 46 additions & 25 deletions
@@ -4574,9 +4574,36 @@ def test_qnn_backend_generate_optrace(self):
             qhas_data = json.load(qhas_file)
             self.assertIn("data", qhas_data)

+    def test_qnn_backend_seq_mse(self):
+        from executorch.backends.qualcomm._passes.seq_mse import SeqMSE
+
+        o_ch, i_ch, kernel, padding = 32, 512, (1, 1), 0
+        module = Conv2dSingle(  # noqa: F405
+            in_channel=i_ch,
+            out_channel=o_ch,
+            kernel_size=kernel,
+            padding=padding,
+        )
+        sample_input = (torch.randn(1, i_ch, 1, o_ch),)
+        # per-channel / per-block
+        quantizers = [
+            make_quantizer(),
+            make_quantizer(quant_dtype=QuantDtype.use_16a4w_block),
+        ]
+        quantizers[-1].set_block_size_map({"conv2d": (1, 32, 1, 1)})
+
+        for i, quantizer in enumerate(quantizers):
+            with self.subTest(i=i):
+                ep = torch.export.export(module, sample_input).module()
+                prepared = prepare_pt2e(ep, quantizer)
+                with SeqMSE(prepared, 100):
+                    prepared(*sample_input)
+                converted = convert_pt2e(prepared)
+                self.lower_module_and_test_output(converted, sample_input)
+

 class TestExampleLLMScript(TestQNN):
-    def test_llama3_2_1b(self):
+    def test_llama3_2_instruct(self):
         if not self.required_envs():
             self.skipTest("missing required envs")
         assert (

@@ -4608,13 +4635,16 @@ def test_llama3_2_1b(self):
             "--temperature",
             "0",
             "--decoder_model",
-            "llama3_2",
+            "llama3_2-1b_instruct",
             "--model_mode",
-            "hybrid",
-            "--prefill_ar_len",
-            "32",
+            "kv",
             "--max_seq_len",
-            "512",
+            "1024",
+            "--eval_perplexity",
+            "--tasks",
+            "wikitext",
+            "--limit",
+            "1",
         ]
         if self.compile_only:
             cmds.extend(["--compile_only"])

@@ -4627,7 +4657,6 @@
         if self.pre_gen_pte:
             cmds.extend(["--pre_gen_pte", self.pre_gen_pte])

-        golden_start_with = "<|start_header_id|>user<|end_header_id|>"
         p = subprocess.Popen(cmds, stdout=subprocess.DEVNULL)
         with Listener((self.ip, self.port)) as listener:
             conn = listener.accept()

@@ -4636,19 +4665,17 @@
             if "Error" in msg:
                 self.fail(msg["Error"])
             else:
-                if not self.compile_only:
-                    model_out = msg["result"][0]
-                    self.assertTrue(
-                        model_out.startswith(golden_start_with),
-                        f"Expected Output: {golden_start_with}. Actual Output: {model_out}",
+                inference_speed_ref = {"SM8650": 37, "SM8750": 49}
+                if (
+                    not self.compile_only
+                    and not self.enable_x86_64
+                    and self.model in inference_speed_ref
+                ):
+                    self.assertLessEqual(msg["pte_size"], 1500000000)
+                    self.assertLessEqual(msg["wiki_ppl"], 15)
+                    self.assertGreaterEqual(
+                        msg["inference_speed"], inference_speed_ref[self.model]
                     )
-                # x86 does not allow weight sharing, so we don't check pte size.
-                # Inference speed on x86 is slow, so we only check when running on Android
-                if not self.enable_x86_64:
-                    pte_size = msg["pte_size"]
-                    self.assertLessEqual(pte_size, 1300000000)
-                if not self.compile_only and not self.enable_x86_64:
-                    self.assertGreaterEqual(msg["inference_speed"], 66)  # Lanai

     def test_llama_stories_260k(self):
         if not self.required_envs():

@@ -4843,12 +4870,6 @@ def test_static_phi4(self):
             cmds.extend(["--enable_x86_64"])
         if self.pre_gen_pte:
             cmds.extend(["--pre_gen_pte", self.pre_gen_pte])
-            cmds.extend(
-                [
-                    "--quant_attrs_path",
-                    f"{self.pre_gen_pte}/kv_llama_qnn_quant_attrs.json",
-                ]
-            )

         p = subprocess.Popen(cmds, stdout=subprocess.DEVNULL)
         with Listener((self.ip, self.port)) as listener:
