Skip to content

Commit 65a9044

Browse files
Arm backend: Add docstrings for tosa/utils.py & tosa rescale helpers (#15772)
cc @freddan80 @per @zingo @oscarandersson8218 @digantdesai Signed-off-by: Sebastian Larsson <[email protected]> Co-authored-by: Zingo Andersen <[email protected]>
1 parent 0f58198 commit 65a9044

File tree

2 files changed

+141
-35
lines changed

2 files changed

+141
-35
lines changed

backends/arm/operators/op_tosa_rescale.py

Lines changed: 61 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,13 +23,31 @@
2323
from torch.fx import Node
2424

2525

26-
# TOSA uses the RESCALE operation to scale between values with differing precision.
27-
# The RESCALE operator is defined using an integer multiply, add, and shift.
28-
# This utility function is for calculating the multiplier and shift given a scale.
29-
# Ref: https://www.mlplatform.org/tosa/tosa_spec.html#_precision_scaling
3026
def _compute_multiplier_and_shift(
3127
scales: list[float], scaleWidth: int = 32
3228
) -> Tuple[list[int], list[int]]:
29+
"""Derive integer multipliers and shifts from floating-point scales.
30+
31+
TOSA uses the RESCALE operation to scale between values with differing
32+
precision. The RESCALE operator is defined using an integer multiply, add,
33+
and shift. This utility function is for calculating the multiplier and shift
34+
given a scale.
35+
Ref: https://www.mlplatform.org/tosa/tosa_spec.html#_precision_scaling
36+
37+
Args:
38+
scales (list[float]): Scale factors to decompose into multiplier and
39+
shift pairs.
40+
scaleWidth (int): Bit-width of the multiplier representation; expects
41+
``16`` or ``32``.
42+
43+
Returns:
44+
Tuple[list[int], list[int]]: Parallel lists containing the computed
45+
multipliers and right shifts.
46+
47+
Raises:
48+
ValueError: If ``scaleWidth`` is not supported.
49+
50+
"""
3351
if scaleWidth == 16:
3452
offset = 15
3553
elif scaleWidth == 32:
@@ -78,8 +96,6 @@ def _compute_multiplier_and_shift(
7896
return multipliers, shifts
7997

8098

81-
# For TOSA spec v1.0 RESCALE operator requires multiplier, shifts, input_zp and output_zp to be
82-
# const inputs. Create constant operators from the data already initialized.
8399
def _create_const_ops_for_rescale(
84100
tosa_fb,
85101
scale_32,
@@ -92,6 +108,29 @@ def _create_const_ops_for_rescale(
92108
output_dtype,
93109
ts,
94110
):
111+
"""Materialize constant operands required by the TOSA RESCALE op.
112+
113+
For TOSA spec v1.0, the RESCALE operator requires multiplier, shifts, input_zp
114+
and output_zp to be const inputs. Create constant operators from the data
115+
already initialized.
116+
117+
Args:
118+
tosa_fb (Any): Graph builder used to emit TOSA operators and tensors.
119+
scale_32 (bool): Flag indicating whether multipliers use 32-bit width.
120+
input_dtype (ts.DType): Data type of the input tensor.
121+
node_name (str): Base name reused for created constant tensors.
122+
multipliers (list[int]): Precomputed multiplier coefficients.
123+
shifts (list[int]): Precomputed shift coefficients.
124+
input_zp (list[int]): Quantization zero points for the input.
125+
output_zp (list[int]): Quantization zero points for the output.
126+
output_dtype (ts.DType): Data type of the output tensor.
127+
ts (module): Reference to the ``tosa_serializer`` module.
128+
129+
Returns:
130+
list[str]: Names of the constant tensors added to ``tosa_fb`` in the
131+
order expected by RESCALE.
132+
133+
"""
95134

96135
multipliers = tosa_fb.addConst(
97136
(len(multipliers),),
@@ -124,6 +163,22 @@ def _build_rescale(
124163
per_channel: bool = False,
125164
is_scale32: bool = True,
126165
):
166+
"""Insert a TOSA RESCALE operator configured for the quantized path.
167+
168+
Args:
169+
tosa_fb (Any): Graph builder receiving the RESCALE operator.
170+
scale (list[float]): Scale factors applied during rescaling.
171+
input_node (Any): Input tensor node feeding the operator.
172+
output_name (str): Name assigned to the RESCALE output tensor.
173+
output_type (ts.DType): Data type of the output tensor.
174+
input_zp (list[int]): Quantization zero points for the input tensor.
175+
output_zp (list[int]): Quantization zero points for the output tensor.
176+
rounding_mode (ts.RoundingMode): Rounding policy for the RESCALE op.
177+
per_channel (bool): Whether scales are applied per output channel.
178+
is_scale32 (bool): Declared scale width; currently overridden from the
179+
input dtype (``False`` for ``ts.DType.INT48``, ``True`` otherwise).
180+
181+
"""
127182
scaleWidth = 16 if input_node.dtype == ts.DType.INT48 else 32
128183
is_scale32 = False if input_node.dtype == ts.DType.INT48 else True
129184
multipliers, shifts = _compute_multiplier_and_shift(scale, scaleWidth)

backends/arm/tosa/utils.py

Lines changed: 80 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
#
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
5-
5+
"""Utility helpers for building TOSA graphs in the Arm backend."""
66

77
import logging
88
from typing import Any
@@ -26,19 +26,21 @@
2626
def are_fake_tensors_broadcastable(
2727
fake_tensors: list[FakeTensor],
2828
) -> tuple[bool, list[int]]:
29-
"""
30-
Determines whether a list of FakeTensors can be broadcast together.
29+
"""Determine whether the fake tensors share a broadcastable shape.
30+
3131
Args:
32-
fake_tensors (list[FakeTensor]): List of 2 or more FakeTensors
33-
who's shapes to evaluate
32+
fake_tensors (list[FakeTensor]): Fake tensors whose shapes should
33+
be validated for broadcasting.
3434
3535
Returns:
36-
tuple[bool, list[int]]: First element is whether the shapes are
37-
broadcastable. Second element is the common shape if compatible.
38-
If not, empty list.
36+
tuple[bool, list[int]]: Tuple where the first element indicates
37+
whether broadcasting is possible and the second element contains
38+
the broadcast shape. The shape list is empty when broadcasting
39+
fails.
3940
4041
Raises:
41-
RuntimeError: If less than 2 tensors are passed in.
42+
RuntimeError: Raised when fewer than two tensors are supplied.
43+
4244
"""
4345
if len(fake_tensors) < 1:
4446
raise RuntimeError(f"Expected 2 or more tensors got {len(fake_tensors)}")
@@ -65,26 +67,27 @@ def are_fake_tensors_broadcastable(
6567
def broadcast_tensors(
6668
tosa_fb, nodes: list[Node], tosa_spec: TosaSpecification
6769
) -> list[Any]:
68-
"""
69-
Given a list of nodes it determines the common shape they broadcast to
70-
and adds the necessary reshape and tile operations to perform the broadcast.
70+
"""Broadcast the FX nodes to a shared shape inside the TOSA graph.
71+
72+
This mirrors ``reshape_for_broadcast`` but also emits the tile operators
73+
needed to materialize the broadcast and supports any number of inputs.
7174
7275
Args:
73-
tosa_fb: Tosa graph to add nodes to
74-
nodes (list[Node]): List of nodes to broadcast together
75-
tosa_spec (TosaSpecification): Tosa spec
76+
tosa_fb (Any): TOSA graph builder that receives the broadcast
77+
operators.
78+
nodes (list[Node]): FX nodes whose tensor metadata should be
79+
broadcast.
80+
tosa_spec (TosaSpecification): Active TOSA specification used to
81+
decode tensor metadata.
7682
7783
Returns:
78-
list[Any]: List containing the fx.Nodes or TosaSerializerTensors
79-
of the right common shape. Order of output matches order of input.
84+
list[Any]: Broadcast versions of the inputs. Each element is either
85+
the original FX node or a TOSA serializer tensor, ordered to match
86+
``nodes``.
8087
8188
Raises:
8289
RuntimeError: If the supplied nodes are not broadcastable.
8390
84-
Note:
85-
This function and `reshape_for_broadcast` both reshape the tensors
86-
for broadcast. However this function also performs the broadcast and
87-
does not have a limit on only two input tensors.
8891
"""
8992
index_fake_tensors = [node.meta["val"] for node in nodes]
9093
broadcastable, common_shape = are_fake_tensors_broadcastable(index_fake_tensors)
@@ -137,6 +140,17 @@ def broadcast_tensors(
137140
def build_reshape_tosa_1_0(
138141
tosa_graph, input_name, new_shape, output_name, shape_name_override=""
139142
):
143+
"""Insert a TOSA reshape operator using the v1.0 semantics.
144+
145+
Args:
146+
tosa_graph (Any): Graph builder used to emit TOSA operators.
147+
input_name (str): Name of the tensor that should be reshaped.
148+
new_shape (list[int]): Target tensor shape.
149+
output_name (str): Name assigned to the reshaped tensor.
150+
shape_name_override (str): Optional override for the shape constant
151+
name.
152+
153+
"""
140154
shape = tosa_graph.addConst(
141155
np.array(new_shape).shape,
142156
ts.DType.SHAPE,
@@ -155,6 +169,19 @@ def build_reshape_tosa_1_0(
155169

156170

157171
def tosa_shape(shape, dim_order):
172+
"""Reorder a shape tuple into TOSA layout while resolving symints.
173+
174+
Args:
175+
shape (Sequence[int | torch.SymInt]): Original tensor shape,
176+
possibly containing ``torch.SymInt``.
177+
dim_order (Sequence[int]): Desired dimension order for the output
178+
shape.
179+
180+
Returns:
181+
list[int]: List containing the reordered dimensions where symbolic
182+
values become ``-1``.
183+
184+
"""
158185
reordered = tuple([shape[dim] for dim in dim_order])
159186
# Dynamic shapes in executorch are represented with torch.SymInt objects in the shapes,
160187
# in TOSA we do not have this concept and instead use -1.
@@ -170,6 +197,26 @@ def get_resize_parameters_1d(
170197
resize_mode: int,
171198
align_corners: bool,
172199
):
200+
"""Compute resize coefficients for a single spatial dimension.
201+
202+
Args:
203+
input_size (int | torch.SymInt): Input size for the axis, possibly
204+
symbolic.
205+
output_size (int | torch.SymInt): Output size for the axis, possibly
206+
symbolic.
207+
resize_mode (int): Target resize mode defined by TOSA.
208+
align_corners (bool): Whether the resize should align the corner
209+
pixels.
210+
211+
Returns:
212+
tuple[int, int, int, int]: Numerator, denominator, offset, and border
213+
terms encoded as integers.
214+
215+
Raises:
216+
RuntimeError: If symbolic shapes are used with ``align_corners`` or if
217+
the computed ratio or border is not constant.
218+
219+
"""
173220
# We don't support align_corners for symbolic shapes, because handling the edge case where size == 1 is tricky.
174221
if align_corners:
175222
if (not isinstance(input_size, int)) or (not isinstance(output_size, int)):
@@ -229,19 +276,23 @@ def get_resize_parameters(
229276
resize_mode: int,
230277
align_corners: bool,
231278
) -> tuple[torch.IntTensor, ...]:
232-
"""Get the tosa.resize parameters based on the input and output size.
279+
"""Calculate 2D resize parameters for TOSA emission.
233280
234281
Args:
235-
input_size_xy (tuple[int | torch.SymInt]): Size of the input
236-
output_size_xy (tuple[int | torch.SymInt]): Size of the output
237-
resize_mode (tosa.ResizeMode): The TOSA resize mode
238-
align_corners (bool): Align the corners pixels of the input and output
282+
input_size_xy (tuple[int | torch.SymInt, int | torch.SymInt]): Height
283+
and width of the input tensor.
284+
output_size_xy (tuple[int | torch.SymInt, int | torch.SymInt]): Height
285+
and width of the output tensor.
286+
resize_mode (int): TOSA resize mode used for coefficient generation.
287+
align_corners (bool): Whether to align corner pixels between input and
288+
output.
239289
240290
Returns:
241-
scale_n (torch.IntTensor), scale_d (torch.IntTensor),
242-
offset (torch.IntTensor), border (torch.IntTensor)
243-
"""
291+
tuple[torch.IntTensor, ...]: Four-element tuple of tensors describing
292+
the scale numerator, scale denominator, offset, and border for Y
293+
and X dimensions.
244294
295+
"""
245296
# Get the parameters for each dimension independently
246297
y_params = get_resize_parameters_1d(
247298
input_size_xy[0], output_size_xy[0], resize_mode, align_corners

0 commit comments

Comments
 (0)