File tree Expand file tree Collapse file tree 1 file changed +12
-0
lines changed
backends/arm/operator_support Expand file tree Collapse file tree 1 file changed +12
-0
lines changed Original file line number Diff line number Diff line change 22#
33# This source code is licensed under the BSD-style license found in the
44# LICENSE file in the root directory of this source tree.
5+ """Declare operator support for ``aten.index_select`` in TOSA.
56
7+ Accept int32 indices and restrict supported weight shapes to 2D or 3D with a
8+ unit batch dimension.
9+
10+ """
611import torch
712import torch .fx as fx
813from executorch .backends .arm .operator_support .tosa_supported_operators import (
1520
1621@register_tosa_support_check
1722class IndexSelectSupported (SupportedTOSAOperatorCheck ):
23+ """Provide TOSA support check for ``aten.index_select``."""
24+
1825 targets = [exir_ops .edge .aten .index_select .default ]
1926
2027 tosa_specs = [
@@ -25,7 +32,12 @@ class IndexSelectSupported(SupportedTOSAOperatorCheck):
2532 def is_node_tosa_supported (
2633 self , node : fx .Node , tosa_spec : TosaSpecification
2734 ) -> bool : # type: ignore[override, misc]
35+ """Return True if the node is supported by TOSA.
36+
37+ Require int32 indices and limit weight shapes to 2D or 3D with a leading
38+ dimension of 1.
2839
40+ """
2941 weights_shape = node .all_input_nodes [0 ].meta ["val" ].shape
3042 indices_val = node .all_input_nodes [1 ].meta ["val" ]
3143 indices_dtype = indices_val .dtype
You can’t perform that action at this time.
0 commit comments