diff --git a/backends/qualcomm/_passes/decompose_einsum.py b/backends/qualcomm/_passes/decompose_einsum.py
index cbf8cbf1249..046c1598311 100644
--- a/backends/qualcomm/_passes/decompose_einsum.py
+++ b/backends/qualcomm/_passes/decompose_einsum.py
@@ -8,6 +8,8 @@
 from executorch.exir.pass_base import ExportPass, PassResult
 from torch.fx.experimental.proxy_tensor import make_fx
 
+from .utils import copy_nn_module_stack
+
 
 class DecomposeEinsum(ExportPass):
     """
@@ -36,6 +38,7 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
                         remap[f"arg1_{i+1}"] = arg
 
                     for decomposed_node in decomposed_module.graph.nodes:
+                        copy_nn_module_stack(node, decomposed_node)
                        # This is the arg[0] equation string, which is not required anymore after decomposition
                        if "arg0" in decomposed_node.name:
                            continue
diff --git a/backends/qualcomm/_passes/decompose_linalg_vector_norm.py b/backends/qualcomm/_passes/decompose_linalg_vector_norm.py
index 7d70f5c9342..993f088da12 100644
--- a/backends/qualcomm/_passes/decompose_linalg_vector_norm.py
+++ b/backends/qualcomm/_passes/decompose_linalg_vector_norm.py
@@ -8,6 +8,8 @@
 from executorch.exir import to_edge
 from executorch.exir.pass_base import ExportPass, PassResult
 
+from .utils import copy_nn_module_stack
+
 
 class LinalgVectorNorm(torch.nn.Module):
     def __init__(self, exp, dim, keepdim):
@@ -62,6 +64,7 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
                     remap = {"x": node.args[0]}
 
                     for decomposed_node in decomposed_module.graph.nodes:
+                        copy_nn_module_stack(node, decomposed_node)
                        # no need to copy existent 'output'
                        if decomposed_node.op == "output":
                            for user in node.users.copy():
diff --git a/backends/qualcomm/_passes/utils.py b/backends/qualcomm/_passes/utils.py
index d538fe0d34f..a8eb6b192ee 100755
--- a/backends/qualcomm/_passes/utils.py
+++ b/backends/qualcomm/_passes/utils.py
@@ -121,6 +121,14 @@ def get_passes_dependency_for_capture_program():
     }
 
 
+def copy_nn_module_stack(src, target):
+    """
+    Copy meta["nn_module_stack"] from src node to target node if existing.
+    """
+    if value := src.meta.get("nn_module_stack"):
+        target.meta["nn_module_stack"] = value
+
+
 def is_float_tensor(node: torch.fx.Node) -> bool:
     if "val" not in node.meta or not isinstance(node.meta["val"], FakeTensor):
         return False
diff --git a/backends/qualcomm/quantizer/quantizer.py b/backends/qualcomm/quantizer/quantizer.py
index 3620841aff9..8e65607dd84 100644
--- a/backends/qualcomm/quantizer/quantizer.py
+++ b/backends/qualcomm/quantizer/quantizer.py
@@ -3,9 +3,10 @@
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
+from dataclasses import dataclass
 from enum import IntEnum, unique
 from functools import partial
-from typing import Callable, Dict, Optional, Sequence, Set, Tuple
+from typing import Callable, Dict, List, Optional, Sequence, Set, Tuple
 
 import torch
 from executorch.backends.qualcomm._passes.qnn_pass_manager import QnnPassManager
@@ -58,7 +59,7 @@ class QuantDtype(IntEnum):
     use_8a8w = 4
 
 
-quant_config_dict = {
+QUANT_CONFIG_DICT = {
     # PTQ
     (QuantDtype.use_16a16w, False): (
         get_16a16w_qnn_ptq_config,
@@ -123,6 +124,59 @@ class QuantDtype(IntEnum):
 }
 
 
+@dataclass
+class ModuleQConfig:
+    quant_dtype: QuantDtype = QuantDtype.use_8a8w
+    is_qat: bool = False
+    is_conv_per_channel: bool = False
+    is_linear_per_channel: bool = False
+    act_observer: Optional[
+        torch.ao.quantization.observer.UniformQuantizationObserverBase
+    ] = None
+
+    def __post_init__(self):
+        if (self.quant_dtype, self.is_qat) not in QUANT_CONFIG_DICT:
+            raise RuntimeError(
+                f"the quant config, (quant_dtype: {self.quant_dtype}, is_qat: {self.is_qat}) is not supported"
+            )
+        (
+            quant_config_func,
+            per_channel_quant_config_func,
+            per_block_quant_config_func,
+        ) = QUANT_CONFIG_DICT[(self.quant_dtype, self.is_qat)]
+        self.quant_config = (
+            quant_config_func(act_observer=self.act_observer)
+            if self.act_observer
+            else quant_config_func()
+        )
+        self.per_channel_quant_config = (
+            per_channel_quant_config_func(act_observer=self.act_observer)
+            if self.act_observer
+            else per_channel_quant_config_func()
+        )
+        self.use_per_channel_weight_quant_ops = set()
+        if self.is_conv_per_channel:
+            self.use_per_channel_weight_quant_ops.update(
+                {
+                    torch.ops.aten.conv1d.default,
+                    torch.ops.aten.conv2d.default,
+                    torch.ops.aten.conv_transpose2d.input,
+                }
+            )
+        if self.is_linear_per_channel:
+            self.use_per_channel_weight_quant_ops.update(
+                {
+                    torch.ops.aten.linear.default,
+                }
+            )
+        if per_block_quant_config_func:
+            self.per_block_quant_config = (
+                per_block_quant_config_func(act_observer=self.act_observer)
+                if self.act_observer
+                else per_block_quant_config_func()
+            )
+
+
 class QnnQuantizer(Quantizer):
     SUPPORTED_OPS: Set = set(OP_ANNOTATOR.keys())
 
@@ -130,14 +184,11 @@ def __init__(self):
         super().__init__()
         self.quant_ops: Set[OpOverload] = self.SUPPORTED_OPS.copy()
 
-        self.is_qat = False
-        self.quant_dtype = QuantDtype.use_8a8w
-        self.quant_config: QuantizationConfig = get_8a8w_qnn_ptq_config()
-        self.per_channel_quant_config = get_ptq_per_channel_quant_config()
-        self.per_block_quant_config = get_ptq_per_block_quant_config()
+        self.default_quant_config = ModuleQConfig()
+        self.submodule_qconfig_list: List[
+            Tuple[Callable[[torch.fx.Node], bool], ModuleQConfig]
+        ] = []
         self.block_size_map = {}
-        self.use_per_channel_weight_quant_ops: Set[OpOverload] = set()
-        self.use_per_block_weight_quant_ops: Set[OpOverload] = set()
 
         self.custom_quant_annotations: Sequence[Callable] = []
         self.discard_nodes: Set[str] = set()
@@ -155,41 +206,38 @@ def _annotate_custom_annotation(self, gm: GraphModule) -> None:
         for annotation_func in self.custom_quant_annotations:
             annotation_func(gm)
 
-    def _get_quant_config(self, op: torch.fx.Node) -> Optional[QuantizationConfig]:
+    def _get_submodule_qconfig(self, node: torch.fx.Node):
+        for func, qconfig in self.submodule_qconfig_list:
+            if func(node):
+                return qconfig
+        return self.default_quant_config
+
+    def _get_quant_config(self, node: torch.fx.Node) -> Optional[QuantizationConfig]:
         """
-        Priority:
-        1. is one of use_per_block_weight_quant_ops
-        2. is one of use_per_channel_weight_quant_ops
-        3. quant config
+        Selection order:
+        1. Use per_block_quant_config if the node name has an entry in block_size_map.
+        2. Otherwise take the first matching submodule config, or the default config.
+        3. Use that config's per_channel_quant_config if the op is in its use_per_channel_weight_quant_ops.
+        4. Otherwise fall back to that config's regular quant_config.
         """
-        target = op.target
-        if isinstance(target, str):
+        op = node.target
+        if isinstance(op, str):
             return
 
-        if target in self.use_per_block_weight_quant_ops:
-            if block_size := self.block_size_map.get(op.name):
-                self.per_block_quant_config.block_size = block_size
-                return self.per_block_quant_config
+        if block_size := self.block_size_map.get(node.name):
+            config = self.default_quant_config.per_block_quant_config
+            config.block_size = block_size
+            return config
 
-        if target in self.use_per_channel_weight_quant_ops:
-            return self.per_channel_quant_config
+        config = self._get_submodule_qconfig(node)
 
-        if target in self.quant_ops:
-            return self.quant_config
+        if op in config.use_per_channel_weight_quant_ops:
+            return config.per_channel_quant_config
 
-        print(f"No quant config is implemented for op, {op}")
-
-    def _update_per_block_weight_quant_ops(self, ops: Set[OpOverload], enable: bool):
-        if enable:
-            self.use_per_block_weight_quant_ops.update(ops)
-        else:
-            self.use_per_block_weight_quant_ops.difference_update(ops)
+        if op in self.quant_ops:
+            return config.quant_config
 
-    def _update_per_channel_weight_quant_ops(self, ops: Set[OpOverload], enable: bool):
-        if enable:
-            self.use_per_channel_weight_quant_ops.update(ops)
-        else:
-            self.use_per_channel_weight_quant_ops.difference_update(ops)
+        print(f"No quant config is implemented for op, {op}")
 
     def add_custom_quant_annotations(
         self, custom_quant_annotations: Sequence[Callable]
@@ -212,55 +260,74 @@ def annotate(self, model: GraphModule) -> GraphModule:
     def get_supported_ops(self) -> Set[OpOverload]:
         return self.SUPPORTED_OPS
 
-    def set_quant_config(
-        self, quant_dtype: QuantDtype, is_qat=False, act_observer=None
+    def set_default_quant_config(
+        self,
+        quant_dtype: QuantDtype,
+        is_qat=False,
+        is_conv_per_channel=False,
+        is_linear_per_channel=False,
+        act_observer=None,
     ) -> None:
-        self.quant_dtype = quant_dtype
-        self.is_qat = is_qat
-        if (quant_dtype, is_qat) not in quant_config_dict:
-            raise RuntimeError(
-                f"the quant config, (quant_dtype: {quant_dtype}, is_qat: {is_qat}) is not support"
-            )
-
-        quant_config_fuc, per_channel_quant_config_fuc, per_block_quant_config_fuc = (
-            quant_config_dict[(quant_dtype, is_qat)]
-        )
-        self.quant_config = (
-            quant_config_fuc(act_observer=act_observer)
-            if act_observer
-            else quant_config_fuc()
+        self.default_quant_config = ModuleQConfig(
+            quant_dtype,
+            is_qat,
+            is_conv_per_channel,
+            is_linear_per_channel,
+            act_observer,
        )
-        self.per_channel_quant_config = (
-            per_channel_quant_config_fuc(act_observer=act_observer)
-            if act_observer
-            else per_channel_quant_config_fuc()
-        )
-        if per_block_quant_config_fuc is not None:
-            self.per_block_quant_config = (
-                per_block_quant_config_fuc(act_observer=act_observer)
-                if act_observer
-                else per_block_quant_config_fuc()
-            )
 
     def set_block_size_map(self, block_size_map: Dict[str, Tuple]) -> None:
         self.block_size_map = block_size_map
 
-    def set_per_block_conv_quant(self, enable: bool) -> None:
-        conv_ops = {torch.ops.aten.conv2d.default}
-        self._update_per_block_weight_quant_ops(conv_ops, enable)
-
-    def set_per_channel_conv_quant(self, enable: bool) -> None:
-        conv_ops = {torch.ops.aten.conv1d.default, torch.ops.aten.conv2d.default}
-        self._update_per_channel_weight_quant_ops(conv_ops, enable)
-
-    def set_per_channel_linear_quant(self, enable: bool) -> None:
-        linear_ops = {
-            torch.ops.aten.linear.default,
-        }
-        self._update_per_channel_weight_quant_ops(linear_ops, enable)
+    def set_submodule_qconfig_list(
+        self, submodule_qconfig_list: List[Tuple[Callable, ModuleQConfig]]
+    ) -> None:
+        """
+        Set a submodule-specific quant config via callback predicates.
+        If a node matches more than one callback, only the first one is applied.
+        """
+        self.submodule_qconfig_list = submodule_qconfig_list
 
     def transform_for_annotation(self, model: GraphModule) -> GraphModule:
         return QnnPassManager().transform_for_annotation_pipeline(model)
 
     def validate(self, model: GraphModule) -> None:
         pass
+
+
+def get_submodule_type_predicate(module_type_str):
+    """
+    An example of nn_module_stack
+    {
+        'L__self__': ('', 'executorch.backends.qualcomm.tests.models.SubModules'),
+        'L__self___add': ('add', 'executorch.backends.qualcomm.tests.models.Add')
+    }
+    """
+
+    def predicate(node):
+        if nn_module_stack := node.meta.get("nn_module_stack"):
+            for _, type_name in nn_module_stack.values():
+                if module_type_str in type_name:
+                    return True
+        return False
+
+    return predicate
+
+
+def get_submodule_name_predicate(module_name_str):
+    """
+    An example of nn_module_stack
+    {
+        'L__self__': ('', 'executorch.backends.qualcomm.tests.models.SubModules'),
+        'L__self___add': ('add', 'executorch.backends.qualcomm.tests.models.Add')
+    }
+    """
+
+    def predicate(node):
+        if nn_module_stack := node.meta.get("nn_module_stack"):
+            for name in nn_module_stack.keys():
+                if module_name_str in name:
+                    return True
+        return False
+
+    return predicate
diff --git a/backends/qualcomm/tests/models.py b/backends/qualcomm/tests/models.py
index 0857a597d88..face416e304 100644
--- a/backends/qualcomm/tests/models.py
+++ b/backends/qualcomm/tests/models.py
@@ -1450,6 +1450,18 @@ def forward(self, x):
         return 10 - x
 
 
+class SimpleSubModules(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.add = Add()
+        self.sub = Sub()
+
+    def forward(self, a, b, c, d):
+        lhs = self.add(a, b)
+        rhs = self.sub(c, d)
+        return torch.mul(lhs, rhs)
+
+
 class SumIntList(torch.nn.Module):
     def __init__(self):
         super().__init__()
diff --git a/backends/qualcomm/tests/test_qnn_delegate.py b/backends/qualcomm/tests/test_qnn_delegate.py
index 795459a9f77..7e17fa11e4e 100644
--- a/backends/qualcomm/tests/test_qnn_delegate.py
+++ b/backends/qualcomm/tests/test_qnn_delegate.py
@@ -24,6 +24,7 @@
 
 from executorch.backends.qualcomm.tests.utils import (
     generate_context_binary,
+    ModuleQConfig,
     QuantDtype,
     TestQNN,
     validate_context_binary,
@@ -1237,7 +1238,6 @@ def test_qnn_backend_conv2d_block(self):
         module = self.get_qdq_module(
             module,
             sample_input,
-            is_conv_per_block=True,
             quant_dtype=QuantDtype.use_16a4w_block,
             block_size_map={"conv2d": (1, 128, 1, 1)},
         )
@@ -1326,8 +1326,8 @@ def test_qnn_backend_element_wise_add(self):
         for module in comb[QCOM_MODULE]:
             for sample_input in comb[QCOM_SAMPLE_INPUTS]:
                 with self.subTest(i=index):
-                    module = self.get_qdq_module(module, sample_input)
-                    self.lower_module_and_test_output(module, sample_input)
+                    gm = self.get_qdq_module(module, sample_input)
+                    self.lower_module_and_test_output(gm, sample_input)
                     index += 1
 
     def test_qnn_backend_element_wise_and(self):
@@ -1367,8 +1367,8 @@ def test_qnn_backend_element_wise_div(self):
         for module in comb[QCOM_MODULE]:
             for sample_input in comb[QCOM_SAMPLE_INPUTS]:
                 with self.subTest(i=index):
-                    module = self.get_qdq_module(module, sample_input)
-                    self.lower_module_and_test_output(module, sample_input)
+                    gm = self.get_qdq_module(module, sample_input)
+                    self.lower_module_and_test_output(gm, sample_input)
                     index += 1
 
     def test_qnn_backend_element_wise_mul(self):
@@ -1395,8 +1395,8 @@ def test_qnn_backend_element_wise_mul(self):
         for module in comb[QCOM_MODULE]:
             for sample_input in comb[QCOM_SAMPLE_INPUTS]:
                 with self.subTest(i=index):
-                    module = self.get_qdq_module(module, sample_input)
-                    self.lower_module_and_test_output(module, sample_input)
+                    gm = self.get_qdq_module(module, sample_input)
+                    self.lower_module_and_test_output(gm, sample_input)
                     index += 1
 
     def test_qnn_backend_element_wise_or(self):
@@ -1455,8 +1455,8 @@ def test_qnn_backend_element_wise_sub(self):
         for module in comb[QCOM_MODULE]:
             for sample_input in comb[QCOM_SAMPLE_INPUTS]:
                 with self.subTest(i=index):
-                    module = self.get_qdq_module(module, sample_input)
-                    self.lower_module_and_test_output(module, sample_input)
+                    gm = self.get_qdq_module(module, sample_input)
+                    self.lower_module_and_test_output(gm, sample_input)
                     index += 1
 
     def test_qnn_backend_elu(self):
@@ -2122,6 +2122,32 @@ def test_qnn_backend_simple_model(self):
         module = self.get_qdq_module(module, sample_input)
         self.lower_module_and_test_output(module, sample_input)
 
+    def test_qnn_backend_submodules(self):
+        module = SimpleSubModules()  # noqa: F405
+        sample_input = (
+            torch.rand(1, 3, 8, 8),
+            torch.rand(1, 3, 8, 8),
+            torch.rand(1, 3, 8, 8),
+            torch.rand(1, 3, 8, 8),
+        )
+
+        from executorch.backends.qualcomm.quantizer.quantizer import (
+            get_submodule_type_predicate,
+        )
+
+        submodule_qconfig_list = [
+            (
+                get_submodule_type_predicate("Add"),
+                ModuleQConfig(QuantDtype.use_16a16w),
+            )  # noqa: F405
+        ]
+        module = self.get_qdq_module(
+            module,
+            sample_input,
+            submodule_qconfig_list=submodule_qconfig_list,
+        )
+        self.lower_module_and_test_output(module, sample_input)
+
     def test_qnn_backend_topk_and_index(self):
         module = TopKandIndex()  # noqa: F405
         sample_input = (torch.randn(3, 10),)
diff --git a/backends/qualcomm/tests/utils.py b/backends/qualcomm/tests/utils.py
index 41c56c08a85..42eec15891c 100644
--- a/backends/qualcomm/tests/utils.py
+++ b/backends/qualcomm/tests/utils.py
@@ -16,7 +16,7 @@
 
 from executorch import exir
 from executorch.backends.qualcomm.qnn_preprocess import QnnBackend
-from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype
+from executorch.backends.qualcomm.quantizer.quantizer import ModuleQConfig, QuantDtype
 from executorch.backends.qualcomm.serialization.qc_schema import QcomChipset
 from executorch.backends.qualcomm.utils.constants import (
     QCOM_DTYPE,
@@ -497,7 +497,6 @@ def get_qdq_module(
         self,
         module: torch.nn.Module,
         inputs: Tuple[torch.Tensor],
-        is_conv_per_block: Optional[bool] = False,
         is_conv_per_channel: Optional[bool] = True,
         is_linear_per_channel: Optional[bool] = False,
         custom_quant_annotations: Tuple[Callable] = (),
@@ -505,7 +504,9 @@ def get_qdq_module(
         dynamic_shapes: Dict = None,
         bypass_check: bool = False,
         block_size_map: Dict[str, Tuple] = None,
+        submodule_qconfig_list: Optional[List[Tuple[Callable, ModuleQConfig]]] = None,
     ) -> torch.fx.GraphModule:
+        module = module.eval()
         m = torch.export.export(
             module, inputs, dynamic_shapes=dynamic_shapes, strict=True
         ).module()
@@ -513,9 +514,9 @@ def get_qdq_module(
         quantizer = make_quantizer(
             quant_dtype=quant_dtype,
             custom_annotations=custom_quant_annotations,
-            per_block_conv=is_conv_per_block,
             per_channel_conv=is_conv_per_channel,
             per_channel_linear=is_linear_per_channel,
+            submodule_qconfig_list=submodule_qconfig_list,
         )
         if block_size_map is not None:
             quantizer.set_block_size_map(block_size_map)
@@ -543,6 +544,7 @@ def get_prepared_qat_module(
         is_linear_per_channel: Optional[bool] = False,
         custom_quant_annotations: Tuple[Callable] = (),
         quant_dtype: QuantDtype = QuantDtype.use_8a8w,
+        submodule_qconfig_list: Optional[List[Tuple[Callable, ModuleQConfig]]] = None,
     ) -> torch.fx.GraphModule:
         m = torch.export.export_for_training(module, inputs, strict=True).module()
 
@@ -551,12 +553,12 @@ def get_prepared_qat_module(
             custom_annotations=custom_quant_annotations,
             per_channel_conv=is_conv_per_channel,
             per_channel_linear=is_linear_per_channel,
+            is_qat=True,
+            submodule_qconfig_list=submodule_qconfig_list,
         )
 
-        if quant_dtype == QuantDtype.use_8a8w:
-            quantizer.set_quant_config(quant_dtype, is_qat=True)
-        else:
-            raise RuntimeError("Shuld not be here")
+        submodule_qconfig_list = submodule_qconfig_list or []
+        quantizer.set_submodule_qconfig_list(submodule_qconfig_list)
 
         prepared = prepare_qat_pt2e(m, quantizer)
         return torch.ao.quantization.move_exported_model_to_train(prepared)
diff --git a/backends/transforms/decompose_sdpa.py b/backends/transforms/decompose_sdpa.py
index 329dab96df2..73e9d986c3d 100644
--- a/backends/transforms/decompose_sdpa.py
+++ b/backends/transforms/decompose_sdpa.py
@@ -62,6 +62,9 @@ def call(
 
                 # Copy node from decompose graph module
                 for decomposed_node in decomposed_module.graph.nodes:
+                    decomposed_node.meta["nn_module_stack"] = node.meta.get(
+                        "nn_module_stack"
+                    )
                     if decomposed_node.op == "placeholder":
                         continue
 
diff --git a/examples/qualcomm/utils.py b/examples/qualcomm/utils.py
index 2b2f32b037b..b17bc8f98bd 100755
--- a/examples/qualcomm/utils.py
+++ b/examples/qualcomm/utils.py
@@ -14,12 +14,16 @@
 import tempfile
 from pathlib import Path
 
-from typing import Callable, List, Optional
+from typing import Callable, List, Optional, Tuple
 
 import numpy as np
 
 import torch
-from executorch.backends.qualcomm.quantizer.quantizer import QnnQuantizer, QuantDtype
+from executorch.backends.qualcomm.quantizer.quantizer import (
+    ModuleQConfig,
+    QnnQuantizer,
+    QuantDtype,
+)
 from executorch.backends.qualcomm.serialization.qc_schema import QcomChipset
 from executorch.backends.qualcomm.utils.utils import (
     generate_htp_compiler_spec,
@@ -254,18 +258,23 @@ def qat_train(ori_model, captured_model, quantizer, dataset):
 def make_quantizer(
     quant_dtype: Optional[QuantDtype] = QuantDtype.use_8a8w,
     custom_annotations=(),
-    per_block_conv=False,
     per_channel_conv=True,
     per_channel_linear=False,
     act_observer=MovingAverageMinMaxObserver,
     is_qat=False,
+    submodule_qconfig_list: Optional[List[Tuple[Callable, ModuleQConfig]]] = None,
 ):
     quantizer = QnnQuantizer()
     quantizer.add_custom_quant_annotations(custom_annotations)
-    quantizer.set_per_block_conv_quant(per_block_conv)
-    quantizer.set_per_channel_conv_quant(per_channel_conv)
-    quantizer.set_per_channel_linear_quant(per_channel_linear)
-    quantizer.set_quant_config(quant_dtype, is_qat, act_observer)
+    quantizer.set_default_quant_config(
+        quant_dtype,
+        is_qat=is_qat,
+        is_conv_per_channel=per_channel_conv,
+        is_linear_per_channel=per_channel_linear,
+        act_observer=act_observer,
+    )
+    submodule_qconfig_list = submodule_qconfig_list or []
+    quantizer.set_submodule_qconfig_list(submodule_qconfig_list)
     return quantizer
 
 
diff --git a/extension/llm/export/quantizer_lib.py b/extension/llm/export/quantizer_lib.py
index 40d81075d9f..e516ae4c0a9 100644
--- a/extension/llm/export/quantizer_lib.py
+++ b/extension/llm/export/quantizer_lib.py
@@ -166,30 +166,39 @@ def get_qnn_quantizer(
         backend == "qnn"
     ), f"The quantization config is for backend {backend} instead of qnn."
     qnn_quantizer = QnnQuantizer()  # pyre-fixme[16]
-    qnn_quantizer.set_per_channel_conv_quant(enable=True)
-    qnn_quantizer.set_per_channel_linear_quant(enable=True)
+    # More custom quantization schemes (e.g. 16a4w) are supported; the default here is 8-bit (8a8w).
 
     custom_annotations = ()
     if quant_config == "8a8w":
         quant_dtype = QuantDtype.use_8a8w  # pyre-fixme[16]
-        qnn_quantizer.set_quant_config(quant_dtype, is_qat=is_qat)
+        qnn_quantizer.set_default_quant_config(
+            quant_dtype,
+            is_qat=is_qat,
+            is_conv_per_channel=True,
+            is_linear_per_channel=True,
+        )
     elif quant_config == "16a16w":
-        quant_dtype = QuantDtype.use_16a16w  # pyre-fixme[16]
         # Due to the error with 16a16w in Qnn Htp, we need to disable per channel linear quantization when use 16a16w
         # TODO: enable it after the issue is fixed
         logging.warning(
             "Disable per channel quantization for linear and conv due to the error with QNN HTP 16a16w."
         )
-        qnn_quantizer.set_per_channel_conv_quant(enable=False)
-        qnn_quantizer.set_per_channel_linear_quant(enable=False)
-        qnn_quantizer.set_quant_config(
-            quant_dtype, is_qat=is_qat, act_observer=MinMaxObserver
+        quant_dtype = QuantDtype.use_16a16w  # pyre-fixme[16]
+        qnn_quantizer.set_default_quant_config(
+            quant_dtype,
+            is_qat=is_qat,
+            is_conv_per_channel=False,
+            is_linear_per_channel=False,
+            act_observer=MinMaxObserver,
         )
     elif quant_config == "16a4w":
-        # pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`.
-        quant_dtype = QuantDtype.use_16a4w
-        qnn_quantizer.set_quant_config(
-            quant_dtype, is_qat=is_qat, act_observer=MinMaxObserver
+        quant_dtype = QuantDtype.use_16a4w  # pyre-fixme[16]
+        qnn_quantizer.set_default_quant_config(
+            quant_dtype,
+            is_qat=is_qat,
+            is_conv_per_channel=True,
+            is_linear_per_channel=True,
+            act_observer=MinMaxObserver,
        )
        # pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`.
        custom_annotations = (custom_annotate_llama_matmul_16a8w,)
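
Usage note (illustrative only, not part of the patch): a minimal sketch of driving the new submodule-level config API introduced above. It uses only symbols added in this diff (QnnQuantizer, set_default_quant_config, ModuleQConfig, set_submodule_qconfig_list, get_submodule_type_predicate); the "Add" predicate mirrors the new test_qnn_backend_submodules test.

# Sketch: quantize ops under any `Add` submodule with 16a16w; everything else uses the 8a8w default.
from executorch.backends.qualcomm.quantizer.quantizer import (
    ModuleQConfig,
    QnnQuantizer,
    QuantDtype,
    get_submodule_type_predicate,
)

quantizer = QnnQuantizer()
# Default config applied to nodes that match no submodule predicate.
quantizer.set_default_quant_config(
    QuantDtype.use_8a8w,
    is_conv_per_channel=True,
    is_linear_per_channel=True,
)
# Predicates are evaluated in order; the first match wins.
quantizer.set_submodule_qconfig_list(
    [(get_submodule_type_predicate("Add"), ModuleQConfig(QuantDtype.use_16a16w))]
)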