22# 
33# This source code is licensed under the BSD-style license found in the 
44# LICENSE file in the root directory of this source tree. 
5+ """Declare operator support for ``aten.embedding`` in TOSA. 
56
7+ Permit embeddings with int32 indices (TOSA lacks int64 support); other dtypes 
8+ are rejected by this check. 
9+ 
10+ """ 
611
712import  torch 
813
1722
1823@register_tosa_support_check  
1924class  EmbeddingSupported (SupportedTOSAOperatorCheck ):
25+     """Provide TOSA support check for ``aten.embedding``.""" 
26+ 
2027    targets  =  [exir_ops .edge .aten .embedding .default ]
2128
2229    tosa_specs  =  [
@@ -27,16 +34,20 @@ class EmbeddingSupported(SupportedTOSAOperatorCheck):
2734    def  is_node_tosa_supported (
2835        self , node : fx .Node , tosa_spec : TosaSpecification 
2936    ) ->  bool :  # type: ignore[override, misc] 
30-         # Note aten.embedding.default requires int64 indices and TOSA does not 
31-         # support it. Int32 indices here for aten.embedding.default is ok since 
32-         # it will be decomposed into ops that can handle it. 
37+         """Return True if the node is supported by TOSA. 
3338
39+         PyTorch's ``aten.embedding`` typically takes int64 indices, but for 
40+         TOSA we only allow int32 indices. The export path decomposes the op so 
41+         that int32 indices are ok. 
42+ 
43+         """ 
3444        if  len (node .all_input_nodes ) !=  2 :
3545            self .reporter .report_reject (
3646                node ,
3747                (f"Expected exactly two input nodes, got { len (node .all_input_nodes )}  " ),
3848            )
3949            return  False 
50+ 
4051        indices_val  =  node .all_input_nodes [1 ].meta ["val" ]
4152        indices_dtype  =  indices_val .dtype 
4253
0 commit comments