From 2b5cf9121aaaeaffff10cba84c8f56ad3d1cf847 Mon Sep 17 00:00:00 2001 From: Chun-I Tsai Date: Tue, 11 Mar 2025 10:14:17 +0800 Subject: [PATCH 1/4] Qualcomm AI Engine Direct - Add submodule quant config setting - Add API to qnn quantizer for setting submodule quant config --- backends/qualcomm/_passes/decompose_einsum.py | 3 + .../_passes/decompose_linalg_vector_norm.py | 3 + backends/qualcomm/_passes/utils.py | 14 +- backends/qualcomm/quantizer/quantizer.py | 186 +++++++++++------- backends/qualcomm/tests/models.py | 12 ++ backends/qualcomm/tests/test_qnn_delegate.py | 34 +++- backends/qualcomm/tests/utils.py | 20 +- backends/qualcomm/utils/constants.py | 1 + backends/transforms/decompose_sdpa.py | 3 + examples/qualcomm/utils.py | 24 ++- extension/llm/export/quantizer_lib.py | 33 ++-- 11 files changed, 224 insertions(+), 109 deletions(-) diff --git a/backends/qualcomm/_passes/decompose_einsum.py b/backends/qualcomm/_passes/decompose_einsum.py index cbf8cbf1249..046c1598311 100644 --- a/backends/qualcomm/_passes/decompose_einsum.py +++ b/backends/qualcomm/_passes/decompose_einsum.py @@ -8,6 +8,8 @@ from executorch.exir.pass_base import ExportPass, PassResult from torch.fx.experimental.proxy_tensor import make_fx +from .utils import copy_nn_module_stack + class DecomposeEinsum(ExportPass): """ @@ -36,6 +38,7 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult: remap[f"arg1_{i+1}"] = arg for decomposed_node in decomposed_module.graph.nodes: + copy_nn_module_stack(node, decomposed_node) # This is the arg[0] equation string, which is not required anymore after decomposition if "arg0" in decomposed_node.name: continue diff --git a/backends/qualcomm/_passes/decompose_linalg_vector_norm.py b/backends/qualcomm/_passes/decompose_linalg_vector_norm.py index 7d70f5c9342..993f088da12 100644 --- a/backends/qualcomm/_passes/decompose_linalg_vector_norm.py +++ b/backends/qualcomm/_passes/decompose_linalg_vector_norm.py @@ -8,6 +8,8 @@ from executorch.exir import to_edge from executorch.exir.pass_base import ExportPass, PassResult +from .utils import copy_nn_module_stack + class LinalgVectorNorm(torch.nn.Module): def __init__(self, exp, dim, keepdim): @@ -62,6 +64,7 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult: remap = {"x": node.args[0]} for decomposed_node in decomposed_module.graph.nodes: + copy_nn_module_stack(node, decomposed_node) # no need to copy existent 'output' if decomposed_node.op == "output": for user in node.users.copy(): diff --git a/backends/qualcomm/_passes/utils.py b/backends/qualcomm/_passes/utils.py index d538fe0d34f..adac5e627b8 100755 --- a/backends/qualcomm/_passes/utils.py +++ b/backends/qualcomm/_passes/utils.py @@ -8,7 +8,11 @@ import torch from executorch.backends.qualcomm.builders.utils import get_parameter -from executorch.backends.qualcomm.utils.constants import QCOM_DTYPE, QCOM_ENCODING +from executorch.backends.qualcomm.utils.constants import ( + QCOM_DTYPE, + QCOM_ENCODING, + QCOM_NN_MODULE_STACK, +) from executorch.exir.dialects._ops import ops as exir_ops from torch._subclasses import FakeTensor @@ -121,6 +125,14 @@ def get_passes_dependency_for_capture_program(): } +def copy_nn_module_stack(src, target): + """ + Copy meta["nn_module_stack"] from src node to target node if existing. 
+ """ + if value := src.meta.get(QCOM_NN_MODULE_STACK): + target.meta[QCOM_NN_MODULE_STACK] = value + + def is_float_tensor(node: torch.fx.Node) -> bool: if "val" not in node.meta or not isinstance(node.meta["val"], FakeTensor): return False diff --git a/backends/qualcomm/quantizer/quantizer.py b/backends/qualcomm/quantizer/quantizer.py index 3620841aff9..e9c0699b7fe 100644 --- a/backends/qualcomm/quantizer/quantizer.py +++ b/backends/qualcomm/quantizer/quantizer.py @@ -3,6 +3,8 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import importlib +from dataclasses import dataclass from enum import IntEnum, unique from functools import partial from typing import Callable, Dict, Optional, Sequence, Set, Tuple @@ -58,7 +60,7 @@ class QuantDtype(IntEnum): use_8a8w = 4 -quant_config_dict = { +QUANT_CONFIG_DICT = { # PTQ (QuantDtype.use_16a16w, False): ( get_16a16w_qnn_ptq_config, @@ -123,6 +125,56 @@ class QuantDtype(IntEnum): } +@dataclass +class ModuleQConfig: + quant_dtype: QuantDtype = QuantDtype.use_8a8w + is_qat: bool = False + is_conv_per_channel: bool = False + is_linear_per_channel: bool = False + act_observer: Optional[ + torch.ao.quantization.observer.UniformQuantizationObserverBase + ] = None + + def __post_init__(self): + if (self.quant_dtype, self.is_qat) not in QUANT_CONFIG_DICT: + raise RuntimeError( + f"the quant config, (quant_dtype: {self.quant_dtype}, is_qat: {self.is_qat}) is not support" + ) + quant_config_func, per_channel_quant_config_func, per_block_quant_config_func = QUANT_CONFIG_DICT[ + (self.quant_dtype, self.is_qat) + ] + self.quant_config = ( + quant_config_func(act_observer=self.act_observer) + if self.act_observer + else quant_config_func() + ) + self.per_channel_quant_config = ( + per_channel_quant_config_func(act_observer=self.act_observer) + if self.act_observer + else per_channel_quant_config_func() + ) + self.per_block_quant_config = ( + per_block_quant_config_func(act_observer=act_observer) + if self.act_observer + else per_block_quant_config_func() + ) + self.use_per_channel_weight_quant_ops = set() + if self.is_conv_per_channel: + self.use_per_channel_weight_quant_ops.update( + { + torch.ops.aten.conv1d.default, + torch.ops.aten.conv2d.default, + torch.ops.aten.conv_transpose2d.input, + } + ) + if self.is_linear_per_channel: + self.use_per_channel_weight_quant_ops.update( + { + torch.ops.aten.linear.default, + } + ) + + class QnnQuantizer(Quantizer): SUPPORTED_OPS: Set = set(OP_ANNOTATOR.keys()) @@ -130,14 +182,9 @@ def __init__(self): super().__init__() self.quant_ops: Set[OpOverload] = self.SUPPORTED_OPS.copy() - self.is_qat = False - self.quant_dtype = QuantDtype.use_8a8w - self.quant_config: QuantizationConfig = get_8a8w_qnn_ptq_config() - self.per_channel_quant_config = get_ptq_per_channel_quant_config() - self.per_block_quant_config = get_ptq_per_block_quant_config() + self.default_quant_config = ModuleQConfig() + self.module_qconfig_dict: Dict[torch.nn.Module, ModuleQConfig] = {} self.block_size_map = {} - self.use_per_channel_weight_quant_ops: Set[OpOverload] = set() - self.use_per_block_weight_quant_ops: Set[OpOverload] = set() self.custom_quant_annotations: Sequence[Callable] = [] self.discard_nodes: Set[str] = set() @@ -155,41 +202,52 @@ def _annotate_custom_annotation(self, gm: GraphModule) -> None: for annotation_func in self.custom_quant_annotations: annotation_func(gm) - def _get_quant_config(self, op: torch.fx.Node) -> Optional[QuantizationConfig]: + def 
_get_submodule(self, node: torch.fx.Node): + """ + An example of nn_module_stack + { + 'L__self__': ('', 'executorch.backends.qualcomm.tests.models.SubModules'), + 'L__self___add': ('add', 'executorch.backends.qualcomm.tests.models.Add') + } """ - Priority: + + nn_module_stack = node.meta.get("nn_module_stack") + if nn_module_stack: + module_source_str, module_str = list(nn_module_stack.values())[-1][ + -1 + ].rsplit(".", 1) + module_source = importlib.import_module(module_source_str) + return getattr(module_source, module_str) + return None + + def _get_quant_config(self, node: torch.fx.Node) -> Optional[QuantizationConfig]: + """ + How to pick: 1. is one of use_per_block_weight_quant_ops - 2. is one of use_per_channel_weight_quant_ops - 3. quant config + 2. Choose specific submodule config if given. + 3. Pick one if op belongs to use_per_channel_weight_quant_ops + 4. If not 2, pick normal quant config """ - target = op.target - if isinstance(target, str): + op = node.target + if isinstance(op, str): return - if target in self.use_per_block_weight_quant_ops: - if block_size := self.block_size_map.get(op.name): - self.per_block_quant_config.block_size = block_size - return self.per_block_quant_config - - if target in self.use_per_channel_weight_quant_ops: - return self.per_channel_quant_config + if block_size := self.block_size_map.get(op.name): + config = self.default_quant_config.per_block_quant_config + config.block_size = block_size + return config - if target in self.quant_ops: - return self.quant_config + config = self.module_qconfig_dict.get( + self._get_submodule(node), self.default_quant_config + ) - print(f"No quant config is implemented for op, {op}") + if op in config.use_per_channel_weight_quant_ops: + return config.per_channel_quant_config - def _update_per_block_weight_quant_ops(self, ops: Set[OpOverload], enable: bool): - if enable: - self.use_per_block_weight_quant_ops.update(ops) - else: - self.use_per_block_weight_quant_ops.difference_update(ops) + if op in self.quant_ops: + return config.quant_config - def _update_per_channel_weight_quant_ops(self, ops: Set[OpOverload], enable: bool): - if enable: - self.use_per_channel_weight_quant_ops.update(ops) - else: - self.use_per_channel_weight_quant_ops.difference_update(ops) + print(f"No quant config is implemented for op, {op}") def add_custom_quant_annotations( self, custom_quant_annotations: Sequence[Callable] @@ -212,52 +270,32 @@ def annotate(self, model: GraphModule) -> GraphModule: def get_supported_ops(self) -> Set[OpOverload]: return self.SUPPORTED_OPS - def set_quant_config( - self, quant_dtype: QuantDtype, is_qat=False, act_observer=None + def set_default_quant_config( + self, + quant_dtype: QuantDtype, + is_qat=False, + is_conv_per_channel=False, + is_linear_per_channel=False, + act_observer=None, ) -> None: - self.quant_dtype = quant_dtype - self.is_qat = is_qat - if (quant_dtype, is_qat) not in quant_config_dict: - raise RuntimeError( - f"the quant config, (quant_dtype: {quant_dtype}, is_qat: {is_qat}) is not support" - ) - - quant_config_fuc, per_channel_quant_config_fuc, per_block_quant_config_fuc = ( - quant_config_dict[(quant_dtype, is_qat)] - ) - self.quant_config = ( - quant_config_fuc(act_observer=act_observer) - if act_observer - else quant_config_fuc() + self.default_quant_config = ModuleQConfig( + quant_dtype, + is_qat, + is_conv_per_channel, + is_linear_per_channel, + act_observer, ) - self.per_channel_quant_config = ( - per_channel_quant_config_fuc(act_observer=act_observer) - if act_observer - else 
per_channel_quant_config_fuc() - ) - if per_block_quant_config_fuc is not None: - self.per_block_quant_config = ( - per_block_quant_config_fuc(act_observer=act_observer) - if act_observer - else per_block_quant_config_fuc() - ) def set_block_size_map(self, block_size_map: Dict[str, Tuple]) -> None: self.block_size_map = block_size_map - def set_per_block_conv_quant(self, enable: bool) -> None: - conv_ops = {torch.ops.aten.conv2d.default} - self._update_per_block_weight_quant_ops(conv_ops, enable) - - def set_per_channel_conv_quant(self, enable: bool) -> None: - conv_ops = {torch.ops.aten.conv1d.default, torch.ops.aten.conv2d.default} - self._update_per_channel_weight_quant_ops(conv_ops, enable) - - def set_per_channel_linear_quant(self, enable: bool) -> None: - linear_ops = { - torch.ops.aten.linear.default, - } - self._update_per_channel_weight_quant_ops(linear_ops, enable) + def set_submodule_quant_config( + self, submodule: torch.nn.Module, module_qconfig: ModuleQConfig + ) -> None: + """ + Set the quant config specific for a submodule + """ + self.module_qconfig_dict[submodule] = module_qconfig def transform_for_annotation(self, model: GraphModule) -> GraphModule: return QnnPassManager().transform_for_annotation_pipeline(model) diff --git a/backends/qualcomm/tests/models.py b/backends/qualcomm/tests/models.py index 0857a597d88..face416e304 100644 --- a/backends/qualcomm/tests/models.py +++ b/backends/qualcomm/tests/models.py @@ -1450,6 +1450,18 @@ def forward(self, x): return 10 - x +class SimpleSubModules(torch.nn.Module): + def __init__(self): + super().__init__() + self.add = Add() + self.sub = Sub() + + def forward(self, a, b, c, d): + lhs = self.add(a, b) + rhs = self.sub(c, d) + return torch.mul(lhs, rhs) + + class SumIntList(torch.nn.Module): def __init__(self): super().__init__() diff --git a/backends/qualcomm/tests/test_qnn_delegate.py b/backends/qualcomm/tests/test_qnn_delegate.py index 795459a9f77..8d563e1e48c 100644 --- a/backends/qualcomm/tests/test_qnn_delegate.py +++ b/backends/qualcomm/tests/test_qnn_delegate.py @@ -24,6 +24,7 @@ from executorch.backends.qualcomm.tests.utils import ( generate_context_binary, + ModuleQConfig, QuantDtype, TestQNN, validate_context_binary, @@ -1326,8 +1327,8 @@ def test_qnn_backend_element_wise_add(self): for module in comb[QCOM_MODULE]: for sample_input in comb[QCOM_SAMPLE_INPUTS]: with self.subTest(i=index): - module = self.get_qdq_module(module, sample_input) - self.lower_module_and_test_output(module, sample_input) + gm = self.get_qdq_module(module, sample_input) + self.lower_module_and_test_output(gm, sample_input) index += 1 def test_qnn_backend_element_wise_and(self): @@ -1367,8 +1368,8 @@ def test_qnn_backend_element_wise_div(self): for module in comb[QCOM_MODULE]: for sample_input in comb[QCOM_SAMPLE_INPUTS]: with self.subTest(i=index): - module = self.get_qdq_module(module, sample_input) - self.lower_module_and_test_output(module, sample_input) + gm = self.get_qdq_module(module, sample_input) + self.lower_module_and_test_output(gm, sample_input) index += 1 def test_qnn_backend_element_wise_mul(self): @@ -1395,8 +1396,8 @@ def test_qnn_backend_element_wise_mul(self): for module in comb[QCOM_MODULE]: for sample_input in comb[QCOM_SAMPLE_INPUTS]: with self.subTest(i=index): - module = self.get_qdq_module(module, sample_input) - self.lower_module_and_test_output(module, sample_input) + gm = self.get_qdq_module(module, sample_input) + self.lower_module_and_test_output(gm, sample_input) index += 1 def 
test_qnn_backend_element_wise_or(self): @@ -1455,8 +1456,8 @@ def test_qnn_backend_element_wise_sub(self): for module in comb[QCOM_MODULE]: for sample_input in comb[QCOM_SAMPLE_INPUTS]: with self.subTest(i=index): - module = self.get_qdq_module(module, sample_input) - self.lower_module_and_test_output(module, sample_input) + gm = self.get_qdq_module(module, sample_input) + self.lower_module_and_test_output(gm, sample_input) index += 1 def test_qnn_backend_elu(self): @@ -2122,6 +2123,23 @@ def test_qnn_backend_simple_model(self): module = self.get_qdq_module(module, sample_input) self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_submodules(self): + module = SimpleSubModules() # noqa: F405 + sample_input = ( + torch.rand(1, 3, 8, 8), + torch.rand(1, 3, 8, 8), + torch.rand(1, 3, 8, 8), + torch.rand(1, 3, 8, 8), + ) + + submodule_quant_config = { + Add: ModuleQConfig(QuantDtype.use_16a16w) # noqa: F405 + } + module = self.get_qdq_module( + module, sample_input, submodule_quant_config=submodule_quant_config + ) + self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_topk_and_index(self): module = TopKandIndex() # noqa: F405 sample_input = (torch.randn(3, 10),) diff --git a/backends/qualcomm/tests/utils.py b/backends/qualcomm/tests/utils.py index 41c56c08a85..2c7f7d47ed0 100644 --- a/backends/qualcomm/tests/utils.py +++ b/backends/qualcomm/tests/utils.py @@ -16,7 +16,11 @@ from executorch import exir from executorch.backends.qualcomm.qnn_preprocess import QnnBackend -from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype +from executorch.backends.qualcomm.quantizer.quantizer import ( + ModuleQConfig, + QnnQuantizer, + QuantDtype, +) from executorch.backends.qualcomm.serialization.qc_schema import QcomChipset from executorch.backends.qualcomm.utils.constants import ( QCOM_DTYPE, @@ -497,7 +501,6 @@ def get_qdq_module( self, module: torch.nn.Module, inputs: Tuple[torch.Tensor], - is_conv_per_block: Optional[bool] = False, is_conv_per_channel: Optional[bool] = True, is_linear_per_channel: Optional[bool] = False, custom_quant_annotations: Tuple[Callable] = (), @@ -505,6 +508,7 @@ def get_qdq_module( dynamic_shapes: Dict = None, bypass_check: bool = False, block_size_map: Dict[str, Tuple] = None, + submodule_quant_config: Optional[Dict[torch.nn.Module, ModuleQConfig]] = None, ) -> torch.fx.GraphModule: m = torch.export.export( module, inputs, dynamic_shapes=dynamic_shapes, strict=True @@ -513,9 +517,9 @@ def get_qdq_module( quantizer = make_quantizer( quant_dtype=quant_dtype, custom_annotations=custom_quant_annotations, - per_block_conv=is_conv_per_block, per_channel_conv=is_conv_per_channel, per_channel_linear=is_linear_per_channel, + submodule_quant_config = submodule_quant_config, ) if block_size_map is not None: quantizer.set_block_size_map(block_size_map) @@ -543,6 +547,7 @@ def get_prepared_qat_module( is_linear_per_channel: Optional[bool] = False, custom_quant_annotations: Tuple[Callable] = (), quant_dtype: QuantDtype = QuantDtype.use_8a8w, + submodule_quant_config: Optional[Dict[str, ModuleQConfig]] = None, ) -> torch.fx.GraphModule: m = torch.export.export_for_training(module, inputs, strict=True).module() @@ -551,12 +556,13 @@ def get_prepared_qat_module( custom_annotations=custom_quant_annotations, per_channel_conv=is_conv_per_channel, per_channel_linear=is_linear_per_channel, + is_qat=True, + submodule_quant_config=submodule_quant_config ) - if quant_dtype == QuantDtype.use_8a8w: - 
quantizer.set_quant_config(quant_dtype, is_qat=True)
-        else:
-            raise RuntimeError("Shuld not be here")
+        submodule_quant_config = submodule_quant_config or {}
+        for submodule, module_qconfig in submodule_quant_config.items():
+            quantizer.set_submodule_quant_config(submodule, module_qconfig)
 
         prepared = prepare_qat_pt2e(m, quantizer)
         return torch.ao.quantization.move_exported_model_to_train(prepared)
diff --git a/backends/qualcomm/utils/constants.py b/backends/qualcomm/utils/constants.py
index ce917bf4115..04c7caa0906 100644
--- a/backends/qualcomm/utils/constants.py
+++ b/backends/qualcomm/utils/constants.py
@@ -21,6 +21,7 @@
 QCOM_INSERTED_PERMUTE = "qnn_permute"
 QCOM_LAYOUT_CHANGE = "layout_change"
 QCOM_NUM_BLOCKS_PER_AXIS = "num_blocks_per_axis"
+QCOM_NN_MODULE_STACK = "nn_module_stack"
 QCOM_OFFSET = "offset"
 QCOM_ORIG_DTYPE = "orig_dtype"
 QCOM_QUANTIZED_IO = "q_tensor_io"
diff --git a/backends/transforms/decompose_sdpa.py b/backends/transforms/decompose_sdpa.py
index 329dab96df2..73e9d986c3d 100644
--- a/backends/transforms/decompose_sdpa.py
+++ b/backends/transforms/decompose_sdpa.py
@@ -62,6 +62,9 @@ def call(
             # Copy node from decompose graph module
             for decomposed_node in decomposed_module.graph.nodes:
+                decomposed_node.meta["nn_module_stack"] = node.meta.get(
+                    "nn_module_stack"
+                )
                 if decomposed_node.op == "placeholder":
                     continue
diff --git a/examples/qualcomm/utils.py b/examples/qualcomm/utils.py
index 2b2f32b037b..bb59c814b0a 100755
--- a/examples/qualcomm/utils.py
+++ b/examples/qualcomm/utils.py
@@ -14,12 +14,16 @@
 import tempfile
 
 from pathlib import Path
-from typing import Callable, List, Optional
+from typing import Callable, Dict, List, Optional
 
 import numpy as np
 
 import torch
-from executorch.backends.qualcomm.quantizer.quantizer import QnnQuantizer, QuantDtype
+from executorch.backends.qualcomm.quantizer.quantizer import (
+    ModuleQConfig,
+    QnnQuantizer,
+    QuantDtype,
+)
 from executorch.backends.qualcomm.serialization.qc_schema import QcomChipset
 from executorch.backends.qualcomm.utils.utils import (
     generate_htp_compiler_spec,
@@ -254,18 +258,24 @@ def qat_train(ori_model, captured_model, quantizer, dataset):
 def make_quantizer(
     quant_dtype: Optional[QuantDtype] = QuantDtype.use_8a8w,
     custom_annotations=(),
-    per_block_conv=False,
     per_channel_conv=True,
     per_channel_linear=False,
     act_observer=MovingAverageMinMaxObserver,
     is_qat=False,
+    submodule_quant_config: Optional[Dict[str, ModuleQConfig]] = None,
 ):
     quantizer = QnnQuantizer()
     quantizer.add_custom_quant_annotations(custom_annotations)
-    quantizer.set_per_block_conv_quant(per_block_conv)
-    quantizer.set_per_channel_conv_quant(per_channel_conv)
-    quantizer.set_per_channel_linear_quant(per_channel_linear)
-    quantizer.set_quant_config(quant_dtype, is_qat, act_observer)
+    quantizer.set_default_quant_config(
+        quant_dtype,
+        is_qat=is_qat,
+        is_conv_per_channel=per_channel_conv,
+        is_linear_per_channel=per_channel_linear,
+        act_observer=act_observer,
+    )
+    submodule_quant_config = submodule_quant_config or {}
+    for submodule, module_qconfig in submodule_quant_config.items():
+        quantizer.set_submodule_quant_config(submodule, module_qconfig)
     return quantizer
 
diff --git a/extension/llm/export/quantizer_lib.py b/extension/llm/export/quantizer_lib.py
index 40d81075d9f..e516ae4c0a9 100644
--- a/extension/llm/export/quantizer_lib.py
+++ b/extension/llm/export/quantizer_lib.py
@@ -166,30 +166,39 @@ def get_qnn_quantizer(
         backend == "qnn"
     ), f"The quantization config is for backend {backend} instead of qnn."
qnn_quantizer = QnnQuantizer()  # pyre-fixme[16]
-    qnn_quantizer.set_per_channel_conv_quant(enable=True)
-    qnn_quantizer.set_per_channel_linear_quant(enable=True)
+
     # more custom quantization are supported including 16a4w etc. default to 8bit quantized
     custom_annotations = ()
     if quant_config == "8a8w":
         quant_dtype = QuantDtype.use_8a8w  # pyre-fixme[16]
-        qnn_quantizer.set_quant_config(quant_dtype, is_qat=is_qat)
+        qnn_quantizer.set_default_quant_config(
+            quant_dtype,
+            is_qat=is_qat,
+            is_conv_per_channel=True,
+            is_linear_per_channel=True,
+        )
     elif quant_config == "16a16w":
-        quant_dtype = QuantDtype.use_16a16w  # pyre-fixme[16]
         # Due to the error with 16a16w in Qnn Htp, we need to disable per channel linear quantization when use 16a16w
         # TODO: enable it after the issue is fixed
         logging.warning(
             "Disable per channel quantization for linear and conv due to the error with QNN HTP 16a16w."
         )
-        qnn_quantizer.set_per_channel_conv_quant(enable=False)
-        qnn_quantizer.set_per_channel_linear_quant(enable=False)
-        qnn_quantizer.set_quant_config(
-            quant_dtype, is_qat=is_qat, act_observer=MinMaxObserver
+        quant_dtype = QuantDtype.use_16a16w  # pyre-fixme[16]
+        qnn_quantizer.set_default_quant_config(
+            quant_dtype,
+            is_qat=is_qat,
+            is_conv_per_channel=False,
+            is_linear_per_channel=False,
+            act_observer=MinMaxObserver,
         )
     elif quant_config == "16a4w":
-        # pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`.
-        quant_dtype = QuantDtype.use_16a4w
-        qnn_quantizer.set_quant_config(
-            quant_dtype, is_qat=is_qat, act_observer=MinMaxObserver
+        quant_dtype = QuantDtype.use_16a4w  # pyre-fixme[16]
+        qnn_quantizer.set_default_quant_config(
+            quant_dtype,
+            is_qat=is_qat,
+            is_conv_per_channel=True,
+            is_linear_per_channel=True,
+            act_observer=MinMaxObserver,
         )
         # pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`.
custom_annotations = (custom_annotate_llama_matmul_16a8w,) From 8493349eb4a0c9ec8ab64e82a1c4f047d45f9273 Mon Sep 17 00:00:00 2001 From: Chun-I Tsai Date: Thu, 27 Mar 2025 14:46:50 +0800 Subject: [PATCH 2/4] Rebase --- backends/qualcomm/quantizer/quantizer.py | 11 ++++++----- backends/qualcomm/tests/test_qnn_delegate.py | 1 - 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/backends/qualcomm/quantizer/quantizer.py b/backends/qualcomm/quantizer/quantizer.py index e9c0699b7fe..07aa11fecbb 100644 --- a/backends/qualcomm/quantizer/quantizer.py +++ b/backends/qualcomm/quantizer/quantizer.py @@ -153,11 +153,6 @@ def __post_init__(self): if self.act_observer else per_channel_quant_config_func() ) - self.per_block_quant_config = ( - per_block_quant_config_func(act_observer=act_observer) - if self.act_observer - else per_block_quant_config_func() - ) self.use_per_channel_weight_quant_ops = set() if self.is_conv_per_channel: self.use_per_channel_weight_quant_ops.update( @@ -173,6 +168,12 @@ def __post_init__(self): torch.ops.aten.linear.default, } ) + if per_block_quant_config_func: + self.per_block_quant_config = ( + per_block_quant_config_func(act_observer=self.act_observer) + if self.act_observer + else per_block_quant_config_func() + ) class QnnQuantizer(Quantizer): diff --git a/backends/qualcomm/tests/test_qnn_delegate.py b/backends/qualcomm/tests/test_qnn_delegate.py index 8d563e1e48c..04ab7ee56f8 100644 --- a/backends/qualcomm/tests/test_qnn_delegate.py +++ b/backends/qualcomm/tests/test_qnn_delegate.py @@ -1238,7 +1238,6 @@ def test_qnn_backend_conv2d_block(self): module = self.get_qdq_module( module, sample_input, - is_conv_per_block=True, quant_dtype=QuantDtype.use_16a4w_block, block_size_map={"conv2d": (1, 128, 1, 1)}, ) From f2ce0e7930f060394b3657268fbece0e50900641 Mon Sep 17 00:00:00 2001 From: Chun-I Tsai Date: Mon, 31 Mar 2025 16:13:03 +0800 Subject: [PATCH 3/4] Fix based on comments - Change to string based way to set up qconfig for submodule --- backends/qualcomm/_passes/utils.py | 5 +- backends/qualcomm/quantizer/quantizer.py | 96 +++++++++++++------- backends/qualcomm/tests/test_qnn_delegate.py | 17 +++- backends/qualcomm/tests/utils.py | 20 ++-- backends/qualcomm/utils/constants.py | 1 - examples/qualcomm/utils.py | 9 +- 6 files changed, 89 insertions(+), 59 deletions(-) diff --git a/backends/qualcomm/_passes/utils.py b/backends/qualcomm/_passes/utils.py index adac5e627b8..bcbbdb02258 100755 --- a/backends/qualcomm/_passes/utils.py +++ b/backends/qualcomm/_passes/utils.py @@ -11,7 +11,6 @@ from executorch.backends.qualcomm.utils.constants import ( QCOM_DTYPE, QCOM_ENCODING, - QCOM_NN_MODULE_STACK, ) from executorch.exir.dialects._ops import ops as exir_ops from torch._subclasses import FakeTensor @@ -129,8 +128,8 @@ def copy_nn_module_stack(src, target): """ Copy meta["nn_module_stack"] from src node to target node if existing. """ - if value := src.meta.get(QCOM_NN_MODULE_STACK): - target.meta[QCOM_NN_MODULE_STACK] = value + if value := src.meta.get("nn_module_stack"): + target.meta["nn_module_stack"] = value def is_float_tensor(node: torch.fx.Node) -> bool: diff --git a/backends/qualcomm/quantizer/quantizer.py b/backends/qualcomm/quantizer/quantizer.py index 07aa11fecbb..8e65607dd84 100644 --- a/backends/qualcomm/quantizer/quantizer.py +++ b/backends/qualcomm/quantizer/quantizer.py @@ -3,11 +3,10 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
-import importlib from dataclasses import dataclass from enum import IntEnum, unique from functools import partial -from typing import Callable, Dict, Optional, Sequence, Set, Tuple +from typing import Callable, Dict, List, Optional, Sequence, Set, Tuple import torch from executorch.backends.qualcomm._passes.qnn_pass_manager import QnnPassManager @@ -140,9 +139,11 @@ def __post_init__(self): raise RuntimeError( f"the quant config, (quant_dtype: {self.quant_dtype}, is_qat: {self.is_qat}) is not support" ) - quant_config_func, per_channel_quant_config_func, per_block_quant_config_func = QUANT_CONFIG_DICT[ - (self.quant_dtype, self.is_qat) - ] + ( + quant_config_func, + per_channel_quant_config_func, + per_block_quant_config_func, + ) = QUANT_CONFIG_DICT[(self.quant_dtype, self.is_qat)] self.quant_config = ( quant_config_func(act_observer=self.act_observer) if self.act_observer @@ -184,7 +185,9 @@ def __init__(self): self.quant_ops: Set[OpOverload] = self.SUPPORTED_OPS.copy() self.default_quant_config = ModuleQConfig() - self.module_qconfig_dict: Dict[torch.nn.Module, ModuleQConfig] = {} + self.submodule_qconfig_list: List[ + Tuple[Callable[[torch.fx.Node], bool], ModuleQConfig] + ] = [] self.block_size_map = {} self.custom_quant_annotations: Sequence[Callable] = [] @@ -203,44 +206,30 @@ def _annotate_custom_annotation(self, gm: GraphModule) -> None: for annotation_func in self.custom_quant_annotations: annotation_func(gm) - def _get_submodule(self, node: torch.fx.Node): - """ - An example of nn_module_stack - { - 'L__self__': ('', 'executorch.backends.qualcomm.tests.models.SubModules'), - 'L__self___add': ('add', 'executorch.backends.qualcomm.tests.models.Add') - } - """ - - nn_module_stack = node.meta.get("nn_module_stack") - if nn_module_stack: - module_source_str, module_str = list(nn_module_stack.values())[-1][ - -1 - ].rsplit(".", 1) - module_source = importlib.import_module(module_source_str) - return getattr(module_source, module_str) - return None + def _get_submodule_qconfig(self, node: torch.fx.Node): + for func, qconfig in self.submodule_qconfig_list: + if func(node): + return qconfig + return self.default_quant_config def _get_quant_config(self, node: torch.fx.Node) -> Optional[QuantizationConfig]: """ How to pick: - 1. is one of use_per_block_weight_quant_ops - 2. Choose specific submodule config if given. + 1. is one of per_block_quant_config + 2. Pick specific submodule config if given. 3. Pick one if op belongs to use_per_channel_weight_quant_ops - 4. If not 2, pick normal quant config + 4. 
If not 3, pick normal quant config """ op = node.target if isinstance(op, str): return - if block_size := self.block_size_map.get(op.name): + if block_size := self.block_size_map.get(node.name): config = self.default_quant_config.per_block_quant_config config.block_size = block_size return config - config = self.module_qconfig_dict.get( - self._get_submodule(node), self.default_quant_config - ) + config = self._get_submodule_qconfig(node) if op in config.use_per_channel_weight_quant_ops: return config.per_channel_quant_config @@ -290,16 +279,55 @@ def set_default_quant_config( def set_block_size_map(self, block_size_map: Dict[str, Tuple]) -> None: self.block_size_map = block_size_map - def set_submodule_quant_config( - self, submodule: torch.nn.Module, module_qconfig: ModuleQConfig + def set_submodule_qconfig_list( + self, submodule_qconfig_list: List[Tuple[Callable, ModuleQConfig]] ) -> None: """ - Set the quant config specific for a submodule + Set specific quant config from a callback function. + If a node fits more than one callback, only apply the first one. """ - self.module_qconfig_dict[submodule] = module_qconfig + self.submodule_qconfig_list = submodule_qconfig_list def transform_for_annotation(self, model: GraphModule) -> GraphModule: return QnnPassManager().transform_for_annotation_pipeline(model) def validate(self, model: GraphModule) -> None: pass + + +def get_submodule_type_predicate(module_type_str): + """ + An example of nn_module_stack + { + 'L__self__': ('', 'executorch.backends.qualcomm.tests.models.SubModules'), + 'L__self___add': ('add', 'executorch.backends.qualcomm.tests.models.Add') + } + """ + + def predicate(node): + if nn_module_stack := node.meta.get("nn_module_stack"): + for _, type_name in nn_module_stack.values(): + if module_type_str in type_name: + return True + return False + + return predicate + + +def get_submodule_name_predicate(module_name_str): + """ + An example of nn_module_stack + { + 'L__self__': ('', 'executorch.backends.qualcomm.tests.models.SubModules'), + 'L__self___add': ('add', 'executorch.backends.qualcomm.tests.models.Add') + } + """ + + def predicate(node): + if nn_module_stack := node.meta.get("nn_module_stack"): + for name in nn_module_stack.keys(): + if module_name_str in name: + return True + return False + + return predicate diff --git a/backends/qualcomm/tests/test_qnn_delegate.py b/backends/qualcomm/tests/test_qnn_delegate.py index 04ab7ee56f8..7e17fa11e4e 100644 --- a/backends/qualcomm/tests/test_qnn_delegate.py +++ b/backends/qualcomm/tests/test_qnn_delegate.py @@ -2131,11 +2131,20 @@ def test_qnn_backend_submodules(self): torch.rand(1, 3, 8, 8), ) - submodule_quant_config = { - Add: ModuleQConfig(QuantDtype.use_16a16w) # noqa: F405 - } + from executorch.backends.qualcomm.quantizer.quantizer import ( + get_submodule_type_predicate, + ) + + submodule_qconfig_list = [ + ( + get_submodule_type_predicate("Add"), + ModuleQConfig(QuantDtype.use_16a16w), + ) # noqa: F405 + ] module = self.get_qdq_module( - module, sample_input, submodule_quant_config=submodule_quant_config + module, + sample_input, + submodule_qconfig_list=submodule_qconfig_list, ) self.lower_module_and_test_output(module, sample_input) diff --git a/backends/qualcomm/tests/utils.py b/backends/qualcomm/tests/utils.py index 2c7f7d47ed0..42eec15891c 100644 --- a/backends/qualcomm/tests/utils.py +++ b/backends/qualcomm/tests/utils.py @@ -16,11 +16,7 @@ from executorch import exir from executorch.backends.qualcomm.qnn_preprocess import QnnBackend -from 
executorch.backends.qualcomm.quantizer.quantizer import ( - ModuleQConfig, - QnnQuantizer, - QuantDtype, -) +from executorch.backends.qualcomm.quantizer.quantizer import ModuleQConfig, QuantDtype from executorch.backends.qualcomm.serialization.qc_schema import QcomChipset from executorch.backends.qualcomm.utils.constants import ( QCOM_DTYPE, @@ -508,8 +504,9 @@ def get_qdq_module( dynamic_shapes: Dict = None, bypass_check: bool = False, block_size_map: Dict[str, Tuple] = None, - submodule_quant_config: Optional[Dict[torch.nn.Module, ModuleQConfig]] = None, + submodule_qconfig_list: Optional[List[Tuple[Callable, ModuleQConfig]]] = None, ) -> torch.fx.GraphModule: + module = module.eval() m = torch.export.export( module, inputs, dynamic_shapes=dynamic_shapes, strict=True ).module() @@ -519,7 +516,7 @@ def get_qdq_module( custom_annotations=custom_quant_annotations, per_channel_conv=is_conv_per_channel, per_channel_linear=is_linear_per_channel, - submodule_quant_config = submodule_quant_config, + submodule_qconfig_list=submodule_qconfig_list, ) if block_size_map is not None: quantizer.set_block_size_map(block_size_map) @@ -547,7 +544,7 @@ def get_prepared_qat_module( is_linear_per_channel: Optional[bool] = False, custom_quant_annotations: Tuple[Callable] = (), quant_dtype: QuantDtype = QuantDtype.use_8a8w, - submodule_quant_config: Optional[Dict[str, ModuleQConfig]] = None, + submodule_qconfig_list: Optional[List[Tuple[Callable, ModuleQConfig]]] = None, ) -> torch.fx.GraphModule: m = torch.export.export_for_training(module, inputs, strict=True).module() @@ -557,12 +554,11 @@ def get_prepared_qat_module( per_channel_conv=is_conv_per_channel, per_channel_linear=is_linear_per_channel, is_qat=True, - submodule_quant_config=submodule_quant_config + submodule_qconfig_list=submodule_qconfig_list, ) - submodule_quant_config = submodule_quant_config or {} - for submodule, module_qconfig in submodule_quant_config.items(): - quantizer.set_submodule_quant_config(submodule, module_qconfig) + submodule_qconfig_list = submodule_qconfig_list or [] + quantizer.set_submodule_qconfig_list(submodule_qconfig_list) prepared = prepare_qat_pt2e(m, quantizer) return torch.ao.quantization.move_exported_model_to_train(prepared) diff --git a/backends/qualcomm/utils/constants.py b/backends/qualcomm/utils/constants.py index 04c7caa0906..ce917bf4115 100644 --- a/backends/qualcomm/utils/constants.py +++ b/backends/qualcomm/utils/constants.py @@ -21,7 +21,6 @@ QCOM_INSERTED_PERMUTE = "qnn_permute" QCOM_LAYOUT_CHANGE = "layout_change" QCOM_NUM_BLOCKS_PER_AXIS = "num_blocks_per_axis" -QCOM_NN_MODULE_STACK = "nn_module_stack" QCOM_OFFSET = "offset" QCOM_ORIG_DTYPE = "orig_dtype" QCOM_QUANTIZED_IO = "q_tensor_io" diff --git a/examples/qualcomm/utils.py b/examples/qualcomm/utils.py index bb59c814b0a..b17bc8f98bd 100755 --- a/examples/qualcomm/utils.py +++ b/examples/qualcomm/utils.py @@ -14,7 +14,7 @@ import tempfile from pathlib import Path -from typing import Callable, Dict, List, Optional +from typing import Callable, List, Optional, Tuple import numpy as np @@ -262,7 +262,7 @@ def make_quantizer( per_channel_linear=False, act_observer=MovingAverageMinMaxObserver, is_qat=False, - submodule_quant_config: Optional[Dict[str, ModuleQConfig]] = None, + callback_qconfig_list: Optional[List[Tuple[Callable, ModuleQConfig]]] = None, ): quantizer = QnnQuantizer() quantizer.add_custom_quant_annotations(custom_annotations) @@ -273,9 +273,8 @@ def make_quantizer( is_linear_per_channel=per_channel_linear, act_observer=act_observer, ) - 
submodule_quant_config = submodule_quant_config or {} - for submodule, module_qconfig in submodule_quant_config.items(): - quantizer.set_submodule_quant_config(submodule, module_qconfig) + callback_qconfig_list = callback_qconfig_list or [] + quantizer.set_submodule_qconfig_list(callback_qconfig_list) return quantizer From 137228aa28b01d303ffb45d2f1c9e6cd69948679 Mon Sep 17 00:00:00 2001 From: Chun-I Tsai Date: Tue, 8 Apr 2025 09:15:34 +0800 Subject: [PATCH 4/4] Fix lint --- backends/qualcomm/_passes/utils.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/backends/qualcomm/_passes/utils.py b/backends/qualcomm/_passes/utils.py index bcbbdb02258..a8eb6b192ee 100755 --- a/backends/qualcomm/_passes/utils.py +++ b/backends/qualcomm/_passes/utils.py @@ -8,10 +8,7 @@ import torch from executorch.backends.qualcomm.builders.utils import get_parameter -from executorch.backends.qualcomm.utils.constants import ( - QCOM_DTYPE, - QCOM_ENCODING, -) +from executorch.backends.qualcomm.utils.constants import QCOM_DTYPE, QCOM_ENCODING from executorch.exir.dialects._ops import ops as exir_ops from torch._subclasses import FakeTensor
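
Usage sketch (editor's illustration, not part of the diffs above): with all four
patches applied, the submodule quant config API introduced by this series is
wired together as below. Only names defined in these patches are used; the
"Add" submodule matches the test case in test_qnn_backend_submodules.

from executorch.backends.qualcomm.quantizer.quantizer import (
    ModuleQConfig,
    QnnQuantizer,
    QuantDtype,
    get_submodule_type_predicate,
)

quantizer = QnnQuantizer()
# Default config for the whole graph: 8a8w PTQ, with per-channel weight
# quantization enabled for conv and linear ops.
quantizer.set_default_quant_config(
    QuantDtype.use_8a8w,
    is_conv_per_channel=True,
    is_linear_per_channel=True,
)
# Submodule override: any node whose nn_module_stack records a module whose
# type name contains "Add" is annotated with 16a16w instead. Per the
# set_submodule_qconfig_list docstring, if a node matches more than one
# predicate, only the first entry in the list applies.
quantizer.set_submodule_qconfig_list(
    [
        (
            get_submodule_type_predicate("Add"),
            ModuleQConfig(QuantDtype.use_16a16w),
        )
    ]
)

The predicates take a torch.fx.Node and return a bool, which is why the
decomposition passes above copy nn_module_stack onto decomposed nodes: without
that metadata, nodes produced by decomposition would silently fall back to the
default config.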