22#
33# This source code is licensed under the BSD-style license found in the
44# LICENSE file in the root directory of this source tree.
5-
5+ """Utility helpers for building TOSA graphs in the Arm backend."""
66
77import logging
88from typing import Any
2626def are_fake_tensors_broadcastable (
2727 fake_tensors : list [FakeTensor ],
2828) -> tuple [bool , list [int ]]:
29- """
30- Determines whether a list of FakeTensors can be broadcast together.
29+ """Determine whether the fake tensors share a broadcastable shape.
30+
3131 Args:
32- fake_tensors (list[FakeTensor]): List of 2 or more FakeTensors
33- who's shapes to evaluate
32+ fake_tensors (list[FakeTensor]): Fake tensors whose shapes should
33+ be validated for broadcasting.
3434
3535 Returns:
36- tuple[bool, list[int]]: First element is whether the shapes are
37- broadcastable. Second element is the common shape if compatible.
38- If not, empty list.
36+ tuple[bool, list[int]]: Tuple where the first element indicates
37+ whether broadcasting is possible and the second element contains
38+ the broadcast shape. The shape list is empty when broadcasting
39+ fails.
3940
4041 Raises:
41- RuntimeError: If less than 2 tensors are passed in.
42+ RuntimeError: Raised when fewer than two tensors are supplied.
43+
4244 """
4345 if len (fake_tensors ) < 1 :
4446 raise RuntimeError (f"Expected 2 or more tensors got { len (fake_tensors )} " )
@@ -65,26 +67,27 @@ def are_fake_tensors_broadcastable(
6567def broadcast_tensors (
6668 tosa_fb , nodes : list [Node ], tosa_spec : TosaSpecification
6769) -> list [Any ]:
68- """
69- Given a list of nodes it determines the common shape they broadcast to
70- and adds the necessary reshape and tile operations to perform the broadcast.
70+ """Broadcast the FX nodes to a shared shape inside the TOSA graph.
71+
72+ This mirrors ``reshape_for_broadcast`` but also emits the tile operators
73+ needed to materialize the broadcast and supports any number of inputs.
7174
7275 Args:
73- tosa_fb: Tosa graph to add nodes to
74- nodes (list[Node]): List of nodes to broadcast together
75- tosa_spec (TosaSpecification): Tosa spec
76+ tosa_fb (Any): TOSA graph builder that receives the broadcast
77+ operators.
78+ nodes (list[Node]): FX nodes whose tensor metadata should be
79+ broadcast.
80+ tosa_spec (TosaSpecification): Active TOSA specification used to
81+ decode tensor metadata.
7682
7783 Returns:
78- list[Any]: List containing the fx.Nodes or TosaSerializerTensors
79- of the right common shape. Order of output matches order of input.
84+ list[Any]: Broadcast versions of the inputs. Each element is either
85+ the original FX node or a TOSA serializer tensor, ordered to match
86+ ``nodes``.
8087
8188 Raises:
8289 RuntimeError: If the supplied nodes are not broadcastable.
8390
84- Note:
85- This function and `reshape_for_broadcast` both reshape the tensors
86- for broadcast. However this function also performs the broadcast and
87- does not have a limit on only two input tensors.
8891 """
8992 index_fake_tensors = [node .meta ["val" ] for node in nodes ]
9093 broadcastable , common_shape = are_fake_tensors_broadcastable (index_fake_tensors )
@@ -137,6 +140,17 @@ def broadcast_tensors(
137140def build_reshape_tosa_1_0 (
138141 tosa_graph , input_name , new_shape , output_name , shape_name_override = ""
139142):
143+ """Insert a TOSA reshape operator using the v1.0 semantics.
144+
145+ Args:
146+ tosa_graph (Any): Graph builder used to emit TOSA operators.
147+ input_name (str): Name of the tensor that should be reshaped.
148+ new_shape (list[int]): Target tensor shape.
149+ output_name (str): Name assigned to the reshaped tensor.
150+ shape_name_override (str): Optional override for the shape constant
151+ name.
152+
153+ """
140154 shape = tosa_graph .addConst (
141155 np .array (new_shape ).shape ,
142156 ts .DType .SHAPE ,
@@ -155,6 +169,19 @@ def build_reshape_tosa_1_0(
155169
156170
157171def tosa_shape (shape , dim_order ):
172+ """Reorder a shape tuple into TOSA layout while resolving symints.
173+
174+ Args:
175+ shape (Sequence[int | torch.SymInt]): Original tensor shape,
176+ possibly containing ``torch.SymInt``.
177+ dim_order (Sequence[int]): Desired dimension order for the output
178+ shape.
179+
180+ Returns:
181+ list[int]: List containing the reordered dimensions where symbolic
182+ values become ``-1``.
183+
184+ """
158185 reordered = tuple ([shape [dim ] for dim in dim_order ])
159186 # Dynamic shapes in executorch are represented with torch.SymInt objects in the shapes,
160187 # in TOSA we do not have this concept and instead use -1.
@@ -170,6 +197,26 @@ def get_resize_parameters_1d(
170197 resize_mode : int ,
171198 align_corners : bool ,
172199):
200+ """Compute resize coefficients for a single spatial dimension.
201+
202+ Args:
203+ input_size (int | torch.SymInt): Input size for the axis, possibly
204+ symbolic.
205+ output_size (int | torch.SymInt): Output size for the axis, possibly
206+ symbolic.
207+ resize_mode (int): Target resize mode defined by TOSA.
208+ align_corners (bool): Whether the resize should align the corner
209+ pixels.
210+
211+ Returns:
212+ tuple[int, int, int, int]: Numerator, denominator, offset, and border
213+ terms encoded as integers.
214+
215+ Raises:
216+ RuntimeError: If symbolic shapes are used with ``align_corners`` or if
217+ the computed ratio or border is not constant.
218+
219+ """
173220 # We don't support align_corners for symbolic shapes, because handling the edge case where size == 1 is tricky.
174221 if align_corners :
175222 if (not isinstance (input_size , int )) or (not isinstance (output_size , int )):
@@ -229,19 +276,23 @@ def get_resize_parameters(
229276 resize_mode : int ,
230277 align_corners : bool ,
231278) -> tuple [torch .IntTensor , ...]:
232- """Get the tosa. resize parameters based on the input and output size .
279+ """Calculate 2D resize parameters for TOSA emission .
233280
234281 Args:
235- input_size_xy (tuple[int | torch.SymInt]): Size of the input
236- output_size_xy (tuple[int | torch.SymInt]): Size of the output
237- resize_mode (tosa.ResizeMode): The TOSA resize mode
238- align_corners (bool): Align the corners pixels of the input and output
282+ input_size_xy (tuple[int | torch.SymInt, int | torch.SymInt]): Height
283+ and width of the input tensor.
284+ output_size_xy (tuple[int | torch.SymInt, int | torch.SymInt]): Height
285+ and width of the output tensor.
286+ resize_mode (int): TOSA resize mode used for coefficient generation.
287+ align_corners (bool): Whether to align corner pixels between input and
288+ output.
239289
240290 Returns:
241- scale_n ( torch.IntTensor), scale_d (torch.IntTensor),
242- offset (torch.IntTensor), border (torch.IntTensor)
243- """
291+ tuple[ torch.IntTensor, ...]: Four-element tuple of tensors describing
292+ the scale numerator, scale denominator, offset, and border for Y
293+ and X dimensions.
244294
295+ """
245296 # Get the parameters for each dimension independently
246297 y_params = get_resize_parameters_1d (
247298 input_size_xy [0 ], output_size_xy [0 ], resize_mode , align_corners
0 commit comments