diff --git a/backends/arm/_passes/insert_rescales_pass.py b/backends/arm/_passes/insert_rescales_pass.py index d56e70e78b3..f6b8fda8a3d 100644 --- a/backends/arm/_passes/insert_rescales_pass.py +++ b/backends/arm/_passes/insert_rescales_pass.py @@ -3,6 +3,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import math from copy import copy from typing import cast, Dict, Optional, Set, Tuple, Type @@ -93,6 +94,7 @@ class InsertRescaleInt32Pass(ArmPass): exir_ops.edge.aten.lt.Tensor, exir_ops.edge.aten.maximum.default, exir_ops.edge.aten.minimum.default, + exir_ops.edge.aten.mul.Tensor, ] def _int32_qargs(self, s): @@ -133,6 +135,15 @@ def _get_inputs_rescaled_qparams( qparams = { i: self._int32_qargs(min_scale) for i in range(len(input_qparams)) } + elif target in [ + exir_ops.edge.aten.mul.Tensor, + ]: + # The input scales do not need to be adjusted for these ops; they + # can remain the same. + qparams = { + i: self._int32_qargs(input_qparams[i].get_scale_per_tensor()) + for i in range(len(input_qparams)) + } else: raise ValueError(f"Not a valid target: {target}") @@ -161,6 +172,20 @@ def _get_output_qparams( ]: # Output is bool for these ops and thus no qparams are present return None + elif target in [exir_ops.edge.aten.mul.Tensor]: + # Mul will cause the scales to also multiply; refer to the formula + # where we compute the output scale S_2: + # + # (Q_2 - ZP_2) * S_2 == ((Q_0 - ZP_0) * S_0) * ((Q_1 - ZP_1) * S_1) + # + # yields: + # + # (Q_2 - ZP_2) == (Q_0 - ZP_0) * (Q_1 - ZP_1) + # S_2 = S_0 * S_1 + output_scale = math.prod( + (qp.get_scale_per_tensor() for qp in inputs_qparams.values()) + ) + return self._int32_qargs(output_scale) else: raise ValueError(f"Not a valid target: {target}") @@ -187,7 +212,7 @@ def _rescale_inputs(self, graph, node, rescale_qargs: Dict[int, QuantArgs]) -> b modified = False for i in qargs: qp = qargs[i] - if qp.dtype != torch.int8: + if qp.dtype not in (torch.int8, torch.int16): continue arg_node = args_copy[i] @@ -226,7 +251,7 @@ def _rescale_outputs(self, graph, node, rescale_qargs: Optional[QuantArgs]) -> b assert rescale_qargs is not None qarg = qargs[0] - if qarg.dtype != torch.int8: + if qarg.dtype not in (torch.int8, torch.int16): return False users_copy = list(node.users) @@ -237,7 +262,7 @@ def _rescale_outputs(self, graph, node, rescale_qargs: Optional[QuantArgs]) -> b exir_ops.backend.tosa.RESCALE.default, ( node, - torch.int8, + qarg.dtype, rescale_qargs.get_scale_per_tensor() / qarg.get_scale_per_tensor(), # Old scale / new scale rescale_qargs.get_zp_per_tensor(), # Old zero point diff --git a/backends/arm/operators/op_mul.py b/backends/arm/operators/op_mul.py index 9d139c68242..9178aa4d014 100644 --- a/backends/arm/operators/op_mul.py +++ b/backends/arm/operators/op_mul.py @@ -7,14 +7,8 @@ from typing import Any, List -import executorch.backends.arm.tosa.quant_utils as tqutils -import executorch.backends.arm.tosa.utils as tutils import torch -from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import ( - get_input_qparams, -) - from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, register_node_visitor, @@ -24,17 +18,17 @@ validate_same_dtype, validate_valid_dtype, ) -from executorch.backends.arm.tosa import TosaSpecification from executorch.backends.arm.tosa.mapping import TosaArg +from executorch.backends.arm.tosa.specification import TosaSpecification @register_node_visitor -class MulVisitor_INT(NodeVisitor): +class MulVisitor(NodeVisitor): target = "aten.mul.Tensor" tosa_specs = [ + TosaSpecification.create_from_string("TOSA-1.0+FP"), TosaSpecification.create_from_string("TOSA-1.0+INT"), - TosaSpecification.create_from_string("TOSA-1.0+INT+int16"), ] def define_node( @@ -52,105 +46,13 @@ def define_node( validate_valid_dtype( self.target, [*inputs, output], - [ts.DType.INT8, ts.DType.INT16, ts.DType.INT32], + [ts.DType.INT32, ts.DType.FP32], output.tosa_spec, ) - if inputs[0].dtype == ts.DType.INT8 or inputs[0].dtype == ts.DType.INT16: - input_A = inputs[0] - input_B = inputs[1] - input_qparams = get_input_qparams(node) - input_A_qargs = input_qparams[0] - input_B_qargs = input_qparams[1] - input_A.shape = tutils.tosa_shape(input_A.shape, input_A.dim_order) - input_B.shape = tutils.tosa_shape(input_B.shape, input_B.dim_order) - - # Rescale inputs to INT32 with zp=0 - input_A_rescaled = tqutils.build_rescale_to_int32( - tosa_graph, - input_A, - input_A_qargs.get_zp_per_tensor(), - 1.0, - tosa_spec=self.tosa_spec, - ) - input_B_rescaled = tqutils.build_rescale_to_int32( - tosa_graph, - input_B, - input_B_qargs.get_zp_per_tensor(), - 1.0, - tosa_spec=self.tosa_spec, - ) - else: - # input[0].dtype == ts.DType.INT16 or ts.DType.INT32 - # Non quantized input, natively support by TOSA.MUL - input_A_rescaled, input_B_rescaled = inputs[0], inputs[1] - - if output.dtype == ts.DType.INT8 or output.dtype == ts.DType.INT16: - output_shape = tutils.tosa_shape(output.shape, output.dim_order) - mul_output = tosa_graph.addIntermediate(output_shape, ts.DType.INT32) - else: - # output.dtype == ts.DType.INT32 (non-quantized) - mul_output = output - - # Do the INT32 Mul - tosa_graph.addConst([1], ts.DType.INT8, 0, name=f"{node.name}_shift") - self._serialize_operator( - node, - tosa_graph, - ts.TosaOp.Op().MUL, - [input_A_rescaled.name, input_B_rescaled.name, f"{node.name}_shift"], - [mul_output.name], - ) - - if output.dtype == ts.DType.INT8: - # Scale output back to 8 bit - output_scale = ( - input_A_qargs.get_scale_per_tensor() # type: ignore[possibly-undefined] - * input_B_qargs.get_scale_per_tensor() # type: ignore[possibly-undefined] - ) - tqutils.insert_rescale_op_to_int8( - tosa_graph, mul_output, output_scale, node, self.tosa_spec - ) - elif output.dtype == ts.DType.INT16: - # Scale output back to 16 bit - output_scale = ( - input_A_qargs.get_scale_per_tensor() # type: ignore[possibly-undefined] - * input_B_qargs.get_scale_per_tensor() # type: ignore[possibly-undefined] - ) - tqutils.insert_rescale_op_to_int16( - tosa_graph, mul_output, output_scale, node, self.tosa_spec - ) - - -@register_node_visitor -class MulVisitor_FP(MulVisitor_INT): - # inheriting 'target' from INT class - - tosa_specs = [TosaSpecification.create_from_string("TOSA-1.0+FP")] - - def define_node( - self, - node: torch.fx.Node, - tosa_graph: Any, - inputs: List[TosaArg], - output: TosaArg, - ) -> None: - - import serializer.tosa_serializer as ts # type: ignore - - validate_num_inputs(self.target, inputs, 2) - validate_same_dtype(self.target, [*inputs, output], ts) - - if inputs[0].dtype == ts.DType.INT8: - return super().define_node(node, tosa_graph, inputs, output) - - input1, input2 = inputs - tosa_graph.addConst([1], ts.DType.INT8, 0, name=f"{node.name}_shift") - self._serialize_operator( - node, - tosa_graph, + tosa_graph.addOperator( ts.TosaOp.Op().MUL, - [input1.name, input2.name, f"{node.name}_shift"], + [inputs[0].name, inputs[1].name, f"{node.name}_shift"], [output.name], ) diff --git a/backends/arm/test/passes/test_insert_rescale_i32_pass.py b/backends/arm/test/passes/test_insert_rescale_i32_pass.py index 096c90d330d..66f09ba89a9 100644 --- a/backends/arm/test/passes/test_insert_rescale_i32_pass.py +++ b/backends/arm/test/passes/test_insert_rescale_i32_pass.py @@ -22,10 +22,11 @@ def __init__(self): super().__init__() def forward(self, x, y): - a = torch.maximum(x, y) - b = torch.abs(a) - c = a > b - return c + a = x * y + b = torch.maximum(a, y) + c = torch.abs(b) + d = c > b + return d def get_inputs(self, dtype) -> input_t: if dtype == torch.float32: @@ -45,8 +46,8 @@ def test_insert_rescales(): ops_not_before = {"executorch_exir_dialects_backend__ops_tosa_RESCALE_default"} ops_after = { # "number of op nodes with i8 output" + "number of i8 node inputs" - "executorch_exir_dialects_backend__ops_tosa_RESCALE_default": 2 - + 5, + "executorch_exir_dialects_backend__ops_tosa_RESCALE_default": 3 + + 7, } pipeline = PassPipeline[input_t]( module,