2
2
#
3
3
# This source code is licensed under the BSD-style license found in the
4
4
# LICENSE file in the root directory of this source tree.
5
+ """Provide TOSA support checks for ``aten.index.Tensor``.
6
+
7
+ Reject unsupported patterns such as high-rank index tensors, front-positioned
8
+ slice/ellipsis/None markers, and cases that exceed ``int32`` element limits.
9
+
10
+ """
5
11
6
12
import math
7
13
18
24
19
25
@register_tosa_support_check
20
26
class IndexTensorSupported (SupportedTOSAOperatorCheck ):
21
- """
27
+ """Prevent partitioning of unsupported ``index.Tensor`` usages.
28
+
22
29
This support check is intended to prevent the partitioning of
23
30
currently unsupported usages of the index.Tensor operator.
24
31
@@ -95,6 +102,7 @@ class IndexTensorSupported(SupportedTOSAOperatorCheck):
95
102
t[1:3, torch.arange(5), 2:3, torch.arange(3).reshape(3,1)]
96
103
are also possible and can result in some unintuitive behaviors
97
104
where batching and indexing are mixed together.
105
+
98
106
"""
99
107
100
108
targets = [exir_ops .edge .aten .index .Tensor ]
@@ -107,6 +115,14 @@ class IndexTensorSupported(SupportedTOSAOperatorCheck):
107
115
def is_node_tosa_supported (
108
116
self , node : fx .Node , tosa_spec : TosaSpecification
109
117
) -> bool : # type: ignore[override, misc]
118
+ """Return True if ``aten.index.Tensor`` usage fits supported patterns.
119
+
120
+ Enforces the following constraints:
121
+ - No ``None`` (unsqueeze), slice, or ellipsis before an indexing tensor.
122
+ - Indexing tensors have rank <= 3.
123
+ - The value tensor element count fits in ``int32``.
124
+
125
+ """
110
126
indices = node .args [1 ]
111
127
for index in indices : # type: ignore[union-attr]
112
128
# Usage 2 guard
0 commit comments