Skip to content

Commit aec847d

Browse files
Arm backend: Add docstrings for operator_support/index_tensor_support.py (#14505)
Signed-off-by: Sebastian Larsson <[email protected]>
1 parent 7dc059f commit aec847d

File tree

1 file changed

+17
-1
lines changed

1 file changed

+17
-1
lines changed

backends/arm/operator_support/index_tensor_support.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,12 @@
22
#
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
5+
"""Provide TOSA support checks for ``aten.index.Tensor``.
6+
7+
Reject unsupported patterns such as high-rank index tensors, front-positioned
8+
slice/ellipsis/None markers, and cases that exceed ``int32`` element limits.
9+
10+
"""
511

612
import math
713

@@ -18,7 +24,8 @@
1824

1925
@register_tosa_support_check
2026
class IndexTensorSupported(SupportedTOSAOperatorCheck):
21-
"""
27+
"""Prevent partitioning of unsupported ``index.Tensor`` usages.
28+
2229
This support check is intended to prevent the partitioning of
2330
currently unsupported usages of the index.Tensor operator.
2431
@@ -95,6 +102,7 @@ class IndexTensorSupported(SupportedTOSAOperatorCheck):
95102
t[1:3, torch.arange(5), 2:3, torch.arange(3).reshape(3,1)]
96103
are also possible and can result in some unintuitive behaviors
97104
where batching and indexing are mixed together.
105+
98106
"""
99107

100108
targets = [exir_ops.edge.aten.index.Tensor]
@@ -107,6 +115,14 @@ class IndexTensorSupported(SupportedTOSAOperatorCheck):
107115
def is_node_tosa_supported(
108116
self, node: fx.Node, tosa_spec: TosaSpecification
109117
) -> bool: # type: ignore[override, misc]
118+
"""Return True if ``aten.index.Tensor`` usage fits supported patterns.
119+
120+
Enforces the following constraints:
121+
- No ``None`` (unsqueeze), slice, or ellipsis before an indexing tensor.
122+
- Indexing tensors have rank <= 3.
123+
- The value tensor element count fits in ``int32``.
124+
125+
"""
110126
indices = node.args[1]
111127
for index in indices: # type: ignore[union-attr]
112128
# Usage 2 guard

0 commit comments

Comments
 (0)