diff --git a/backends/arm/_passes/annotate_channels_last_dim_order_pass.py b/backends/arm/_passes/annotate_channels_last_dim_order_pass.py index 029dc421920..f8ead856fbb 100644 --- a/backends/arm/_passes/annotate_channels_last_dim_order_pass.py +++ b/backends/arm/_passes/annotate_channels_last_dim_order_pass.py @@ -203,10 +203,18 @@ def insert_tosa_transposes(self, graph_module: torch.fx.GraphModule): - 1D/2D tensors """ for node in graph_module.graph.nodes: - if node.op != "call_function": + # call_function and placeholder allowed due to + # index.Tensor being able to come in as both + if node.op not in ["call_function", "placeholder"]: continue - elif node.target == exir_ops.edge.aten.view_copy.default: + elif node.target in ( + exir_ops.edge.aten.view_copy.default, + exir_ops.edge.aten.index.Tensor, + ): + # For index.Tensor: + # If we want to support 4D indexing tensors this logic + # should be updated. input_node = node.args[0] input_shape = input_node.meta["val"].shape output_shape = node.meta["val"].shape diff --git a/backends/arm/operator_support/__init__.py b/backends/arm/operator_support/__init__.py index 4a1f1269fe2..2075e0f554f 100644 --- a/backends/arm/operator_support/__init__.py +++ b/backends/arm/operator_support/__init__.py @@ -10,6 +10,7 @@ embedding_support, ethos_u55_support, index_select_support, + index_tensor_support, minmax_support, pool_2d_support, reduce_sum_support, diff --git a/backends/arm/operator_support/index_tensor_support.py b/backends/arm/operator_support/index_tensor_support.py new file mode 100644 index 00000000000..7330f98667d --- /dev/null +++ b/backends/arm/operator_support/index_tensor_support.py @@ -0,0 +1,128 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import math + +import torch +import torch.fx as fx +from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor +from executorch.backends.arm.operator_support.tosa_supported_operators import ( + register_tosa_support_check, + SupportedTOSAOperatorCheck, +) +from executorch.backends.arm.tosa_specification import TosaSpecification +from executorch.exir.dialects._ops import ops as exir_ops + + +@register_tosa_support_check +class IndexTensorSupported(SupportedTOSAOperatorCheck): + """ + This support check is intended to prevent the partitioning of + currently unsupported usages of the index.Tensor operator. + + 1. Usages where indexing tensors are of rank 4 or higher. + This is due to the AnnotateChannelsLastDimOrder pass and + the rarity of such operation. + Support is possible but would require further changes to the above + pass which can be added at such a time as is necessary. + + 2. Usages where slice, ellipsis or None are present before an indexing tensor: + t[{start}:{end}, indexTensor] - slicing + t[None, indexTensor] - unsqueeze + t[..., indexTensor] - ellipsis + + 3. Usages where the value tensor contains more than int32.max elements + This is due to int32 TOSA limitation and the fact that we flatten out + and accumulate all index tensors. + As such to avoid overflow we reject lowering of this operator if it is + possible for indices to go over the int32 limit. + + Extra information regarding #2: + Pytorch decomposes slice and None usages before they reach aten. + In the case of Slicing and Unsqueeze, Pytorch will add the relevant + operation just before the index.Tensor op. 
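+    (Example #1 below shows the decomposed form produced for slicing,
+    and example #2 the form produced for None/unsqueeze.)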
+ In the case of Ellipsis no extra operation is added. + + In all three cases Pytorch will insert "None"(s) in the index list + only if the above operations are done on a dimension BEFORE one being indexed. + + When slicing, unsqueeze and ellipsis are done on dimensions after + the ones being indexed, then they do not affect the final output + values, only the shape. Thus None is not passed to the index.Tensor op. + + The purpose of None is to signify to index.Tensor that a dimension + should not be indexed. + In such cases the logic behaves similar to batching along that dimension. + For the sake of simplicity we have not implemented this behavior yet + and thus have put this support check in place to prevent the partitioning + of index.Tensor ops which include None. + + Examples: + #1 - Slice ----------------------------------------------------- + t = torch.randint(25, size(25, 3, 6)) + t[1:5, torch.arange(3)] + + Turns into: (edge pseudo code) + slice_res = ...edge__ops_aten_slice_copy_Tensor(t, dim=0, start=1, end=2) + out = ...edge__ops_aten_index_Tensor(slice_res, [None, torch.arange(3)]) + + #2 - None (Unsqueeze) ------------------------------------------ + t = torch.randint(25, size(25, 3, 6)) + t[None, torch.arange(3)] + + Turns into: edge pseudo code) + unsqueeze_res = ...edge__ops_aten_unsqueeze(t, dim=0) + out = ...edge__ops_aten_index_Tensor(unsqueeze_res, [None, torch.arange(3)]) + + #3 - None (Unsqueeze) After index ------------------------------ + t = torch.randint(25, size(25, 3, 6)) + t[torch.arange(3), None] + + Turns into: edge pseudo code) + unsqueeze_res = ...edge__ops_aten_unsqueeze(t, dim=1) + out = ...edge__ops_aten_index_Tensor(unsqueeze_res, [torch.arange(3)]) + + NB. + With the current implementation of flattening tensors and indices out, + supporting None (Unsqueeze) is simply a matter of ignoring the + None dimension. + This is not the case for Slice and Ellipsis operators, where + the size of the new dimension can be > 1. + + Note that slice ops interleaved between indexes such as: + t[1:3, torch.arange(5), 2:3, torch.arange(3).reshape(3,1)] + are also possible and can result in some unintuitive behaviors + where batching and indexing are mixed together. 
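+
+    Worked example for #3 (the int32 guard): a values tensor of shape
+    (2048, 1024, 1024) holds 2**31 elements, which already exceeds
+    torch.iinfo(torch.int32).max, so the node is rejected no matter
+    what the index tensors look like.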
+ """ + + targets = [exir_ops.edge.aten.index.Tensor] + + tosa_specs = [ + TosaSpecification.create_from_string("TOSA-0.80+BI"), + TosaSpecification.create_from_string("TOSA-0.80+MI"), + TosaSpecification.create_from_string("TOSA-1.0+INT"), + TosaSpecification.create_from_string("TOSA-1.0+FP"), + ] + + def is_node_tosa_supported( + self, node: fx.Node, tosa_spec: TosaSpecification + ) -> bool: # type: ignore[override, misc] + indices = node.args[1] + for index in indices: # type: ignore[union-attr] + # Usage 2 guard + if index is None: + return False + + # Usage 1 guard + fake_tensor = get_first_fake_tensor(index) # type: ignore[arg-type] + if len(fake_tensor.size()) > 3: + return False + + # Usage 3 guard + total_vals = math.prod(get_first_fake_tensor(node.args[0]).shape) # type: ignore[arg-type] + if total_vals > torch.iinfo(torch.int32).max: + return False + + return True diff --git a/backends/arm/operators/__init__.py b/backends/arm/operators/__init__.py index 1e2620e4533..260299d6423 100644 --- a/backends/arm/operators/__init__.py +++ b/backends/arm/operators/__init__.py @@ -25,6 +25,7 @@ op_ge, op_gt, op_index_select, + op_index_tensor, op_le, op_log, op_lt, diff --git a/backends/arm/operators/op_index_tensor.py b/backends/arm/operators/op_index_tensor.py new file mode 100644 index 00000000000..8c5c84ddd5a --- /dev/null +++ b/backends/arm/operators/op_index_tensor.py @@ -0,0 +1,354 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-unsafe + +import math +from typing import Any, List + +import executorch.backends.arm.tosa_utils as tutils + +import numpy as np + +from executorch.backends.arm.operators.node_visitor import ( + NodeVisitor, + register_node_visitor, +) +from executorch.backends.arm.operators.operator_validation_utils import ( + validate_same_dtype, +) +from executorch.backends.arm.tosa_mapping import extract_tensor_meta, TosaArg +from executorch.backends.arm.tosa_specification import TosaSpecification +from torch.fx import Node + + +class CommonIndexTensorVisitor(NodeVisitor): + target = "aten.index.Tensor" + + def __init__(self, *args): + super().__init__(*args) + + def _get_tensor_info(self, tensor: Node): + """ + Consolidates obtaining name, dtype and shape into a common function + reconciling access based on the type of the input. + + Args: + fake_tensors (list[FakeTensor]): List of 2 or more FakeTensors + who's shapes to evaluate + + Returns: + tuple[bool, list[int]]: First element is whether the shapes are + broadcastable. Second element is the common shape if compatible. + If not, empty list. + + """ + if isinstance(tensor, Node): + dtype, shape, _ = extract_tensor_meta(tensor.meta, self.tosa_spec) + return tensor.name, dtype, shape + else: + return tensor.name, tensor.dtype, tensor.shape + + def _calculate_tosa_vals( + self, + index_shape: List[int], + index_nodes: List[Node], + values_shape: List[int], + ): + # From TOSA spec + # N - number of batches + # W - number of indices in each batch + # K - range of each index (number of elements to index through) + # C - number of data channels for each index + N, K, W, C = 1, 1, 1, 1 + + # Calculate K, W, C + # N - kept to 1 for generic n-dim implementation + # Note: If/when slice and ellipsis support is added batching + # may have to be used to facilitate proper implementation of + # the relevant logic. 
+ # W - common between all indices as they have been broadcast + # to a common shape in a pass. + W = math.prod(index_shape) + + for i, dim in enumerate(values_shape): + if i < len(index_nodes): + K *= dim + + total_vals = math.prod(values_shape) + C = int(total_vals / K) + + return N, K, W, C + + def _calculate_value_strides(self, values_shape: List[int]) -> List[int]: + values_strides: List[int] = [] + stride = 1 + for dim in reversed(values_shape): + values_strides.insert(0, stride) + stride *= dim + + return values_strides + + +@register_node_visitor +class IndexTensorVisitor_080(CommonIndexTensorVisitor): + tosa_specs = [ + TosaSpecification.create_from_string("TOSA-0.80+MI"), + TosaSpecification.create_from_string("TOSA-0.80+BI"), + ] + + def __init__(self, *args): + super().__init__(*args) + + def define_node( + self, + node: Node, + tosa_graph: Any, + inputs: List[TosaArg], + output: TosaArg, + ) -> None: + """ + This approach uses the fact that all indexing tensors are incremented + simultaneously and they essentially act as a map along the corresponding + dimensions of the values tensor. + Note: that this does not hold true when slicing or ellipsis ops + are involved as such they are not currently not supported. + + As such this approach flattens out the values tensor and + constructs a flattened out index obtained by flattening out the + index tensors, multiplying them by the relevant stride and accumulating them. + + This approach suffers from the fact that we are taking a number of index tensors of + type int32 and applying multiplications and additions. + + If the number of total elements in the values tensor exceeds int32 limits + then this approach falls apart. + """ + import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore + + validate_same_dtype(self.target, [inputs[0], output]) + + values, indices = inputs + index_nodes = indices.special + + # Broadcast indices + broadcasted_tensors = tutils.broadcast_tensors( + tosa_graph, index_nodes, self.tosa_spec + ) + + values_strides = self._calculate_value_strides(values.shape) + + # The indices have already been broadcast to a common shape + # in so they are all the same. + _, index_dtype, index_shape = self._get_tensor_info(broadcasted_tensors[0]) + + N, K, W, C = self._calculate_tosa_vals(index_shape, index_nodes, values.shape) + + gather_idx_shape = [N, W] + + gather_index_name = "" + # Flatten out and shift indexes. + for i, index_node in enumerate(broadcasted_tensors): + index_name, _, _ = self._get_tensor_info(index_node) + index_name = index_node.name + + stride_shifted_indices = tosa_graph.addIntermediate( + index_shape, + index_dtype, + ) + + # Division by C is necessary when len(indices) < values.rank + # When there are dimensions left unindexed that changes the + # channels and thus the stride-shift. + data = np.full(index_shape, int(values_strides[i] / C)) + mul_const = tosa_graph.addConst(index_shape, index_dtype, data) + attr = ts.TosaSerializerAttribute() + attr.MulAttribute(shift=0) + tosa_graph.addOperator( + ts.TosaOp.Op().MUL, + [index_name, mul_const.name], + [stride_shifted_indices.name], + attr, + ) + + reshaped_idxs = tosa_graph.addIntermediate( + gather_idx_shape, + index_dtype, + ) + tutils.build_reshape( + tosa_graph, + stride_shifted_indices.name, + gather_idx_shape, + reshaped_idxs.name, + ) + + # Guarantees that the accumulation tensor is properly + # initialized and does not contain junk data. 
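+            # The first iteration seeds the accumulator with the reshaped,
+            # stride-scaled indices; later iterations ADD into it, so after
+            # the loop it holds the flattened index into the K dimension of
+            # the [N, K, C]-reshaped values tensor.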
+ if i == 0: + gather_index_name = reshaped_idxs.name + else: + tosa_graph.addOperator( + ts.TosaOp.Op().ADD, + [gather_index_name, reshaped_idxs.name], + [gather_index_name], + ) + + gather_vals_shape = [N, K, C] + reshaped_input = tosa_graph.addIntermediate(gather_vals_shape, values.dtype) + tutils.build_reshape( + tosa_graph, values.name, gather_vals_shape, reshaped_input.name + ) + + gather_out_shape = (N, W, C) + gather_out = tosa_graph.addIntermediate( + gather_out_shape, + output.dtype, + ) + tosa_graph.addOperator( + ts.TosaOp.Op().GATHER, + [reshaped_input.name, gather_index_name], + [gather_out.name], + None, + ) + + output_shape = tutils.tosa_shape(output.shape, output.dim_order) + tutils.build_reshape(tosa_graph, gather_out.name, output_shape, output.name) + + +@register_node_visitor +class IndexTensorVisitor(CommonIndexTensorVisitor): + tosa_specs = [ + TosaSpecification.create_from_string("TOSA-1.0+INT"), + TosaSpecification.create_from_string("TOSA-1.0+FP"), + ] + + def __init__(self, *args): + super().__init__(*args) + + def define_node( + self, + node: Node, + tosa_graph: Any, + inputs: List[TosaArg], + output: TosaArg, + ) -> None: + """ + This approach uses the fact that all indexing tensors are incremented + simultaneously and they essentially act as a map along the corresponding + dimensions of the values tensor. + Note: that this does not hold true when slicing or ellipsis ops + are involved as such they are not currently not supported. + + As such this approach flattens out the values tensor and + constructs a flattened out index obtained by flattening out the + index tensors, multiplying them by the relevant stride and accumulating them. + + This approach suffers from the fact that we are taking a number of index tensors of + type int32 and applying multiplications and additions. + + If the number of total elements in the values tensor exceeds int32 limits + then this approach falls apart. + """ + import serializer.tosa_serializer as ts + + validate_same_dtype(self.target, [inputs[0], output]) + + values, indices = inputs + index_nodes = indices.special + + # Broadcast indices + broadcasted_tensors = tutils.broadcast_tensors( + tosa_graph, index_nodes, self.tosa_spec + ) + + # Calculate strides so we can shift indices down the line. + values_strides = self._calculate_value_strides(values.shape) + + # The indices have already been broadcast to a common shape + # in so they are all the same. + _, index_dtype, index_shape = self._get_tensor_info(broadcasted_tensors[0]) + + N, K, W, C = self._calculate_tosa_vals(index_shape, index_nodes, values.shape) + + gather_idx_shape = [N, W] + + gather_index_name = "" + # Flatten out and shift indexes. + for i, index_node in enumerate(broadcasted_tensors): + index_name, _, _ = self._get_tensor_info(index_node) + index_name = index_node.name + + stride_shifted_indices = tosa_graph.addIntermediate( + index_shape, + index_dtype, + ) + + # Division by C is necessary when len(indices) < values.rank + # When there are dimensions left unindexed that changes the + # channels and thus the stride-shift. 
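+            # e.g. for a values tensor of shape (5, 3, 2) and a single
+            # index tensor: strides are (6, 2, 1) and C = 6, so index 0 is
+            # scaled by 6 / 6 = 1 and each gathered row carries C = 6
+            # elements.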
+ data = np.full(index_shape, int(values_strides[i] / C)) + mul_const = tosa_graph.addConst(index_shape, index_dtype, data) + tosa_graph.addConst([1], ts.DType.INT8, 0, name=f"{node.name}_{i}_shift") + tosa_graph.addOperator( + ts.TosaOp.Op().MUL, + [index_name, mul_const.name, f"{node.name}_{i}_shift"], + [stride_shifted_indices.name], + ) + + reshaped_idxs = tosa_graph.addIntermediate( + gather_idx_shape, + index_dtype, + ) + tutils.build_reshape_tosa_1_0( + tosa_graph, + stride_shifted_indices.name, + gather_idx_shape, + reshaped_idxs.name, + shape_name_override=f"{node.name}_{i}_shape", + ) + + # Guarantees that the accumulation tensor is properly + # initialized and does not contain junk data. + if i == 0: + gather_index_name = reshaped_idxs.name + else: + tosa_graph.addOperator( + ts.TosaOp.Op().ADD, + [gather_index_name, reshaped_idxs.name], + [gather_index_name], + ) + + gather_vals_shape = [N, K, C] + reshaped_input = tosa_graph.addIntermediate(gather_vals_shape, values.dtype) + + tutils.build_reshape_tosa_1_0( + tosa_graph, + values.name, + gather_vals_shape, + reshaped_input.name, + shape_name_override=f"{node.name}_index_shape", + ) + + gather_out_shape = (N, W, C) + gather_out = tosa_graph.addIntermediate( + gather_out_shape, + output.dtype, + ) + tosa_graph.addOperator( + ts.TosaOp.Op().GATHER, + [reshaped_input.name, gather_index_name], + [gather_out.name], + None, + ) + + output_shape = tutils.tosa_shape(output.shape, output.dim_order) + + tutils.build_reshape_tosa_1_0( + tosa_graph, + gather_out.name, + list(output_shape), + output.name, + shape_name_override=f"{node.name}_output_shape", + ) diff --git a/backends/arm/quantizer/quantization_annotator.py b/backends/arm/quantizer/quantization_annotator.py index 5c2f7822097..83a648c7d8a 100644 --- a/backends/arm/quantizer/quantization_annotator.py +++ b/backends/arm/quantizer/quantization_annotator.py @@ -260,6 +260,7 @@ def _match_pattern( torch.ops.aten.clamp.Tensor, torch.ops.aten.unflatten.int, torch.ops.aten.index_select.default, + torch.ops.aten.index.Tensor, ] _one_to_one_shared_input_or_input_act_qspec = [ diff --git a/backends/arm/test/ops/test_index_tensor.py b/backends/arm/test/ops/test_index_tensor.py new file mode 100644 index 00000000000..f1f6f5171d8 --- /dev/null +++ b/backends/arm/test/ops/test_index_tensor.py @@ -0,0 +1,462 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +from enum import IntEnum +from typing import Tuple + +import torch +from executorch.backends.arm.test import common +from executorch.backends.arm.test.tester.test_pipeline import ( + TosaPipelineBI, + TosaPipelineMI, +) + + +class IndexTensorTestCommon: + """Class containing constants common between the tests""" + + aten_op = "torch.ops.aten.index.Tensor" + exir_op = "executorch_exir_dialects_edge__ops_aten_index_Tensor" + + # Gathers and reshapes should result in no inaccuracies + rtol = 0.0 + atol = 0.0 + + class OpPlacement(IntEnum): + """ + Simple enum used to indicate where slices or ellipsis should be placed + in tests. + IntEnum so that Dynamo does not complain about unsupported types. 
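+        The numeric values themselves are arbitrary; the forward()
+        implementations only match on the enum members.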
+ """ + + BEFORE = 1 + MIDDLE = 2 + AFTER = 3 + + +input_params_slice = Tuple[ + torch.Tensor, int, int, IndexTensorTestCommon.OpPlacement, Tuple[torch.Tensor] +] +input_params = Tuple[torch.Tensor, Tuple[torch.Tensor]] + + +class IndexTensor_Ellipsis(torch.nn.Module): + """ + There are technical limitations with torch/export as it does not support + the ellipsis class and as such the forward function has been crafted + to circumvent that limitation. + """ + + # xfail - ellipsis unsupported + test_data_ellipsis: dict[input_params] = { + "test_4d_ellipsis_before": ( + torch.rand(size=(25, 5, 13, 7)), + IndexTensorTestCommon.OpPlacement.BEFORE, + (torch.arange(2, dtype=torch.int32),), + ), + "test_4d_ellipsis_middle": ( + torch.rand(size=(25, 5, 13, 7)), + IndexTensorTestCommon.OpPlacement.MIDDLE, + ( + torch.arange(2, dtype=torch.int32), + torch.arange(2, dtype=torch.int32), + ), + ), + "test_4d_ellipsis_after": ( + # Due to the information passed to the NodeVisitor and + # preceding passes, detecting this and rejecting it for + # partitioning is difficult and unreliable, as such + # it is not xfail as the existing logic can handle it. + torch.rand(size=(25, 5, 13, 7)), + IndexTensorTestCommon.OpPlacement.AFTER, + (torch.arange(2, dtype=torch.int32),), + ), + } + + def forward( + self, + input_: torch.Tensor, + position: IndexTensorTestCommon.OpPlacement, + indices: Tuple[None | torch.Tensor], + ): + match position: + case IndexTensorTestCommon.OpPlacement.BEFORE: + return input_[..., indices[0]] + case IndexTensorTestCommon.OpPlacement.MIDDLE: + return input_[indices[0], ..., indices[1]] + case IndexTensorTestCommon.OpPlacement.AFTER: + return input_[indices[0], ...] + + return input_[indices] + + +@common.parametrize( + "test_data", + IndexTensor_Ellipsis.test_data_ellipsis, + xfails={ + # More info in index_tensor_support.py + "test_4d_ellipsis_before": "Ellipsis before index unsupported", + "test_4d_ellipsis_middle": "Ellipsis before index unsupported", + }, +) +def test_index_tensor_tosa_MI_ellipsis(test_data: input_params): + test_input = test_data + with torch.no_grad(): + ( + TosaPipelineMI[input_params]( + IndexTensor_Ellipsis(), + test_input, + IndexTensorTestCommon.aten_op, + IndexTensorTestCommon.exir_op, + atol=IndexTensorTestCommon.atol, + rtol=IndexTensorTestCommon.rtol, + ).run() + ) + + +@common.parametrize( + "test_data", + IndexTensor_Ellipsis.test_data_ellipsis, + xfails={ + # More info in index_tensor_support.py + "test_4d_ellipsis_before": "Ellipsis before index unsupported", + "test_4d_ellipsis_middle": "Ellipsis before index unsupported", + }, +) +def test_index_tensor_tosa_BI_ellipsis(test_data: input_params): + test_input = test_data + with torch.no_grad(): + ( + TosaPipelineBI[input_params]( + IndexTensor_Ellipsis(), + test_input, + IndexTensorTestCommon.aten_op, + IndexTensorTestCommon.exir_op, + ).run() + ) + + +class IndexTensor_Slice(torch.nn.Module): + """ + There are technical limitations with Dynamo as it does not support the + slice class and as such the forward function has been crafted + to circumvent that limitation. 
+ """ + + # xfail - None unsupported + test_data: dict[input_params_slice] = { + "test_4d_slice_before_1d_idx": ( + # Value tens is 3D because with the + torch.rand(size=(5, 3, 4, 5)), + 0, + 2, + IndexTensorTestCommon.OpPlacement.BEFORE, + (torch.arange(2, dtype=torch.int32),), + ), + "test_3d_slice_before_2d_idx": ( + # TODO: MLETORCH-859 - Testing framework does not support output rank > 4 + # With the bellow configuration a 4D value tensor and 2D index tensor + # results in a 5D output. + torch.arange(5 * 3 * 4, dtype=torch.float32).reshape(5, 3, 4), + 0, + 2, + IndexTensorTestCommon.OpPlacement.BEFORE, + (torch.arange(2, dtype=torch.int32).unsqueeze(0).tile(2, 1),), + ), + "test_4d_slice_middle": ( + torch.arange(5 * 3 * 2, dtype=torch.int32).reshape(5, 3, 2), + 0, + 2, + IndexTensorTestCommon.OpPlacement.MIDDLE, + ( + torch.arange(2, dtype=torch.int32), + torch.arange(2, dtype=torch.int32), + ), + ), + "test_4d_slice_after": ( + # Due to the information passed to the NodeVisitor and + # preceding passes, detecting this and rejecting it for + # partitioning is difficult and unreliable, as such + # it is not xfail as the existing logic can handle it. + torch.rand(size=(25, 5, 13, 7)), + 0, + 2, + IndexTensorTestCommon.OpPlacement.AFTER, + (torch.arange(2, dtype=torch.int32),), + ), + } + + def forward( + self, + input_: torch.Tensor, + slice_start: int, + slice_end: int, + position: IndexTensorTestCommon.OpPlacement, + indices: Tuple[None | torch.Tensor], + ): + match position: + case IndexTensorTestCommon.OpPlacement.BEFORE: + return input_[slice_start:slice_end, indices[0]] + case IndexTensorTestCommon.OpPlacement.MIDDLE: + return input_[indices[0], slice_start:slice_end, indices[1]] + case IndexTensorTestCommon.OpPlacement.AFTER: + return input_[indices[0], slice_start:slice_end] + + +@common.parametrize( + "test_data", + IndexTensor_Slice.test_data, + xfails={ + # More info in index_tensor_support.py + "test_4d_slice_before_1d_idx": "Slice before index unsupported", + "test_3d_slice_before_2d_idx": "Slice before index unsupported", + "test_4d_slice_middle": "Slice before index unsupported", + }, +) +def test_index_tensor_tosa_MI_slice(test_data: input_params_slice): + test_input = test_data + with torch.no_grad(): + ( + TosaPipelineMI[input_params_slice]( + IndexTensor_Slice(), + test_input, + IndexTensorTestCommon.aten_op, + IndexTensorTestCommon.exir_op, + atol=IndexTensorTestCommon.atol, + rtol=IndexTensorTestCommon.rtol, + ).run() + ) + + +@common.parametrize( + "test_data", + IndexTensor_Slice.test_data, + xfails={ + # More info in index_tensor_support.py + "test_4d_slice_before_1d_idx": "Slice before index unsupported", + "test_3d_slice_before_2d_idx": "Slice before index unsupported", + "test_4d_slice_middle": "Slice before index unsupported", + }, +) +def test_index_tensor_tosa_BI_slice(test_data: input_params_slice): + test_input = test_data + with torch.no_grad(): + ( + TosaPipelineBI[input_params_slice]( + IndexTensor_Slice(), + test_input, + IndexTensorTestCommon.aten_op, + IndexTensorTestCommon.exir_op, + ).run() + ) + + +class IndexTensor(torch.nn.Module): + test_data: dict[input_params] = { + "test_2d_1_idx": (torch.rand(5, 2), (torch.arange(5, dtype=torch.int32),)), + "test_2d_1_less_than_max_idx": ( + torch.rand(5, 2), + (torch.arange(3, dtype=torch.int32),), + ), + "test_2d_1_2d_idx": ( + torch.rand(5, 2), + (torch.randint(5, size=(4, 3), dtype=torch.int32)), + ), + "test_2d_2_idx": ( + torch.rand(5, 2), + ( + torch.randint(5, size=(5,), dtype=torch.int32), + 
torch.randint(2, size=(5,), dtype=torch.int32), + ), + ), + "test_2d_2_2d_idx_broadcastable": ( + torch.rand(5, 2), + ( + torch.randint(5, size=(5, 3), dtype=torch.int32), + torch.randint(2, size=(1, 3), dtype=torch.int32), + ), + ), + "test_2d_2_2d_idx_broadcastable_2": ( + torch.rand(5, 2), + ( + torch.randint(5, size=(5, 1), dtype=torch.int32), + torch.randint(2, size=(3,), dtype=torch.int32), + ), + ), + "test_3d_1_idx": (torch.rand(12, 3, 7), (torch.arange(12, dtype=torch.int32),)), + "test_3d_2_idx": ( + torch.rand(12, 3, 7), + ( + torch.arange(12, dtype=torch.int32), + torch.randint(3, size=(12,), dtype=torch.int32), + ), + ), + "test_3d_3_idx": ( + torch.rand(12, 3, 7), + ( + torch.arange(12, dtype=torch.int32), + torch.randint(3, size=(12,), dtype=torch.int32), + torch.randint(7, size=(12,), dtype=torch.int32), + ), + ), + "test_4d_1_idx": ( + torch.rand(15, 3, 7, 2), + (torch.arange(15, dtype=torch.int32),), + ), + "test_4d_2_idx": ( + torch.rand(15, 3, 7, 2), + ( + torch.randint(15, size=(15,), dtype=torch.int32), + torch.randint(3, size=(1,), dtype=torch.int32), + ), + ), + "test_4d_3_idx": ( + torch.rand(15, 3, 7, 2), + ( + torch.arange(15, dtype=torch.int32), + torch.randint(3, size=(15,), dtype=torch.int32), + torch.randint(7, size=(15,), dtype=torch.int32), + ), + ), + "test_4d_4_id_broadcastable": ( + torch.rand(15, 3, 7, 2), + ( + torch.arange(15, dtype=torch.int32), + torch.randint(3, size=(3, 1), dtype=torch.int32), + torch.randint(6, size=(6, 1, 1), dtype=torch.int32), + torch.randint(2, size=(15,), dtype=torch.int32), + ), + ), + } + + # xfail - None (unsqueeze) unsupported + test_data_none: dict[input_params] = { + "test_3d_3_idx_with_none_before": ( + torch.rand(12, 3, 7), + ( + None, + torch.randint(3, size=(12,), dtype=torch.int32), + ), + ), + "test_3d_3_idx_with_2_none_before": ( + torch.rand(12, 3, 7), + ( + None, + None, + torch.randint(3, size=(12,), dtype=torch.int32), + ), + ), + "test_3d_3_idx_with_none_around": ( + torch.rand(12, 3, 7), + ( + None, + torch.randint(3, size=(12,), dtype=torch.int32), + None, + ), + ), + "test_3d_3_idx_with_none_after": ( + # Due to the information passed to the NodeVisitor and + # preceding passes, detecting this and rejecting it for + # partitioning is difficult and unreliable, as such + # it is not xfail as the existing logic can handle it. 
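+            # A trailing None is never forwarded to index.Tensor; PyTorch
+            # handles it with a separate unsqueeze (see example #3 in
+            # index_tensor_support.py), so the existing lowering works.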
+ torch.rand(12, 3, 7), + ( + torch.randint(3, size=(12,), dtype=torch.int32), + None, + ), + ), + "test_3d_3_idx_with_none_middle": ( + torch.rand(12, 3, 7), + ( + torch.randint(3, size=(12,), dtype=torch.int32), + None, + torch.randint(3, size=(12,), dtype=torch.int32), + ), + ), + } + + def forward(self, input_: torch.Tensor, indices: Tuple[None | torch.Tensor]): + return input_[indices] + + +@common.parametrize("test_data", IndexTensor.test_data) +def test_index_tensor_tosa_MI(test_data: input_params): + test_input = test_data + with torch.no_grad(): + ( + TosaPipelineMI[input_params]( + IndexTensor(), + test_input, + IndexTensorTestCommon.aten_op, + IndexTensorTestCommon.exir_op, + atol=IndexTensorTestCommon.atol, + rtol=IndexTensorTestCommon.rtol, + ).run() + ) + + +@common.parametrize("test_data", IndexTensor.test_data) +def test_index_tensor_tosa_BI(test_data: input_params): + test_input = test_data + with torch.no_grad(): + ( + TosaPipelineBI[input_params]( + IndexTensor(), + test_input, + IndexTensorTestCommon.aten_op, + IndexTensorTestCommon.exir_op, + ).run() + ) + + +@common.parametrize( + "test_data", + IndexTensor.test_data_none, + xfails={ + # More info in index_tensor_support.py + "test_3d_3_idx_with_none_before": "None (Unsqueeze) unsupported", + "test_3d_3_idx_with_2_none_before": "None (Unsqueeze) unsupported", + "test_3d_3_idx_with_none_around": "None (Unsqueeze) unsupported", + "test_3d_3_idx_with_none_middle": "None (Unsqueeze) unsupported", + }, +) +def test_index_tensor_tosa_MI_none(test_data: input_params): + test_input = test_data + with torch.no_grad(): + ( + TosaPipelineMI[input_params]( + IndexTensor(), + test_input, + IndexTensorTestCommon.aten_op, + IndexTensorTestCommon.exir_op, + atol=IndexTensorTestCommon.atol, + rtol=IndexTensorTestCommon.rtol, + ).run() + ) + + +@common.parametrize( + "test_data", + IndexTensor.test_data_none, + xfails={ + # More info in index_tensor_support.py + "test_3d_3_idx_with_none_before": "None (Unsqueeze) unsupported", + "test_3d_3_idx_with_2_none_before": "None (Unsqueeze) unsupported", + "test_3d_3_idx_with_none_around": "None (Unsqueeze) unsupported", + "test_3d_3_idx_with_none_middle": "None (Unsqueeze) unsupported", + }, +) +def test_index_tensor_tosa_BI_none(test_data: input_params): + test_input = test_data + with torch.no_grad(): + ( + TosaPipelineBI[input_params]( + IndexTensor(), + test_input, + IndexTensorTestCommon.aten_op, + IndexTensorTestCommon.exir_op, + ).run() + ) diff --git a/backends/arm/tosa_utils.py b/backends/arm/tosa_utils.py index a176bc62973..3b56fdd1cbf 100644 --- a/backends/arm/tosa_utils.py +++ b/backends/arm/tosa_utils.py @@ -15,14 +15,19 @@ import torch import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore -from executorch.backends.arm.tosa_mapping import TosaArg -from executorch.backends.arm.tosa_specification import TosaSpecification +from executorch.backends.arm.tosa_mapping import extract_tensor_meta, TosaArg +from executorch.backends.arm.tosa_specification import ( + Tosa_0_80, + Tosa_1_00, + TosaSpecification, +) from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.print_program import inspect_node + +from torch._subclasses.fake_tensor import FakeTensor from torch.fx import Node -from tosa_tools.v0_80.serializer.tosa_serializer import TosaOp logger = logging.getLogger(__name__) @@ -116,17 +121,149 @@ def get_output_node(node: Node) -> Node: def build_reshape(tosa_fb, input_name, new_shape, output_name): attr = ts.TosaSerializerAttribute() 
attr.ReshapeAttribute(new_shape) - tosa_fb.addOperator(TosaOp.Op().RESHAPE, [input_name], [output_name], attr) + tosa_fb.addOperator(ts.TosaOp.Op().RESHAPE, [input_name], [output_name], attr) + + +def are_fake_tensors_broadcastable( + fake_tensors: list[FakeTensor], +) -> tuple[bool, list[int]]: + """ + Determines whether a list of FakeTensors can be broadcast together. + Args: + fake_tensors (list[FakeTensor]): List of 2 or more FakeTensors + who's shapes to evaluate + Returns: + tuple[bool, list[int]]: First element is whether the shapes are + broadcastable. Second element is the common shape if compatible. + If not, empty list. -def build_reshape_tosa_1_0(tosa_graph, input_name, new_shape, output_name): + Raises: + RuntimeError: If less than 2 tensors are passed in. + """ + if len(fake_tensors) < 1: + raise RuntimeError(f"Expected 2 or more tensors got {len(fake_tensors)}") + + reversed_shapes = [list(reversed(ft.shape)) for ft in fake_tensors] + sorted_shapes = sorted(reversed_shapes, key=len, reverse=True) + + broadcast_shape = [] + for dim in range(len(sorted_shapes[0])): + curr_dim = 1 + for shape in sorted_shapes: + if dim >= len(shape): + continue + if curr_dim == 1 and shape[dim] != 1: + curr_dim = shape[dim] + elif shape[dim] == 1: + continue + elif curr_dim != 1 and shape[dim] != curr_dim: + return (False, []) + broadcast_shape.append(curr_dim) + return (True, list(reversed(broadcast_shape))) + + +def broadcast_tensors( + tosa_fb, nodes: list[Node], tosa_spec: TosaSpecification +) -> list[Any]: + """ + Given a list of nodes it determines the common shape they broadcast to + and adds the necessary reshape and tile operations to perform the broadcast. + + Args: + tosa_fb: Tosa graph to add nodes to + nodes (list[Node]): List of nodes to broadcast together + tosa_spec (TosaSpecification): Tosa spec + + Returns: + list[Any]: List containing the fx.Nodes or TosaSerializerTensors + of the right common shape. Order of output matches order of input. + + Raises: + RuntimeError: If the supplied nodes are not broadcastable. + + Note: + This function and `reshape_for_broadcast` both reshape the tensors + for broadcast. However this function also performs the broadcast and + does not have a limit on only two input tensors. + """ + if isinstance(tosa_spec, Tosa_0_80): + import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore + + reshape_helper = build_reshape + elif isinstance(tosa_spec, Tosa_1_00): + import serializer.tosa_serializer as ts + + reshape_helper = build_reshape_tosa_1_0 + else: + raise ValueError(f"Unsupported TOSA spec: {tosa_spec}") + + index_fake_tensors = [node.meta["val"] for node in nodes] + broadcastable, common_shape = are_fake_tensors_broadcastable(index_fake_tensors) + if not broadcastable: + raise RuntimeError("FakeTensors are not broadcastable") + + broadcast_tensors = [] + for node in nodes: + tens_dtype, tens_shape, _ = extract_tensor_meta(node.meta, tosa_spec) + list_tens_shape = list(tens_shape) + # Already in the right shape we can just add it to the list. 
+ if list_tens_shape == common_shape: + broadcast_tensors.append(node) + continue + + rank_diff = len(common_shape) - len(tens_shape) + new_shape = [1] * rank_diff + list_tens_shape + reshaped = tosa_fb.addIntermediate( + new_shape, + tens_dtype, + ) + + reshape_helper(tosa_fb, node.name, new_shape, reshaped.name) + + tiled = tosa_fb.addIntermediate(common_shape, tens_dtype) + multipliers = [ + comm if curr == 1 else 1 for comm, curr in zip(common_shape, new_shape) + ] + if isinstance(tosa_spec, Tosa_0_80): + attr = ts.TosaSerializerAttribute() + attr.TileAttribute(multipliers) + tosa_fb.addOperator( + ts.TosaOp.Op().TILE, + [reshaped.name], + [tiled.name], + attr, + ) + elif isinstance(tosa_spec, Tosa_1_00): + multiple_shapes = tosa_fb.addConst( + (len(multipliers),), + ts.DType.SHAPE, + multipliers, + name=f"{node.name}_multiples", + ) + + tosa_fb.addOperator( + ts.TosaOp.Op().TILE, + [reshaped.name, multiple_shapes.name], + [tiled.name], + None, + ) + + broadcast_tensors.append(tiled) + + return broadcast_tensors + + +def build_reshape_tosa_1_0( + tosa_graph, input_name, new_shape, output_name, shape_name_override="" +): import serializer.tosa_serializer as ts_ # type: ignore shape = tosa_graph.addConst( np.array(new_shape).shape, ts_.DType.SHAPE, np.array(new_shape), - name=output_name + "_shape", + name=shape_name_override if shape_name_override else output_name + "_shape", ) attr = ts_.TosaSerializerAttribute()