diff --git a/backends/arm/quantizer/TARGETS b/backends/arm/quantizer/TARGETS index a2445f26c0d..bbd7322daf5 100644 --- a/backends/arm/quantizer/TARGETS +++ b/backends/arm/quantizer/TARGETS @@ -5,12 +5,22 @@ python_library( srcs = ["arm_quantizer.py"], deps = [ ":arm_quantizer_utils", + ":quantization_annotator", "//caffe2:torch", - "//executorch/backends/arm/quantizer/quantization_annotation:quantization_annotation", "//executorch/exir:lib", ], ) +python_library( + name = "quantization_annotator", + srcs = ["quantization_annotator.py"], + deps = [ + ":arm_quantizer_utils", + ":quantization_config", + "//caffe2:torch", + ], +) + python_library( name = "quantization_config", srcs = ["quantization_config.py"], diff --git a/backends/arm/quantizer/arm_quantizer.py b/backends/arm/quantizer/arm_quantizer.py index 8815d40b0b0..fe104db972b 100644 --- a/backends/arm/quantizer/arm_quantizer.py +++ b/backends/arm/quantizer/arm_quantizer.py @@ -13,24 +13,16 @@ from __future__ import annotations -import copy import functools -from typing import Any, Callable, Dict, List, Optional, Set +from typing import Any, Callable, Dict, List, Optional import torch -import torch.nn.functional as F from executorch.backends.arm._passes.arm_pass_manager import ArmPassManager from executorch.backends.arm.quantizer import arm_quantizer_utils -from executorch.backends.arm.quantizer.arm_quantizer_utils import ( - mark_nodes_as_annotated, - propagate_annotation, -) -from executorch.backends.arm.quantizer.quantization_annotation import ( - OP_TO_ANNOTATOR, - OperatorConfig, - OperatorPatternType, -) +from executorch.backends.arm.quantizer.arm_quantizer_utils import mark_node_as_annotated +from executorch.backends.arm.quantizer.quantization_annotator import annotate_graph + from executorch.backends.arm.quantizer.quantization_config import QuantizationConfig from torch.ao.quantization.fake_quantize import ( FakeQuantize, @@ -58,44 +50,6 @@ ] -def _supported_symmetric_quantized_operators() -> Dict[str, List[OperatorPatternType]]: - supported_operators: Dict[str, List[OperatorPatternType]] = { - # Both conv and linear should be able to handle relu + hardtanh fusion since - # those are clamp ops - "conv2d": [ - [torch.nn.Conv2d, torch.nn.ReLU], - [torch.nn.Conv2d, F.relu], - [F.conv2d, torch.nn.ReLU], - [F.conv2d, F.relu], - ], - "linear": [[torch.nn.Linear], [F.linear]], - "add": [[torch.add]], - "max_pool2d": [[torch.nn.MaxPool2d], [F.max_pool2d]], - "adaptive_avg_pool2d": [ - [torch.nn.AdaptiveAvgPool2d], - [F.adaptive_avg_pool2d], - ], - "mul": [[torch.mul]], - "sub": [[torch.sub]], - "min_max": [[torch.min], [torch.max]], - } - return copy.deepcopy(supported_operators) - - -def _get_supported_symmetric_config_and_operators() -> List[OperatorConfig]: - supported_config_and_operators: List[OperatorConfig] = [] - for quantization_config in [ - get_symmetric_quantization_config(), - get_symmetric_quantization_config(is_per_channel=True), - ]: - ops = _supported_symmetric_quantized_operators() - for pattern_list in ops.values(): - supported_config_and_operators.append( - OperatorConfig(quantization_config, pattern_list) - ) - return copy.deepcopy(supported_config_and_operators) - - @functools.lru_cache def get_symmetric_quantization_config( is_per_channel: bool = False, @@ -180,10 +134,6 @@ def get_symmetric_quantization_config( return quantization_config -def _get_supported_config_and_operators() -> List[OperatorConfig]: - return _get_supported_symmetric_config_and_operators() - - NodeFilterType = Callable[[Node], bool] """Type 
for a Node Filter used by annotators. A Node filter is a function that takes a Node and returns whether the node should be annotated or not. @@ -255,26 +205,6 @@ def not_module_type_or_name_filter(n: Node) -> bool: class ArmQuantizer(Quantizer): - supported_config_and_operators = _get_supported_config_and_operators() - - # A list of supported static quantization annotators, in order of application. - # For example, fusions come before singular ops. - # The name must match the name used when registering the annotator. - STATIC_ANNOTATION_ORDER = [ - "linear", - "conv", - "adaptive_avg_pool2d", - "max_pool2d", - "add", - "sub", - "mul", - "min_max", - "mm", - "one_to_one", - "generic", - "upsample_nearest2d", - ] - def __init__(self) -> None: super().__init__() self.global_config: Optional[QuantizationConfig] = None @@ -331,7 +261,6 @@ def annotate(self, model: GraphModule) -> GraphModule: The annotated model. """ model = self._annotate_for_static_quantization_config(model) - propagate_annotation(model) return model def _annotate_all_static_patterns( @@ -353,8 +282,7 @@ def _annotate_all_static_patterns( if quantization_config is None: return model - for op in self.STATIC_ANNOTATION_ORDER: - OP_TO_ANNOTATOR[op](model, quantization_config, filter_fn) + annotate_graph(model, quantization_config, filter_fn) return model def _annotate_for_static_quantization_config( @@ -363,6 +291,9 @@ def _annotate_for_static_quantization_config( """Matches the correct QuantizationConfig with the correct module using a filter when running _annotate_all_static_patterns. """ + if self.io_config: + self._annotate_io(model, self.io_config) + module_name_list = list(self.module_name_config.keys()) for module_name, config in self.module_name_config.items(): self._annotate_all_static_patterns( @@ -381,9 +312,6 @@ def _annotate_for_static_quantization_config( _get_not_module_type_or_name_filter(tp_list, module_name_list), ) - if self.io_config: - self._annotate_io(model, self.io_config) - return model def _annotate_io( @@ -399,44 +327,13 @@ def _annotate_io( node, quantization_config.get_output_act_qspec(), ) - mark_nodes_as_annotated([node]) + mark_node_as_annotated(node) if node.op == "output": parent = node.all_input_nodes[0] _annotate_input_qspec_map( node, parent, quantization_config.get_input_act_qspec() ) - mark_nodes_as_annotated([node]) + mark_node_as_annotated(node) def validate(self, model: GraphModule) -> None: pass - - @classmethod - def get_supported_operators(cls) -> List[OperatorConfig]: - return cls.supported_config_and_operators - - @classmethod - def get_supported_quantization_configs(cls) -> List[QuantizationConfig]: - op_configs: Set[QuantizationConfig] = set({}) - for spec, _ in cls.supported_config_and_operators: - op_configs.add(spec) - return list(op_configs) - - @classmethod - def get_supported_operator_for_quantization_config( - cls, quantization_config: Optional[QuantizationConfig] - ) -> List[OperatorPatternType]: - if quantization_config is None: - all_ops = [] - for _, ops in cls.supported_config_and_operators: - all_ops.extend(ops) - return all_ops - - for config, ops in cls.supported_config_and_operators: - # note: this assumes each entry in cls.supported_spec_and_operators - # corresponds to one spec, e.g. 
we don't have - # [(spec1, op_list1), (spec1, op_list2), (spec2, op_list3)] - # where the first and second entry have the same spec but did not - # merge the op list - if config == quantization_config: - return ops - return [] diff --git a/backends/arm/quantizer/arm_quantizer_utils.py b/backends/arm/quantizer/arm_quantizer_utils.py index 4d52b7ddf16..7b460ccae74 100644 --- a/backends/arm/quantizer/arm_quantizer_utils.py +++ b/backends/arm/quantizer/arm_quantizer_utils.py @@ -11,17 +11,12 @@ # Utility functions for ArmQuantizer # -import operator -from typing import Callable, cast, List +from typing import cast import torch -from executorch.backends.arm.quantizer.quantization_config import QuantizationConfig from torch._subclasses import FakeTensor -from torch.ao.quantization.quantizer import ( - QuantizationAnnotation, - SharedQuantizationSpec, -) +from torch.ao.quantization.quantizer import QuantizationAnnotation from torch.fx import GraphModule, Node @@ -35,72 +30,30 @@ def is_annotated(node: Node) -> bool: ) -def are_annotated(nodes: List[Node]) -> bool: - """Given a list of nodes (that represents an operator pattern), - return True if any of the nodes - is annotated, otherwise return False. - """ - for node in nodes: - if is_annotated(node): - return True - return False +def is_output_annotated(node: Node) -> bool: + """Given a node, return whether the output of the node is annotated.""" + if "quantization_annotation" in node.meta: + annotation = cast(QuantizationAnnotation, node.meta["quantization_annotation"]) + return annotation._annotated and annotation.output_qspec is not None + else: + return False -def mark_nodes_as_annotated(nodes: List[Node]) -> None: - """Marks all nodes in list 'nodes' as annotated. If needed, an empty - QuantizationAnnotation is added to the quantization_annotation node meta entry. - """ - for node in nodes: - if node is not None: - if "quantization_annotation" not in node.meta: - node.meta["quantization_annotation"] = QuantizationAnnotation() - node.meta["quantization_annotation"]._annotated = True - - -def get_shared_qspec( - node: Node, gm: GraphModule, quantization_config: QuantizationConfig -): - """Returns a Quantization constallation with a SharedQuantizationSpec for the inputs - and output to the parameter 'node'. - Parameters: - node: a node with two inputs that should share Quantization parameters. - gm: The GraphModule containing the node. Used to inspect global graph features. - quantization_config : a QuantizationConfig with the input QuantizationSpec to share - Returns: - input_qspec_map: a dict[node, QuantizationSpec] that maps the inputs to 'node' to - the correct QuantizationSpec. - shared_with_input0_spec: The SharedQuantizationSpec to be used as output QuantizationSpec. - - Both outputs are None if one of the inputs is a node that can't be quantized. +def mark_node_as_annotated(node: Node) -> None: + """Marks node as annotated. If needed, an empty QuantizationAnnotation is added + to the quantization_annotation node meta entry. 
""" - input_act0 = cast(Node, node.args[0]) - input_act1 = node.args[1] - - input_act_qspec = quantization_config.get_input_act_qspec() - shared_with_input0_qspec = SharedQuantizationSpec((input_act0, node)) - - input_qspec_map = {} - if isinstance(input_act0, Node): - if not is_input_ok_for_quantization(input_act0, gm): - return None, None - input_qspec_map[input_act0] = input_act_qspec - - if isinstance(input_act1, Node): - if not is_input_ok_for_quantization(input_act1, gm): - return None, None - if input_act0 is not input_act1: - input_qspec_map[input_act1] = shared_with_input0_qspec - return input_qspec_map, shared_with_input0_qspec + if "quantization_annotation" not in node.meta: + node.meta["quantization_annotation"] = QuantizationAnnotation() + node.meta["quantization_annotation"]._annotated = True -def is_input_ok_for_quantization(input_act: Node, gm: GraphModule): - """Check if an input can be quantized. The input can not be quantized if: +def is_ok_for_quantization(node: Node, gm: GraphModule): + """Check if an node can be quantized. The node can not be quantized if: - The node does not output a float tensor or, - The node outputs a large scalar. """ - return not ( - is_input_non_float_tensor(input_act) or is_input_large_scalar(input_act, gm) - ) + return not (is_non_float_tensor(node) or is_large_scalar(node, gm)) def get_node_target(module: torch.nn.Module | GraphModule, target_str: str): @@ -110,7 +63,7 @@ def get_node_target(module: torch.nn.Module | GraphModule, target_str: str): return getattr(module, targets[-1]) -def is_input_large_scalar(node: Node, gm: GraphModule): +def is_large_scalar(node: Node, gm: GraphModule): """Check if input is a large scalar value. So that we can skip quantization for the node since histc op (in HistogramObserver) only works for values up to certain upper bound """ @@ -122,72 +75,10 @@ def is_input_large_scalar(node: Node, gm: GraphModule): return False -def is_input_non_float_tensor(node: Node) -> bool: +def is_non_float_tensor(node: Node) -> bool: """Check if the input is not a float tensor, so that we can skip quantization for the node since observers only works with float Tensors """ if "val" not in node.meta or not isinstance(node.meta["val"], FakeTensor): return True return node.meta["val"].dtype != torch.float32 - - -def is_share_obs_or_fq_op(op: Callable) -> bool: - """Returns whether the the operation 'op' can be quantized using a shared observer or - fake quantizer. This means that the operation can inherit it's quantization spec - from parent nodes. - """ - return op in [ - torch.ops.aten.hardtanh.default, - torch.ops.aten.hardtanh_.default, - torch.ops.aten.relu.default, - torch.ops.aten.mean.default, - torch.ops.aten.mean.dim, - torch.ops.aten.permute.default, - torch.ops.aten.permute_copy.default, - # TODO: remove? - torch.ops.aten.adaptive_avg_pool2d.default, - torch.ops.aten.avg_pool2d.default, - torch.ops.aten.max_pool2d.default, - torch.ops.aten.full.default, - torch.ops.aten.flatten.using_ints, - torch.ops.aten.dropout.default, - operator.getitem, - ] - - -def propagate_annotation(model: GraphModule) -> None: - """For unannotated ops that can share observer or have fake quantizers, - annotate with a SharedQuantizationSpec, where the shared spec is the - output spec of the parent node. - This propagates output qspecs downward in the graph until - an op that is already annotated or can't share qspec is encountered. 
- """ - for n in model.graph.nodes: - n = cast(Node, n) - if is_annotated(n): - continue - if n.op != "call_function" or not is_share_obs_or_fq_op( - cast(Callable, n.target) - ): - continue - - prev_node = n.args[0] - if not isinstance(prev_node, Node): - continue - - quantization_annotation = cast( - QuantizationAnnotation | None, - prev_node.meta.get("quantization_annotation", None), - ) - if not quantization_annotation or not quantization_annotation.output_qspec: - continue - - # propagate the previous output_qspec to the current node - shared_qspec = SharedQuantizationSpec(prev_node) - n.meta["quantization_annotation"] = QuantizationAnnotation( - input_qspec_map={ - prev_node: shared_qspec, - }, - output_qspec=shared_qspec, - _annotated=True, - ) diff --git a/backends/arm/quantizer/quantization_annotation/TARGETS b/backends/arm/quantizer/quantization_annotation/TARGETS deleted file mode 100644 index 4ce8b5cad2c..00000000000 --- a/backends/arm/quantizer/quantization_annotation/TARGETS +++ /dev/null @@ -1,12 +0,0 @@ -load("@fbcode_macros//build_defs:python_library.bzl", "python_library") - -python_library( - name = "quantization_annotation", - srcs = glob(["*.py"]), - typing = True, - deps = [ - "//caffe2:torch", - "//executorch/backends/arm/quantizer:arm_quantizer_utils", - "//executorch/backends/arm/quantizer:quantization_config", - ], -) diff --git a/backends/arm/quantizer/quantization_annotation/__init__.py b/backends/arm/quantizer/quantization_annotation/__init__.py deleted file mode 100644 index d9d27cee2ac..00000000000 --- a/backends/arm/quantizer/quantization_annotation/__init__.py +++ /dev/null @@ -1,64 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# Copyright 2024 Arm Limited and/or its affiliates. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -# pyre-unsafe - - -from typing import Callable, Dict, List, NamedTuple, Optional - -import torch -from executorch.backends.arm.quantizer.quantization_config import QuantizationConfig -from torch.fx import Node - -OperatorPatternType = List[Callable] -OperatorPatternType.__module__ = "executorch.backends.arm.quantizer.arm_quantizer_utils" - - -class OperatorConfig(NamedTuple): - # fix List[str] with List[List[Union[nn.Module, FunctionType, BuiltinFunctionType]]] - # Basically we are mapping a quantization config to some list of patterns. - # a pattern is defined as a list of nn module, function or builtin function names - # e.g. [nn.Conv2d, torch.relu, torch.add] - # We have not resolved whether fusion can be considered internal details of the - # quantizer hence it does not need communication to user. - # Note this pattern is not really informative since it does not really - # tell us the graph structure resulting from the list of ops. - config: QuantizationConfig - operators: List[OperatorPatternType] - - -AnnotatorType = Callable[ - [ - torch.fx.GraphModule, - QuantizationConfig, - Optional[Callable[[Node], bool]], - ], - Optional[List[List[Node]]], -] -OP_TO_ANNOTATOR: Dict[str, AnnotatorType] = {} - - -def register_annotator(op: str): - def decorator(annotator: AnnotatorType): - OP_TO_ANNOTATOR[op] = annotator - - return decorator - - -from . 
import ( # noqa - adaptive_ang_pool2d_annotator, - add_annotator, - conv_annotator, - generic_annotator, - linear_annotator, - max_pool2d_annotator, - min_max_annotator, - mm_annotator, - mul_annotator, - one_to_one_annotator, - sub_annotator, - upsample_nearest2d_annotator, -) diff --git a/backends/arm/quantizer/quantization_annotation/adaptive_ang_pool2d_annotator.py b/backends/arm/quantizer/quantization_annotation/adaptive_ang_pool2d_annotator.py deleted file mode 100644 index 723a48f6644..00000000000 --- a/backends/arm/quantizer/quantization_annotation/adaptive_ang_pool2d_annotator.py +++ /dev/null @@ -1,72 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# Copyright 2024 Arm Limited and/or its affiliates. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -# pyre-unsafe - -import itertools -from typing import Callable, List, Optional - -import torch -import torch.nn.functional as F -from executorch.backends.arm.quantizer import arm_quantizer_utils -from executorch.backends.arm.quantizer.quantization_annotation import register_annotator -from executorch.backends.arm.quantizer.quantization_config import QuantizationConfig -from torch.ao.quantization.quantizer import ( - QuantizationAnnotation, - SharedQuantizationSpec, -) -from torch.fx import Node -from torch.fx.passes.utils.source_matcher_utils import get_source_partitions - - -@register_annotator("adaptive_avg_pool2d") -def _annotate_adaptive_avg_pool2d( - gm: torch.fx.GraphModule, - quantization_config: QuantizationConfig, - filter_fn: Optional[Callable[[Node], bool]] = None, -) -> Optional[List[List[Node]]]: - """Always annotate adaptive_avg_pool2d op""" - module_partitions = get_source_partitions( - gm.graph, [torch.nn.AdaptiveAvgPool2d, F.adaptive_avg_pool2d], filter_fn - ) - partitions = list(itertools.chain.from_iterable(module_partitions.values())) - annotated_partitions = [] - for partition in partitions: - pool_node = partition.output_nodes[0] - if ( - pool_node.op != "call_function" - or pool_node.target != torch.ops.aten.adaptive_avg_pool2d.default - ): - raise ValueError(f"{pool_node} is not an aten adaptive_avg_pool2d operator") - - if arm_quantizer_utils.is_annotated(pool_node): - continue - - annotated_partitions.append(partition.nodes) - input_act = pool_node.args[0] - assert isinstance(input_act, Node) - - # only annotate input output sharing operator - # when the output of the input node is annotated - if ( - "quantization_annotation" not in input_act.meta - or not input_act.meta["quantization_annotation"]._annotated - or input_act.meta["quantization_annotation"].output_qspec is None - ): - input_act_qspec = quantization_config.get_input_act_qspec() - else: - input_act_qspec = SharedQuantizationSpec(input_act) - - # output sharing with input - output_act_qspec = SharedQuantizationSpec((input_act, pool_node)) - pool_node.meta["quantization_annotation"] = QuantizationAnnotation( - input_qspec_map={ - input_act: input_act_qspec, - }, - output_qspec=output_act_qspec, - _annotated=True, - ) - return annotated_partitions diff --git a/backends/arm/quantizer/quantization_annotation/add_annotator.py b/backends/arm/quantizer/quantization_annotation/add_annotator.py deleted file mode 100644 index 600c5d31f69..00000000000 --- a/backends/arm/quantizer/quantization_annotation/add_annotator.py +++ /dev/null @@ -1,46 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# Copyright 2024 Arm Limited and/or its affiliates. 
-# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -# pyre-unsafe - -from typing import Callable, List, Optional - -import torch -from executorch.backends.arm.quantizer import arm_quantizer_utils -from executorch.backends.arm.quantizer.quantization_annotation import register_annotator -from executorch.backends.arm.quantizer.quantization_config import QuantizationConfig -from torch.ao.quantization.quantizer import QuantizationAnnotation -from torch.fx import Node - - -@register_annotator("add") -def _annotate_add( - gm: torch.fx.GraphModule, - quantization_config: QuantizationConfig, - filter_fn: Optional[Callable[[Node], bool]] = None, -) -> Optional[List[List[Node]]]: - annotated_partitions = [] - for node in gm.graph.nodes: - if node.target not in ( - torch.ops.aten.add.Tensor, - torch.ops.aten.add_.Tensor, - ): - continue - annotated_partitions.append(node) - add_node = node - if arm_quantizer_utils.is_annotated(add_node): - continue - - input_qspec_map, output_qspec = arm_quantizer_utils.get_shared_qspec( - add_node, gm, quantization_config - ) - if input_qspec_map is not None: - add_node.meta["quantization_annotation"] = QuantizationAnnotation( - input_qspec_map=input_qspec_map, - output_qspec=output_qspec, - _annotated=True, - ) - return annotated_partitions diff --git a/backends/arm/quantizer/quantization_annotation/conv_annotator.py b/backends/arm/quantizer/quantization_annotation/conv_annotator.py deleted file mode 100644 index 4ff7dd9e800..00000000000 --- a/backends/arm/quantizer/quantization_annotation/conv_annotator.py +++ /dev/null @@ -1,65 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# Copyright 2024 Arm Limited and/or its affiliates. 
-# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree.f - -# pyre-unsafe - -from typing import Callable, List, Optional - -import torch -from executorch.backends.arm.quantizer import arm_quantizer_utils -from executorch.backends.arm.quantizer.quantization_annotation import register_annotator -from executorch.backends.arm.quantizer.quantization_config import QuantizationConfig -from torch.ao.quantization.quantizer import QuantizationAnnotation - -from torch.fx import Node - - -@register_annotator("conv") -def _annotate_conv( - gm: torch.fx.GraphModule, - quantization_config: QuantizationConfig, - filter_fn: Optional[Callable[[Node], bool]] = None, -) -> Optional[List[List[Node]]]: - annotated_partitions = [] - for n in gm.graph.nodes: - if n.op != "call_function" or n.target not in [ - torch.ops.aten.conv1d.default, - torch.ops.aten.conv2d.default, - ]: - continue - conv_node = n - - input_qspec_map = {} - input_act = conv_node.args[0] - assert isinstance(input_act, Node) - input_qspec_map[input_act] = quantization_config.get_input_act_qspec() - - weight = conv_node.args[1] - assert isinstance(weight, Node) - input_qspec_map[weight] = quantization_config.get_weight_qspec() - - # adding weight node to the partition as well - partition_nodes = [conv_node, conv_node.args[1]] - - bias = conv_node.args[2] if len(conv_node.args) > 2 else None - if isinstance(bias, Node): - input_qspec_map[bias] = quantization_config.get_bias_qspec() - partition_nodes.append(bias) - - if arm_quantizer_utils.are_annotated(partition_nodes): - continue - - if filter_fn and any(not filter_fn(n) for n in partition_nodes): - continue - - conv_node.meta["quantization_annotation"] = QuantizationAnnotation( - input_qspec_map=input_qspec_map, - output_qspec=quantization_config.get_output_act_qspec(), - _annotated=True, - ) - arm_quantizer_utils.mark_nodes_as_annotated(partition_nodes) - annotated_partitions.append(partition_nodes) - return annotated_partitions diff --git a/backends/arm/quantizer/quantization_annotation/generic_annotator.py b/backends/arm/quantizer/quantization_annotation/generic_annotator.py deleted file mode 100644 index 5db29c95e8e..00000000000 --- a/backends/arm/quantizer/quantization_annotation/generic_annotator.py +++ /dev/null @@ -1,108 +0,0 @@ -# Copyright 2024 Arm Limited and/or its affiliates. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. 
- -# pyre-unsafe -from typing import Callable, List, Optional - -import torch -import torch.fx -from executorch.backends.arm.quantizer import arm_quantizer_utils -from executorch.backends.arm.quantizer.quantization_annotation import register_annotator -from executorch.backends.arm.quantizer.quantization_config import QuantizationConfig -from torch.ao.quantization.quantizer import SharedQuantizationSpec -from torch.ao.quantization.quantizer.utils import ( - _annotate_input_qspec_map, - _annotate_output_qspec, -) -from torch.fx import Node - - -_SUPPORTED_OPS = [ - # DATA LAYOUT OPS - torch.ops.aten.squeeze.default, - torch.ops.aten.squeeze_copy.default, - torch.ops.aten.squeeze_copy.dim, - torch.ops.aten.squeeze.dim, - torch.ops.aten.squeeze.dims, - torch.ops.aten.unsqueeze.default, - torch.ops.aten.unsqueeze_copy.default, - torch.ops.aten.reshape.default, - torch.ops.aten.repeat.default, - torch.ops.aten.expand_copy.default, - torch.ops.aten.expand.default, - # Disabling these as there seems to be an issue with support for complex - # datatypes in torch: - # torch.ops.aten.view_as_complex.default, - # torch.ops.aten.view_as_complex_copy.default, - # torch.ops.aten.view_as_real.default, - # torch.ops.aten.view_as_real_copy.default, - torch.ops.aten.view.default, - torch.ops.aten.view_as.default, - torch.ops.aten.view_copy.default, - torch.ops.aten.select.int, - torch.ops.aten.select_copy.int, - torch.ops.aten.slice.Tensor, - torch.ops.aten.slice_copy.Tensor, - torch.ops.aten.split.Tensor, - torch.ops.aten.split_with_sizes.default, - torch.ops.aten.transpose.Dimname, - torch.ops.aten.transpose.int, - torch.ops.aten.transpose_copy.int, - torch.ops.aten.tile.default, - torch.ops.aten.flip.default, - torch.ops.aten.cat.default, - torch.ops.aten.concatenate.default, - torch.ops.aten.stack.default, - torch.ops.aten.chunk.default, - torch.ops.aten.contiguous.default, -] - - -@register_annotator("generic") -def _annotate_generic( - gm: torch.fx.GraphModule, - quantization_config: QuantizationConfig, - filter_fn: Optional[Callable[[Node], bool]] = None, -) -> Optional[List[List[Node]]]: - """Propagate qspecs to generic ops like unsqueeze, reshape etc.""" - annotated_partitions = [] - - for node in gm.graph.nodes: - if node.op != "call_function" or node.target not in _SUPPORTED_OPS: - continue - if filter_fn and not filter_fn(node): - continue - if arm_quantizer_utils.is_annotated(node): - continue - - input_acts = node.args[0] - - # Check to see if there are multiple inputs. - # this allows for stack/cat ops to be annotated - # in a similar way. - has_multi_inputs = isinstance(input_acts, list) - - input_act0 = input_acts[0] if has_multi_inputs else input_acts - - # Using a non-shared quantization spec here as a SharedQuantizationSpec - # can lead to a recursion. - _annotate_input_qspec_map( - node, input_act0, quantization_config.get_input_act_qspec() - ) - shared_with_input0_qspec = SharedQuantizationSpec((input_act0, node)) - - if has_multi_inputs: - # For the rest of the inputs, share qspec with first. 
- for input_act in input_acts[1:]: - if input_act is not input_act0: - node.meta["quantization_annotation"].input_qspec_map[ - input_act - ] = shared_with_input0_qspec - - _annotate_output_qspec(node, shared_with_input0_qspec) - arm_quantizer_utils.mark_nodes_as_annotated([node]) - annotated_partitions.append([node]) - - return annotated_partitions diff --git a/backends/arm/quantizer/quantization_annotation/linear_annotator.py b/backends/arm/quantizer/quantization_annotation/linear_annotator.py deleted file mode 100644 index 7c3f91ec707..00000000000 --- a/backends/arm/quantizer/quantization_annotation/linear_annotator.py +++ /dev/null @@ -1,68 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# Copyright 2024 Arm Limited and/or its affiliates. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -# pyre-unsafe - -from typing import Callable, List, Optional - -import torch -from executorch.backends.arm.quantizer import arm_quantizer_utils -from executorch.backends.arm.quantizer.quantization_annotation import register_annotator -from executorch.backends.arm.quantizer.quantization_config import QuantizationConfig -from torch.ao.quantization.quantizer.utils import ( - _annotate_input_qspec_map, - _annotate_output_qspec, -) -from torch.fx import Node - - -@register_annotator("linear") -def _annotate_linear( - gm: torch.fx.GraphModule, - quantization_config: QuantizationConfig, - filter_fn: Optional[Callable[[Node], bool]] = None, -) -> Optional[List[List[Node]]]: - annotated_partitions = [] - input_act_qspec = quantization_config.get_input_act_qspec() - output_act_qspec = quantization_config.get_output_act_qspec() - weight_qspec = quantization_config.get_weight_qspec() - bias_qspec = quantization_config.get_bias_qspec() - - for node in gm.graph.nodes: - if node.op != "call_function" or node.target != torch.ops.aten.linear.default: - continue - if filter_fn and not filter_fn(node): - continue - act_node = node.args[0] - weight_node = node.args[1] - bias_node = None - if len(node.args) > 2: - bias_node = node.args[2] - - if arm_quantizer_utils.is_annotated(node) is False: # type: ignore[list-item] - _annotate_input_qspec_map( - node, - act_node, - input_act_qspec, - ) - _annotate_input_qspec_map( - node, - weight_node, - weight_qspec, - ) - nodes_to_mark_annotated = [node, weight_node] - if bias_node: - _annotate_input_qspec_map( - node, - bias_node, - bias_qspec, - ) - nodes_to_mark_annotated.append(bias_node) - _annotate_output_qspec(node, output_act_qspec) - arm_quantizer_utils.mark_nodes_as_annotated(nodes_to_mark_annotated) - annotated_partitions.append(nodes_to_mark_annotated) - - return annotated_partitions diff --git a/backends/arm/quantizer/quantization_annotation/max_pool2d_annotator.py b/backends/arm/quantizer/quantization_annotation/max_pool2d_annotator.py deleted file mode 100644 index 0ef2ee39fe5..00000000000 --- a/backends/arm/quantizer/quantization_annotation/max_pool2d_annotator.py +++ /dev/null @@ -1,72 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# Copyright 2024 Arm Limited and/or its affiliates. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. 
- -# pyre-unsafe - -import itertools -from typing import Callable, List, Optional - -import torch -from executorch.backends.arm.quantizer import arm_quantizer_utils -from executorch.backends.arm.quantizer.quantization_annotation import register_annotator -from executorch.backends.arm.quantizer.quantization_config import QuantizationConfig -from torch.ao.quantization.quantizer import ( - QuantizationAnnotation, - SharedQuantizationSpec, -) -from torch.fx import Node -from torch.fx.passes.utils.source_matcher_utils import get_source_partitions - - -@register_annotator("max_pool2d") -def _annotate_max_pool2d( - gm: torch.fx.GraphModule, - quantization_config: QuantizationConfig, - filter_fn: Optional[Callable[[Node], bool]] = None, -) -> Optional[List[List[Node]]]: - module_partitions = get_source_partitions( - gm.graph, [torch.nn.MaxPool2d, torch.nn.functional.max_pool2d], filter_fn - ) - maxpool_partitions = list(itertools.chain.from_iterable(module_partitions.values())) - annotated_partitions = [] - for maxpool_partition in maxpool_partitions: - annotated_partitions.append(maxpool_partition.nodes) - output_node = maxpool_partition.output_nodes[0] - maxpool_node = None - for n in maxpool_partition.nodes: - if n.target == torch.ops.aten.max_pool2d.default: - maxpool_node = n - assert ( - maxpool_node is not None - ), "ArmQuantizer only works with torch.ops.aten.max_pool2d.default, " - "please make sure you are exporting the model correctly" - if arm_quantizer_utils.are_annotated([output_node, maxpool_node]): # type: ignore[list-item] - continue - - input_act = maxpool_node.args[0] # type: ignore[union-attr] - assert isinstance(input_act, Node) - - # only annotate maxpool when the output of the input node is annotated - if ( - "quantization_annotation" not in input_act.meta - or not input_act.meta["quantization_annotation"]._annotated - or input_act.meta["quantization_annotation"].output_qspec is None - ): - continue - # input and output of maxpool will share quantization parameter with input of maxpool - act_qspec = SharedQuantizationSpec(input_act) - # act_qspec = get_act_qspec(quantization_config) - maxpool_node.meta["quantization_annotation"] = QuantizationAnnotation( # type: ignore[union-attr] - input_qspec_map={ - input_act: act_qspec, - }, - _annotated=True, - ) - output_node.meta["quantization_annotation"] = QuantizationAnnotation( - output_qspec=act_qspec, - _annotated=True, - ) - return annotated_partitions diff --git a/backends/arm/quantizer/quantization_annotation/min_max_annotator.py b/backends/arm/quantizer/quantization_annotation/min_max_annotator.py deleted file mode 100644 index 43c4d20c134..00000000000 --- a/backends/arm/quantizer/quantization_annotation/min_max_annotator.py +++ /dev/null @@ -1,46 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# Copyright 2024 Arm Limited and/or its affiliates. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. 
- -# pyre-unsafe - -from typing import Callable, List, Optional - -import torch -from executorch.backends.arm.quantizer import arm_quantizer_utils -from executorch.backends.arm.quantizer.quantization_annotation import register_annotator -from executorch.backends.arm.quantizer.quantization_config import QuantizationConfig -from torch.ao.quantization.quantizer import QuantizationAnnotation -from torch.fx import GraphModule, Node - - -@register_annotator("min_max") -def _annotate_min_max( - gm: GraphModule, - quantization_config: QuantizationConfig, - filter_fn: Optional[Callable[[Node], bool]] = None, -) -> Optional[List[List[Node]]]: - annotated_partitions = [] - for node in gm.graph.nodes: - if node.target not in ( - torch.ops.aten.minimum.default, - torch.ops.aten.maximum.default, - ): - continue - annotated_partitions.append(node) - min_max_node = node - if arm_quantizer_utils.is_annotated(min_max_node): - continue - - input_qspec_map, output_qspec = arm_quantizer_utils.get_shared_qspec( - min_max_node, gm, quantization_config - ) - if input_qspec_map is not None: - min_max_node.meta["quantization_annotation"] = QuantizationAnnotation( - input_qspec_map=input_qspec_map, - output_qspec=output_qspec, - _annotated=True, - ) - return annotated_partitions diff --git a/backends/arm/quantizer/quantization_annotation/mm_annotator.py b/backends/arm/quantizer/quantization_annotation/mm_annotator.py deleted file mode 100644 index 60d9adb1c3c..00000000000 --- a/backends/arm/quantizer/quantization_annotation/mm_annotator.py +++ /dev/null @@ -1,60 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# Copyright 2024 Arm Limited and/or its affiliates. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. 
- -# pyre-unsafe - -import itertools -from typing import Callable, List, Optional - -import torch -from executorch.backends.arm.quantizer import arm_quantizer_utils -from executorch.backends.arm.quantizer.quantization_annotation import register_annotator -from executorch.backends.arm.quantizer.quantization_config import QuantizationConfig -from torch.ao.quantization.quantizer import QuantizationAnnotation -from torch.fx import Node -from torch.fx.passes.utils.source_matcher_utils import get_source_partitions - - -@register_annotator("mm") -def _annotate_mm( - gm: torch.fx.GraphModule, - quantization_config: QuantizationConfig, - filter_fn: Optional[Callable[[Node], bool]] = None, -) -> Optional[List[List[Node]]]: - mm_partitions = get_source_partitions( - gm.graph, [torch.mm, torch.bmm, torch.matmul], filter_fn - ) - mm_partitions = list(itertools.chain.from_iterable(mm_partitions.values())) - annotated_partitions = [] - for mm_partition in mm_partitions: - annotated_partitions.append(mm_partition.nodes) - mm_node = mm_partition.output_nodes[0] - - if arm_quantizer_utils.is_annotated(mm_node): - continue - - input_act_qspec = quantization_config.get_input_act_qspec() - output_act_qspec = quantization_config.get_output_act_qspec() - - input_qspec_map = {} - input_act0 = mm_node.args[0] - if isinstance(input_act0, Node): - if not arm_quantizer_utils.is_input_ok_for_quantization(input_act0, gm): - continue - input_qspec_map[input_act0] = input_act_qspec - - input_act1 = mm_node.args[1] - if isinstance(input_act1, Node): - if not arm_quantizer_utils.is_input_ok_for_quantization(input_act1, gm): - continue - input_qspec_map[input_act1] = input_act_qspec - - mm_node.meta["quantization_annotation"] = QuantizationAnnotation( - input_qspec_map=input_qspec_map, - output_qspec=output_act_qspec, - _annotated=True, - ) - return annotated_partitions diff --git a/backends/arm/quantizer/quantization_annotation/mul_annotator.py b/backends/arm/quantizer/quantization_annotation/mul_annotator.py deleted file mode 100644 index 3a206c3aba8..00000000000 --- a/backends/arm/quantizer/quantization_annotation/mul_annotator.py +++ /dev/null @@ -1,61 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# Copyright 2024 Arm Limited and/or its affiliates. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. 
- -# pyre-unsafe - -from typing import Callable, List, Optional - -import torch -import torch.fx -from executorch.backends.arm.quantizer import arm_quantizer_utils -from executorch.backends.arm.quantizer.quantization_annotation import register_annotator -from executorch.backends.arm.quantizer.quantization_config import QuantizationConfig -from torch.ao.quantization.quantizer import QuantizationAnnotation -from torch.fx import Node - - -@register_annotator("mul") -def _annotate_mul( - gm: torch.fx.GraphModule, - quantization_config: QuantizationConfig, - filter_fn: Optional[Callable[[Node], bool]] = None, -) -> Optional[List[List[Node]]]: - - annotated_partitions = [] - for node in gm.graph.nodes: - if node.target not in (torch.ops.aten.mul.Tensor, torch.ops.aten.mul_.Tensor): - continue - mul_node = node - annotated_partitions.append([mul_node]) - if arm_quantizer_utils.is_annotated(mul_node): - continue - - input_act_qspec = quantization_config.get_input_act_qspec() - output_act_qspec = quantization_config.get_output_act_qspec() - - input_qspec_map = {} - input_act0 = mul_node.args[0] - if isinstance(input_act0, Node): - if arm_quantizer_utils.is_input_large_scalar(input_act0, gm): - continue - if arm_quantizer_utils.is_input_non_float_tensor(input_act0): - continue - input_qspec_map[input_act0] = input_act_qspec - - input_act1 = mul_node.args[1] - if isinstance(input_act1, Node): - if arm_quantizer_utils.is_input_large_scalar(input_act1, gm): - continue - if arm_quantizer_utils.is_input_non_float_tensor(input_act1): - continue - input_qspec_map[input_act1] = input_act_qspec - - mul_node.meta["quantization_annotation"] = QuantizationAnnotation( - input_qspec_map=input_qspec_map, - output_qspec=output_act_qspec, - _annotated=True, - ) - return annotated_partitions diff --git a/backends/arm/quantizer/quantization_annotation/one_to_one_annotator.py b/backends/arm/quantizer/quantization_annotation/one_to_one_annotator.py deleted file mode 100644 index ec008826f00..00000000000 --- a/backends/arm/quantizer/quantization_annotation/one_to_one_annotator.py +++ /dev/null @@ -1,65 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# Copyright 2024 Arm Limited and/or its affiliates. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -# pyre-unsafe - -from typing import Callable, List, Optional - -import torch -import torch.fx -from executorch.backends.arm.quantizer import arm_quantizer_utils -from executorch.backends.arm.quantizer.quantization_annotation import register_annotator -from executorch.backends.arm.quantizer.quantization_config import QuantizationConfig -from torch.ao.quantization.quantizer.utils import ( - _annotate_input_qspec_map, - _annotate_output_qspec, -) -from torch.fx import Node - - -@register_annotator("one_to_one") -def _annotate_one_to_one( - gm: torch.fx.GraphModule, - quantization_config: QuantizationConfig, - filter_fn: Optional[Callable[[Node], bool]] = None, -) -> Optional[List[List[Node]]]: - """ - This annotator adds the input and output qspec from the quantization config to - ops in 'one_to_one_ops' that have the following properties: - - Have a single input and single output. - - Can handle different qspecs on the input and output. - - Typical ops are ops implemented with a lookup table. 
- """ - annotated_partitions = [] - one_to_one_ops = ( - torch.ops.aten.exp.default, - torch.ops.aten.log.default, - torch.ops.aten.reciprocal.default, - torch.ops.aten.rsqrt.default, - torch.ops.aten.sigmoid.default, - torch.ops.aten.tanh.default, - torch.ops.aten.sum.dim_IntList, - ) - for node in gm.graph.nodes: - if node.op != "call_function" or node.target not in one_to_one_ops: - continue - if filter_fn and not filter_fn(node): - continue - input_node = node.args[0] - - if not arm_quantizer_utils.is_annotated(node): - _annotate_input_qspec_map( - node, - input_node, - quantization_config.get_input_act_qspec(), - ) - _annotate_output_qspec(node, quantization_config.get_output_act_qspec()) - - arm_quantizer_utils.mark_nodes_as_annotated([node]) - annotated_partitions.append([node]) - - return annotated_partitions diff --git a/backends/arm/quantizer/quantization_annotation/sub_annotator.py b/backends/arm/quantizer/quantization_annotation/sub_annotator.py deleted file mode 100644 index 437f3e22e75..00000000000 --- a/backends/arm/quantizer/quantization_annotation/sub_annotator.py +++ /dev/null @@ -1,43 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# Copyright 2024 Arm Limited and/or its affiliates. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -# pyre-unsafe - -from typing import Callable, List, Optional - -import torch -from executorch.backends.arm.quantizer import arm_quantizer_utils -from executorch.backends.arm.quantizer.quantization_annotation import register_annotator -from executorch.backends.arm.quantizer.quantization_config import QuantizationConfig -from torch.ao.quantization.quantizer import QuantizationAnnotation -from torch.fx import GraphModule, Node - - -@register_annotator("sub") -def _annotate_sub( - gm: GraphModule, - quantization_config: QuantizationConfig, - filter_fn: Optional[Callable[[Node], bool]] = None, -) -> Optional[List[List[Node]]]: - annotated_partitions = [] - for node in gm.graph.nodes: - if node.target not in (torch.ops.aten.sub.Tensor, torch.ops.aten.sub_.Tensor): - continue - annotated_partitions.append(node) - sub_node = node - if arm_quantizer_utils.is_annotated(sub_node): - continue - - input_qspec_map, output_qspec = arm_quantizer_utils.get_shared_qspec( - sub_node, gm, quantization_config - ) - if input_qspec_map is not None: - sub_node.meta["quantization_annotation"] = QuantizationAnnotation( - input_qspec_map=input_qspec_map, - output_qspec=output_qspec, - _annotated=True, - ) - return annotated_partitions diff --git a/backends/arm/quantizer/quantization_annotation/upsample_nearest2d_annotator.py b/backends/arm/quantizer/quantization_annotation/upsample_nearest2d_annotator.py deleted file mode 100644 index eb4daf3a5c3..00000000000 --- a/backends/arm/quantizer/quantization_annotation/upsample_nearest2d_annotator.py +++ /dev/null @@ -1,73 +0,0 @@ -# Copyright 2024 Arm Limited and/or its affiliates. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. 
- -# pyre-unsafe - -import itertools -from typing import Callable, List, Optional - -import torch -from executorch.backends.arm.quantizer.quantization_annotation import register_annotator -from executorch.backends.arm.quantizer.quantization_config import QuantizationConfig -from torch.ao.quantization.quantizer import ( - QuantizationAnnotation, - SharedQuantizationSpec, -) -from torch.fx import Node -from torch.fx.passes.utils.source_matcher_utils import get_source_partitions - - -def _filter_upsample_nearest2d(filter_fn: Optional[Callable[[Node], bool]] = None): - def filter(node: Node): - is_upsample = node.target == torch.ops.aten.upsample_nearest2d.vec - if filter_fn is None: - return is_upsample - else: - return is_upsample and filter_fn(node) - - return filter - - -@register_annotator("upsample_nearest2d") -def _annotate_upsample_nearest2d( - gm: torch.fx.GraphModule, - quantization_config: QuantizationConfig, - filter_fn: Optional[Callable[[Node], bool]] = None, -) -> Optional[List[List[Node]]]: - module_partitions = get_source_partitions( - gm.graph, - [ - torch.nn.UpsamplingNearest2d, - torch.nn.Upsample, - torch.nn.functional.interpolate, - ], - _filter_upsample_nearest2d(filter_fn), - ) - upsample_partitions = list( - itertools.chain.from_iterable(module_partitions.values()) - ) - annotated_partitions = [] - - for upsample_partition in upsample_partitions: - annotated_partitions.append(upsample_partition.nodes) - - assert len(upsample_partition.nodes) == 1 - upsample_node = upsample_partition.nodes[0] - - input_act = upsample_node.args[0] - assert isinstance(input_act, Node) - - input_act_qspec = quantization_config.get_input_act_qspec() - output_act_qspec = SharedQuantizationSpec((input_act, upsample_node)) - - upsample_node.meta["quantization_annotation"] = QuantizationAnnotation( - input_qspec_map={ - input_act: input_act_qspec, - }, - output_qspec=output_act_qspec, - _annotated=True, - ) - - return annotated_partitions diff --git a/backends/arm/quantizer/quantization_annotator.py b/backends/arm/quantizer/quantization_annotator.py new file mode 100644 index 00000000000..0b29570f36f --- /dev/null +++ b/backends/arm/quantizer/quantization_annotator.py @@ -0,0 +1,298 @@ +# Copyright 2024 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +import operator +from dataclasses import dataclass +from typing import Callable, List, Optional + +import torch +import torch.fx +from executorch.backends.arm.quantizer import arm_quantizer_utils +from executorch.backends.arm.quantizer.quantization_config import QuantizationConfig +from torch.ao.quantization.quantizer import QuantizationSpecBase, SharedQuantizationSpec +from torch.ao.quantization.quantizer.utils import ( + _annotate_input_qspec_map, + _annotate_output_qspec, +) +from torch.fx import Node + + +@dataclass(frozen=True) +class _QuantProperty: + """Specify how the input/output at 'index' must be quantized.""" + + index: int + qspec: type[QuantizationSpecBase] | List[type[QuantizationSpecBase]] + optional: bool = False + mark_annotated: bool = False + + +class _OpQuantProperties: + def __init__(self): + self.quant_inputs: List[_QuantProperty] = [] + self.quant_output: Optional[_QuantProperty] = None + + +def _as_list(x): + if isinstance(x, list): + return x + else: + return [ + x, + ] + + +def _is_ok_for_quantization( + node: Node, quant_property: _QuantProperty, gm: torch.fx.GraphModule +) -> bool: + if quant_property.optional and ( + quant_property.index >= len(node.args) + or node.args[quant_property.index] is None + ): + return True + + for n_arg in _as_list(node.args[quant_property.index]): + assert isinstance(n_arg, Node) + if not arm_quantizer_utils.is_ok_for_quantization(n_arg, gm): + return False + + return True + + +def _annotate_input(node: Node, quant_property: _QuantProperty): + assert not arm_quantizer_utils.is_annotated(node) + if quant_property.optional and ( + quant_property.index >= len(node.args) + or node.args[quant_property.index] is None + ): + return + + for n_arg, qspec in zip( + _as_list(node.args[quant_property.index]), + _as_list(quant_property.qspec), + strict=True, + ): + assert isinstance(n_arg, Node) + _annotate_input_qspec_map(node, n_arg, qspec) + if quant_property.mark_annotated: + arm_quantizer_utils.mark_node_as_annotated(n_arg) + + +def _annotate_output(node: Node, quant_property: _QuantProperty): + assert not arm_quantizer_utils.is_annotated(node) + assert not quant_property.mark_annotated + assert not quant_property.optional + assert quant_property.index == 0, "Only one output annotation supported currently" + + _annotate_output_qspec(node, quant_property.qspec) + + +_one_to_one = [ + torch.ops.aten.exp.default, + torch.ops.aten.log.default, + torch.ops.aten.reciprocal.default, + torch.ops.aten.rsqrt.default, + torch.ops.aten.sigmoid.default, + torch.ops.aten.tanh.default, + torch.ops.aten.sum.dim_IntList, +] + +_one_to_one_shared_input_qspec = [ + torch.ops.aten.squeeze.default, + torch.ops.aten.squeeze_copy.default, + torch.ops.aten.squeeze_copy.dim, + torch.ops.aten.squeeze.dim, + torch.ops.aten.squeeze.dims, + torch.ops.aten.unsqueeze.default, + torch.ops.aten.unsqueeze_copy.default, + torch.ops.aten.reshape.default, + torch.ops.aten.repeat.default, + torch.ops.aten.expand_copy.default, + torch.ops.aten.expand.default, + # Disabling these as there seems to be an issue with support for complex + # datatypes in torch: + # torch.ops.aten.view_as_complex.default, + # torch.ops.aten.view_as_complex_copy.default, + # torch.ops.aten.view_as_real.default, + # torch.ops.aten.view_as_real_copy.default, + torch.ops.aten.view.default, + torch.ops.aten.view_as.default, + torch.ops.aten.view_copy.default, + torch.ops.aten.select.int, + torch.ops.aten.select_copy.int, + torch.ops.aten.slice.Tensor, + torch.ops.aten.slice_copy.Tensor, + 
torch.ops.aten.split.Tensor, + torch.ops.aten.split_with_sizes.default, + torch.ops.aten.transpose.Dimname, + torch.ops.aten.transpose.int, + torch.ops.aten.transpose_copy.int, + torch.ops.aten.tile.default, + torch.ops.aten.flip.default, + torch.ops.aten.chunk.default, + torch.ops.aten.contiguous.default, + torch.ops.aten.upsample_nearest2d.vec, +] + +# Operators that can inherit the quantization specs from its parent node +# as SharedQuantizationSpec. +_parent_shared_qspec = [ + torch.ops.aten.hardtanh.default, + torch.ops.aten.hardtanh_.default, + torch.ops.aten.relu.default, + torch.ops.aten.mean.default, + torch.ops.aten.mean.dim, + torch.ops.aten.permute.default, + torch.ops.aten.permute_copy.default, + torch.ops.aten.avg_pool2d.default, + torch.ops.aten.max_pool2d.default, + torch.ops.aten.full.default, + torch.ops.aten.flatten.using_ints, + torch.ops.aten.dropout.default, + operator.getitem, +] + + +def get_quant_properties( # noqa: C901 + node: Node, gm: torch.fx.GraphModule, quantization_config +) -> _OpQuantProperties: + input_act_qspec = quantization_config.get_input_act_qspec() + weight_qspec = quantization_config.get_weight_qspec() + output_act_qspec = quantization_config.get_output_act_qspec() + bias_qspec = quantization_config.get_bias_qspec() + + quant_properties = _OpQuantProperties() + if node.target in ( + torch.ops.aten.conv1d.default, + torch.ops.aten.conv2d.default, + torch.ops.aten.linear.default, + ): + quant_properties.quant_inputs = [ + _QuantProperty(0, input_act_qspec), + _QuantProperty(1, weight_qspec, mark_annotated=True), + _QuantProperty(2, bias_qspec, optional=True, mark_annotated=True), + ] + quant_properties.quant_output = _QuantProperty(0, output_act_qspec) + elif node.target in ( + torch.ops.aten.matmul.default, + torch.ops.aten.mm.default, + torch.ops.aten.bmm.default, + torch.ops.aten.mul.Tensor, + torch.ops.aten.mul_.Tensor, + ): + quant_properties.quant_inputs = [ + _QuantProperty(0, input_act_qspec), + _QuantProperty(1, input_act_qspec), + ] + quant_properties.quant_output = _QuantProperty(0, output_act_qspec) + elif node.target in ( + torch.ops.aten.add.Tensor, + torch.ops.aten.add_.Tensor, + torch.ops.aten.sub.Tensor, + torch.ops.aten.sub_.Tensor, + torch.ops.aten.minimum.default, + torch.ops.aten.maximum.default, + ): + shared_qspec = SharedQuantizationSpec((node.args[0], node)) + quant_properties.quant_inputs = [ + _QuantProperty(0, input_act_qspec), + _QuantProperty( + 1, input_act_qspec if node.args[0] == node.args[1] else shared_qspec + ), + ] + quant_properties.quant_output = _QuantProperty(0, shared_qspec) + elif node.target == torch.ops.aten.adaptive_avg_pool2d.default: + input_qspec = ( + SharedQuantizationSpec(node.args[0]) + if arm_quantizer_utils.is_output_annotated(node.args[0]) + else input_act_qspec + ) + quant_properties.quant_inputs = [_QuantProperty(0, input_qspec)] + quant_properties.quant_output = _QuantProperty( + 0, SharedQuantizationSpec((node.args[0], node)) + ) + elif node.target in ( + torch.ops.aten.cat.default, + torch.ops.aten.concatenate.default, + torch.ops.aten.stack.default, + ): + assert isinstance(node.args[0], list) + assert len(node.args[0]) != 0 + + shared_qspec = SharedQuantizationSpec((node.args[0][0], node)) + quant_properties.quant_inputs = [ + _QuantProperty( + 0, + [ + input_act_qspec if n == node.args[0][0] else shared_qspec + for n in node.args[0] + ], + ) + ] + quant_properties.quant_output = _QuantProperty(0, shared_qspec) + elif node.target in _one_to_one: + quant_properties.quant_inputs = 
[_QuantProperty(0, input_act_qspec)] + quant_properties.quant_output = _QuantProperty(0, output_act_qspec) + elif node.target in _one_to_one_shared_input_qspec: + quant_properties.quant_inputs = [_QuantProperty(0, input_act_qspec)] + quant_properties.quant_output = _QuantProperty( + 0, SharedQuantizationSpec((node.args[0], node)) + ) + elif node.target in _parent_shared_qspec: + if not isinstance(node.args[0], Node): + return None + + if not arm_quantizer_utils.is_output_annotated(node.args[0]): + return None + + shared_qspec = SharedQuantizationSpec(node.args[0]) + quant_properties.quant_inputs = [_QuantProperty(0, shared_qspec)] + quant_properties.quant_output = _QuantProperty(0, shared_qspec) + else: + return None + + # Don't check if operator.getitem is ok for quantization, it's always ok + if node.target == operator.getitem: + return quant_properties + + # Check that each inputs/outputs can be quantized properly with the + # provided QuantProperties + for quant_property in quant_properties.quant_inputs: + if not _is_ok_for_quantization(node, quant_property, gm): + return None + + if quant_properties.quant_output is not None: + if not _is_ok_for_quantization(node, quant_properties.quant_output, gm): + return None + + return quant_properties + + +def annotate_graph( + gm: torch.fx.GraphModule, + quantization_config: QuantizationConfig, + filter_fn: Optional[Callable[[Node], bool]] = None, +) -> Optional[List[List[Node]]]: + for node in gm.graph.nodes: + if node.op != "call_function": + continue + + if arm_quantizer_utils.is_annotated(node): + continue + + if filter_fn is not None and not filter_fn(node): + continue + + quant_properties = get_quant_properties(node, gm, quantization_config) + if quant_properties is None: + continue + + for quant_property in quant_properties.quant_inputs: + _annotate_input(node, quant_property) + + if quant_properties.quant_output is not None: + _annotate_output(node, quant_properties.quant_output) + + arm_quantizer_utils.mark_node_as_annotated(node)
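
For reference, a minimal, hypothetical usage sketch of the refactored flow (not part of the patch). ArmQuantizer.set_global() is assumed from the existing quantizer API, which this diff does not touch, and the export/prepare/convert calls are the standard PT2E entry points; the exact export step varies by torch version (older flows use capture_pre_autograd_graph instead of torch.export). Calling prepare_pt2e() triggers ArmQuantizer.annotate(), which after this change delegates to annotate_graph() in quantization_annotator.py instead of iterating the removed OP_TO_ANNOTATOR registry.

    import torch
    from executorch.backends.arm.quantizer.arm_quantizer import (
        ArmQuantizer,
        get_symmetric_quantization_config,
    )
    from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e


    class SmallModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 8, kernel_size=3)
            self.relu = torch.nn.ReLU()

        def forward(self, x):
            return self.relu(self.conv(x))


    example_inputs = (torch.randn(1, 3, 32, 32),)
    # Export to an FX GraphModule (exact export API depends on the torch version).
    graph_module = torch.export.export(SmallModel(), example_inputs).module()

    quantizer = ArmQuantizer()
    quantizer.set_global(get_symmetric_quantization_config(is_per_channel=True))

    # annotate() -> annotate_graph() walks each call_function node once, asks
    # get_quant_properties() for its _QuantProperty entries, and annotates the
    # node's inputs/outputs accordingly, replacing the per-op annotator modules.
    prepared = prepare_pt2e(graph_module, quantizer)
    prepared(*example_inputs)  # calibration run
    quantized = convert_pt2e(prepared)

The main design change is that per-operator annotation logic now lives in the single table-driven get_quant_properties() function: each supported target maps its input/output indices to _QuantProperty entries (plain, shared, or parent-inherited qspecs), so supporting a new operator means adding one branch there rather than a new annotator module and registry entry.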