diff --git a/backends/xnnpack/operators/__init__.py b/backends/xnnpack/operators/__init__.py
index e199c95b0f0..f056ad8b086 100644
--- a/backends/xnnpack/operators/__init__.py
+++ b/backends/xnnpack/operators/__init__.py
@@ -15,7 +15,6 @@
     op_ceiling,
     op_clamp,
     op_conv2d,
-    op_dequantize_per_tensor,
     op_div,
     op_dynamic_dequantize_ops,
     op_dynamic_quantize_ops,
@@ -35,7 +34,7 @@
     op_negate,
     op_permute,
     op_prelu,
-    op_quantize_per_tensor,
+    op_quant_dequant,
     op_relu,
     op_rsqrt,
     op_sdpa,
diff --git a/backends/xnnpack/operators/op_dequantize_per_tensor.py b/backends/xnnpack/operators/op_dequantize_per_tensor.py
deleted file mode 100644
index cea76b31057..00000000000
--- a/backends/xnnpack/operators/op_dequantize_per_tensor.py
+++ /dev/null
@@ -1,70 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the BSD-style license found in the
-# LICENSE file in the root directory of this source tree.
-
-from typing import Dict
-
-import torch
-from executorch.backends.xnnpack._passes.tag_implicit_q_dq_pass import (
-    TagImplicitQDqPass,
-)
-from executorch.backends.xnnpack.operators.node_visitor import (
-    NodeVisitor,
-    register_node_visitor,
-)
-from executorch.backends.xnnpack.operators.quant_params import QuantParams
-from executorch.backends.xnnpack.serialization.xnnpack_graph_schema import (
-    XNNConvert,
-    XNNGraph,
-    XNode,
-)
-from executorch.backends.xnnpack.utils.utils import get_input_node
-
-
-@register_node_visitor
-class OpDeQuantizePerTensor(NodeVisitor):
-    """
-    Dequantize Per Tensor Node visitor
-    """
-
-    target = "quantized_decomposed.dequantize_per_tensor.default"
-
-    def __init__(self, *args) -> None:
-        super().__init__(*args)
-
-    def define_node(
-        self,
-        node: torch.fx.Node,
-        xnn_graph: XNNGraph,
-        vals_to_ids: Dict[torch.fx.Node, int],
-        debug_handle: int,
-    ) -> None:
-        """
-        We only define a node if it is not an implict dq node
-        """
-        if not TagImplicitQDqPass.is_tagged_as_implicit_q_dq(node):
-            dq_input = get_input_node(node, 0)
-            input_quant_params = QuantParams.from_q_dq_node(node)
-            # fp32 output
-            self.define_tensor(node, xnn_graph, vals_to_ids)
-            output_id = vals_to_ids[node]
-
-            # qint8 input
-            input_quant_params.is_output = False
-            self.define_tensor(
-                dq_input, xnn_graph, vals_to_ids, quant_params=input_quant_params
-            )
-            input_id = vals_to_ids[dq_input]
-
-            ser_node = XNode(
-                xnode_union=XNNConvert(input_id=input_id, output_id=output_id, flags=0),
-                debug_handle=debug_handle,
-            )
-            xnn_graph.xnodes.append(ser_node)
-        else:
-            # If this node was ignored, then its id is the same as its parent
-            dq_input = get_input_node(node, 0)
-            if dq_input in vals_to_ids:
-                vals_to_ids[node] = vals_to_ids[dq_input]
diff --git a/backends/xnnpack/operators/op_quant_dequant.py b/backends/xnnpack/operators/op_quant_dequant.py
new file mode 100644
index 00000000000..521a8b6475a
--- /dev/null
+++ b/backends/xnnpack/operators/op_quant_dequant.py
@@ -0,0 +1,198 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Dict
+
+import torch
+from executorch.backends.xnnpack._passes.tag_implicit_q_dq_pass import (
+    TagImplicitQDqPass,
+)
+from executorch.backends.xnnpack.operators.node_visitor import (
+    NodeVisitor,
+    register_node_visitor,
+)
+from executorch.backends.xnnpack.operators.quant_params import QuantParams
+from executorch.backends.xnnpack.serialization.xnnpack_graph_schema import (
+    XNNConvert,
+    XNNGraph,
+    XNode,
+)
+from executorch.backends.xnnpack.utils.quant_utils import (
+    is_per_channel_group,
+    validate_quant_scales,
+    validate_quant_zeropoints,
+)
+from executorch.backends.xnnpack.utils.utils import get_input_node, get_param_tensor
+
+
+class OpStaticQDQNode(NodeVisitor):
+    def check_scales_zeropoints(self, node) -> None:
+        scales = node.args[1]
+        zero_points = node.args[2]
+        is_groupwise = is_per_channel_group(node)
+        dtype = node.args[-1]
+        if is_groupwise:
+            dtype = node.args[-3]
+
+        if isinstance(scales, torch.fx.Node):
+            scales = get_param_tensor(self.exported_program, scales)
+
+        if isinstance(zero_points, torch.fx.Node):
+            zero_points = get_param_tensor(self.exported_program, zero_points)
+
+        try:
+            validate_quant_scales(scales)
+            validate_quant_zeropoints(zero_points, dtype, is_groupwise)
+        except ValueError as e:
+            raise ValueError(
+                f"Invalid quantization scale or zero point for {node}: {e}"
+            )
+
+    def define_node(
+        self,
+        node: torch.fx.Node,
+        xnn_graph: XNNGraph,
+        vals_to_ids: Dict[torch.fx.Node, int],
+        debug_handle: int,
+    ) -> None:
+        # check scales and zp are valid
+        self.check_scales_zeropoints(node)
+
+
+@register_node_visitor
+class OpDeQuantizePerTensor(OpStaticQDQNode):
+    """
+    Dequantize Per Tensor Node visitor
+    """
+
+    target = "quantized_decomposed.dequantize_per_tensor.default"
+
+    def __init__(self, *args) -> None:
+        super().__init__(*args)
+
+    def define_node(
+        self,
+        node: torch.fx.Node,
+        xnn_graph: XNNGraph,
+        vals_to_ids: Dict[torch.fx.Node, int],
+        debug_handle: int,
+    ) -> None:
+        """
+        We only define a node if it is not an implicit dq node
+        """
+        # check scales and zp are valid
+        super().define_node(node, xnn_graph, vals_to_ids, debug_handle)
+
+        if not TagImplicitQDqPass.is_tagged_as_implicit_q_dq(node):
+            dq_input = get_input_node(node, 0)
+            input_quant_params = QuantParams.from_q_dq_node(node)
+            # fp32 output
+            self.define_tensor(node, xnn_graph, vals_to_ids)
+            output_id = vals_to_ids[node]
+
+            # qint8 input
+            input_quant_params.is_output = False
+            self.define_tensor(
+                dq_input, xnn_graph, vals_to_ids, quant_params=input_quant_params
+            )
+            input_id = vals_to_ids[dq_input]
+
+            ser_node = XNode(
+                xnode_union=XNNConvert(input_id=input_id, output_id=output_id, flags=0),
+                debug_handle=debug_handle,
+            )
+            xnn_graph.xnodes.append(ser_node)
+        else:
+            # If this node was ignored, then its id is the same as its parent
+            dq_input = get_input_node(node, 0)
+            if dq_input in vals_to_ids:
+                vals_to_ids[node] = vals_to_ids[dq_input]
+
+
+@register_node_visitor
+class OpQuantizePerTensor(OpStaticQDQNode):
+    """
+    Quantize Per Tensor Node visitor
+    """
+
+    target = "quantized_decomposed.quantize_per_tensor.default"
+
+    def __init__(self, *args) -> None:
+        super().__init__(*args)
+
+    def define_node(
+        self,
+        node: torch.fx.Node,
+        xnn_graph: XNNGraph,
+        vals_to_ids: Dict[torch.fx.Node, int],
+        debug_handle: int,
+    ) -> None:
+        """
+        We only define a node if it is not an implicit q node
+        """
+        # check scales and zp are valid
+        super().define_node(node, xnn_graph, vals_to_ids, debug_handle)
+
+        q_input = get_input_node(node, 0)
+        if not TagImplicitQDqPass.is_tagged_as_implicit_q_dq(node):
+            input_quant_params = QuantParams.from_q_dq_node(node)
+            # fp32 input
+            self.define_tensor(q_input, xnn_graph, vals_to_ids)
+            input_id = vals_to_ids[q_input]
+
+            # qint8 output
+            input_quant_params.q_input = node
+            input_quant_params.is_input = False
+            self.define_tensor(
+                node, xnn_graph, vals_to_ids, quant_params=input_quant_params
+            )
+            output_id = vals_to_ids[node]
+
+            ser_node = XNode(
+                xnode_union=XNNConvert(input_id=input_id, output_id=output_id, flags=0),
+                debug_handle=debug_handle,
+            )
+            xnn_graph.xnodes.append(ser_node)
+        else:
+            # If this node was ignored, then its id is the same as its parent's
+            if q_input in vals_to_ids:
+                vals_to_ids[node] = vals_to_ids[q_input]
+
+
+@register_node_visitor
+class OpDequantizePerChannelDefault(OpStaticQDQNode):
+    """
+    Validate scales/zero points for dequantize_per_channel.default; no node is serialized
+    """
+
+    target = "quantized_decomposed.dequantize_per_channel.default"
+
+
+@register_node_visitor
+class OpQuantizePerChannelDefault(OpStaticQDQNode):
+    """
+    Validate scales/zero points for quantize_per_channel.default; no node is serialized
+    """
+
+    target = "quantized_decomposed.quantize_per_channel.default"
+
+
+@register_node_visitor
+class OpQuantizePerChannelGroupDefault(OpStaticQDQNode):
+    """
+    Validate scales/zero points for quantize_per_channel_group.default; no node is serialized
+    """
+
+    target = "quantized_decomposed.quantize_per_channel_group.default"
+
+
+@register_node_visitor
+class OpDequantizePerChannelGroupDefault(OpStaticQDQNode):
+    """
+    Validate scales/zero points for dequantize_per_channel_group.default; no node is serialized
+    """
+
+    target = "quantized_decomposed.dequantize_per_channel_group.default"
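(Illustration only, not part of the diff: the extension pattern OpStaticQDQNode establishes. Any static q/dq visitor inherits the scale/zero-point validation by delegating to super().define_node() before doing its own serialization. The target below is a hypothetical placeholder, not a real op.)

    @register_node_visitor
    class OpSomeStaticQDQ(OpStaticQDQNode):
        # Hypothetical target, for illustration only.
        target = "quantized_decomposed.some_static_qdq_op.default"

        def define_node(self, node, xnn_graph, vals_to_ids, debug_handle) -> None:
            # Runs check_scales_zeropoints() first; raises ValueError on bad qparams.
            super().define_node(node, xnn_graph, vals_to_ids, debug_handle)
            # ... serialize an XNode here only if the op is not an implicit q/dq ...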
diff --git a/backends/xnnpack/operators/op_quantize_per_tensor.py b/backends/xnnpack/operators/op_quantize_per_tensor.py
deleted file mode 100644
index da15559410e..00000000000
--- a/backends/xnnpack/operators/op_quantize_per_tensor.py
+++ /dev/null
@@ -1,70 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the BSD-style license found in the
-# LICENSE file in the root directory of this source tree.
-
-from typing import Dict
-
-import torch
-from executorch.backends.xnnpack._passes.tag_implicit_q_dq_pass import (
-    TagImplicitQDqPass,
-)
-from executorch.backends.xnnpack.operators.node_visitor import (
-    NodeVisitor,
-    register_node_visitor,
-)
-from executorch.backends.xnnpack.operators.quant_params import QuantParams
-from executorch.backends.xnnpack.serialization.xnnpack_graph_schema import (
-    XNNConvert,
-    XNNGraph,
-    XNode,
-)
-from executorch.backends.xnnpack.utils.utils import get_input_node
-
-
-@register_node_visitor
-class OpQuantizePerTensor(NodeVisitor):
-    """
-    Quantize Per Tensor Node visitor
-    """
-
-    target = "quantized_decomposed.quantize_per_tensor.default"
-
-    def __init__(self, *args) -> None:
-        super().__init__(*args)
-
-    def define_node(
-        self,
-        node: torch.fx.Node,
-        xnn_graph: XNNGraph,
-        vals_to_ids: Dict[torch.fx.Node, int],
-        debug_handle: int,
-    ) -> None:
-        """
-        We only define a node if it is not an implict q node
-        """
-        q_input = get_input_node(node, 0)
-        if not TagImplicitQDqPass.is_tagged_as_implicit_q_dq(node):
-            input_quant_params = QuantParams.from_q_dq_node(node)
-            # fp32 input
-            self.define_tensor(q_input, xnn_graph, vals_to_ids)
-            input_id = vals_to_ids[q_input]
-
-            # qint8 output
-            input_quant_params.q_input = node
-            input_quant_params.is_input = False
-            self.define_tensor(
-                node, xnn_graph, vals_to_ids, quant_params=input_quant_params
-            )
-            output_id = vals_to_ids[node]
-
-            ser_node = XNode(
-                xnode_union=XNNConvert(input_id=input_id, output_id=output_id, flags=0),
-                debug_handle=debug_handle,
-            )
-            xnn_graph.xnodes.append(ser_node)
-        else:
-            # If this node was ignored, then its id is the same as its parents
-            if q_input in vals_to_ids:
-                vals_to_ids[node] = vals_to_ids[q_input]
diff --git a/backends/xnnpack/operators/op_skip_ops.py b/backends/xnnpack/operators/op_skip_ops.py
index 6597c0568e3..face7342d8f 100644
--- a/backends/xnnpack/operators/op_skip_ops.py
+++ b/backends/xnnpack/operators/op_skip_ops.py
@@ -41,15 +41,6 @@ class OpChooseQparamsTensor(OpSkipOps):
     target = "quantized_decomposed.choose_qparams.tensor"
 
 
-@register_node_visitor
-class OpDequantizePerChannelDefault(OpSkipOps):
-    """
-    do nothing if node is dequantize_per_channel.default
-    """
-
-    target = "quantized_decomposed.dequantize_per_channel.default"
-
-
 @register_node_visitor
 class OpGetItem(OpSkipOps):
     """
@@ -59,15 +50,6 @@ class OpGetItem(OpSkipOps):
     target = "getitem"
 
 
-@register_node_visitor
-class OpQuantizePerChannelDefault(OpSkipOps):
-    """
-    do nothing if node is quantize_per_channel.default
-    """
-
-    target = "quantized_decomposed.quantize_per_channel.default"
-
-
 @register_node_visitor
 class OpTCopyDefault(OpSkipOps):
     """
@@ -113,21 +95,3 @@ class OpChooseQparamsToken(OpSkipOps):
     """
 
     target = "quantized_decomposed.choose_qparams_per_token_asymmetric.default"
-
-
-@register_node_visitor
-class OpQuantizePerChannelGroupDefault(OpSkipOps):
-    """
-    do nothing if node is quantize_per_channel_group.default
-    """
-
-    target = "quantized_decomposed.quantize_per_channel_group.default"
-
-
-@register_node_visitor
-class OpDequantizePerChannelGroupDefault(OpSkipOps):
-    """
-    do nothing if node is dequantize_per_channel_group.default
-    """
-
-    target = "quantized_decomposed.dequantize_per_channel_group.default"
diff --git a/backends/xnnpack/operators/quant_params.py b/backends/xnnpack/operators/quant_params.py
index a2d26772ecc..e695b151560 100644
--- a/backends/xnnpack/operators/quant_params.py
+++ b/backends/xnnpack/operators/quant_params.py
@@ -102,6 +102,16 @@ def __init__(
         assert group_size > 0, "Group size must be greater than 0"
         self.is_per_channel_group = self.per_channel and self.group_size > 0
 
+        if per_channel and not self.is_per_channel_group:
+            tensor = q_input.meta["val"]
+            assert (
+                tensor.shape[self.axis] == cast(torch.Tensor, self.scale).shape[0]
+            ), f"Invalid size of per channel quantization scales, axis: {self.axis}, scale size: {self.scale.shape}, tensor shape: {tensor.shape}"
+
+            assert (
+                tensor.shape[self.axis] == cast(torch.Tensor, self.zp).shape[0]
+            ), f"Invalid size of per channel quantization zero-points, axis: {self.axis}, zp size: {self.zp.shape}, tensor shape: {tensor.shape}"
+
     def quantize_tensor(self, tensor: torch.Tensor) -> torch.Tensor:
         # Do nothing if already quantized by the Quantizer
         if tensor.dtype == self.dtype:
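(Illustration only, not part of the diff: the shape invariant the new assertions enforce. For a tensor quantized per-channel along axis 0, the scale and zero-point tensors must carry exactly one entry per channel; the values below are made up.)

    import torch

    weight = torch.randn(8, 16)           # quantized per-channel along axis 0
    scales = torch.rand(8) * 0.1 + 0.01   # scales.shape[0] == weight.shape[0] == 8
    zero_points = torch.zeros(8, dtype=torch.int64)

    # These mirror the checks QuantParams.__init__ now performs.
    assert weight.shape[0] == scales.shape[0]
    assert weight.shape[0] == zero_points.shape[0]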
diff --git a/backends/xnnpack/test/ops/test_check_quant_params.py b/backends/xnnpack/test/ops/test_check_quant_params.py
new file mode 100644
index 00000000000..cd18568afba
--- /dev/null
+++ b/backends/xnnpack/test/ops/test_check_quant_params.py
@@ -0,0 +1,104 @@
+import unittest
+
+import torch
+from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
+
+from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import (
+    get_symmetric_quantization_config,
+    XNNPACKQuantizer,
+)
+from executorch.backends.xnnpack.utils.utils import get_param_tensor
+from executorch.exir import to_edge_transform_and_lower
+from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
+from torch.export import export_for_training
+
+
+class TestCheckQuantParams(unittest.TestCase):
+    def create_invalid_value_injector(
+        self, invalid_value, is_per_channel=False, is_zp=False
+    ):
+        def inject_invalid_scale_in_per_tensor(aten):
+            for node in aten.graph_module.graph.nodes:
+                target_to_find = (
+                    torch.ops.quantized_decomposed.quantize_per_tensor.default
+                    if not is_per_channel
+                    else torch.ops.quantized_decomposed.dequantize_per_channel.default
+                )
+                if node.op == "call_function" and node.target == target_to_find:
+                    if is_zp:
+                        node_args = list(node.args)
+                        node_args[2] = invalid_value
+                        node.args = tuple(node_args)
+                        break
+                    else:
+                        scale = node.args[1]
+                        if is_per_channel:
+                            self.assertTrue(isinstance(scale, torch.fx.Node))
+                            scale_tensor = get_param_tensor(aten, scale)
+                            scale_tensor[2] = invalid_value
+                        else:
+                            self.assertTrue(isinstance(scale, float))
+                            node_args = list(node.args)
+                            node_args[1] = invalid_value
+                            node.args = tuple(node_args)
+                        break
+
+        return inject_invalid_scale_in_per_tensor
+
+    def _test_check_quant_message(self, ep_modifier, expected_message):
+        mod = torch.nn.Linear(10, 10)
+        quantizer = XNNPACKQuantizer()
+        captured = export_for_training(mod, (torch.randn(1, 10),)).module()
+        quantizer.set_global(get_symmetric_quantization_config(is_per_channel=True))
+        prepared = prepare_pt2e(captured, quantizer)
+
+        prepared(*(torch.randn(1, 10),))
+        converted = convert_pt2e(prepared)
+        aten = torch.export.export(converted, (torch.randn(1, 10),))
+
+        ep_modifier(aten)
+
+        with self.assertRaises(ValueError) as context:
+            to_edge_transform_and_lower(aten, partitioner=[XnnpackPartitioner()])
+
+        self.assertEqual(str(context.exception), expected_message)
+
+    def test_in_per_tensor_quant(self):
+
+        for invalid_scale in [
+            float("nan"),
+            float("inf"),
+            -float("inf"),
+            1.0000002153053333e-39,
+        ]:
+            self._test_check_quant_message(
+                self.create_invalid_value_injector(invalid_scale),
+                "Invalid quantization scale or zero point for quantized_decomposed_quantize_per_tensor_default: "
+                "Scales must be finite and normal, however found scale value: "
+                f"{invalid_scale} in scale tensor at index: (0,)",
+            )
+
+    def test_in_per_channel_quant(self):
+        for invalid_scale in [
+            float("nan"),
+            float("inf"),
+            -float("inf"),
+            1.0000002153053333e-39,
+        ]:
+            self._test_check_quant_message(
+                self.create_invalid_value_injector(invalid_scale, is_per_channel=True),
+                "Invalid quantization scale or zero point for quantized_decomposed_dequantize_per_channel_default: "
+                "Scales must be finite and normal, however found scale value: "
+                f"{invalid_scale} in scale tensor at index: (2,)",
+            )
+
+    def test_inject_invalid_zp(self):
+        for invalid_zp in [-129, 128]:
+            self._test_check_quant_message(
+                self.create_invalid_value_injector(
+                    invalid_zp, is_zp=True, is_per_channel=False
+                ),
+                "Invalid quantization scale or zero point for quantized_decomposed_quantize_per_tensor_default: "
+                f"Found invalid zeropoint {invalid_zp} "
+                "in zero point tensor at index: (0,)",
+            )
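(Usage note, assuming an ExecuTorch development checkout with the repo root importable; the new suite runs under plain unittest:)

    python -m unittest backends.xnnpack.test.ops.test_check_quant_params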
diff --git a/backends/xnnpack/test/tester/tester.py b/backends/xnnpack/test/tester/tester.py
index 79544256022..bbd4843dbb0 100644
--- a/backends/xnnpack/test/tester/tester.py
+++ b/backends/xnnpack/test/tester/tester.py
@@ -306,9 +306,8 @@ def __init__(
         self.edge_dialect_program = None
 
     def run(self, artifact: ExportedProgram, inputs=None) -> None:
-        artifact_to_run = copy.deepcopy(artifact)
         self.edge_dialect_program = to_edge_transform_and_lower(
-            artifact_to_run,
+            artifact,
             compile_config=self.edge_compile_conf,
             partitioner=self.partitioners,
         )
diff --git a/backends/xnnpack/utils/quant_utils.py b/backends/xnnpack/utils/quant_utils.py
index 7e3df9097cc..db1914e3910 100644
--- a/backends/xnnpack/utils/quant_utils.py
+++ b/backends/xnnpack/utils/quant_utils.py
@@ -6,7 +6,7 @@
 
 import operator
 from itertools import accumulate
-from typing import cast
+from typing import cast, Union
 
 import torch
 from executorch.exir.backend.canonical_partitioners.config_partitioner import (
@@ -213,3 +213,62 @@ def extract_qdq_affine_op_args_for_decomposed_ops(node: torch.fx.Node):
         args.append(node.args[-1])
 
     return args
+
+
+def is_tensor_subnormal(tensor: torch.Tensor):
+    finfo = torch.finfo(tensor.dtype)
+    return (tensor >= 0) & (torch.abs(tensor) < finfo.smallest_normal)
+
+
+def validate_quant_scales(scales: Union[float, torch.Tensor]):
+    if isinstance(scales, float):
+        scales = torch.tensor([scales])
+
+    is_infinite = torch.isinf(scales) | torch.isnan(scales)
+
+    is_subnormal = is_tensor_subnormal(scales)
+
+    if is_infinite.nonzero().numel() != 0:
+        idx = torch.where(is_infinite)
+        idx = tuple(int(index[0]) for index in idx)
+        value = scales[idx]
+        raise ValueError(
+            f"Scales must be finite and normal, however found scale value: {value}"
+            f" in scale tensor at index: {idx}"
+        )
+
+    if is_subnormal.nonzero().numel() != 0:
+        idx = torch.where(is_subnormal)
+        idx = tuple(int(index[0]) for index in idx)
+        value = scales[idx]
+        raise ValueError(
+            f"Scales must be finite and normal, however found scale value: {value}"
+            f" in scale tensor at index: {idx}"
+        )
+
+
+def validate_quant_zeropoints(
+    zp: Union[float, int, torch.Tensor], dtype: torch.dtype, is_4bit: bool
+):
+    if not isinstance(zp, torch.Tensor):
+        zp = torch.tensor([zp])
+
+    if dtype == torch.int8 or dtype == torch.qint8:
+        if is_4bit:
+            invalid_zp = (zp < 0) | (zp > 15)
+        else:
+            invalid_zp = (zp < -128) | (zp > 127)
+    elif dtype == torch.uint8 or dtype == torch.quint8:
+        invalid_zp = (zp < 0) | (zp > 255)
+    elif dtype == torch.int32:
+        invalid_zp = zp != 0
+    else:
+        raise ValueError("Unsupported dtype for quantization")
+
+    if invalid_zp.nonzero().numel() != 0:
+        idx = torch.where(invalid_zp)
+        idx = tuple(int(index[0]) for index in idx)
+        value = zp[idx]
+        raise ValueError(
+            f"Found invalid zeropoint {value}" f" in zero point tensor at index: {idx}"
+        )
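(Illustration only, not part of the diff, using just the functions added above: validate_quant_scales rejects NaN, infinite, and subnormal scales, including zero, and validate_quant_zeropoints range-checks zero points against the quantized dtype.)

    import torch
    from executorch.backends.xnnpack.utils.quant_utils import (
        validate_quant_scales,
        validate_quant_zeropoints,
    )

    # A NaN scale is rejected before anything is serialized to the XNNPACK graph.
    try:
        validate_quant_scales(torch.tensor([0.5, float("nan"), 0.25]))
    except ValueError as e:
        print(e)  # "Scales must be finite and normal, ... at index: (1,)"

    # int8 zero points must lie in [-128, 127] (or [0, 15] for 4-bit groupwise),
    # so 130 raises ValueError here.
    validate_quant_zeropoints(torch.tensor([130]), torch.int8, False)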