Skip to content

Commit 80a43ac

Browse files
Arm backend: Add docstrings for operator_support/index_select_support.py (#15522)
Signed-off-by: Sebastian Larsson <[email protected]>
1 parent 27ae6cc commit 80a43ac

File tree

1 file changed

+12
-0
lines changed

1 file changed

+12
-0
lines changed

backends/arm/operator_support/index_select_support.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,12 @@
22
#
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
5+
"""Declare operator support for ``aten.index_select`` in TOSA.
56
7+
Accept int32 indices and restrict supported weight shapes to 2D or 3D with a
8+
unit batch dimension.
9+
10+
"""
611
import torch
712
import torch.fx as fx
813
from executorch.backends.arm.operator_support.tosa_supported_operators import (
@@ -15,6 +20,8 @@
1520

1621
@register_tosa_support_check
1722
class IndexSelectSupported(SupportedTOSAOperatorCheck):
23+
"""Provide TOSA support check for ``aten.index_select``."""
24+
1825
targets = [exir_ops.edge.aten.index_select.default]
1926

2027
tosa_specs = [
@@ -25,7 +32,12 @@ class IndexSelectSupported(SupportedTOSAOperatorCheck):
2532
def is_node_tosa_supported(
2633
self, node: fx.Node, tosa_spec: TosaSpecification
2734
) -> bool: # type: ignore[override, misc]
35+
"""Return True if the node is supported by TOSA.
36+
37+
Require int32 indices and limit weight shapes to 2D or 3D with a leading
38+
dimension of 1.
2839
40+
"""
2941
weights_shape = node.all_input_nodes[0].meta["val"].shape
3042
indices_val = node.all_input_nodes[1].meta["val"]
3143
indices_dtype = indices_val.dtype

0 commit comments

Comments
 (0)