Skip to content

Commit 574c5d3

Browse files
Martin LindströmMartin Lindström
authored and committed
Arm backend: Remove quant_utils.py
quant_utils.py has only one public function, `build_rescale`, which in turn is only used by one file, op_tosa_rescale.py. For this reason, move `build_rescale` into op_tosa_rescale.py and remove quant_utils.py. Signed-off-by: Martin Lindström <[email protected]> Change-Id: I2a3564e60e1566979b0c3714d3b17cded91b61cd
1 parent 4efd79c commit 574c5d3

File tree

7 files changed

+136
-170
lines changed

7 files changed

+136
-170
lines changed

backends/arm/TARGETS

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,6 @@ runtime.python_library(
8383
"fbsource//third-party/tosa_tools:tosa",
8484
"//executorch/backends/arm/operators:node_visitor",
8585
"//executorch/backends/arm/tosa:mapping",
86-
"//executorch/backends/arm/tosa:quant_utils",
8786
"//executorch/backends/arm/tosa:utils",
8887
"//executorch/exir:lib",
8988
],

backends/arm/_passes/TARGETS

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@ runtime.python_library(
66
deps = [
77
"//executorch/backends/arm:common",
88
"//executorch/backends/arm:constants",
9-
"//executorch/backends/arm/tosa:quant_utils",
109
"//executorch/backends/arm/tosa:utils",
1110
"//executorch/backends/arm/tosa/dialect:lib",
1211
"//executorch/backends/transforms:fuse_view_copy",

backends/arm/operators/TARGETS

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@ runtime.python_library(
2424
":node_visitor",
2525
":operator_validation_utils",
2626
"//executorch/backends/arm/tosa:mapping",
27-
"//executorch/backends/arm/tosa:quant_utils",
2827
"//executorch/backends/arm/tosa:utils",
2928
"//executorch/backends/arm/_passes:passes",
3029
"//executorch/exir:lib",

backends/arm/operators/op_index_select.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66

77
from typing import Any, List
88

9-
import executorch.backends.arm.tosa.quant_utils as tqutils # noqa: F401
109
import tosa_serializer as ts
1110

1211
from executorch.backends.arm.operators.node_visitor import (

backends/arm/operators/op_tosa_rescale.py

Lines changed: 136 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44
# LICENSE file in the root directory of this source tree.
55

66

7-
from typing import Any, cast, List
7+
import math
8+
from typing import Any, cast, List, Tuple
89

910
import torch
1011

@@ -19,10 +20,142 @@
1920

2021
from executorch.backends.arm.tosa import TosaSpecification
2122
from executorch.backends.arm.tosa.mapping import map_dtype, TosaArg
22-
from executorch.backends.arm.tosa.quant_utils import build_rescale
2323
from torch.fx import Node
2424

2525

26+
# TOSA uses the RESCALE operation to scale between values with differing precision.
27+
# The RESCALE operator is defined using an integer multiply, add, and shift.
28+
# This utility function is for calculating the multiplier and shift given a scale.
29+
# Ref: https://www.mlplatform.org/tosa/tosa_spec.html#_precision_scaling
30+
def _compute_multiplier_and_shift(
31+
scales: list[float], scaleWidth: int = 32
32+
) -> Tuple[list[int], list[int]]:
33+
if scaleWidth == 16:
34+
offset = 15
35+
elif scaleWidth == 32:
36+
offset = 31
37+
else:
38+
raise ValueError(
39+
f"Unsupported scale width: {scaleWidth}, only 16 and 32 are valid values."
40+
)
41+
42+
multipliers = []
43+
shifts = []
44+
for scale in scales:
45+
mantissa, exponent = math.frexp(scale)
46+
shift = exponent
47+
48+
const_2_power_15_or_31 = 1 << offset
49+
shifted_mantissa = round(mantissa * const_2_power_15_or_31)
50+
51+
assert (
52+
shifted_mantissa <= const_2_power_15_or_31
53+
), f"Mantissa {shifted_mantissa} exceeds limit {const_2_power_15_or_31}"
54+
55+
if shifted_mantissa == const_2_power_15_or_31:
56+
shifted_mantissa = shifted_mantissa // 2
57+
shift += 1
58+
59+
# TOSA expects right shift to be positive, and embed (1 << offset) into right shift bits.
60+
shift = offset - shift
61+
62+
# INT32_MAX, 2^31 - 1
63+
assert shifted_mantissa <= (const_2_power_15_or_31 - 1), (
64+
f"Mantissa {shifted_mantissa} exceeds signed max "
65+
f"{const_2_power_15_or_31 - 1}"
66+
)
67+
68+
multiplier = shifted_mantissa
69+
70+
if shift > 62:
71+
multiplier = multiplier >> min(31, shift - 62)
72+
shift = 62
73+
74+
assert multiplier >= 0, "Multiplier should be non-negative"
75+
assert shift >= 2 and shift <= 62, "Shift should be in range [2, 62]"
76+
multipliers.append(multiplier)
77+
shifts.append(shift)
78+
return multipliers, shifts
79+
80+
81+
# For TOSA spec v1.0 RESCALE operator requires multiplier, shifts, input_zp and output_zp to be
82+
# const inputs. Create constant operators from the data already initialized.
83+
def _create_const_ops_for_rescale(
    tosa_fb,
    scale_32,
    input_dtype,
    node_name,
    multipliers,
    shifts,
    input_zp,
    output_zp,
    output_dtype,
    ts,
):
    """Register the four constant operands of a RESCALE and return their names.

    Adds one constant each for the multipliers, the shifts, the input zero
    point and the output zero point to ``tosa_fb`` via ``addConst``. Every
    constant is named ``node_name`` plus a fixed suffix.

    Args:
        tosa_fb: Graph builder exposing ``addConst``.
        scale_32: Truthy to store multipliers as INT32, otherwise INT16.
        input_dtype: Dtype used for the input zero-point constant.
        node_name: Prefix for the generated constant names.
        multipliers: Multiplier values; the constant's shape is their count.
        shifts: Shift values, stored as INT8; the shape is their count.
        input_zp: Input zero-point data, stored with shape ``[1]``.
        output_zp: Output zero-point data, stored with shape ``[1]``.
        output_dtype: Dtype used for the output zero-point constant.
        ts: Serializer module providing ``DType``.

    Returns:
        The constant names in the order
        ``[multipliers, shifts, input_zp, output_zp]``.
    """
    multiplier_dtype = ts.DType.INT32 if scale_32 else ts.DType.INT16

    multiplier_const = tosa_fb.addConst(
        (len(multipliers),),
        multiplier_dtype,
        multipliers,
        name=node_name + "_multipliers",
    )
    shift_const = tosa_fb.addConst(
        (len(shifts),),
        ts.DType.INT8,
        shifts,
        name=node_name + "_shifts",
    )
    input_zp_const = tosa_fb.addConst(
        [1],
        input_dtype,
        input_zp,
        name=node_name + "_input_zp",
    )
    output_zp_const = tosa_fb.addConst(
        [1],
        output_dtype,
        output_zp,
        name=node_name + "_output_zp",
    )

    created = (multiplier_const, shift_const, input_zp_const, output_zp_const)
    return [const.name for const in created]
113+
114+
115+
def _build_rescale(
    tosa_fb: Any,
    scale: list[float],
    input_node: Any,
    output_name: str,
    output_type: Any,
    input_zp: list[int],
    output_zp: list[int],
    rounding_mode: ts.RoundingMode,
    per_channel: bool = False,
    is_scale32: bool = True,
):
    """Add a RESCALE operator, with its constant operands, to ``tosa_fb``.

    The scales are converted to integer multipliers and shifts, which are
    registered as constants together with the zero points. A RESCALE operator
    is then added that reads ``input_node`` and writes ``output_name``.

    Args:
        tosa_fb: Graph builder exposing ``addConst`` and ``addOperator``.
        scale: Scale factors to convert to multiplier/shift pairs.
        input_node: Operand to rescale; its ``name`` and ``dtype`` are read.
        output_name: Name of the output tensor; also the prefix of the
            generated constant names.
        output_type: Dtype used for the output zero-point constant.
        input_zp: Input zero-point data.
        output_zp: Output zero-point data.
        rounding_mode: Rounding mode forwarded to the RESCALE attribute.
        per_channel: Forwarded to the RESCALE attribute.
        is_scale32: Ignored. The value is always derived from the input dtype
            (False for INT48 input, True otherwise).
    """
    # INT48 input uses a 16-bit scale; every other dtype uses a 32-bit scale.
    # This deliberately overrides whatever the caller passed as `is_scale32`.
    int48_input = input_node.dtype == ts.DType.INT48
    is_scale32 = not int48_input
    scaleWidth = 16 if int48_input else 32

    multipliers, shifts = _compute_multiplier_and_shift(scale, scaleWidth)
    const_names = _create_const_ops_for_rescale(
        tosa_fb,
        is_scale32,
        input_node.dtype,
        output_name,
        multipliers,
        shifts,
        input_zp,
        output_zp,
        output_type,
        ts,
    )

    attr_rescale = ts.TosaSerializerAttribute()
    attr_rescale.RescaleAttribute(
        scale32=is_scale32,
        rounding_mode=rounding_mode,
        per_channel=per_channel,
        input_unsigned=False,
        output_unsigned=False,
    )

    # Operand order: the rescaled input first, then the constants in the
    # order returned by the helper above.
    operator_inputs = [input_node.name, *const_names]
    tosa_fb.addOperator(
        ts.Op.RESCALE,
        operator_inputs,
        [output_name],
        attr_rescale,
    )
157+
158+
26159
@register_node_visitor
27160
class RescaleVisitor(NodeVisitor):
28161
target = "tosa.RESCALE.default"
@@ -60,7 +193,7 @@ def define_node(
60193
f"If output dtype is not int8 or int16, output_zp must be 0. Got {ts.DTypeNames[output_dtype]}, {output_zp=}"
61194
)
62195

63-
build_rescale(
196+
_build_rescale(
64197
tosa_graph,
65198
scale=scales,
66199
input_node=inputs[0],

backends/arm/tosa/TARGETS

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -11,20 +11,6 @@ runtime.python_library(
1111
":specification",
1212
],
1313
)
14-
runtime.python_library(
15-
name = "quant_utils",
16-
srcs = [
17-
"quant_utils.py",
18-
],
19-
deps = [
20-
"fbsource//third-party/pypi/numpy:numpy",
21-
"fbsource//third-party/tosa_tools:serializer",
22-
"fbsource//third-party/tosa_tools:tosa",
23-
"//executorch/backends/arm:constants",
24-
":mapping",
25-
"//executorch/exir/dialects:lib",
26-
],
27-
)
2814
runtime.python_library(
2915
name = "specification",
3016
srcs = [
@@ -41,7 +27,6 @@ runtime.python_library(
4127
"utils.py",
4228
],
4329
deps = [
44-
":quant_utils",
4530
"//executorch/backends/arm/operators:node_visitor",
4631
],
4732
)

backends/arm/tosa/quant_utils.py

Lines changed: 0 additions & 148 deletions
This file was deleted.

0 commit comments

Comments
 (0)