Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion backends/arm/TARGETS
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,6 @@ runtime.python_library(
"fbsource//third-party/tosa_tools:tosa",
"//executorch/backends/arm/operators:node_visitor",
"//executorch/backends/arm/tosa:mapping",
"//executorch/backends/arm/tosa:quant_utils",
"//executorch/backends/arm/tosa:utils",
"//executorch/exir:lib",
],
Expand Down
1 change: 0 additions & 1 deletion backends/arm/_passes/TARGETS
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ runtime.python_library(
deps = [
"//executorch/backends/arm:common",
"//executorch/backends/arm:constants",
"//executorch/backends/arm/tosa:quant_utils",
"//executorch/backends/arm/tosa:utils",
"//executorch/backends/arm/tosa/dialect:lib",
"//executorch/backends/transforms:fuse_view_copy",
Expand Down
1 change: 0 additions & 1 deletion backends/arm/operators/TARGETS
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ runtime.python_library(
":node_visitor",
":operator_validation_utils",
"//executorch/backends/arm/tosa:mapping",
"//executorch/backends/arm/tosa:quant_utils",
"//executorch/backends/arm/tosa:utils",
"//executorch/backends/arm/_passes:passes",
"//executorch/exir:lib",
Expand Down
1 change: 0 additions & 1 deletion backends/arm/operators/op_index_select.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

from typing import Any, List

import executorch.backends.arm.tosa.quant_utils as tqutils # noqa: F401
import tosa_serializer as ts

from executorch.backends.arm.operators.node_visitor import (
Expand Down
139 changes: 136 additions & 3 deletions backends/arm/operators/op_tosa_rescale.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
# LICENSE file in the root directory of this source tree.


from typing import Any, cast, List
import math
from typing import Any, cast, List, Tuple

import torch

Expand All @@ -19,10 +20,142 @@

from executorch.backends.arm.tosa import TosaSpecification
from executorch.backends.arm.tosa.mapping import map_dtype, TosaArg
from executorch.backends.arm.tosa.quant_utils import build_rescale
from torch.fx import Node


# TOSA uses the RESCALE operation to scale between values with differing precision.
# The RESCALE operator is defined using an integer multiply, add, and shift.
# This utility function is for calculating the multiplier and shift given a scale.
# Ref: https://www.mlplatform.org/tosa/tosa_spec.html#_precision_scaling
def _compute_multiplier_and_shift(
scales: list[float], scaleWidth: int = 32
) -> Tuple[list[int], list[int]]:
if scaleWidth == 16:
offset = 15
elif scaleWidth == 32:
offset = 31
else:
raise ValueError(
f"Unsupported scale width: {scaleWidth}, only 16 and 32 are valid values."
)

multipliers = []
shifts = []
for scale in scales:
mantissa, exponent = math.frexp(scale)
shift = exponent

const_2_power_15_or_31 = 1 << offset
shifted_mantissa = round(mantissa * const_2_power_15_or_31)

assert (
shifted_mantissa <= const_2_power_15_or_31
), f"Mantissa {shifted_mantissa} exceeds limit {const_2_power_15_or_31}"

if shifted_mantissa == const_2_power_15_or_31:
shifted_mantissa = shifted_mantissa // 2
shift += 1

# TOSA expects right shift to be positive, and embed (1 << offset) into right shift bits.
shift = offset - shift

# INT32_MAX, 2^31 - 1
assert shifted_mantissa <= (const_2_power_15_or_31 - 1), (
f"Mantissa {shifted_mantissa} exceeds signed max "
f"{const_2_power_15_or_31 - 1}"
)

multiplier = shifted_mantissa

if shift > 62:
multiplier = multiplier >> min(31, shift - 62)
shift = 62

assert multiplier >= 0, "Multiplier should be non-negative"
assert shift >= 2 and shift <= 62, "Shift should be in range [2, 62]"
multipliers.append(multiplier)
shifts.append(shift)
return multipliers, shifts


# For TOSA spec v1.0 RESCALE operator requires multiplier, shifts, input_zp and output_zp to be
# const inputs. Create constant operators from the data already initialized.
def _create_const_ops_for_rescale(
tosa_fb,
scale_32,
input_dtype,
node_name,
multipliers,
shifts,
input_zp,
output_zp,
output_dtype,
ts,
):

multipliers = tosa_fb.addConst(
(len(multipliers),),
ts.DType.INT32 if scale_32 else ts.DType.INT16,
multipliers,
name=node_name + "_multipliers",
)
shifts = tosa_fb.addConst(
(len(shifts),), ts.DType.INT8, shifts, name=node_name + "_shifts"
)
input_zp = tosa_fb.addConst(
[1], input_dtype, input_zp, name=node_name + "_input_zp"
)
output_zp = tosa_fb.addConst(
[1], output_dtype, output_zp, name=node_name + "_output_zp"
)

return [multipliers.name, shifts.name, input_zp.name, output_zp.name]


def _build_rescale(
tosa_fb: Any,
scale: list[float],
input_node: Any,
output_name: str,
output_type: Any,
input_zp: list[int],
output_zp: list[int],
rounding_mode: ts.RoundingMode,
per_channel: bool = False,
is_scale32: bool = True,
):
scaleWidth = 16 if input_node.dtype == ts.DType.INT48 else 32
is_scale32 = False if input_node.dtype == ts.DType.INT48 else True
multipliers, shifts = _compute_multiplier_and_shift(scale, scaleWidth)
rescale_inputs = _create_const_ops_for_rescale(
tosa_fb,
is_scale32,
input_node.dtype,
output_name,
multipliers,
shifts,
input_zp,
output_zp,
output_type,
ts,
)
attr_rescale = ts.TosaSerializerAttribute()
attr_rescale.RescaleAttribute(
scale32=is_scale32,
rounding_mode=rounding_mode,
per_channel=per_channel,
input_unsigned=False,
output_unsigned=False,
)

tosa_fb.addOperator(
ts.Op.RESCALE,
[input_node.name, *rescale_inputs],
[output_name],
attr_rescale,
)


@register_node_visitor
class RescaleVisitor(NodeVisitor):
target = "tosa.RESCALE.default"
Expand Down Expand Up @@ -60,7 +193,7 @@ def define_node(
f"If output dtype is not int8 or int16, output_zp must be 0. Got {ts.DTypeNames[output_dtype]}, {output_zp=}"
)

build_rescale(
_build_rescale(
tosa_graph,
scale=scales,
input_node=inputs[0],
Expand Down
15 changes: 0 additions & 15 deletions backends/arm/tosa/TARGETS
Original file line number Diff line number Diff line change
Expand Up @@ -11,20 +11,6 @@ runtime.python_library(
":specification",
],
)
runtime.python_library(
name = "quant_utils",
srcs = [
"quant_utils.py",
],
deps = [
"fbsource//third-party/pypi/numpy:numpy",
"fbsource//third-party/tosa_tools:serializer",
"fbsource//third-party/tosa_tools:tosa",
"//executorch/backends/arm:constants",
":mapping",
"//executorch/exir/dialects:lib",
],
)
runtime.python_library(
name = "specification",
srcs = [
Expand All @@ -41,7 +27,6 @@ runtime.python_library(
"utils.py",
],
deps = [
":quant_utils",
"//executorch/backends/arm/operators:node_visitor",
],
)
Expand Down
148 changes: 0 additions & 148 deletions backends/arm/tosa/quant_utils.py

This file was deleted.

Loading