Commit 2b5cf91

Author: Chun-I Tsai (committed)
Qualcomm AI Engine Direct - Add submodule quant config setting
- Add API to qnn quantizer for setting submodule quant config
1 parent 6adff9c commit 2b5cf91

File tree

11 files changed: +224 −109 lines changed

backends/qualcomm/_passes/decompose_einsum.py

Lines changed: 3 additions & 0 deletions
@@ -8,6 +8,8 @@
 from executorch.exir.pass_base import ExportPass, PassResult
 from torch.fx.experimental.proxy_tensor import make_fx
 
+from .utils import copy_nn_module_stack
+
 
 class DecomposeEinsum(ExportPass):
     """
@@ -36,6 +38,7 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
                         remap[f"arg1_{i+1}"] = arg
 
                     for decomposed_node in decomposed_module.graph.nodes:
+                        copy_nn_module_stack(node, decomposed_node)
                         # This is the arg[0] equation string, which is not required anymore after decomposition
                         if "arg0" in decomposed_node.name:
                             continue

backends/qualcomm/_passes/decompose_linalg_vector_norm.py

Lines changed: 3 additions & 0 deletions
@@ -8,6 +8,8 @@
 from executorch.exir import to_edge
 from executorch.exir.pass_base import ExportPass, PassResult
 
+from .utils import copy_nn_module_stack
+
 
 class LinalgVectorNorm(torch.nn.Module):
     def __init__(self, exp, dim, keepdim):
@@ -62,6 +64,7 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
                     remap = {"x": node.args[0]}
 
                     for decomposed_node in decomposed_module.graph.nodes:
+                        copy_nn_module_stack(node, decomposed_node)
                         # no need to copy existent 'output'
                         if decomposed_node.op == "output":
                             for user in node.users.copy():

backends/qualcomm/_passes/utils.py

Lines changed: 13 additions & 1 deletion
@@ -8,7 +8,11 @@
 
 import torch
 from executorch.backends.qualcomm.builders.utils import get_parameter
-from executorch.backends.qualcomm.utils.constants import QCOM_DTYPE, QCOM_ENCODING
+from executorch.backends.qualcomm.utils.constants import (
+    QCOM_DTYPE,
+    QCOM_ENCODING,
+    QCOM_NN_MODULE_STACK,
+)
 from executorch.exir.dialects._ops import ops as exir_ops
 from torch._subclasses import FakeTensor
 
@@ -121,6 +125,14 @@ def get_passes_dependency_for_capture_program():
     }
 
 
+def copy_nn_module_stack(src, target):
+    """
+    Copy meta["nn_module_stack"] from src node to target node if existing.
+    """
+    if value := src.meta.get(QCOM_NN_MODULE_STACK):
+        target.meta[QCOM_NN_MODULE_STACK] = value
+
+
 def is_float_tensor(node: torch.fx.Node) -> bool:
     if "val" not in node.meta or not isinstance(node.meta["val"], FakeTensor):
         return False
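
The helper above is what lets the quantizer trace decomposed nodes back to their originating submodule. Below is a minimal sketch, not part of the commit, of the same check-then-copy on a hand-built FX graph; it assumes QCOM_NN_MODULE_STACK is the standard "nn_module_stack" meta key (the key _get_quant_config reads later) and uses illustrative node names.

import torch

# Build a tiny FX graph by hand so the example is self-contained.
graph = torch.fx.Graph()
src = graph.placeholder("x")
src.meta["nn_module_stack"] = {
    "L__self___add": ("add", "executorch.backends.qualcomm.tests.models.Add"),
}
decomposed = graph.call_function(torch.ops.aten.add.Tensor, (src, src))

# Same logic as copy_nn_module_stack(src, decomposed): copy the stack only if present.
if value := src.meta.get("nn_module_stack"):
    decomposed.meta["nn_module_stack"] = value

assert decomposed.meta["nn_module_stack"] is src.meta["nn_module_stack"]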

backends/qualcomm/quantizer/quantizer.py

Lines changed: 112 additions & 74 deletions
@@ -3,6 +3,8 @@
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
+import importlib
+from dataclasses import dataclass
 from enum import IntEnum, unique
 from functools import partial
 from typing import Callable, Dict, Optional, Sequence, Set, Tuple
@@ -58,7 +60,7 @@ class QuantDtype(IntEnum):
     use_8a8w = 4
 
 
-quant_config_dict = {
+QUANT_CONFIG_DICT = {
     # PTQ
     (QuantDtype.use_16a16w, False): (
         get_16a16w_qnn_ptq_config,
@@ -123,21 +125,66 @@ class QuantDtype(IntEnum):
 }
 
 
+@dataclass
+class ModuleQConfig:
+    quant_dtype: QuantDtype = QuantDtype.use_8a8w
+    is_qat: bool = False
+    is_conv_per_channel: bool = False
+    is_linear_per_channel: bool = False
+    act_observer: Optional[
+        torch.ao.quantization.observer.UniformQuantizationObserverBase
+    ] = None
+
+    def __post_init__(self):
+        if (self.quant_dtype, self.is_qat) not in QUANT_CONFIG_DICT:
+            raise RuntimeError(
+                f"the quant config, (quant_dtype: {self.quant_dtype}, is_qat: {self.is_qat}) is not supported"
+            )
+        quant_config_func, per_channel_quant_config_func, per_block_quant_config_func = QUANT_CONFIG_DICT[
+            (self.quant_dtype, self.is_qat)
+        ]
+        self.quant_config = (
+            quant_config_func(act_observer=self.act_observer)
+            if self.act_observer
+            else quant_config_func()
+        )
+        self.per_channel_quant_config = (
+            per_channel_quant_config_func(act_observer=self.act_observer)
+            if self.act_observer
+            else per_channel_quant_config_func()
+        )
+        self.per_block_quant_config = (
+            per_block_quant_config_func(act_observer=self.act_observer)
+            if self.act_observer
+            else per_block_quant_config_func()
+        )
+        self.use_per_channel_weight_quant_ops = set()
+        if self.is_conv_per_channel:
+            self.use_per_channel_weight_quant_ops.update(
+                {
+                    torch.ops.aten.conv1d.default,
+                    torch.ops.aten.conv2d.default,
+                    torch.ops.aten.conv_transpose2d.input,
+                }
+            )
+        if self.is_linear_per_channel:
+            self.use_per_channel_weight_quant_ops.update(
+                {
+                    torch.ops.aten.linear.default,
+                }
+            )
+
+
 class QnnQuantizer(Quantizer):
     SUPPORTED_OPS: Set = set(OP_ANNOTATOR.keys())
 
     def __init__(self):
         super().__init__()
         self.quant_ops: Set[OpOverload] = self.SUPPORTED_OPS.copy()
 
-        self.is_qat = False
-        self.quant_dtype = QuantDtype.use_8a8w
-        self.quant_config: QuantizationConfig = get_8a8w_qnn_ptq_config()
-        self.per_channel_quant_config = get_ptq_per_channel_quant_config()
-        self.per_block_quant_config = get_ptq_per_block_quant_config()
+        self.default_quant_config = ModuleQConfig()
+        self.module_qconfig_dict: Dict[torch.nn.Module, ModuleQConfig] = {}
         self.block_size_map = {}
-        self.use_per_channel_weight_quant_ops: Set[OpOverload] = set()
-        self.use_per_block_weight_quant_ops: Set[OpOverload] = set()
 
         self.custom_quant_annotations: Sequence[Callable] = []
         self.discard_nodes: Set[str] = set()
@@ -155,41 +202,52 @@ def _annotate_custom_annotation(self, gm: GraphModule) -> None:
         for annotation_func in self.custom_quant_annotations:
             annotation_func(gm)
 
-    def _get_quant_config(self, op: torch.fx.Node) -> Optional[QuantizationConfig]:
+    def _get_submodule(self, node: torch.fx.Node):
+        """
+        An example of nn_module_stack
+        {
+            'L__self__': ('', 'executorch.backends.qualcomm.tests.models.SubModules'),
+            'L__self___add': ('add', 'executorch.backends.qualcomm.tests.models.Add')
+        }
         """
-        Priority:
+
+        nn_module_stack = node.meta.get("nn_module_stack")
+        if nn_module_stack:
+            module_source_str, module_str = list(nn_module_stack.values())[-1][
+                -1
+            ].rsplit(".", 1)
+            module_source = importlib.import_module(module_source_str)
+            return getattr(module_source, module_str)
+        return None
+
+    def _get_quant_config(self, node: torch.fx.Node) -> Optional[QuantizationConfig]:
+        """
+        How to pick:
         1. is one of use_per_block_weight_quant_ops
-        2. is one of use_per_channel_weight_quant_ops
-        3. quant config
+        2. Choose specific submodule config if given.
+        3. Pick one if op belongs to use_per_channel_weight_quant_ops
+        4. If not 2, pick normal quant config
         """
-        target = op.target
-        if isinstance(target, str):
+        op = node.target
+        if isinstance(op, str):
             return
 
-        if target in self.use_per_block_weight_quant_ops:
-            if block_size := self.block_size_map.get(op.name):
-                self.per_block_quant_config.block_size = block_size
-                return self.per_block_quant_config
-
-        if target in self.use_per_channel_weight_quant_ops:
-            return self.per_channel_quant_config
+        if block_size := self.block_size_map.get(node.name):
+            config = self.default_quant_config.per_block_quant_config
+            config.block_size = block_size
+            return config
 
-        if target in self.quant_ops:
-            return self.quant_config
+        config = self.module_qconfig_dict.get(
+            self._get_submodule(node), self.default_quant_config
+        )
 
-        print(f"No quant config is implemented for op, {op}")
+        if op in config.use_per_channel_weight_quant_ops:
+            return config.per_channel_quant_config
 
-    def _update_per_block_weight_quant_ops(self, ops: Set[OpOverload], enable: bool):
-        if enable:
-            self.use_per_block_weight_quant_ops.update(ops)
-        else:
-            self.use_per_block_weight_quant_ops.difference_update(ops)
+        if op in self.quant_ops:
+            return config.quant_config
 
-    def _update_per_channel_weight_quant_ops(self, ops: Set[OpOverload], enable: bool):
-        if enable:
-            self.use_per_channel_weight_quant_ops.update(ops)
-        else:
-            self.use_per_channel_weight_quant_ops.difference_update(ops)
+        print(f"No quant config is implemented for op, {op}")
 
     def add_custom_quant_annotations(
         self, custom_quant_annotations: Sequence[Callable]
@@ -212,52 +270,32 @@ def annotate(self, model: GraphModule) -> GraphModule:
     def get_supported_ops(self) -> Set[OpOverload]:
         return self.SUPPORTED_OPS
 
-    def set_quant_config(
-        self, quant_dtype: QuantDtype, is_qat=False, act_observer=None
+    def set_default_quant_config(
+        self,
+        quant_dtype: QuantDtype,
+        is_qat=False,
+        is_conv_per_channel=False,
+        is_linear_per_channel=False,
+        act_observer=None,
     ) -> None:
-        self.quant_dtype = quant_dtype
-        self.is_qat = is_qat
-        if (quant_dtype, is_qat) not in quant_config_dict:
-            raise RuntimeError(
-                f"the quant config, (quant_dtype: {quant_dtype}, is_qat: {is_qat}) is not support"
-            )
-
-        quant_config_fuc, per_channel_quant_config_fuc, per_block_quant_config_fuc = (
-            quant_config_dict[(quant_dtype, is_qat)]
-        )
-        self.quant_config = (
-            quant_config_fuc(act_observer=act_observer)
-            if act_observer
-            else quant_config_fuc()
+        self.default_quant_config = ModuleQConfig(
+            quant_dtype,
+            is_qat,
+            is_conv_per_channel,
+            is_linear_per_channel,
+            act_observer,
         )
-        self.per_channel_quant_config = (
-            per_channel_quant_config_fuc(act_observer=act_observer)
-            if act_observer
-            else per_channel_quant_config_fuc()
-        )
-        if per_block_quant_config_fuc is not None:
-            self.per_block_quant_config = (
-                per_block_quant_config_fuc(act_observer=act_observer)
-                if act_observer
-                else per_block_quant_config_fuc()
-            )
 
     def set_block_size_map(self, block_size_map: Dict[str, Tuple]) -> None:
         self.block_size_map = block_size_map
 
-    def set_per_block_conv_quant(self, enable: bool) -> None:
-        conv_ops = {torch.ops.aten.conv2d.default}
-        self._update_per_block_weight_quant_ops(conv_ops, enable)
-
-    def set_per_channel_conv_quant(self, enable: bool) -> None:
-        conv_ops = {torch.ops.aten.conv1d.default, torch.ops.aten.conv2d.default}
-        self._update_per_channel_weight_quant_ops(conv_ops, enable)
-
-    def set_per_channel_linear_quant(self, enable: bool) -> None:
-        linear_ops = {
-            torch.ops.aten.linear.default,
-        }
-        self._update_per_channel_weight_quant_ops(linear_ops, enable)
+    def set_submodule_quant_config(
+        self, submodule: torch.nn.Module, module_qconfig: ModuleQConfig
+    ) -> None:
+        """
+        Set the quant config specific for a submodule
+        """
+        self.module_qconfig_dict[submodule] = module_qconfig
 
     def transform_for_annotation(self, model: GraphModule) -> GraphModule:
         return QnnPassManager().transform_for_annotation_pipeline(model)
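
Taken together, ModuleQConfig, set_default_quant_config, and set_submodule_quant_config give the quantizer a per-submodule override path. A hedged usage sketch (not from the commit) follows; it keys the override by the module class, since _get_submodule resolves each node's nn_module_stack entry back to a class via importlib, and it reuses the Add test model from backends/qualcomm/tests/models.py.

from executorch.backends.qualcomm.quantizer.quantizer import (
    ModuleQConfig,
    QnnQuantizer,
    QuantDtype,
)
from executorch.backends.qualcomm.tests.models import Add

quantizer = QnnQuantizer()
# Default: 8a8w PTQ, with per-channel weight quantization for conv ops.
quantizer.set_default_quant_config(
    QuantDtype.use_8a8w, is_qat=False, is_conv_per_channel=True
)
# Override: any node whose nn_module_stack resolves to the Add class
# is annotated with a 16a16w config instead of the default one.
quantizer.set_submodule_quant_config(
    Add, ModuleQConfig(quant_dtype=QuantDtype.use_16a16w)
)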

backends/qualcomm/tests/models.py

Lines changed: 12 additions & 0 deletions
@@ -1450,6 +1450,18 @@ def forward(self, x):
         return 10 - x
 
 
+class SimpleSubModules(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.add = Add()
+        self.sub = Sub()
+
+    def forward(self, a, b, c, d):
+        lhs = self.add(a, b)
+        rhs = self.sub(c, d)
+        return torch.mul(lhs, rhs)
+
+
 class SumIntList(torch.nn.Module):
     def __init__(self):
         super().__init__()
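
For completeness, a rough end-to-end sketch using the SimpleSubModules test model added above. The quantizer calls are the ones from this commit; the capture and PT2E prepare/convert steps are assumptions (the exact capture API differs across PyTorch/ExecuTorch versions), so treat this as an illustration rather than the repository's test flow.

import torch
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

from executorch.backends.qualcomm.quantizer.quantizer import (
    ModuleQConfig,
    QnnQuantizer,
    QuantDtype,
)
from executorch.backends.qualcomm.tests.models import Add, SimpleSubModules

model = SimpleSubModules()
inputs = tuple(torch.randn(1, 4) for _ in range(4))

quantizer = QnnQuantizer()
# Only nodes originating from the Add submodule get the 16a16w override;
# the Sub submodule and the final mul keep the default config.
quantizer.set_submodule_quant_config(Add, ModuleQConfig(QuantDtype.use_16a16w))

# Capture, annotate, calibrate, convert (API names assumed, see note above).
captured = torch.export.export_for_training(model, inputs).module()
prepared = prepare_pt2e(captured, quantizer)
prepared(*inputs)  # calibration pass
quantized = convert_pt2e(prepared)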
