Commit 7481013
QC SeqMSE Draft
1 parent e31eb56 commit 7481013

File tree

8 files changed: +267 −16 lines changed

backends/qualcomm/_passes/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -37,6 +37,7 @@
 from .remove_redundancy import RemoveRedundancy
 from .replace_arange_args import ReplaceArangeArgs
 from .replace_inf_values import ReplaceInfValues
+from .seq_mse import InsertSeqMse, RemoveSeqMse
 from .tag_quant_io import TagQuantIO


@@ -65,13 +66,15 @@
     I64toI32,
     InsertIOQDQ,
     InsertRequantize,
+    InsertSeqMse,
     LayoutTransform,
     LiftConstantScalarOperands,
     RecomposePixelUnshuffle,
     RecomposeRmsNorm,
     ReduceDynamicRange,
     Remove0DTensor,
     RemoveRedundancy,
+    RemoveSeqMse,
     ReplaceArangeArgs,
     ReplaceInfValues,
     TagQuantIO,
backends/qualcomm/_passes/seq_mse.py (new file)

Lines changed: 203 additions & 0 deletions

@@ -0,0 +1,203 @@
+# Copyright (c) Qualcomm Innovation Center, Inc.
+# All rights reserved
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+import types
+
+import torch
+import torchao
+from torchao.quantization.pt2e import PerChannelMinMaxObserver
+from executorch.backends.qualcomm.quantizer.observers.per_block_param_observer import (
+    PerBlockParamObserver,
+)
+from executorch.exir.pass_base import ExportPass, PassResult
+
+class SeqMseModule(torch.nn.Module):
+    """
+    Args:
+        nominal_weight: Tensor
+            nominal weight parameter taken from the operator
+        nominal_bias: Tensor
+            nominal bias parameter taken from the operator
+        operator: fx.Node
+            operator to be executed
+        observer: UniformQuantizationObserverBase
+            parameter observer (specific to the weight)
+        num_candidates: int
+            number of grid points to search for the minimal MSE loss
+    """
+
+    def __init__(
+        self,
+        nominal_weight,
+        nominal_bias,
+        operator,
+        observer,
+        num_candidates,
+    ):
+        super().__init__()
+        self.nominal_weight = nominal_weight
+        self.nominal_bias = nominal_bias
+        self.observer = observer
+        step = 1 / num_candidates
+        self.steps = torch.arange(start=step, end=1 + step, step=step).tolist()
+        self.operator = self._make_operator(operator)
+        self.best_candidate = 1.0
+
+    def _make_operator(self, aten_op):
+        if aten_op.target == torch.ops.aten.conv2d.default:
+            stride = [1, 1] if len(aten_op.args) < 4 else aten_op.args[3]
+            padding = [0, 0] if len(aten_op.args) < 5 else aten_op.args[4]
+            dilation = [1, 1] if len(aten_op.args) < 6 else aten_op.args[5]
+            groups = 1 if len(aten_op.args) < 7 else aten_op.args[6]
+            has_bias = self.nominal_bias is not None
+            module = torch.nn.Conv2d(
+                in_channels=self.nominal_weight.shape[1],
+                out_channels=self.nominal_weight.shape[0],
+                kernel_size=self.nominal_weight.shape[-2:],
+                stride=stride,
+                padding=padding,
+                dilation=dilation,
+                groups=groups,
+                bias=has_bias,
+            )
+            module.weight.data = self.nominal_weight
+            if has_bias:
+                module.bias.data = self.nominal_bias
+            return module
+        else:
+            raise NotImplementedError(f"target of {aten_op.target} is not implemented")
+
+    def _per_block_qdq(self, scale, zero_point):
+        return torchao.quantization.quant_primitives._fake_quantize_affine(
+            input=self.nominal_weight,
+            block_size=self.observer.block_size,
+            scale=scale,
+            zero_point=zero_point,
+            quant_dtype=self.observer.dtype,
+            quant_min=self.observer.quant_min,
+            quant_max=self.observer.quant_max,
+        )
+
+    def _per_channel_qdq(self, scale, zero_point):
+        return torch.fake_quantize_per_channel_affine(
+            input=self.nominal_weight,
+            scale=scale,
+            zero_point=zero_point,
+            axis=0,
+            quant_min=self.observer.quant_min,
+            quant_max=self.observer.quant_max,
+        )
+
+    def _fake_quant(self, scale, zero_point):
+        dispatcher = {
+            PerChannelMinMaxObserver: self._per_channel_qdq,
+            PerBlockParamObserver: self._per_block_qdq,
+        }
+        return dispatcher[type(self.observer)](scale, zero_point)
+
+    def _find_best_candidate(self, nominal_input, nominal_output):
+        # calculate the current baseline
+        scale, zero_point = self.observer.calculate_qparams()
+        zero_point = zero_point.to(torch.int32)
+        self.operator.weight.data = self._fake_quant(scale, zero_point)
+        candidate, current_loss = 1, torch.nn.functional.mse_loss(
+            self.operator(nominal_input), nominal_output
+        ).item()
+        for step in self.steps:
+            self.operator.weight.data = self._fake_quant(scale * step, zero_point)
+            loss = torch.nn.functional.mse_loss(
+                self.operator(nominal_input), nominal_output
+            ).item()
+            if loss < current_loss:
+                candidate, current_loss = step, loss
+        return candidate
+
+    def forward(self, nominal_input, nominal_output):
+        self.best_candidate = self._find_best_candidate(
+            nominal_input=nominal_input, nominal_output=nominal_output
+        )
+
+
+class InsertSeqMse(ExportPass):
+    """
+    Insert SeqMSE modules to find the best quantization encoding for supported nodes' weights.
+    """
+
+    seq_mse_ops = {
+        torch.ops.aten.conv2d.default
+    }
+
+    def __init__(self, num_candidates=1000):
+        super(InsertSeqMse, self).__init__()
+        self.num_candidates = num_candidates
+
+    def _insert_seq_mse(self, graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule:
+        count = 0
+        for node in graph_module.graph.nodes:
+            if node.target in self.seq_mse_ops:
+                # extract the weight observer
+                weight_node_obs = node.args[1]
+                observer = getattr(graph_module, weight_node_obs.name)
+                # extract parameters
+                weight_node = weight_node_obs.args[0]
+                weight_tensor = graph_module.get_parameter(
+                    weight_node.target
+                ).detach()
+                bias_tensor = None
+                if len(node.args) > 2 and node.args[2] is not None:
+                    bias_tensor = graph_module.get_parameter(
+                        node.args[2].args[0].target
+                    ).detach()
+
+                with graph_module.graph.inserting_after(node):
+                    seq_mse_mod = SeqMseModule(
+                        nominal_weight=weight_tensor,
+                        nominal_bias=bias_tensor,
+                        operator=node,
+                        observer=observer,
+                        num_candidates=self.num_candidates,
+                    )
+                    module_name = f"seq_mse_{count}"
+                    count += 1
+                    setattr(graph_module, module_name, seq_mse_mod)
+                    input_nodes = (node.args[0], node)
+                    graph_module.graph.create_node("call_module", module_name, input_nodes, {})
+
+    def call(self, graph_module: torch.fx.GraphModule):
+        self._insert_seq_mse(graph_module)
+        graph_module.recompile()
+        return PassResult(graph_module, True)
+
+
+class RemoveSeqMse(ExportPass):
+    """
+    Remove SeqMSE modules before invoking convert_pt2e and update the final quantization encodings.
+    """
+    def __init__(self):
+        super(RemoveSeqMse, self).__init__()
+
+    def _remove_seq_mse(self, graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule:
+        node_to_erase = []
+        for node in graph_module.graph.nodes:
+            if node.op == "call_module" and "seq_mse" in node.name:
+                # extract the SeqMse module
+                module = getattr(graph_module, node.target)
+                # rewrite the observer method to return the pre-calculated scale
+                scale, zero_point = module.observer.calculate_qparams()
+                module.observer.updated_encoding = (
+                    scale * module.best_candidate, zero_point
+                )
+                module.observer.calculate_qparams = (
+                    types.MethodType(lambda s: s.updated_encoding, module.observer)
+                )
+                node_to_erase.append(node)
+
+        for node in node_to_erase:
+            graph_module.graph.erase_node(node)
+
+    def call(self, graph_module: torch.fx.GraphModule):
+        self._remove_seq_mse(graph_module)
+        graph_module.recompile()
+        return PassResult(graph_module, True)
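
To see what the SeqMSE search does, here is a minimal, self-contained sketch of the same idea outside the pass infrastructure (plain torch, per-tensor symmetric int8 for brevity; the toy conv, grid size, and helper names are illustrative and not part of this commit):

# Toy illustration of the SeqMSE scale search: shrink the weight-quantization
# scale by a fraction in (0, 1] and keep the fraction whose fake-quantized
# conv output stays closest (in MSE) to the float output.
import torch

conv = torch.nn.Conv2d(8, 16, 3, bias=False)
x = torch.randn(4, 8, 16, 16)
y_ref = conv(x)

w = conv.weight.detach()
base_scale = w.abs().max() / 127  # symmetric int8, per-tensor for simplicity

def fake_quant(weight, scale):
    q = torch.clamp(torch.round(weight / scale), -128, 127)
    return q * scale

best_frac, best_loss = 1.0, float("inf")
for frac in torch.linspace(0.01, 1.0, steps=100):
    w_q = fake_quant(w, base_scale * frac)
    loss = torch.nn.functional.mse_loss(
        torch.nn.functional.conv2d(x, w_q), y_ref
    ).item()
    if loss < best_loss:
        best_frac, best_loss = frac.item(), loss

print(f"best scale fraction: {best_frac:.3f}, mse: {best_loss:.3e}")

SeqMseModule performs the equivalent search per observer type (per-channel or per-block), records the winning fraction in best_candidate, and RemoveSeqMse later pins the observer's calculate_qparams to the rescaled encoding.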

backends/qualcomm/quantizer/custom_annotation.py

Lines changed: 4 additions & 4 deletions
@@ -285,9 +285,9 @@ def annotate_matmul_input1(node: Node):
         quantization_config_8a8w = get_8a8w_qnn_ptq_config(
             act_symmetric=True, act_observer=MinMaxObserver
         )
-        quantization_config_8a4w_per_channel = get_ptq_per_channel_quant_config(
-            act_dtype=torch.uint8,
-            weight_dtype=torch.int4,
+        quantization_config_16a8w_per_channel = get_ptq_per_channel_quant_config(
+            act_dtype=torch.uint16,
+            weight_dtype=torch.int8,
             act_observer=MinMaxObserver,
             act_symmetric=True,
         )

@@ -318,7 +318,7 @@ def annotate_matmul_input1(node: Node):
                 node = node.args[0][1]
             elif node.target == torch.ops.aten.conv2d.default:
                 annotate_conv2d(
-                    node, quantization_config=quantization_config_8a4w_per_channel
+                    node, quantization_config=quantization_config_16a8w_per_channel
                 )
                 break
             elif node.target in [torch.ops.aten.add.Tensor, torch.ops.aten.sub.Tensor]:

backends/qualcomm/quantizer/observers/per_block_param_observer.py

Lines changed: 1 addition & 0 deletions
@@ -34,6 +34,7 @@ def __init__(
             eps=eps,
             **kwargs,
         )
+        self.dtype = dtype
         self.block_size = block_size
         self.calibrated = False

backends/qualcomm/quantizer/quantizer.py

Lines changed: 11 additions & 0 deletions
@@ -3,12 +3,14 @@
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
+from contextlib import contextmanager
 from dataclasses import dataclass
 from enum import IntEnum, unique
 from functools import partial
 from typing import Callable, Dict, List, Optional, Sequence, Set, Tuple

 import torch
+from executorch.backends.qualcomm._passes import InsertSeqMse, RemoveSeqMse
 from executorch.backends.qualcomm._passes.qnn_pass_manager import QnnPassManager

 from torch._ops import OpOverload

@@ -427,3 +429,12 @@ def predicate(node):
         return False

     return predicate
+
+
+@contextmanager
+def qnn_ptq_manager(prepared_gm):
+    prepared_gm = InsertSeqMse()(prepared_gm).graph_module
+    try:
+        yield
+    finally:
+        prepared_gm = RemoveSeqMse()(prepared_gm).graph_module
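
For orientation, a usage sketch of the new context manager (exported_module, quantizer, and calibration_data are placeholders, and the prepare_pt2e/convert_pt2e import path depends on the torch/torchao revision in the tree; the full version is in the test later in this commit):

# Sketch: InsertSeqMse runs on entry and each calibration forward records the
# best scale candidates; RemoveSeqMse then pins the observers' encodings.
prepared = prepare_pt2e(exported_module, quantizer)  # hypothetical setup
with qnn_ptq_manager(prepared):
    for sample in calibration_data:  # hypothetical dataset
        prepared(*sample)
quantized = convert_pt2e(prepared)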

backends/qualcomm/tests/test_qnn_delegate.py

Lines changed: 27 additions & 0 deletions
@@ -3981,6 +3981,33 @@ def test_qnn_backend_generate_optrace(self):
             qhas_data = json.load(qhas_file)
             self.assertIn("data", qhas_data)

+    def test_qnn_backend_seq_mse(self):
+        from executorch.backends.qualcomm.quantizer.quantizer import qnn_ptq_manager
+
+        o_ch, i_ch, kernel, padding = 32, 512, (1, 1), 0
+        module = Conv2dSingle(  # noqa: F405
+            in_channel=i_ch,
+            out_channel=o_ch,
+            kernel_size=kernel,
+            padding=padding,
+        )
+        sample_input = (torch.randn(1, i_ch, 1, o_ch),)
+        # per-channel / per-block
+        quantizers = [
+            make_quantizer(),
+            make_quantizer(quant_dtype=QuantDtype.use_16a4w_block),
+        ]
+        quantizers[-1].set_block_size_map({"conv2d": (1, 32, 1, 1)})
+
+        for i, quantizer in enumerate(quantizers):
+            with self.subTest(i=i):
+                ep = torch.export.export(module, sample_input).module()
+                prepared = prepare_pt2e(ep, quantizer)
+                with qnn_ptq_manager(prepared):
+                    prepared(*sample_input)
+                converted = convert_pt2e(prepared)
+                self.lower_module_and_test_output(converted, sample_input)
+

 class TestExampleLLMScript(TestQNN):
     def test_llama3_2_1b(self):

examples/qualcomm/oss_scripts/llama/llama.py

Lines changed: 11 additions & 11 deletions
@@ -142,6 +142,9 @@ def _kv_calibrate(
     updater=smart_mask_updater,
     use_i64_token=False,
 ):
+    from contextlib import nullcontext
+    from executorch.backends.qualcomm.quantizer.quantizer import qnn_ptq_manager
+
     _, atten_mask, _, k_caches, v_caches = example_inputs

     # TODO: change criteria & support batch inputs if necessary

@@ -191,13 +194,14 @@ def _kv_calibrate(
                 dim=-1,
             )

-        logits, new_k_caches, new_v_caches = module(
-            tmp_token_list,
-            tmp_atten_mask,
-            tmp_pos,
-            *k_caches,
-            *v_caches,
-        )
+        with qnn_ptq_manager(module) if pos == max_seq_len - 1 else nullcontext():
+            logits, new_k_caches, new_v_caches = module(
+                tmp_token_list,
+                tmp_atten_mask,
+                tmp_pos,
+                *k_caches,
+                *v_caches,
+            )
         atten_mask, pos, k_caches, v_caches = updater(
             ar_len, atten_mask, pos, k_caches, v_caches, new_k_caches, new_v_caches
         )

@@ -647,10 +651,6 @@ def permute(w, heads):
     if args.ptq:
         start_quantize_ts = time.time()
         custom_annotations = (annotate_matmul_16a8w,)
-        if args.llama_model == "stories110m":
-            custom_annotations = custom_annotations + (
-                annotate_linear_16a8w_in_affine_layer,
-            )
         kv_quant_attrs = {}
         for i, llama_instance in enumerate(llama_instance_list):
             llama_instance.quantize(
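
The calibration change above wraps only the final decode step in qnn_ptq_manager, presumably so the weight-scale search runs once after the observers have accumulated statistics over the whole calibration sequence. Abstracted from the diff (the loop and names are illustrative):

# Pattern: run SeqMSE only on the last calibration step.
from contextlib import nullcontext

for pos, batch in enumerate(calibration_steps):  # hypothetical loop
    ctx = qnn_ptq_manager(module) if pos == len(calibration_steps) - 1 else nullcontext()
    with ctx:
        module(*batch)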

examples/qualcomm/utils.py

Lines changed: 7 additions & 1 deletion
@@ -374,6 +374,7 @@ def build_executorch_binary(
     online_prepare=False,
     optrace=False,
     op_package_options: QnnExecuTorchOpPackageOptions = None,
+    use_seq_mse=False,
 ):
     """
     A function to generate an ExecuTorch binary for Qualcomm platforms.

@@ -397,10 +398,14 @@ def build_executorch_binary(
         optrace (bool, optional): Enable optrace mode for performance analysis if set to True.
         op_package_options: Optional structure to specify op packages
             loaded and used by the backend.
+        use_seq_mse (bool, optional): Run SeqMSE to minimize the MSE error of operator outputs during calibration.

     Returns:
         None: The function writes the output to a specified .pte file.
     """
+    from contextlib import nullcontext
+    from executorch.backends.qualcomm.quantizer.quantizer import qnn_ptq_manager
+
     backend_options = generate_htp_compiler_spec(
         use_fp16=False if quant_dtype else True
     )

@@ -426,7 +431,8 @@ def build_executorch_binary(
     else:
         quantizer = custom_quantizer or make_quantizer(quant_dtype=quant_dtype)
         # ptq calibration
-        annotated_model = ptq_calibrate(captured_model, quantizer, dataset)
+        with qnn_ptq_manager(captured_model) if use_seq_mse else nullcontext():
+            annotated_model = ptq_calibrate(captured_model, quantizer, dataset)

         quantized_model = convert_pt2e(annotated_model)
         edge_prog_mgr = to_edge_transform_and_lower_to_qnn(
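
An example script could then opt in via the new flag. A sketch only: the arguments follow the existing build_executorch_binary signature in examples/qualcomm/utils.py, and everything except use_seq_mse is illustrative:

# Hypothetical call site; only use_seq_mse is introduced by this commit.
build_executorch_binary(
    module,                           # nn.Module to lower
    sample_input,                     # example inputs
    args.model,                       # target SoC
    f"{args.artifact}/model",         # output .pte name
    dataset=calibration_dataset,      # PTQ calibration data
    quant_dtype=QuantDtype.use_8a8w,
    use_seq_mse=True,                 # wrap calibration in qnn_ptq_manager
)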
