diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py index b4bb809b851..25811d077bb 100644 --- a/backends/arm/_passes/arm_pass_manager.py +++ b/backends/arm/_passes/arm_pass_manager.py @@ -29,10 +29,6 @@ DecomposeSoftmaxesPass, ) from executorch.backends.arm._passes.decompose_var_pass import DecomposeVarPass -from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import ( - FoldAndAnnotateQParamsPass, - QuantizeFullArgument, -) from executorch.backends.arm._passes.keep_dims_false_to_squeeze_pass import ( KeepDimsFalseToSqueezePass, ) @@ -54,7 +50,6 @@ from executorch.backends.xnnpack._passes.remove_getitem_op import RemoveGetItemPass from executorch.exir import ExportedProgram from executorch.exir.backend.compile_spec_schema import CompileSpec -from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_manager import PassManager @@ -85,19 +80,6 @@ def transform_to_backend_pipeline( self.add_pass(Conv1dUnsqueezePass(exported_program)) self.add_pass(DecomposeSoftmaxesPass()) self.add_pass(DecomposeLinearPass()) - self.add_pass(QuantizeFullArgument()) - self.add_pass( - FoldAndAnnotateQParamsPass( - [ - exir_ops.edge.aten.minimum.default, - exir_ops.edge.aten.maximum.default, - exir_ops.edge.aten.add.Tensor, - exir_ops.edge.aten.avg_pool2d.default, - exir_ops.edge.aten.convolution.default, - exir_ops.edge.aten.full.default, - ] - ) - ) for spec in compile_spec: if spec.key == "permute_memory_format": memory_format = spec.value.decode() diff --git a/backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py b/backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py deleted file mode 100644 index 6ba72eb1022..00000000000 --- a/backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py +++ /dev/null @@ -1,181 +0,0 @@ -# Copyright 2024 Arm Limited and/or its affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -import copy - -from typing import Callable, cast, Iterable - -from executorch.backends.arm.tosa_quant_utils import QuantArgs - -from executorch.exir.dialects._ops import ops as exir_ops - -from executorch.exir.pass_base import ExportPass, PassResult -from torch.fx import GraphModule, Node - -q_op = exir_ops.edge.quantized_decomposed.quantize_per_tensor.default -dq_op = exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default - - -def get_input_qparams(node: Node) -> dict[int, QuantArgs]: - """ - Get the input quantization parameters from a node, set by the 'FoldAndAnnotateQParamsPass'. - Raises a ValueError if the node doesn't have any parameters set. - """ - if "input_qparams" not in node.meta.keys(): - raise ValueError(f"No input quantization parameter found in node {node}") - input_qparams = cast(dict[int, QuantArgs], node.meta["input_qparams"]) - if len(input_qparams) == 0: - raise ValueError(f"No input quantization parameter found in node {node}") - return input_qparams - - -def get_output_qparams(node: Node) -> dict[int, QuantArgs]: - """ - Get the output quantization parameters from a node, set by the 'FoldAndAnnotateQParamsPass'. - Raises a ValueError if the node doesn't have any parameters set. - """ - if "output_qparams" not in node.meta.keys(): - raise ValueError(f"No output quantization parameter found in node {node}") - input_qparams = cast(dict[int, QuantArgs], node.meta["output_qparams"]) - if len(input_qparams) == 0: - raise ValueError(f"No output quantization parameter found in node {node}") - return input_qparams - - -class FoldAndAnnotateQParamsPass(ExportPass): - """ - A pass that walks the graph and removes any DQ and Q nodes before and after the target - node in the supplied list of operators. - The quantization parameters from the DQ/Q nodes are stored as meta values to be - accessible for later lowering and serialization passes. - The assumption is that the quantization annotatation adds DQ nodes for all tensor - inputs to the target one Q node to the output. - - Example ('executorch_exir_dialects_edge__ops_' prefix removed from operators for readability): - - x_q: "i8[5]" = quantized_decomposed_quantize_per_tensor_default(x, 0.05487706884741783, -128, -128, 127, torch.int8) - - x_dq: "f32[5]" = quantized_decomposed_dequantize_per_tensor_default(x_q, 0.05487706884741783, -128, -128, 127, torch.int8) - aten_add_tensor: "f32[5]" = ops_aten_add_Tensor(x_dq, x_dq) - aten_add_tensor_q: "i8[5]" = quantized_decomposed_quantize_per_tensor_default(aten_add_tensor, 0.05487706884741783, -128, -128, 127, torch.int8) - - output_dq: "f32[5]" = quantized_decomposed_dequantize_per_tensor_default(aten_add_tensor_q, 0.05487706884741783, -128, -128, 127, torch.int8) - - Becomes: - x_q: "i8[5]" = quantized_decomposed_quantize_per_tensor_default(x, 0.05487706884741783, -128, -128, 127, torch.int8) - - aten_add_tensor: "i8[5]" = aten_add_Tensor(x_q, x_q) - - output_dq: "f32[5]" = quantized_decomposed_dequantize_per_tensor_default(aten_add_tensor_q, 0.05487706884741783, -128, -128, 127, torch.int8) - - The quantization parameters for x_dq and aten_add_tensor_q are store in meta for the aten_add_tensor node. - - """ - - def __init__(self, targeted_ops: Iterable[Callable]): - super().__init__() - self.targeted_ops = targeted_ops - - def call(self, graph_module: GraphModule) -> PassResult: - - # Loop over the graph nodes and find any node in the 'targeted_ops' list. - for n in graph_module.graph.nodes: - n = cast(Node, n) - if n.op != "call_function" or n.target not in self.targeted_ops: - continue - - # Make sure we haven't already set qparams meta information on the node - assert "input_qparams" not in n.meta.keys() - assert "output_qparams" not in n.meta.keys() - - # for the inputs and outputs search the graph for quantization info and - # store the information in a dict with order of the _tensor_ inputs as key, - # ignoring any other arguments to the target node. - n.meta["input_qparams"] = {} - n.meta["output_qparams"] = {} - for i, arg in enumerate(n.args): - if not isinstance(arg, Node): - continue - - # Make sure arg has requires_grad set to False - # For parameters that are not quantized, sometimes (i.e. convolution) - # the Parameter(FakeTensor(...)) has requires_grad set to True, which - # causes the retracing of the graph to fail with: - # - # E RuntimeError: isDifferentiableType(variable.scalar_type()) INTERNAL ASSERT FAILED at "/Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/autograd/functions/utils.h":74, please report a bug to PyTorch. - # E - # E While executing %aten_convolution_default : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.convolution.default](args = (%quantized_decomposed_quantize_per_tensor_default, %b__frozen_param0, %p__param_constant1, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), kwargs = {}) - # E Original traceback: - # E File "/Users/perast01/src/executorch/backends/arm/test/ops/test_conv2d.py", line 110, in forward - # E x = conv(x) - # - if arg.op == "placeholder": - arg.meta["val"].requires_grad = False - - if arg.target != dq_op: - continue - - # arg.target for argument i is a dequant node, extract the information - n.meta["input_qparams"][i] = QuantArgs.from_operator( - arg.target, arg.args - ) - - # arg.args[0] is the tensor input, replace the input usage - n.replace_input_with(arg, arg.args[0]) - graph_module.graph.erase_node(arg) - - # Copy the users, since we are modifying it. - users_copy = copy.copy(n.users) - for i, user in enumerate(users_copy): - if user.target != q_op: - continue - - # quantization node found here, store the quantization parameters in meta value - n.meta["output_qparams"][i] = QuantArgs.from_operator( - user.target, user.args - ) - - user.replace_all_uses_with(n) - graph_module.graph.erase_node(user) - - # retrace the graph to update the fake tensor types - graph_module = super().call(graph_module).graph_module - - graph_module.recompile() - return PassResult(graph_module, True) - - -class QuantizeFullArgument(ExportPass): - """ - Make sure the fill_value for full.default is quantized. This pass needs to be run before - the folding pass above to make sure that the retraced output of the full.default op is - the right dtype. - """ - - def call(self, graph_module: GraphModule) -> PassResult: - modified = False - # Loop over the graph nodes and find any node in the 'targeted_ops' list. - for n in graph_module.graph.nodes: - n = cast(Node, n) - if n.target != exir_ops.edge.aten.full.default: - continue - - # Make sure we have a quantized operator - user = list(n.users)[0] - if user.target != q_op: - continue - - qargs = QuantArgs.from_operator(user.target, user.args) - if "dtype" not in n.kwargs.keys() or n.kwargs["dtype"] != qargs.dtype: - # replace the node arg with a quantized dito and also set dtype - # to get the right output according to the Edge IR specification: - # exir/dialects/edge/edge.yaml:3596 - quantized_full_value = qargs.quantize_value(n.args[1]).item() - n.update_arg(1, quantized_full_value) - n.update_kwarg("dtype", qargs.dtype) - modified = True - - return PassResult(graph_module, modified) diff --git a/backends/arm/operator_support/tosa_supported_operators.py b/backends/arm/operator_support/tosa_supported_operators.py index 7f92574cfd1..836c6733703 100644 --- a/backends/arm/operator_support/tosa_supported_operators.py +++ b/backends/arm/operator_support/tosa_supported_operators.py @@ -94,8 +94,6 @@ def is_node_supported(self, submodules, node: fx.Node) -> bool: exir_ops.edge.aten.sigmoid.default, exir_ops.edge.aten.mean.dim, exir_ops.edge.aten.mm.default, - exir_ops.edge.aten.minimum.default, - exir_ops.edge.aten.maximum.default, exir_ops.edge.aten.repeat.default, exir_ops.edge.aten.reciprocal.default, exir_ops.edge.aten.relu.default, diff --git a/backends/arm/operators/__init__.py b/backends/arm/operators/__init__.py index 6db9c968f09..8c4aa85e579 100644 --- a/backends/arm/operators/__init__.py +++ b/backends/arm/operators/__init__.py @@ -19,9 +19,7 @@ op_get_item, op_hardtanh, op_log, - op_max, op_max_pool2d, - op_min, op_mm, op_mul, op_permute, diff --git a/backends/arm/operators/op_add.py b/backends/arm/operators/op_add.py index a81e52c5c6e..382779df3d8 100644 --- a/backends/arm/operators/op_add.py +++ b/backends/arm/operators/op_add.py @@ -11,6 +11,7 @@ import executorch.backends.arm.tosa_utils as tutils import serializer.tosa_serializer as ts +import torch from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, register_node_visitor, @@ -40,27 +41,33 @@ def define_node( output: TosaArg, is_quant_node: bool, ) -> None: - # Specification (0.80) states that input and output types - # should all be the same - assert inputs[0].dtype == inputs[1].dtype == output.dtype - # Handle int8 (quantized) and int32 - assert inputs[0].dtype in [ts.DType.INT8, ts.DType.INT32] - - if inputs[0].dtype == ts.DType.INT8: - rescaled_inputs, scale_back = tqutils.insert_rescale_ops_to_int32( - tosa_graph, inputs, node + input_nodes = tutils.get_two_inputs(node) + + if not is_quant_node and not all( + tensor.meta["val"].dtype in (torch.int8, torch.int32) + for tensor in input_nodes + ): + raise RuntimeError( + f"Unexpected non quantized {AddVisitor_080_BI.target} node." ) - else: - # input[0].dtype == ts.DType.INT32 - # Non quantized input, natively support by TOSA.ADD - rescaled_inputs = inputs - if output.dtype == ts.DType.INT8: + needs_rescale = not ( + all(tensor.meta["val"].dtype == torch.int32 for tensor in input_nodes) + and node.meta["val"].dtype == torch.int32 + ) + + if needs_rescale: + # Rescale inputs to 32 bit + rescaled_inputs, scale = tqutils.rescale_nodes_to_int32( + input_nodes, tosa_graph + ) + + # Prepare add output tensor broadcasted_shape = tutils.tosa_shape(output.shape, output.dim_order) add_output = tosa_graph.addIntermediate(broadcasted_shape, ts.DType.INT32) else: - # output.dtype == ts.DType.INT32 add_output = output + rescaled_inputs = inputs # Do the INT32 Add tosa_graph.addOperator( @@ -73,10 +80,10 @@ def define_node( None, ) - if output.dtype == ts.DType.INT8: + if needs_rescale: # Scale output back to 8 bit # pyre-ignore - tqutils.insert_rescale_op_to_int8(tosa_graph, add_output, scale_back, node) + tqutils.rescale_node_back_to_int8(node, add_output, scale, tosa_graph) @register_node_visitor @@ -98,19 +105,11 @@ def define_node( output: TosaArg, is_quant_node: bool, ) -> None: - # Specification (0.80) states that input and output types - # should all be the same - assert inputs[0].dtype == inputs[1].dtype == output.dtype - - if inputs[0].dtype in [ts.DType.INT8, ts.DType.INT32]: + if is_quant_node: # Call the inherited define_node for handling integers super().define_node(node, tosa_graph, inputs, output, is_quant_node) else: # FP32 Add lowering - assert inputs[0].dtype == ts.DType.FP32 - assert output.dtype == ts.DType.FP32 - - # MI lowering tosa_graph.addOperator( TosaOp.Op().ADD, [inputs[0].name, inputs[1].name], diff --git a/backends/arm/operators/op_avg_pool2d.py b/backends/arm/operators/op_avg_pool2d.py index 40491fb5f64..4caaad92028 100644 --- a/backends/arm/operators/op_avg_pool2d.py +++ b/backends/arm/operators/op_avg_pool2d.py @@ -8,41 +8,30 @@ import serializer.tosa_serializer as ts import torch -from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import ( - get_input_qparams, - get_output_qparams, -) from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, register_node_visitor, ) from executorch.backends.arm.tosa_mapping import TosaArg -from executorch.backends.arm.tosa_specification import TosaSpecification +from executorch.backends.arm.tosa_utils import build_avg_pool_2d_common @register_node_visitor -class AvgPool2dVisitor_0_80_BI(NodeVisitor): +class AvgPool2dVisitor(NodeVisitor): target = "aten.avg_pool2d.default" - tosa_specs = [ - TosaSpecification.create_from_string("TOSA-0.80+BI"), - ] - def __init__(self, *args): super().__init__(*args) - def _build_generic_avgpool2d( + def define_node( self, node: torch.fx.Node, tosa_graph: ts.TosaSerializer, inputs: List[TosaArg], output: TosaArg, - input_zp: int, - output_zp: int, - accumulator_type, + is_quant_node: bool, ) -> None: input_tensor = inputs[0] - kernel_size_list = inputs[1].special stride_size_list = inputs[2].special try: @@ -50,76 +39,13 @@ def _build_generic_avgpool2d( except IndexError: pad_size_list = [0, 0, 0, 0] - attr = ts.TosaSerializerAttribute() - attr.PoolAttribute( - kernel=kernel_size_list, - stride=stride_size_list, - pad=pad_size_list, - input_zp=input_zp, - output_zp=output_zp, - accum_dtype=accumulator_type, - ) - - tosa_graph.addOperator( - ts.TosaOp.Op().AVG_POOL2D, - [input_tensor.name], - [output.name], - attr, - ) - - def define_node( - self, - node: torch.fx.Node, - tosa_graph: ts.TosaSerializer, - inputs: List[TosaArg], - output: TosaArg, - is_quant_node: bool, - ) -> None: - input_tensor = inputs[0] - assert input_tensor.dtype == ts.DType.INT8 - - accumulator_type = ts.DType.INT32 - - input_qargs = get_input_qparams(node) - input_zp = input_qargs[0].zp - - output_qargs = get_output_qparams(node) - output_zp = output_qargs[0].zp - - self._build_generic_avgpool2d( - node, tosa_graph, inputs, output, input_zp, output_zp, accumulator_type + build_avg_pool_2d_common( + node, + tosa_graph, + input_tensor, + kernel_size_list, + stride_size_list, + pad_size_list, + is_quant_node, + output, ) - - -@register_node_visitor -class AvgPool2dVisitor_0_80_MI(AvgPool2dVisitor_0_80_BI): - # inheriting 'target' from BI class - - tosa_specs = [ - TosaSpecification.create_from_string("TOSA-0.80+MI"), - ] - - def define_node( - self, - node: torch.fx.Node, - tosa_graph: ts.TosaSerializer, - inputs: List[TosaArg], - output: TosaArg, - is_quant_node: bool, - ) -> None: - assert ( - inputs[0].dtype == ts.DType.INT8 or inputs[0].dtype == ts.DType.FP32 - ), "Only FP32 and INT8 supported" - - if inputs[0].dtype == ts.DType.INT8: - super().define_node(node, tosa_graph, inputs, output, is_quant_node) - - if inputs[0].dtype == ts.DType.FP32: - accumulator_type = ts.DType.FP32 - # Initilize zero point to zero. - input_zp = 0 - output_zp = 0 - - self._build_generic_avgpool2d( - node, tosa_graph, inputs, output, input_zp, output_zp, accumulator_type - ) diff --git a/backends/arm/operators/op_batch_norm.py b/backends/arm/operators/op_batch_norm.py index c3b9bb0c43e..d17c3a1b81f 100644 --- a/backends/arm/operators/op_batch_norm.py +++ b/backends/arm/operators/op_batch_norm.py @@ -13,7 +13,6 @@ register_node_visitor, ) from executorch.backends.arm.tosa_mapping import TosaArg -from executorch.backends.arm.tosa_specification import TosaSpecification from executorch.backends.arm.tosa_utils import promote_shape, tosa_shape from serializer.tosa_serializer import TosaOp @@ -22,10 +21,6 @@ class BatchNormVisitor(NodeVisitor): target = "aten._native_batch_norm_legit_no_training.default" - tosa_specs = [ - TosaSpecification.create_from_string("TOSA-0.80+MI"), - ] - def __init__(self, *args): super().__init__(*args) diff --git a/backends/arm/operators/op_conv2d.py b/backends/arm/operators/op_conv2d.py index dc64e169364..ffbeee7306d 100644 --- a/backends/arm/operators/op_conv2d.py +++ b/backends/arm/operators/op_conv2d.py @@ -8,16 +8,16 @@ import serializer.tosa_serializer as ts import torch -from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import ( - get_input_qparams, - get_output_qparams, -) from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, register_node_visitor, ) from executorch.backends.arm.tosa_mapping import TosaArg -from executorch.backends.arm.tosa_quant_utils import build_rescale_conv_output +from executorch.backends.arm.tosa_quant_utils import ( + build_rescale_conv_output, + get_quant_arg_downstream, + get_quant_arg_upstream, +) from executorch.backends.arm.tosa_utils import build_reshape, tosa_shape from serializer.tosa_serializer import TosaOp @@ -57,6 +57,9 @@ def define_node( ) -> None: input, weight, bias, stride, pad, dilation, _, _, group = inputs + # Currently only int8 is supported in quantized types. + actual_out_type = ts.DType.INT8 if is_quant_node else output.dtype + # Get the attributes of convolution. attr = ts.TosaSerializerAttribute() pad_attr = [val for val in pad.special for _ in (0, 1)] @@ -79,11 +82,9 @@ def define_node( dilation_attr[1], ) - input_zp = 0 - if inputs[0].dtype == ts.DType.INT8: - # int8 input requires quantization information - input_qparams = get_input_qparams(node) - input_zp = input_qparams[0].zp + input_zp = ( + get_quant_arg_upstream(node.all_input_nodes[0]).zp if is_quant_node else 0 + ) attr.ConvAttribute( pad=pad_attr, @@ -99,22 +100,16 @@ def define_node( # Create a zero bias tensor if not presented out_channels = weight.shape[0] bias_name = "bias" + node.name.split("default", 1)[1] - bias_type = output.dtype - if output.dtype == ts.DType.INT8: - # Conv is quantized to int8, but the TOSA operator has - # output type int32, and the bias must be the same type - # as the TOSA output type - bias_type = ts.DType.INT32 bias = tosa_graph.addConst( [out_channels], - bias_type, + ts.DType.INT32 if is_quant_node else output.dtype, [0] * out_channels, name=bias_name, ) # The output type is int32 when input type is int8. conv2d_output_name = output.name - if output.dtype == ts.DType.INT8: + if is_quant_node: conv2d_res = tosa_graph.addIntermediate( tosa_shape(output.shape, output.dim_order), ts.DType.INT32 ) @@ -137,7 +132,7 @@ def define_node( weight_reshaped = tosa_graph.addIntermediate( weight_post_shape, - weight.dtype, + ts.DType.INT8 if is_quant_node else weight.dtype, ) build_reshape( tosa_graph, weight.name, weight_post_shape, weight_reshaped.name @@ -162,19 +157,20 @@ def define_node( # For quantized convolution, rescale the output value back to the same # integer value domain of the next op. Otherwise return float32 output. - if inputs[0].dtype == ts.DType.INT8: + if is_quant_node: # Get scale_factor from input, weight, and output. - input_scale = input_qparams[0].scale - weight_scale = input_qparams[1].scale - output_qargs = get_output_qparams(node) + input_scale = get_quant_arg_upstream(node.all_input_nodes[0]).scale + weight_scale = get_quant_arg_upstream(node.all_input_nodes[1]).scale + output_qargs = get_quant_arg_downstream(list(node.users)[0]) + build_rescale_conv_output( tosa_graph, # pyre-fixme[61]: Uninitialized local [61]: Local variable `conv2d_res` is undefined, or not always defined. conv2d_res, output.name, - output.dtype, + actual_out_type, input_scale, weight_scale, - output_qargs[0].scale, - output_qargs[0].zp, + output_qargs.scale, + output_qargs.zp, ) diff --git a/backends/arm/operators/op_div.py b/backends/arm/operators/op_div.py index 2332e807c4d..0857e0ed32a 100644 --- a/backends/arm/operators/op_div.py +++ b/backends/arm/operators/op_div.py @@ -13,7 +13,6 @@ register_node_visitor, ) from executorch.backends.arm.tosa_mapping import TosaArg -from executorch.backends.arm.tosa_specification import TosaSpecification from executorch.backends.arm.tosa_utils import tosa_shape from serializer.tosa_serializer import TosaOp @@ -22,11 +21,6 @@ class DivVisitor(NodeVisitor): target = "aten.div.Tensor" - # Only supported for MI - tosa_specs = [ - TosaSpecification.create_from_string("TOSA-0.80+MI"), - ] - def __init__(self, *args): super().__init__(*args) diff --git a/backends/arm/operators/op_full.py b/backends/arm/operators/op_full.py index 23a13dd4869..d2bc1377ce7 100644 --- a/backends/arm/operators/op_full.py +++ b/backends/arm/operators/op_full.py @@ -14,6 +14,10 @@ register_node_visitor, ) from executorch.backends.arm.tosa_mapping import TosaArg +from executorch.backends.arm.tosa_quant_utils import ( + get_quant_arg_downstream, + quantize_value, +) from executorch.backends.arm.tosa_utils import tosa_shape from torch.fx import Node @@ -37,14 +41,19 @@ def define_node( shape = tosa_shape(inputs[0].special, output.dim_order) value = inputs[1].number - - if output.dtype == ts.DType.INT8: - fill_dtype = np.int8 + if is_quant_node: + qargs = get_quant_arg_downstream(list(node.users)[0]) + qvalue = quantize_value(value, qargs) + dtype = ts.DType.INT8 + data = np.full(shape, qvalue, dtype=np.int8) else: - fill_dtype = np.float32 - data = np.full(shape, value, dtype=fill_dtype) + assert ( + output.dtype == ts.DType.FP32 + ), "'Full' currently only supports FP32 for unquantized models." + dtype = ts.DType.FP32 + data = np.full(shape, value, dtype=np.float32) - tosa_graph.addConst(shape, output.dtype, data, node.name + "full-const") + tosa_graph.addConst(shape, dtype, data, node.name + "full-const") tosa_graph.addOperator( ts.TosaOp.Op.IDENTITY, [node.name + "full-const"], [output.name] ) diff --git a/backends/arm/operators/op_max.py b/backends/arm/operators/op_max.py deleted file mode 100644 index 61d889e0db7..00000000000 --- a/backends/arm/operators/op_max.py +++ /dev/null @@ -1,74 +0,0 @@ -# Copyright 2024 Arm Limited and/or its affiliates. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -# pyre-unsafe - -from typing import List - -import executorch.backends.arm.tosa_quant_utils as tqutils -import serializer.tosa_serializer as ts -from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import ( - get_input_qparams, -) -from executorch.backends.arm.operators.node_visitor import ( - NodeVisitor, - register_node_visitor, -) -from executorch.backends.arm.tosa_mapping import TosaArg -from executorch.backends.arm.tosa_utils import tosa_shape - -from serializer.tosa_serializer import TosaOp -from torch.fx import Node - - -@register_node_visitor -class MaxVisitor(NodeVisitor): - target = "aten.maximum.default" - - def __init__(self, *args): - super().__init__(*args) - - def define_node( - self, - node: Node, - tosa_graph: ts.TosaSerializer, - inputs: List[TosaArg], - output: TosaArg, - is_quant_node: bool, - ) -> None: - assert inputs[0].dtype == inputs[1].dtype - - max_output = output - if inputs[0].dtype == ts.DType.INT8: - input_qparams = get_input_qparams(node) - assert ( - len(input_qparams) == 2 - ), f"Both inputs needs to have quantization information for {node}" - # insert RESCALEs to int32 - assert ( - input_qparams[0] == input_qparams[1] - ), "Both inputs must have same quantization for MAX" - - operand_inputs, scale_back = tqutils.insert_rescale_ops_to_int32( - tosa_graph, inputs, node - ) - - output.shape = tosa_shape(output.shape, output.dim_order) - max_output = tosa_graph.addIntermediate(output.shape, ts.DType.INT32) - else: - operand_inputs = inputs - - tosa_graph.addOperator( - TosaOp.Op().MAXIMUM, - [ - operand_inputs[0].name, - operand_inputs[1].name, - ], - [max_output.name], - ) - - if output.dtype == ts.DType.INT8: - # insert RESCALE from int32 back to int8 - tqutils.insert_rescale_op_to_int8(tosa_graph, max_output, scale_back, node) diff --git a/backends/arm/operators/op_max_pool2d.py b/backends/arm/operators/op_max_pool2d.py index 0a4092e3a9a..74e33ddb02c 100644 --- a/backends/arm/operators/op_max_pool2d.py +++ b/backends/arm/operators/op_max_pool2d.py @@ -13,7 +13,7 @@ register_node_visitor, ) from executorch.backends.arm.tosa_mapping import TosaArg -from executorch.backends.arm.tosa_quant_utils import ( +from executorch.backends.arm.tosa_utils import ( get_quant_arg_downstream, get_quant_arg_upstream, ) diff --git a/backends/arm/operators/op_min.py b/backends/arm/operators/op_min.py deleted file mode 100644 index 6750ddd41fc..00000000000 --- a/backends/arm/operators/op_min.py +++ /dev/null @@ -1,75 +0,0 @@ -# Copyright 2024 Arm Limited and/or its affiliates. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -# pyre-unsafe - -from typing import List - -import executorch.backends.arm.tosa_quant_utils as tqutils - -import serializer.tosa_serializer as ts -from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import ( - get_input_qparams, -) -from executorch.backends.arm.operators.node_visitor import ( - NodeVisitor, - register_node_visitor, -) -from executorch.backends.arm.tosa_mapping import TosaArg -from executorch.backends.arm.tosa_utils import tosa_shape - -from serializer.tosa_serializer import TosaOp -from torch.fx import Node - - -@register_node_visitor -class MinVisitor(NodeVisitor): - target = "aten.minimum.default" - - def __init__(self, *args): - super().__init__(*args) - - def define_node( - self, - node: Node, - tosa_graph: ts.TosaSerializer, - inputs: List[TosaArg], - output: TosaArg, - is_quant_node: bool, - ) -> None: - assert inputs[0].dtype == inputs[1].dtype - - min_output = output - if inputs[0].dtype == ts.DType.INT8: - input_qparams = get_input_qparams(node) - assert ( - len(input_qparams) == 2 - ), f"Both inputs needs to have quantization information for {node}" - # insert RESCALEs to int32 - assert ( - input_qparams[0] == input_qparams[1] - ), "Both inputs must have same quantization for MIN" - - operand_inputs, scale_back = tqutils.insert_rescale_ops_to_int32( - tosa_graph, inputs, node - ) - - output.shape = tosa_shape(output.shape, output.dim_order) - min_output = tosa_graph.addIntermediate(output.shape, ts.DType.INT32) - else: - operand_inputs = inputs - - tosa_graph.addOperator( - TosaOp.Op().MINIMUM, - [ - operand_inputs[0].name, - operand_inputs[1].name, - ], - [min_output.name], - ) - - if output.dtype == ts.DType.INT8: - # insert RESCALE from int32 back to int8 - tqutils.insert_rescale_op_to_int8(tosa_graph, min_output, scale_back, node) diff --git a/backends/arm/process_node.py b/backends/arm/process_node.py index 3b1ea9d70fe..2d3a0c2786c 100644 --- a/backends/arm/process_node.py +++ b/backends/arm/process_node.py @@ -11,12 +11,10 @@ import serializer.tosa_serializer as ts import torch import torch.fx -from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import ( - get_input_qparams, -) from executorch.backends.arm.operators.node_visitor import NodeVisitor from executorch.backends.arm.tosa_mapping import map_dtype, TosaArg from executorch.backends.arm.tosa_quant_utils import ( + get_quant_arg_upstream, get_quantized_node_output_dtype, is_node_quantized, ) @@ -112,10 +110,8 @@ def process_quantized_bias( _, ) = consumer_node.all_input_nodes - input_qargs = get_input_qparams(consumer_node) - - input_node_scale = input_qargs[0].scale - weight_node_scale = input_qargs[1].scale + input_node_scale = get_quant_arg_upstream(input_node).scale + weight_node_scale = get_quant_arg_upstream(weight_node).scale bias_values_quantized = ( (parameter_values / (input_node_scale * weight_node_scale)) .round() diff --git a/backends/arm/quantizer/arm_quantizer.py b/backends/arm/quantizer/arm_quantizer.py index 8815d40b0b0..6f2a5689d39 100644 --- a/backends/arm/quantizer/arm_quantizer.py +++ b/backends/arm/quantizer/arm_quantizer.py @@ -77,7 +77,6 @@ def _supported_symmetric_quantized_operators() -> Dict[str, List[OperatorPattern ], "mul": [[torch.mul]], "sub": [[torch.sub]], - "min_max": [[torch.min], [torch.max]], } return copy.deepcopy(supported_operators) @@ -268,7 +267,6 @@ class ArmQuantizer(Quantizer): "add", "sub", "mul", - "min_max", "mm", "one_to_one", "generic", diff --git a/backends/arm/quantizer/quantization_annotation/__init__.py b/backends/arm/quantizer/quantization_annotation/__init__.py index d9d27cee2ac..1201df51adc 100644 --- a/backends/arm/quantizer/quantization_annotation/__init__.py +++ b/backends/arm/quantizer/quantization_annotation/__init__.py @@ -55,7 +55,6 @@ def decorator(annotator: AnnotatorType): generic_annotator, linear_annotator, max_pool2d_annotator, - min_max_annotator, mm_annotator, mul_annotator, one_to_one_annotator, diff --git a/backends/arm/quantizer/quantization_annotation/min_max_annotator.py b/backends/arm/quantizer/quantization_annotation/min_max_annotator.py deleted file mode 100644 index 43c4d20c134..00000000000 --- a/backends/arm/quantizer/quantization_annotation/min_max_annotator.py +++ /dev/null @@ -1,46 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# Copyright 2024 Arm Limited and/or its affiliates. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -# pyre-unsafe - -from typing import Callable, List, Optional - -import torch -from executorch.backends.arm.quantizer import arm_quantizer_utils -from executorch.backends.arm.quantizer.quantization_annotation import register_annotator -from executorch.backends.arm.quantizer.quantization_config import QuantizationConfig -from torch.ao.quantization.quantizer import QuantizationAnnotation -from torch.fx import GraphModule, Node - - -@register_annotator("min_max") -def _annotate_min_max( - gm: GraphModule, - quantization_config: QuantizationConfig, - filter_fn: Optional[Callable[[Node], bool]] = None, -) -> Optional[List[List[Node]]]: - annotated_partitions = [] - for node in gm.graph.nodes: - if node.target not in ( - torch.ops.aten.minimum.default, - torch.ops.aten.maximum.default, - ): - continue - annotated_partitions.append(node) - min_max_node = node - if arm_quantizer_utils.is_annotated(min_max_node): - continue - - input_qspec_map, output_qspec = arm_quantizer_utils.get_shared_qspec( - min_max_node, gm, quantization_config - ) - if input_qspec_map is not None: - min_max_node.meta["quantization_annotation"] = QuantizationAnnotation( - input_qspec_map=input_qspec_map, - output_qspec=output_qspec, - _annotated=True, - ) - return annotated_partitions diff --git a/backends/arm/test/ops/test_maximum.py b/backends/arm/test/ops/test_maximum.py deleted file mode 100644 index 7e750645229..00000000000 --- a/backends/arm/test/ops/test_maximum.py +++ /dev/null @@ -1,137 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# Copyright 2024 Arm Limited and/or its affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -import unittest - -from typing import Tuple - -import torch -from executorch.backends.arm.test import common -from executorch.backends.arm.test.tester.arm_tester import ArmTester -from executorch.exir import EdgeCompileConfig -from executorch.exir.backend.compile_spec_schema import CompileSpec -from parameterized import parameterized - - -class TestMaximum(unittest.TestCase): - """Tests a single maximum op""" - - class Maximum(torch.nn.Module): - test_parameters = [ - ( - torch.FloatTensor([1, 2, 3, 5, 7]), - (torch.FloatTensor([2, 1, 2, 1, 10])), - ), - (torch.ones(1, 10, 4, 6), 2 * torch.ones(1, 10, 4, 6)), - (torch.randn(1, 1, 4, 4), torch.ones(1, 1, 4, 1)), - (torch.randn(1, 3, 4, 4), torch.randn(1, 3, 4, 4)), - (10000 * torch.randn(1, 1, 4, 4), torch.randn(1, 1, 4, 1)), - ] - - def __init__(self): - super().__init__() - - def forward(self, x, y): - return torch.maximum(x, y) - - _edge_compile_config: EdgeCompileConfig = EdgeCompileConfig( - _skip_dim_order=True, # TODO(T182928844): Delegate dim order op to backend. - ) - - def _test_maximum_tosa_MI_pipeline( - self, module: torch.nn.Module, test_data: Tuple[torch.Tensor] - ): - ( - ArmTester( - module, - example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec("TOSA-0.80+MI"), - ) - .export() - .check_count({"torch.ops.aten.maximum.default": 1}) - .check_not(["torch.ops.quantized_decomposed"]) - .to_edge(config=self._edge_compile_config) - .partition() - .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) - .to_executorch() - .run_method_and_compare_outputs(inputs=test_data) - ) - - def _test_maximum_tosa_BI_pipeline( - self, module: torch.nn.Module, test_data: Tuple[torch.Tensor] - ): - ( - ArmTester( - module, - example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec("TOSA-0.80+BI"), - ) - .quantize() - .export() - .check_count({"torch.ops.aten.maximum.default": 1}) - .check(["torch.ops.quantized_decomposed"]) - .to_edge(config=self._edge_compile_config) - .partition() - .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) - .to_executorch() - .run_method_and_compare_outputs(inputs=test_data, qtol=1) - ) - - def _test_maximum_ethos_BI_pipeline( - self, - module: torch.nn.Module, - compile_spec: CompileSpec, - test_data: Tuple[torch.Tensor], - ): - tester = ( - ArmTester( - module, - example_inputs=test_data, - compile_spec=compile_spec, - ) - .quantize() - .export() - .to_edge() - .partition() - .to_executorch() - .serialize() - ) - - return tester - - @parameterized.expand(Maximum.test_parameters) - def test_maximum_tosa_MI(self, operand1: torch.Tensor, operand2: torch.Tensor): - test_data = (operand1, operand2) - self._test_maximum_tosa_MI_pipeline(self.Maximum(), test_data) - - @parameterized.expand(Maximum.test_parameters) - def test_maximum_tosa_BI(self, operand1: torch.Tensor, operand2: torch.Tensor): - test_data = (operand1, operand2) - self._test_maximum_tosa_BI_pipeline(self.Maximum(), test_data) - - @parameterized.expand(Maximum.test_parameters) - @unittest.expectedFailure # Bug in Vela, disabled until pin changes, bug MLETORCH-513 - def test_maximum_u55_BI(self, operand1: torch.Tensor, operand2: torch.Tensor): - test_data = (operand1, operand2) - tester = self._test_maximum_ethos_BI_pipeline( - self.Maximum(), common.get_u55_compile_spec(), test_data - ) - if common.is_option_enabled("corstone_fvp"): - tester.run_method_and_compare_outputs( - qtol=1, inputs=test_data, target_board="corstone-300" - ) - - @parameterized.expand(Maximum.test_parameters) - def test_maximum_u85_BI(self, operand1: torch.Tensor, operand2: torch.Tensor): - test_data = (operand1, operand2) - tester = self._test_maximum_ethos_BI_pipeline( - self.Maximum(), common.get_u85_compile_spec(), test_data - ) - if common.is_option_enabled("corstone_fvp"): - tester.run_method_and_compare_outputs( - qtol=1, inputs=test_data, target_board="corstone-320" - ) diff --git a/backends/arm/test/ops/test_minimum.py b/backends/arm/test/ops/test_minimum.py deleted file mode 100644 index ddbdb24657a..00000000000 --- a/backends/arm/test/ops/test_minimum.py +++ /dev/null @@ -1,137 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# Copyright 2024 Arm Limited and/or its affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -import unittest - -from typing import Tuple - -import torch -from executorch.backends.arm.test import common -from executorch.backends.arm.test.tester.arm_tester import ArmTester -from executorch.exir import EdgeCompileConfig -from executorch.exir.backend.compile_spec_schema import CompileSpec -from parameterized import parameterized - - -class TestMinimum(unittest.TestCase): - """Tests a single minimum op""" - - class Minimum(torch.nn.Module): - test_parameters = [ - ( - torch.FloatTensor([1, 2, 3, 5, 7]), - (torch.FloatTensor([2, 1, 2, 1, 10])), - ), - (torch.ones(1, 10, 4, 6), 2 * torch.ones(1, 10, 4, 6)), - (torch.randn(1, 1, 4, 4), torch.ones(1, 1, 4, 1)), - (torch.randn(1, 3, 4, 4), torch.randn(1, 3, 4, 4)), - (10000 * torch.randn(1, 1, 4, 4), torch.randn(1, 1, 4, 1)), - ] - - def __init__(self): - super().__init__() - - def forward(self, x, y): - return torch.minimum(x, y) - - _edge_compile_config: EdgeCompileConfig = EdgeCompileConfig( - _skip_dim_order=True, # TODO(T182928844): Delegate dim order op to backend. - ) - - def _test_minimum_tosa_MI_pipeline( - self, module: torch.nn.Module, test_data: Tuple[torch.Tensor] - ): - ( - ArmTester( - module, - example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec("TOSA-0.80+MI"), - ) - .export() - .check_count({"torch.ops.aten.minimum.default": 1}) - .check_not(["torch.ops.quantized_decomposed"]) - .to_edge(config=self._edge_compile_config) - .partition() - .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) - .to_executorch() - .run_method_and_compare_outputs(inputs=test_data) - ) - - def _test_minimum_tosa_BI_pipeline( - self, module: torch.nn.Module, test_data: Tuple[torch.Tensor] - ): - ( - ArmTester( - module, - example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec("TOSA-0.80+BI"), - ) - .quantize() - .export() - .check_count({"torch.ops.aten.minimum.default": 1}) - .check(["torch.ops.quantized_decomposed"]) - .to_edge(config=self._edge_compile_config) - .partition() - .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) - .to_executorch() - .run_method_and_compare_outputs(inputs=test_data, qtol=1) - ) - - def _test_minimum_ethos_BI_pipeline( - self, - module: torch.nn.Module, - compile_spec: CompileSpec, - test_data: Tuple[torch.Tensor], - ): - tester = ( - ArmTester( - module, - example_inputs=test_data, - compile_spec=compile_spec, - ) - .quantize() - .export() - .to_edge() - .partition() - .to_executorch() - .serialize() - ) - - return tester - - @parameterized.expand(Minimum.test_parameters) - def test_minimum_tosa_MI(self, operand1: torch.Tensor, operand2: torch.Tensor): - test_data = (operand1, operand2) - self._test_minimum_tosa_MI_pipeline(self.Minimum(), test_data) - - @parameterized.expand(Minimum.test_parameters) - def test_minimum_tosa_BI(self, operand1: torch.Tensor, operand2: torch.Tensor): - test_data = (operand1, operand2) - self._test_minimum_tosa_BI_pipeline(self.Minimum(), test_data) - - @parameterized.expand(Minimum.test_parameters) - @unittest.expectedFailure # Bug in Vela, disabled until pin changes, bug MLETORCH-513 - def test_minimum_u55_BI(self, operand1: torch.Tensor, operand2: torch.Tensor): - test_data = (operand1, operand2) - tester = self._test_minimum_ethos_BI_pipeline( - self.Minimum(), common.get_u55_compile_spec(), test_data - ) - if common.is_option_enabled("corstone_fvp"): - tester.run_method_and_compare_outputs( - qtol=1, inputs=test_data, target_board="corstone-300" - ) - - @parameterized.expand(Minimum.test_parameters) - def test_minimum_u85_BI(self, operand1: torch.Tensor, operand2: torch.Tensor): - test_data = (operand1, operand2) - tester = self._test_minimum_ethos_BI_pipeline( - self.Minimum(), common.get_u85_compile_spec(), test_data - ) - if common.is_option_enabled("corstone_fvp"): - tester.run_method_and_compare_outputs( - qtol=1, inputs=test_data, target_board="corstone-320" - ) diff --git a/backends/arm/test/passes/test_fold_qdq_pass.py b/backends/arm/test/passes/test_fold_qdq_pass.py deleted file mode 100644 index cd7cf751391..00000000000 --- a/backends/arm/test/passes/test_fold_qdq_pass.py +++ /dev/null @@ -1,75 +0,0 @@ -# Copyright 2024 Arm Limited and/or its affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -import unittest - -import torch -from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import ( - FoldAndAnnotateQParamsPass, -) - -from executorch.backends.arm.test import common -from executorch.backends.arm.test.tester.arm_tester import ArmTester - -from executorch.backends.xnnpack.test.tester.tester import RunPasses - -from executorch.exir.dialects._ops import ops as exir_ops - - -class SimpleQuantizeModel(torch.nn.Module): - def forward(self, x, y): - return x + torch.max((x + x), (y + y)) - - def get_inputs(self): - return (torch.rand(1, 1280, 7, 7), torch.rand(1, 1280, 7, 7)) - - -class FoldAndAnnotateQParamsPassTestClass(FoldAndAnnotateQParamsPass): - def __init__(self): - super(FoldAndAnnotateQParamsPassTestClass, self).__init__( - [ - exir_ops.edge.aten.add.Tensor, - exir_ops.edge.aten.maximum.default, - ] - ) - - -class TestFoldAndAnnotateQParamsPass(unittest.TestCase): - """ - Tests the FoldAndAnnotateQParamsPass which folds dq/q nodes into - the node and stores the quantization parameters in meta. - """ - - def test_fold_qdq_pass(self): - """ - Check that the pass runs for add operation and that one q node and one dq node - is removed from the representation. - """ - module = SimpleQuantizeModel() - test_pass_stage = RunPasses([FoldAndAnnotateQParamsPassTestClass]) - ( - ArmTester( - module, - example_inputs=module.get_inputs(), - compile_spec=common.get_tosa_compile_spec("TOSA-0.80+BI"), - ) - .quantize() - .export() - .to_edge() - .check_count( - { - "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 7, - "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 6, - } - ) - .run_passes(test_pass_stage) - .check_count( - { - "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 1, - "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 2, - } - ) - ) diff --git a/backends/arm/test/runner_utils.py b/backends/arm/test/runner_utils.py index 9ae1a27cf7e..4de84ed3458 100644 --- a/backends/arm/test/runner_utils.py +++ b/backends/arm/test/runner_utils.py @@ -127,7 +127,7 @@ def _get_output_node(program: ExportedProgram) -> Node: def _get_output_quantization_params( program: ExportedProgram, output_node: Node -) -> Optional[QuantizationParams]: +) -> QuantizationParams: """ Get output QuantizationParams from a program. Args: @@ -153,6 +153,8 @@ def _get_output_quantization_params( dtype=node.args[5], ) break # break early, there's only one output node + if quant_params is None: + raise RuntimeError("No Quantization parameters not found in exported model.") return quant_params @@ -483,17 +485,13 @@ def run_tosa_ref_model( if tosa_ref_output.dtype == np.int8: tosa_ref_output = tosa_ref_output.astype(np.int32) quant_param = self.qp_output - if quant_param is not None: - # I.e. bool output is possible for quantized models - tosa_ref_output = ( - tosa_ref_output - quant_param.zp - ) * quant_param.scale + assert ( + quant_param is not None + ), "There are no quantization parameters, check output parameters" + tosa_ref_output = (tosa_ref_output - quant_param.zp) * quant_param.scale if tosa_ref_output.dtype == np.double: tosa_ref_output = tosa_ref_output.astype("float32") - elif tosa_ref_output.dtype == bool: - # retain the bool output though for boolean related comparisons - tosa_ref_output = tosa_ref_output.astype("bool") # tosa_output is a numpy array, convert to torch tensor for comparison tosa_ref_outputs.append(torch.from_numpy(tosa_ref_output)) diff --git a/backends/arm/test/tester/arm_tester.py b/backends/arm/test/tester/arm_tester.py index 7b129a98877..4f9eae64be8 100644 --- a/backends/arm/test/tester/arm_tester.py +++ b/backends/arm/test/tester/arm_tester.py @@ -12,6 +12,7 @@ import executorch.backends.xnnpack.test.tester.tester as tester +import numpy as np import serializer.tosa_serializer as ts import torch.fx @@ -318,15 +319,12 @@ def run_method_and_compare_outputs( target_board, ) - quantization_scale = None if is_quantized: reference_stage = self.stages[self.stage_name(tester.Quantize)] - # bool output is quantized with none quantized output so allow - # self.runner_util.qp_output to be none - if self.runner_util.qp_output is not None: - quantization_scale = self.runner_util.qp_output.scale + quantization_scale = self.runner_util.qp_output.scale else: reference_stage = self.stages[self.stage_name(InitialModel)] + quantization_scale = None logger.info( f"Comparing Stage '{self.stage_name(test_stage)}' with Stage '{self.stage_name(reference_stage)}'" @@ -506,7 +504,7 @@ def transpose_data_format( inputs_transposed = list(data) for i in range(len(data)): if hasattr(data[i], "shape") and len(data[i].shape) == 4: - inputs_transposed[i] = torch.permute(data[i], dim_order) + inputs_transposed[i] = np.transpose(data[i], dim_order) return tuple(inputs_transposed) def _compare_outputs( diff --git a/backends/arm/tosa_quant_utils.py b/backends/arm/tosa_quant_utils.py index cdc2eaa3d75..b526a2aa8e8 100644 --- a/backends/arm/tosa_quant_utils.py +++ b/backends/arm/tosa_quant_utils.py @@ -57,10 +57,6 @@ def insert_rescale_ops_to_int32( the graph upstream for DQ nodes. """ - from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import ( - get_input_qparams, - ) - tensors = inputs.copy() # Reshape tensor according to TOSA dim order @@ -68,8 +64,7 @@ def insert_rescale_ops_to_int32( dim_order = tensor.dim_order tensor.shape = [tensor.shape[i] for i in dim_order] - input_qparams = get_input_qparams(node) - qargs = input_qparams.values() + qargs = list(cast(dict[int, QuantArgs], node.meta["input_qparams"]).values()) # Scale the int8 quantized input to a common scale in the integer # domain @@ -89,7 +84,7 @@ def insert_rescale_ops_to_int32( return rescaled_nodes, min_scale -def insert_rescale_op_to_int8( +def insert_rescale_node_back_to_int8( tosa_graph: ts.TosaSerializer, last_tensor: TosaArg, scale: float, @@ -107,14 +102,9 @@ def insert_rescale_op_to_int8( in the node meta dict as opposed to 'rescale_node_back_to_int8' which search the graph downstream for Q nodes. """ - from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import ( - get_output_qparams, - ) - - output_qparams = get_output_qparams(node) - assert len(output_qparams) == 1, "More than one output not supported" + assert len(node.meta["output_qparams"]) == 1 - qargs_out = output_qparams[0] + qargs_out = cast(dict[int, QuantArgs], node.meta["output_qparams"])[0] output_rescale_scale = scale / qargs_out.scale # Rescale Back to INT8 @@ -146,17 +136,6 @@ def quantize_value(self, x): def dequantize_value(self, qx: int) -> float: return (qx - self.zp) * self.scale - def __eq__(self, other): - if isinstance(other, QuantArgs): - return ( - self.scale == other.scale - and self.zp == other.zp - and self.qmin == other.qmin - and self.qmax == other.qmax - and self.dtype == other.dtype - ) - return False - @classmethod def from_operator(cls, op, args): if op in dq_q_ops: diff --git a/backends/arm/tosa_utils.py b/backends/arm/tosa_utils.py index 5bda9bbf188..0b03b31582f 100644 --- a/backends/arm/tosa_utils.py +++ b/backends/arm/tosa_utils.py @@ -7,13 +7,18 @@ import logging import os -from typing import Any +from typing import Any, cast import numpy as np import serializer.tosa_serializer as ts import torch from executorch.backends.arm.tosa_mapping import TosaArg +from executorch.backends.arm.tosa_quant_utils import ( + get_quant_arg_downstream, + get_quant_arg_upstream, + q_op, +) from executorch.exir.dialects._ops import ops as exir_ops from serializer.tosa_serializer import TosaOp from torch.fx import Node @@ -135,15 +140,10 @@ def build_reshape(tosa_fb, input_name, new_shape, output_name): def is_bias_node_for_quantized_conv(node): consumer_node = list(node.users)[0] - - if ( + return ( consumer_node.target == exir_ops.edge.aten.convolution.default - and consumer_node.args[2] == node - and consumer_node.meta["val"].dtype == torch.int8 - ): - return True - - return False + and list(consumer_node.users)[0].target == q_op + ) def is_consumer_node_depthwise_conv2d(node): @@ -159,6 +159,48 @@ def is_consumer_node_depthwise_conv2d(node): return False +def build_avg_pool_2d_common( + node: torch.fx.Node, + tosa_graph: ts.TosaSerializer, + input_tensor: TosaArg, + kernel_size: list, + stride: list, + padding: list, + is_quant_node: bool, + output: TosaArg, +): + accumulator_type = input_tensor.dtype + + if is_quant_node: + # Accumulator type always is int32 when input tensor is an integer type. + accumulator_type = ts.DType.INT32 + + # Initilize zero point to zero. + input_zp = 0 + output_zp = 0 + + if is_quant_node: + input_zp = get_quant_arg_upstream(cast(torch.fx.Node, node.args[0])).zp + output_zp = get_quant_arg_downstream(list(node.users)[0]).zp + + attr = ts.TosaSerializerAttribute() + attr.PoolAttribute( + kernel=kernel_size, + stride=stride, + pad=padding, + input_zp=input_zp, + output_zp=output_zp, + accum_dtype=accumulator_type, + ) + + tosa_graph.addOperator( + TosaOp.Op().AVG_POOL2D, + [input_tensor.name], + [output.name], + attr, + ) + + def get_two_inputs(node: Node, check: bool = False) -> tuple[Node, Node]: """Returns two input nodes to 'node' in order. If 'node' only has one input, it is returned twice.