
Commit 2b155d1

Chun-I Tsai committed
Qualcomm AI Engine Direct - Add submodule quant config setting
- Add an API to the QNN quantizer for setting submodule-specific quant configs
1 parent 1ea101e commit 2b155d1

File tree

11 files changed: +224 −109 lines


backends/qualcomm/_passes/decompose_einsum.py

Lines changed: 3 additions & 0 deletions
@@ -8,6 +8,8 @@
 from executorch.exir.pass_base import ExportPass, PassResult
 from torch.fx.experimental.proxy_tensor import make_fx
 
+from .utils import copy_nn_module_stack
+
 
 class DecomposeEinsum(ExportPass):
     """
@@ -36,6 +38,7 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
                     remap[f"arg1_{i+1}"] = arg
 
                     for decomposed_node in decomposed_module.graph.nodes:
+                        copy_nn_module_stack(node, decomposed_node)
                         # This is the arg[0] equation string, which is not required anymore after decomposition
                         if "arg0" in decomposed_node.name:
                             continue
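Why the pass needs this copy: make_fx retraces the replacement graph from a plain function, so the nodes it produces carry no nn_module_stack of their own. A minimal illustrative sketch (not part of the commit) showing that the retraced nodes lack the metadata:

import torch
from torch.fx.experimental.proxy_tensor import make_fx

# Retrace an einsum through make_fx, the same way DecomposeEinsum builds
# its replacement graph.
decomposed = make_fx(lambda x, y: torch.einsum("ij,jk->ik", x, y))(
    torch.randn(2, 3), torch.randn(3, 4)
)

# The retraced nodes have no module provenance; without copy_nn_module_stack,
# a submodule-level quant config could not be matched to them later.
for n in decomposed.graph.nodes:
    assert not n.meta.get("nn_module_stack")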

backends/qualcomm/_passes/decompose_linalg_vector_norm.py

Lines changed: 3 additions & 0 deletions
@@ -8,6 +8,8 @@
 from executorch.exir import to_edge
 from executorch.exir.pass_base import ExportPass, PassResult
 
+from .utils import copy_nn_module_stack
+
 
 class LinalgVectorNorm(torch.nn.Module):
     def __init__(self, exp, dim, keepdim):
@@ -62,6 +64,7 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
                     remap = {"x": node.args[0]}
 
                     for decomposed_node in decomposed_module.graph.nodes:
+                        copy_nn_module_stack(node, decomposed_node)
                         # no need to copy existent 'output'
                         if decomposed_node.op == "output":
                             for user in node.users.copy():

backends/qualcomm/_passes/utils.py

Lines changed: 13 additions & 1 deletion
@@ -8,7 +8,11 @@
 
 import torch
 from executorch.backends.qualcomm.builders.utils import get_parameter
-from executorch.backends.qualcomm.utils.constants import QCOM_DTYPE, QCOM_ENCODING
+from executorch.backends.qualcomm.utils.constants import (
+    QCOM_DTYPE,
+    QCOM_ENCODING,
+    QCOM_NN_MODULE_STACK,
+)
 from executorch.exir.dialects._ops import ops as exir_ops
 from torch._subclasses import FakeTensor
 
@@ -122,6 +126,14 @@ def get_passes_dependency_for_capture_program():
     }
 
 
+def copy_nn_module_stack(src, target):
+    """
+    Copy meta["nn_module_stack"] from the src node to the target node, if present.
+    """
+    if value := src.meta.get(QCOM_NN_MODULE_STACK):
+        target.meta[QCOM_NN_MODULE_STACK] = value
+
+
 def is_float_tensor(node: torch.fx.Node) -> bool:
     if "val" not in node.meta or not isinstance(node.meta["val"], FakeTensor):
         return False
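For context, a small sketch of the metadata this helper carries over (illustrative, not from the commit; Wrapper is a made-up module, and QCOM_NN_MODULE_STACK is assumed to be the literal key "nn_module_stack", as the quantizer reads it below):

import torch

class Wrapper(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        return self.relu(x)

ep = torch.export.export(Wrapper(), (torch.randn(2),))
graph = ep.graph_module.graph
src = next(n for n in graph.nodes if n.op == "call_function")
print(src.meta.get("nn_module_stack"))
# e.g. {'L__self__': ('', '__main__.Wrapper'),
#       'L__self___relu': ('relu', 'torch.nn.modules.activation.ReLU')}

# A node inserted later starts with empty meta; copying the stack over,
# which is all copy_nn_module_stack does, preserves the submodule
# attribution that QnnQuantizer._get_submodule relies on.
with graph.inserting_after(src):
    new_node = graph.call_function(torch.ops.aten.relu.default, (src.args[0],))
if value := src.meta.get("nn_module_stack"):
    new_node.meta["nn_module_stack"] = value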

backends/qualcomm/quantizer/quantizer.py

Lines changed: 112 additions & 74 deletions
@@ -3,6 +3,8 @@
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
+import importlib
+from dataclasses import dataclass
 from enum import IntEnum, unique
 from functools import partial
 from typing import Callable, Dict, Optional, Sequence, Set, Tuple
@@ -71,7 +73,7 @@ class QuantDtype(IntEnum):
     use_8a8w = 4
 
 
-quant_config_dict = {
+QUANT_CONFIG_DICT = {
     # PTQ
     (QuantDtype.use_16a16w, False): (
         get_16a16w_qnn_ptq_config,
@@ -136,21 +138,66 @@ class QuantDtype(IntEnum):
 }
 
 
+@dataclass
+class ModuleQConfig:
+    quant_dtype: QuantDtype = QuantDtype.use_8a8w
+    is_qat: bool = False
+    is_conv_per_channel: bool = False
+    is_linear_per_channel: bool = False
+    act_observer: Optional[
+        torch.ao.quantization.observer.UniformQuantizationObserverBase
+    ] = None
+
+    def __post_init__(self):
+        if (self.quant_dtype, self.is_qat) not in QUANT_CONFIG_DICT:
+            raise RuntimeError(
+                f"the quant config (quant_dtype: {self.quant_dtype}, is_qat: {self.is_qat}) is not supported"
+            )
+        (
+            quant_config_func,
+            per_channel_quant_config_func,
+            per_block_quant_config_func,
+        ) = QUANT_CONFIG_DICT[(self.quant_dtype, self.is_qat)]
+        self.quant_config = (
+            quant_config_func(act_observer=self.act_observer)
+            if self.act_observer
+            else quant_config_func()
+        )
+        self.per_channel_quant_config = (
+            per_channel_quant_config_func(act_observer=self.act_observer)
+            if self.act_observer
+            else per_channel_quant_config_func()
+        )
+        self.per_block_quant_config = (
+            per_block_quant_config_func(act_observer=self.act_observer)
+            if self.act_observer
+            else per_block_quant_config_func()
+        )
+        self.use_per_channel_weight_quant_ops = set()
+        if self.is_conv_per_channel:
+            self.use_per_channel_weight_quant_ops.update(
+                {
+                    torch.ops.aten.conv1d.default,
+                    torch.ops.aten.conv2d.default,
+                    torch.ops.aten.conv_transpose2d.input,
+                }
+            )
+        if self.is_linear_per_channel:
+            self.use_per_channel_weight_quant_ops.update(
+                {
+                    torch.ops.aten.linear.default,
+                }
+            )
+
+
 class QnnQuantizer(Quantizer):
     SUPPORTED_OPS: Set = set(OP_ANNOTATOR.keys())
 
     def __init__(self):
         super().__init__()
         self.quant_ops: Set[OpOverload] = self.SUPPORTED_OPS.copy()
 
-        self.is_qat = False
-        self.quant_dtype = QuantDtype.use_8a8w
-        self.quant_config: QuantizationConfig = get_8a8w_qnn_ptq_config()
-        self.per_channel_quant_config = get_ptq_per_channel_quant_config()
-        self.per_block_quant_config = get_ptq_per_block_quant_config()
+        self.default_quant_config = ModuleQConfig()
+        self.module_qconfig_dict: Dict[torch.nn.Module, ModuleQConfig] = {}
         self.block_size_map = {}
-        self.use_per_channel_weight_quant_ops: Set[OpOverload] = set()
-        self.use_per_block_weight_quant_ops: Set[OpOverload] = set()
 
         self.custom_quant_annotations: Sequence[Callable] = []
         self.discard_nodes: Set[str] = set()
@@ -168,41 +215,52 @@ def _annotate_custom_annotation(self, gm: GraphModule) -> None:
         for annotation_func in self.custom_quant_annotations:
             annotation_func(gm)
 
-    def _get_quant_config(self, op: torch.fx.Node) -> Optional[QuantizationConfig]:
-        """
-        Priority:
-        1. is one of use_per_block_weight_quant_ops
-        2. is one of use_per_channel_weight_quant_ops
-        3. quant config
-        """
-        target = op.target
-        if isinstance(target, str):
+    def _get_submodule(self, node: torch.fx.Node):
+        """
+        An example of nn_module_stack:
+        {
+            'L__self__': ('', 'executorch.backends.qualcomm.tests.models.SubModules'),
+            'L__self___add': ('add', 'executorch.backends.qualcomm.tests.models.Add')
+        }
+        """
+        nn_module_stack = node.meta.get("nn_module_stack")
+        if nn_module_stack:
+            module_source_str, module_str = list(nn_module_stack.values())[-1][
+                -1
+            ].rsplit(".", 1)
+            module_source = importlib.import_module(module_source_str)
+            return getattr(module_source, module_str)
+        return None
+
+    def _get_quant_config(self, node: torch.fx.Node) -> Optional[QuantizationConfig]:
+        """
+        Priority for picking a config:
+        1. The per-block config, if the node has an entry in block_size_map.
+        2. The config registered for the node's submodule, if any;
+           otherwise the default config.
+        3. The chosen config's per-channel variant, if the op is listed in
+           its use_per_channel_weight_quant_ops.
+        4. Otherwise, the chosen config's regular quant config.
+        """
+        op = node.target
+        if isinstance(op, str):
             return
 
-        if target in self.use_per_block_weight_quant_ops:
-            if block_size := self.block_size_map.get(op.name):
-                self.per_block_quant_config.block_size = block_size
-                return self.per_block_quant_config
-
-        if target in self.use_per_channel_weight_quant_ops:
-            return self.per_channel_quant_config
+        if block_size := self.block_size_map.get(node.name):
+            config = self.default_quant_config.per_block_quant_config
+            config.block_size = block_size
+            return config
 
-        if target in self.quant_ops:
-            return self.quant_config
+        config = self.module_qconfig_dict.get(
+            self._get_submodule(node), self.default_quant_config
+        )
 
-        print(f"No quant config is implemented for op, {op}")
+        if op in config.use_per_channel_weight_quant_ops:
+            return config.per_channel_quant_config
 
-    def _update_per_block_weight_quant_ops(self, ops: Set[OpOverload], enable: bool):
-        if enable:
-            self.use_per_block_weight_quant_ops.update(ops)
-        else:
-            self.use_per_block_weight_quant_ops.difference_update(ops)
+        if op in self.quant_ops:
+            return config.quant_config
 
-    def _update_per_channel_weight_quant_ops(self, ops: Set[OpOverload], enable: bool):
-        if enable:
-            self.use_per_channel_weight_quant_ops.update(ops)
-        else:
-            self.use_per_channel_weight_quant_ops.difference_update(ops)
+        print(f"No quant config is implemented for op, {op}")
 
     def add_custom_quant_annotations(
         self, custom_quant_annotations: Sequence[Callable]
@@ -225,52 +283,32 @@ def annotate(self, model: GraphModule) -> GraphModule:
     def get_supported_ops(self) -> Set[OpOverload]:
         return self.SUPPORTED_OPS
 
-    def set_quant_config(
-        self, quant_dtype: QuantDtype, is_qat=False, act_observer=None
+    def set_default_quant_config(
+        self,
+        quant_dtype: QuantDtype,
+        is_qat=False,
+        is_conv_per_channel=False,
+        is_linear_per_channel=False,
+        act_observer=None,
     ) -> None:
-        self.quant_dtype = quant_dtype
-        self.is_qat = is_qat
-        if (quant_dtype, is_qat) not in quant_config_dict:
-            raise RuntimeError(
-                f"the quant config, (quant_dtype: {quant_dtype}, is_qat: {is_qat}) is not support"
-            )
-
-        quant_config_fuc, per_channel_quant_config_fuc, per_block_quant_config_fuc = (
-            quant_config_dict[(quant_dtype, is_qat)]
-        )
-        self.quant_config = (
-            quant_config_fuc(act_observer=act_observer)
-            if act_observer
-            else quant_config_fuc()
-        )
-        self.per_channel_quant_config = (
-            per_channel_quant_config_fuc(act_observer=act_observer)
-            if act_observer
-            else per_channel_quant_config_fuc()
-        )
-        if per_block_quant_config_fuc is not None:
-            self.per_block_quant_config = (
-                per_block_quant_config_fuc(act_observer=act_observer)
-                if act_observer
-                else per_block_quant_config_fuc()
-            )
+        self.default_quant_config = ModuleQConfig(
+            quant_dtype,
+            is_qat,
+            is_conv_per_channel,
+            is_linear_per_channel,
+            act_observer,
+        )
 
     def set_block_size_map(self, block_size_map: Dict[str, Tuple]) -> None:
         self.block_size_map = block_size_map
 
-    def set_per_block_conv_quant(self, enable: bool) -> None:
-        conv_ops = {torch.ops.aten.conv2d.default}
-        self._update_per_block_weight_quant_ops(conv_ops, enable)
-
-    def set_per_channel_conv_quant(self, enable: bool) -> None:
-        conv_ops = {torch.ops.aten.conv1d.default, torch.ops.aten.conv2d.default}
-        self._update_per_channel_weight_quant_ops(conv_ops, enable)
-
-    def set_per_channel_linear_quant(self, enable: bool) -> None:
-        linear_ops = {
-            torch.ops.aten.linear.default,
-        }
-        self._update_per_channel_weight_quant_ops(linear_ops, enable)
+    def set_submodule_quant_config(
+        self, submodule: torch.nn.Module, module_qconfig: ModuleQConfig
+    ) -> None:
+        """
+        Set a quant config specific to one submodule type.
+        """
+        self.module_qconfig_dict[submodule] = module_qconfig
 
     def transform_for_annotation(self, model: GraphModule) -> GraphModule:
         model = ReduceDynamicRange()(model).graph_module
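Putting the new surface together, a caller would configure the quantizer roughly as follows. This is an assumed usage sketch, not code from this commit; MyAttention is a made-up stand-in for any user-defined submodule class:

import torch
from executorch.backends.qualcomm.quantizer.quantizer import (
    ModuleQConfig,
    QnnQuantizer,
    QuantDtype,
)

class MyAttention(torch.nn.Module):  # hypothetical user submodule
    def forward(self, x):
        return torch.softmax(x, dim=-1)

quantizer = QnnQuantizer()

# Graph-wide default: 8a8w with per-channel weight quant for conv/linear ops.
quantizer.set_default_quant_config(
    QuantDtype.use_8a8w,
    is_conv_per_channel=True,
    is_linear_per_channel=True,
)

# Override for one submodule class: nodes whose nn_module_stack resolves to
# MyAttention are annotated with the 16a16w config instead.
quantizer.set_submodule_quant_config(
    MyAttention,
    ModuleQConfig(QuantDtype.use_16a16w),
)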

backends/qualcomm/tests/models.py

Lines changed: 12 additions & 0 deletions
@@ -1452,6 +1452,18 @@ def forward(self, x):
         return 10 - x
 
 
+class SimpleSubModules(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.add = Add()
+        self.sub = Sub()
+
+    def forward(self, a, b, c, d):
+        lhs = self.add(a, b)
+        rhs = self.sub(c, d)
+        return torch.mul(lhs, rhs)
+
+
 class SumIntList(torch.nn.Module):
     def __init__(self):
         super().__init__()
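A hedged end-to-end sketch (assumed usage patterned on the standard PT2E flow, not a test from this commit) of how SimpleSubModules could exercise the feature, quantizing the Add submodule at 16a16w while the rest of the graph keeps the 8a8w default:

import torch
from executorch.backends.qualcomm.quantizer.quantizer import (
    ModuleQConfig,
    QnnQuantizer,
    QuantDtype,
)
from executorch.backends.qualcomm.tests.models import Add, SimpleSubModules
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

quantizer = QnnQuantizer()
quantizer.set_submodule_quant_config(Add, ModuleQConfig(QuantDtype.use_16a16w))

inputs = tuple(torch.randn(2, 3) for _ in range(4))
m = torch.export.export(SimpleSubModules(), inputs).module()
m = prepare_pt2e(m, quantizer)  # annotate the graph and insert observers
m(*inputs)                      # calibrate
m = convert_pt2e(m)             # nodes under self.add carry 16-bit encodings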
