diff --git a/backends/arm/TARGETS b/backends/arm/TARGETS index a737c4bc9de..d1b560df4b8 100644 --- a/backends/arm/TARGETS +++ b/backends/arm/TARGETS @@ -83,7 +83,6 @@ runtime.python_library( "fbsource//third-party/tosa_tools:tosa", "//executorch/backends/arm/operators:node_visitor", "//executorch/backends/arm/tosa:mapping", - "//executorch/backends/arm/tosa:quant_utils", "//executorch/backends/arm/tosa:utils", "//executorch/exir:lib", ], diff --git a/backends/arm/_passes/TARGETS b/backends/arm/_passes/TARGETS index bb4e992ada1..a75c63fb86e 100644 --- a/backends/arm/_passes/TARGETS +++ b/backends/arm/_passes/TARGETS @@ -6,7 +6,6 @@ runtime.python_library( deps = [ "//executorch/backends/arm:common", "//executorch/backends/arm:constants", - "//executorch/backends/arm/tosa:quant_utils", "//executorch/backends/arm/tosa:utils", "//executorch/backends/arm/tosa/dialect:lib", "//executorch/backends/transforms:fuse_view_copy", diff --git a/backends/arm/operators/TARGETS b/backends/arm/operators/TARGETS index afe1c4dd22c..38eb9e7cad9 100644 --- a/backends/arm/operators/TARGETS +++ b/backends/arm/operators/TARGETS @@ -24,7 +24,6 @@ runtime.python_library( ":node_visitor", ":operator_validation_utils", "//executorch/backends/arm/tosa:mapping", - "//executorch/backends/arm/tosa:quant_utils", "//executorch/backends/arm/tosa:utils", "//executorch/backends/arm/_passes:passes", "//executorch/exir:lib", diff --git a/backends/arm/operators/op_index_select.py b/backends/arm/operators/op_index_select.py index db2488fa163..5b73b5e91ae 100644 --- a/backends/arm/operators/op_index_select.py +++ b/backends/arm/operators/op_index_select.py @@ -6,7 +6,6 @@ from typing import Any, List -import executorch.backends.arm.tosa.quant_utils as tqutils # noqa: F401 import tosa_serializer as ts from executorch.backends.arm.operators.node_visitor import ( diff --git a/backends/arm/operators/op_tosa_rescale.py b/backends/arm/operators/op_tosa_rescale.py index 75268938579..b8119799e68 100644 --- a/backends/arm/operators/op_tosa_rescale.py +++ b/backends/arm/operators/op_tosa_rescale.py @@ -4,7 +4,8 @@ # LICENSE file in the root directory of this source tree. -from typing import Any, cast, List +import math +from typing import Any, cast, List, Tuple import torch @@ -19,10 +20,142 @@ from executorch.backends.arm.tosa import TosaSpecification from executorch.backends.arm.tosa.mapping import map_dtype, TosaArg -from executorch.backends.arm.tosa.quant_utils import build_rescale from torch.fx import Node +# TOSA uses the RESCALE operation to scale between values with differing precision. +# The RESCALE operator is defined using an integer multiply, add, and shift. +# This utility function is for calculating the multiplier and shift given a scale. +# Ref: https://www.mlplatform.org/tosa/tosa_spec.html#_precision_scaling +def _compute_multiplier_and_shift( + scales: list[float], scaleWidth: int = 32 +) -> Tuple[list[int], list[int]]: + if scaleWidth == 16: + offset = 15 + elif scaleWidth == 32: + offset = 31 + else: + raise ValueError( + f"Unsupported scale width: {scaleWidth}, only 16 and 32 are valid values." + ) + + multipliers = [] + shifts = [] + for scale in scales: + mantissa, exponent = math.frexp(scale) + shift = exponent + + const_2_power_15_or_31 = 1 << offset + shifted_mantissa = round(mantissa * const_2_power_15_or_31) + + assert ( + shifted_mantissa <= const_2_power_15_or_31 + ), f"Mantissa {shifted_mantissa} exceeds limit {const_2_power_15_or_31}" + + if shifted_mantissa == const_2_power_15_or_31: + shifted_mantissa = shifted_mantissa // 2 + shift += 1 + + # TOSA expects right shift to be positive, and embed (1 << offset) into right shift bits. + shift = offset - shift + + # INT32_MAX, 2^31 - 1 + assert shifted_mantissa <= (const_2_power_15_or_31 - 1), ( + f"Mantissa {shifted_mantissa} exceeds signed max " + f"{const_2_power_15_or_31 - 1}" + ) + + multiplier = shifted_mantissa + + if shift > 62: + multiplier = multiplier >> min(31, shift - 62) + shift = 62 + + assert multiplier >= 0, "Multiplier should be non-negative" + assert shift >= 2 and shift <= 62, "Shift should be in range [2, 62]" + multipliers.append(multiplier) + shifts.append(shift) + return multipliers, shifts + + +# For TOSA spec v1.0 RESCALE operator requires multiplier, shifts, input_zp and output_zp to be +# const inputs. Create constant operators from the data already initialized. +def _create_const_ops_for_rescale( + tosa_fb, + scale_32, + input_dtype, + node_name, + multipliers, + shifts, + input_zp, + output_zp, + output_dtype, + ts, +): + + multipliers = tosa_fb.addConst( + (len(multipliers),), + ts.DType.INT32 if scale_32 else ts.DType.INT16, + multipliers, + name=node_name + "_multipliers", + ) + shifts = tosa_fb.addConst( + (len(shifts),), ts.DType.INT8, shifts, name=node_name + "_shifts" + ) + input_zp = tosa_fb.addConst( + [1], input_dtype, input_zp, name=node_name + "_input_zp" + ) + output_zp = tosa_fb.addConst( + [1], output_dtype, output_zp, name=node_name + "_output_zp" + ) + + return [multipliers.name, shifts.name, input_zp.name, output_zp.name] + + +def _build_rescale( + tosa_fb: Any, + scale: list[float], + input_node: Any, + output_name: str, + output_type: Any, + input_zp: list[int], + output_zp: list[int], + rounding_mode: ts.RoundingMode, + per_channel: bool = False, + is_scale32: bool = True, +): + scaleWidth = 16 if input_node.dtype == ts.DType.INT48 else 32 + is_scale32 = False if input_node.dtype == ts.DType.INT48 else True + multipliers, shifts = _compute_multiplier_and_shift(scale, scaleWidth) + rescale_inputs = _create_const_ops_for_rescale( + tosa_fb, + is_scale32, + input_node.dtype, + output_name, + multipliers, + shifts, + input_zp, + output_zp, + output_type, + ts, + ) + attr_rescale = ts.TosaSerializerAttribute() + attr_rescale.RescaleAttribute( + scale32=is_scale32, + rounding_mode=rounding_mode, + per_channel=per_channel, + input_unsigned=False, + output_unsigned=False, + ) + + tosa_fb.addOperator( + ts.Op.RESCALE, + [input_node.name, *rescale_inputs], + [output_name], + attr_rescale, + ) + + @register_node_visitor class RescaleVisitor(NodeVisitor): target = "tosa.RESCALE.default" @@ -60,7 +193,7 @@ def define_node( f"If output dtype is not int8 or int16, output_zp must be 0. Got {ts.DTypeNames[output_dtype]}, {output_zp=}" ) - build_rescale( + _build_rescale( tosa_graph, scale=scales, input_node=inputs[0], diff --git a/backends/arm/tosa/TARGETS b/backends/arm/tosa/TARGETS index 51919025591..d0f7a743f53 100644 --- a/backends/arm/tosa/TARGETS +++ b/backends/arm/tosa/TARGETS @@ -11,20 +11,6 @@ runtime.python_library( ":specification", ], ) -runtime.python_library( - name = "quant_utils", - srcs = [ - "quant_utils.py", - ], - deps = [ - "fbsource//third-party/pypi/numpy:numpy", - "fbsource//third-party/tosa_tools:serializer", - "fbsource//third-party/tosa_tools:tosa", - "//executorch/backends/arm:constants", - ":mapping", - "//executorch/exir/dialects:lib", - ], -) runtime.python_library( name = "specification", srcs = [ @@ -41,7 +27,6 @@ runtime.python_library( "utils.py", ], deps = [ - ":quant_utils", "//executorch/backends/arm/operators:node_visitor", ], ) diff --git a/backends/arm/tosa/quant_utils.py b/backends/arm/tosa/quant_utils.py deleted file mode 100644 index b3840c6ab1c..00000000000 --- a/backends/arm/tosa/quant_utils.py +++ /dev/null @@ -1,148 +0,0 @@ -# Copyright 2023-2025 Arm Limited and/or its affiliates. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - - -# Utility functions for TOSA quantized lowerings - -import math - -from typing import Any, Tuple - -import tosa_serializer as ts - - -# TOSA uses the RESCALE operation to scale between values with differing precision. -# The RESCALE operator is defined using an integer multiply, add, and shift. -# This utility function is for calculating the multiplier and shift given a scale. -# Ref: https://www.mlplatform.org/tosa/tosa_spec.html#_precision_scaling -def _compute_multiplier_and_shift( - scales: list[float], scaleWidth: int = 32 -) -> Tuple[list[int], list[int]]: - if scaleWidth == 16: - offset = 15 - elif scaleWidth == 32: - offset = 31 - else: - raise ValueError( - f"Unsupported scale width: {scaleWidth}, only 16 and 32 are valid values." - ) - - multipliers = [] - shifts = [] - for scale in scales: - mantissa, exponent = math.frexp(scale) - shift = exponent - - const_2_power_15_or_31 = 1 << offset - shifted_mantissa = round(mantissa * const_2_power_15_or_31) - - assert ( - shifted_mantissa <= const_2_power_15_or_31 - ), f"Mantissa {shifted_mantissa} exceeds limit {const_2_power_15_or_31}" - - if shifted_mantissa == const_2_power_15_or_31: - shifted_mantissa = shifted_mantissa // 2 - shift += 1 - - # TOSA expects right shift to be positive, and embed (1 << offset) into right shift bits. - shift = offset - shift - - # INT32_MAX, 2^31 - 1 - assert shifted_mantissa <= (const_2_power_15_or_31 - 1), ( - f"Mantissa {shifted_mantissa} exceeds signed max " - f"{const_2_power_15_or_31 - 1}" - ) - - multiplier = shifted_mantissa - - if shift > 62: - multiplier = multiplier >> min(31, shift - 62) - shift = 62 - - assert multiplier >= 0, "Multiplier should be non-negative" - assert shift >= 2 and shift <= 62, "Shift should be in range [2, 62]" - multipliers.append(multiplier) - shifts.append(shift) - return multipliers, shifts - - -# For TOSA spec v1.0 RESCALE operator requires multiplier, shifts, input_zp and output_zp to be -# const inputs. Create constant operators from the data already initialized. -def _create_const_ops_for_rescale( - tosa_fb, - scale_32, - input_dtype, - node_name, - multipliers, - shifts, - input_zp, - output_zp, - output_dtype, - ts, -): - - multipliers = tosa_fb.addConst( - (len(multipliers),), - ts.DType.INT32 if scale_32 else ts.DType.INT16, - multipliers, - name=node_name + "_multipliers", - ) - shifts = tosa_fb.addConst( - (len(shifts),), ts.DType.INT8, shifts, name=node_name + "_shifts" - ) - input_zp = tosa_fb.addConst( - [1], input_dtype, input_zp, name=node_name + "_input_zp" - ) - output_zp = tosa_fb.addConst( - [1], output_dtype, output_zp, name=node_name + "_output_zp" - ) - - return [multipliers.name, shifts.name, input_zp.name, output_zp.name] - - -def build_rescale( - tosa_fb: Any, - scale: list[float], - input_node: Any, - output_name: str, - output_type: Any, - input_zp: list[int], - output_zp: list[int], - rounding_mode: ts.RoundingMode, - per_channel: bool = False, - is_scale32: bool = True, -): - scaleWidth = 16 if input_node.dtype == ts.DType.INT48 else 32 - is_scale32 = False if input_node.dtype == ts.DType.INT48 else True - multipliers, shifts = _compute_multiplier_and_shift(scale, scaleWidth) - rescale_inputs = _create_const_ops_for_rescale( - tosa_fb, - is_scale32, - input_node.dtype, - output_name, - multipliers, - shifts, - input_zp, - output_zp, - output_type, - ts, - ) - attr_rescale = ts.TosaSerializerAttribute() - attr_rescale.RescaleAttribute( - scale32=is_scale32, - rounding_mode=rounding_mode, - per_channel=per_channel, - input_unsigned=False, - output_unsigned=False, - ) - - tosa_fb.addOperator( - ts.Op.RESCALE, - [input_node.name, *rescale_inputs], - [output_name], - attr_rescale, - ) - - return