22#
33# This source code is licensed under the BSD-style license found in the
44# LICENSE file in the root directory of this source tree.
5- """Declare operator support for ``aten.embedding`` in TOSA.
65
7- Permit embeddings with int32 indices (TOSA lacks int64 support); other dtypes
8- are rejected by this check.
9-
10- """
116
127import torch
138
2217
2318@register_tosa_support_check
2419class EmbeddingSupported (SupportedTOSAOperatorCheck ):
25- """Provide TOSA support check for ``aten.embedding``."""
26-
2720 targets = [exir_ops .edge .aten .embedding .default ]
2821
2922 tosa_specs = [
@@ -34,20 +27,16 @@ class EmbeddingSupported(SupportedTOSAOperatorCheck):
3427 def is_node_tosa_supported (
3528 self , node : fx .Node , tosa_spec : TosaSpecification
3629 ) -> bool : # type: ignore[override, misc]
37- """Return True if the node is supported by TOSA.
30+ # Note aten.embedding.default requires int64 indices and TOSA does not
31+ # support it. Int32 indices here for aten.embedding.default is ok since
32+ # it will be decomposed into ops that can handle it.
3833
39- PyTorch's ``aten.embedding`` typically takes int64 indices, but for
40- TOSA we only allow int32 indices. The export path decomposes the op so
41- that int32 indices are ok.
42-
43- """
4434 if len (node .all_input_nodes ) != 2 :
4535 self .reporter .report_reject (
4636 node ,
4737 (f"Expected exactly two input nodes, got { len (node .all_input_nodes )} " ),
4838 )
4939 return False
50-
5140 indices_val = node .all_input_nodes [1 ].meta ["val" ]
5241 indices_dtype = indices_val .dtype
5342
0 commit comments