
Commit f2ce0e7

Chun-I Tsai committed
Fix based on comments
- Change to a string-based way to set up qconfig for submodules
1 parent 8493349 commit f2ce0e7

File tree

6 files changed: +89 lines, -59 lines

backends/qualcomm/_passes/utils.py

Lines changed: 2 additions & 3 deletions
@@ -11,7 +11,6 @@
 from executorch.backends.qualcomm.utils.constants import (
     QCOM_DTYPE,
     QCOM_ENCODING,
-    QCOM_NN_MODULE_STACK,
 )
 from executorch.exir.dialects._ops import ops as exir_ops
 from torch._subclasses import FakeTensor
@@ -129,8 +128,8 @@ def copy_nn_module_stack(src, target):
     """
     Copy meta["nn_module_stack"] from src node to target node if existing.
     """
-    if value := src.meta.get(QCOM_NN_MODULE_STACK):
-        target.meta[QCOM_NN_MODULE_STACK] = value
+    if value := src.meta.get("nn_module_stack"):
+        target.meta["nn_module_stack"] = value


 def is_float_tensor(node: torch.fx.Node) -> bool:
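
A minimal sketch of the updated helper in action; the traced module and node choice are illustrative assumptions, and the stack dict mirrors the example given in the quantizer docstrings:

    import torch

    from executorch.backends.qualcomm._passes.utils import copy_nn_module_stack

    # Trace a trivial module just to get two FX nodes to demonstrate with.
    gm = torch.fx.symbolic_trace(torch.nn.ReLU())
    src, target = list(gm.graph.nodes)[:2]

    # Illustrative stack, shaped like the example in the quantizer docstrings.
    src.meta["nn_module_stack"] = {
        "L__self__": ("", "executorch.backends.qualcomm.tests.models.SubModules"),
        "L__self___add": ("add", "executorch.backends.qualcomm.tests.models.Add"),
    }

    # The helper now reads the standard FX meta key "nn_module_stack" directly.
    copy_nn_module_stack(src, target)
    assert target.meta["nn_module_stack"] == src.meta["nn_module_stack"]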

backends/qualcomm/quantizer/quantizer.py

Lines changed: 62 additions & 34 deletions
@@ -3,11 +3,10 @@
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
-import importlib
 from dataclasses import dataclass
 from enum import IntEnum, unique
 from functools import partial
-from typing import Callable, Dict, Optional, Sequence, Set, Tuple
+from typing import Callable, Dict, List, Optional, Sequence, Set, Tuple

 import torch
 from executorch.backends.qualcomm._passes.qnn_pass_manager import QnnPassManager
@@ -140,9 +139,11 @@ def __post_init__(self):
             raise RuntimeError(
                 f"the quant config, (quant_dtype: {self.quant_dtype}, is_qat: {self.is_qat}) is not support"
             )
-        quant_config_func, per_channel_quant_config_func, per_block_quant_config_func = QUANT_CONFIG_DICT[
-            (self.quant_dtype, self.is_qat)
-        ]
+        (
+            quant_config_func,
+            per_channel_quant_config_func,
+            per_block_quant_config_func,
+        ) = QUANT_CONFIG_DICT[(self.quant_dtype, self.is_qat)]
         self.quant_config = (
             quant_config_func(act_observer=self.act_observer)
             if self.act_observer
@@ -184,7 +185,9 @@ def __init__(self):
         self.quant_ops: Set[OpOverload] = self.SUPPORTED_OPS.copy()

         self.default_quant_config = ModuleQConfig()
-        self.module_qconfig_dict: Dict[torch.nn.Module, ModuleQConfig] = {}
+        self.submodule_qconfig_list: List[
+            Tuple[Callable[[torch.fx.Node], bool], ModuleQConfig]
+        ] = []
         self.block_size_map = {}

         self.custom_quant_annotations: Sequence[Callable] = []
@@ -203,44 +206,30 @@ def _annotate_custom_annotation(self, gm: GraphModule) -> None:
         for annotation_func in self.custom_quant_annotations:
             annotation_func(gm)

-    def _get_submodule(self, node: torch.fx.Node):
-        """
-        An example of nn_module_stack
-        {
-            'L__self__': ('', 'executorch.backends.qualcomm.tests.models.SubModules'),
-            'L__self___add': ('add', 'executorch.backends.qualcomm.tests.models.Add')
-        }
-        """
-
-        nn_module_stack = node.meta.get("nn_module_stack")
-        if nn_module_stack:
-            module_source_str, module_str = list(nn_module_stack.values())[-1][
-                -1
-            ].rsplit(".", 1)
-            module_source = importlib.import_module(module_source_str)
-            return getattr(module_source, module_str)
-        return None
+    def _get_submodule_qconfig(self, node: torch.fx.Node):
+        for func, qconfig in self.submodule_qconfig_list:
+            if func(node):
+                return qconfig
+        return self.default_quant_config

     def _get_quant_config(self, node: torch.fx.Node) -> Optional[QuantizationConfig]:
         """
         How to pick:
-            1. is one of use_per_block_weight_quant_ops
-            2. Choose specific submodule config if given.
+            1. is one of per_block_quant_config
+            2. Pick specific submodule config if given.
             3. Pick one if op belongs to use_per_channel_weight_quant_ops
-            4. If not 2, pick normal quant config
+            4. If not 3, pick normal quant config
         """
         op = node.target
         if isinstance(op, str):
             return

-        if block_size := self.block_size_map.get(op.name):
+        if block_size := self.block_size_map.get(node.name):
             config = self.default_quant_config.per_block_quant_config
             config.block_size = block_size
             return config

-        config = self.module_qconfig_dict.get(
-            self._get_submodule(node), self.default_quant_config
-        )
+        config = self._get_submodule_qconfig(node)

         if op in config.use_per_channel_weight_quant_ops:
             return config.per_channel_quant_config
@@ -290,16 +279,55 @@ def set_default_quant_config(
     def set_block_size_map(self, block_size_map: Dict[str, Tuple]) -> None:
         self.block_size_map = block_size_map

-    def set_submodule_quant_config(
-        self, submodule: torch.nn.Module, module_qconfig: ModuleQConfig
+    def set_submodule_qconfig_list(
+        self, submodule_qconfig_list: List[Tuple[Callable, ModuleQConfig]]
     ) -> None:
         """
-        Set the quant config specific for a submodule
+        Set specific quant config from a callback function.
+        If a node fits more than one callback, only apply the first one.
         """
-        self.module_qconfig_dict[submodule] = module_qconfig
+        self.submodule_qconfig_list = submodule_qconfig_list

     def transform_for_annotation(self, model: GraphModule) -> GraphModule:
         return QnnPassManager().transform_for_annotation_pipeline(model)

     def validate(self, model: GraphModule) -> None:
         pass
+
+
+def get_submodule_type_predicate(module_type_str):
+    """
+    An example of nn_module_stack
+    {
+        'L__self__': ('', 'executorch.backends.qualcomm.tests.models.SubModules'),
+        'L__self___add': ('add', 'executorch.backends.qualcomm.tests.models.Add')
+    }
+    """
+
+    def predicate(node):
+        if nn_module_stack := node.meta.get("nn_module_stack"):
+            for _, type_name in nn_module_stack.values():
+                if module_type_str in type_name:
+                    return True
+        return False
+
+    return predicate
+
+
+def get_submodule_name_predicate(module_name_str):
+    """
+    An example of nn_module_stack
+    {
+        'L__self__': ('', 'executorch.backends.qualcomm.tests.models.SubModules'),
+        'L__self___add': ('add', 'executorch.backends.qualcomm.tests.models.Add')
+    }
+    """
+
+    def predicate(node):
+        if nn_module_stack := node.meta.get("nn_module_stack"):
+            for name in nn_module_stack.keys():
+                if module_name_str in name:
+                    return True
+        return False
+
+    return predicate
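
Taken together, this replaces module-object lookups with node predicates evaluated against each node's nn_module_stack. A minimal usage sketch of the new flow; the "Linear" type string and dtype choice are illustrative assumptions, not part of this commit:

    from executorch.backends.qualcomm.quantizer.quantizer import (
        ModuleQConfig,
        QnnQuantizer,
        QuantDtype,
        get_submodule_type_predicate,
    )

    quantizer = QnnQuantizer()
    # Predicates are checked in order; the first match wins, and nodes that
    # match nothing fall back to the quantizer's default config.
    quantizer.set_submodule_qconfig_list(
        [
            (get_submodule_type_predicate("Linear"), ModuleQConfig(QuantDtype.use_16a16w)),
        ]
    )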

backends/qualcomm/tests/test_qnn_delegate.py

Lines changed: 13 additions & 4 deletions
@@ -2131,11 +2131,20 @@ def test_qnn_backend_submodules(self):
             torch.rand(1, 3, 8, 8),
         )

-        submodule_quant_config = {
-            Add: ModuleQConfig(QuantDtype.use_16a16w)  # noqa: F405
-        }
+        from executorch.backends.qualcomm.quantizer.quantizer import (
+            get_submodule_type_predicate,
+        )
+
+        submodule_qconfig_list = [
+            (
+                get_submodule_type_predicate("Add"),
+                ModuleQConfig(QuantDtype.use_16a16w),
+            )  # noqa: F405
+        ]
         module = self.get_qdq_module(
-            module, sample_input, submodule_quant_config=submodule_quant_config
+            module,
+            sample_input,
+            submodule_qconfig_list=submodule_qconfig_list,
         )
         self.lower_module_and_test_output(module, sample_input)
backends/qualcomm/tests/utils.py

Lines changed: 8 additions & 12 deletions
@@ -16,11 +16,7 @@

 from executorch import exir
 from executorch.backends.qualcomm.qnn_preprocess import QnnBackend
-from executorch.backends.qualcomm.quantizer.quantizer import (
-    ModuleQConfig,
-    QnnQuantizer,
-    QuantDtype,
-)
+from executorch.backends.qualcomm.quantizer.quantizer import ModuleQConfig, QuantDtype
 from executorch.backends.qualcomm.serialization.qc_schema import QcomChipset
 from executorch.backends.qualcomm.utils.constants import (
     QCOM_DTYPE,
@@ -508,8 +504,9 @@ def get_qdq_module(
         dynamic_shapes: Dict = None,
         bypass_check: bool = False,
         block_size_map: Dict[str, Tuple] = None,
-        submodule_quant_config: Optional[Dict[torch.nn.Module, ModuleQConfig]] = None,
+        submodule_qconfig_list: Optional[List[Tuple[Callable, ModuleQConfig]]] = None,
     ) -> torch.fx.GraphModule:
+        module = module.eval()
         m = torch.export.export(
             module, inputs, dynamic_shapes=dynamic_shapes, strict=True
         ).module()
@@ -519,7 +516,7 @@ def get_qdq_module(
             custom_annotations=custom_quant_annotations,
             per_channel_conv=is_conv_per_channel,
             per_channel_linear=is_linear_per_channel,
-            submodule_quant_config = submodule_quant_config,
+            submodule_qconfig_list=submodule_qconfig_list,
         )
         if block_size_map is not None:
             quantizer.set_block_size_map(block_size_map)
@@ -547,7 +544,7 @@ def get_prepared_qat_module(
         is_linear_per_channel: Optional[bool] = False,
         custom_quant_annotations: Tuple[Callable] = (),
         quant_dtype: QuantDtype = QuantDtype.use_8a8w,
-        submodule_quant_config: Optional[Dict[str, ModuleQConfig]] = None,
+        submodule_qconfig_list: Optional[List[Tuple[Callable, ModuleQConfig]]] = None,
     ) -> torch.fx.GraphModule:
         m = torch.export.export_for_training(module, inputs, strict=True).module()

@@ -557,12 +554,11 @@ def get_prepared_qat_module(
             per_channel_conv=is_conv_per_channel,
             per_channel_linear=is_linear_per_channel,
             is_qat=True,
-            submodule_quant_config=submodule_quant_config
+            submodule_qconfig_list=submodule_qconfig_list,
         )

-        submodule_quant_config = submodule_quant_config or {}
-        for submodule, module_qconfig in submodule_quant_config.items():
-            quantizer.set_submodule_quant_config(submodule, module_qconfig)
+        submodule_qconfig_list = submodule_qconfig_list or []
+        quantizer.set_submodule_qconfig_list(submodule_qconfig_list)

         prepared = prepare_qat_pt2e(m, quantizer)
         return torch.ao.quantization.move_exported_model_to_train(prepared)

backends/qualcomm/utils/constants.py

Lines changed: 0 additions & 1 deletion
@@ -21,7 +21,6 @@
 QCOM_INSERTED_PERMUTE = "qnn_permute"
 QCOM_LAYOUT_CHANGE = "layout_change"
 QCOM_NUM_BLOCKS_PER_AXIS = "num_blocks_per_axis"
-QCOM_NN_MODULE_STACK = "nn_module_stack"
 QCOM_OFFSET = "offset"
 QCOM_ORIG_DTYPE = "orig_dtype"
 QCOM_QUANTIZED_IO = "q_tensor_io"

examples/qualcomm/utils.py

Lines changed: 4 additions & 5 deletions
@@ -14,7 +14,7 @@
 import tempfile
 from pathlib import Path

-from typing import Callable, Dict, List, Optional
+from typing import Callable, List, Optional, Tuple

 import numpy as np

@@ -262,7 +262,7 @@ def make_quantizer(
     per_channel_linear=False,
     act_observer=MovingAverageMinMaxObserver,
     is_qat=False,
-    submodule_quant_config: Optional[Dict[str, ModuleQConfig]] = None,
+    callback_qconfig_list: Optional[List[Tuple[Callable, ModuleQConfig]]] = None,
 ):
     quantizer = QnnQuantizer()
     quantizer.add_custom_quant_annotations(custom_annotations)
@@ -273,9 +273,8 @@ def make_quantizer(
         is_linear_per_channel=per_channel_linear,
         act_observer=act_observer,
     )
-    submodule_quant_config = submodule_quant_config or {}
-    for submodule, module_qconfig in submodule_quant_config.items():
-        quantizer.set_submodule_quant_config(submodule, module_qconfig)
+    callback_qconfig_list = callback_qconfig_list or []
+    quantizer.set_submodule_qconfig_list(callback_qconfig_list)
     return quantizer
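A usage sketch of the renamed make_quantizer parameter; the import path for make_quantizer and the "add" name string are assumptions for illustration:

    from executorch.backends.qualcomm.quantizer.quantizer import (
        ModuleQConfig,
        QuantDtype,
        get_submodule_name_predicate,
    )

    # Import path assumed from this repository layout (examples/qualcomm/utils.py).
    from examples.qualcomm.utils import make_quantizer

    # Nodes whose nn_module_stack entry name contains "add" get the 16a16w
    # config; all other nodes keep the quantizer's default config.
    quantizer = make_quantizer(
        callback_qconfig_list=[
            (get_submodule_name_predicate("add"), ModuleQConfig(QuantDtype.use_16a16w)),
        ],
    )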
