22#
33# This source code is licensed under the BSD-style license found in the
44# LICENSE file in the root directory of this source tree.
5+ """Provide TOSA support checks for ``aten.index.Tensor``.
6+
7+ Reject unsupported patterns such as high-rank index tensors, front-positioned
8+ slice/ellipsis/None markers, and cases that exceed ``int32`` element limits.
9+
10+ """
511
612import math
713
1824
1925@register_tosa_support_check
2026class IndexTensorSupported (SupportedTOSAOperatorCheck ):
21- """
27+ """Prevent partitioning of unsupported ``index.Tensor`` usages.
28+
2229 This support check is intended to prevent the partitioning of
2330 currently unsupported usages of the index.Tensor operator.
2431
@@ -95,6 +102,7 @@ class IndexTensorSupported(SupportedTOSAOperatorCheck):
95102 t[1:3, torch.arange(5), 2:3, torch.arange(3).reshape(3,1)]
96103 are also possible and can result in some unintuitive behaviors
97104 where batching and indexing are mixed together.
105+
98106 """
99107
100108 targets = [exir_ops .edge .aten .index .Tensor ]
@@ -107,6 +115,14 @@ class IndexTensorSupported(SupportedTOSAOperatorCheck):
107115 def is_node_tosa_supported (
108116 self , node : fx .Node , tosa_spec : TosaSpecification
109117 ) -> bool : # type: ignore[override, misc]
118+ """Return True if ``aten.index.Tensor`` usage fits supported patterns.
119+
120+ Enforces the following constraints:
121+ - No ``None`` (unsqueeze), slice, or ellipsis before an indexing tensor.
122+ - Indexing tensors have rank <= 3.
123+ - The value tensor element count fits in ``int32``.
124+
125+ """
110126 indices = node .args [1 ]
111127 for index in indices : # type: ignore[union-attr]
112128 # Usage 2 guard
0 commit comments