1+ # (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
2+
3+ # pyre-strict
4+
5+ from typing import Callable , Optional
6+
7+ import torch
8+
9+ from executorch .backends .test .compliance_suite import (
10+ dtype_test ,
11+ operator_test ,
12+ OperatorTest ,
13+ )
14+
15+ class Model (torch .nn .Module ):
16+ def __init__ (
17+ self ,
18+ num_embeddings = 10 ,
19+ embedding_dim = 5 ,
20+ mode = 'mean' ,
21+ padding_idx : Optional [int ] = None ,
22+ norm_type : float = 2.0 ,
23+ include_last_offset : bool = False ,
24+ ):
25+ super ().__init__ ()
26+ self .embedding_bag = torch .nn .EmbeddingBag (
27+ num_embeddings = num_embeddings ,
28+ embedding_dim = embedding_dim ,
29+ mode = mode ,
30+ padding_idx = padding_idx ,
31+ norm_type = norm_type ,
32+ include_last_offset = include_last_offset ,
33+ )
34+
35+ def forward (self , x , offsets = None ):
36+ return self .embedding_bag (x , offsets )
37+
38+ @operator_test
39+ class TestEmbeddingBag (OperatorTest ):
40+ @dtype_test
41+ def test_embedding_bag_dtype (self , dtype , tester_factory : Callable ) -> None :
42+ # Input: indices and offsets
43+ # Note: Input indices should be of type Long (int64)
44+ model = Model ().to (dtype )
45+ indices = torch .tensor ([1 , 2 , 4 , 5 , 4 , 3 , 2 , 9 ], dtype = torch .long )
46+ offsets = torch .tensor ([0 , 4 ], dtype = torch .long ) # 2 bags
47+ self ._test_op (model , (indices , offsets ), tester_factory , use_random_test_inputs = False )
48+
49+ def test_embedding_bag_basic (self , tester_factory : Callable ) -> None :
50+ # Basic test with default parameters
51+ indices = torch .tensor ([1 , 2 , 4 , 5 , 4 , 3 , 2 , 9 ], dtype = torch .long )
52+ offsets = torch .tensor ([0 , 4 ], dtype = torch .long ) # 2 bags
53+ self ._test_op (Model (), (indices , offsets ), tester_factory , use_random_test_inputs = False )
54+
55+ def test_embedding_bag_sizes (self , tester_factory : Callable ) -> None :
56+ # Test with different dictionary sizes and embedding dimensions
57+ indices = torch .tensor ([1 , 2 , 3 , 1 ], dtype = torch .long )
58+ offsets = torch .tensor ([0 , 2 ], dtype = torch .long )
59+
60+ self ._test_op (Model (num_embeddings = 5 , embedding_dim = 3 ),
61+ (indices , offsets ), tester_factory , use_random_test_inputs = False )
62+
63+ indices = torch .tensor ([5 , 20 , 10 , 43 , 7 ], dtype = torch .long )
64+ offsets = torch .tensor ([0 , 2 , 4 ], dtype = torch .long )
65+ self ._test_op (Model (num_embeddings = 50 , embedding_dim = 10 ),
66+ (indices , offsets ), tester_factory , use_random_test_inputs = False )
67+
68+ indices = torch .tensor ([100 , 200 , 300 , 400 ], dtype = torch .long )
69+ offsets = torch .tensor ([0 , 2 ], dtype = torch .long )
70+ self ._test_op (Model (num_embeddings = 500 , embedding_dim = 20 ),
71+ (indices , offsets ), tester_factory , use_random_test_inputs = False )
72+
73+ def test_embedding_bag_modes (self , tester_factory : Callable ) -> None :
74+ # Test with different modes (sum, mean, max)
75+ indices = torch .tensor ([1 , 2 , 4 , 5 , 4 , 3 , 2 , 9 ], dtype = torch .long )
76+ offsets = torch .tensor ([0 , 4 ], dtype = torch .long )
77+
78+ self ._test_op (Model (mode = 'sum' ), (indices , offsets ), tester_factory , use_random_test_inputs = False )
79+ self ._test_op (Model (mode = 'mean' ), (indices , offsets ), tester_factory , use_random_test_inputs = False )
80+ self ._test_op (Model (mode = 'max' ), (indices , offsets ), tester_factory , use_random_test_inputs = False )
81+
82+ def test_embedding_bag_padding_idx (self , tester_factory : Callable ) -> None :
83+ # Test with padding_idx
84+ indices = torch .tensor ([0 , 1 , 2 , 0 , 3 , 0 , 4 ], dtype = torch .long )
85+ offsets = torch .tensor ([0 , 3 , 6 ], dtype = torch .long )
86+
87+ self ._test_op (Model (padding_idx = 0 ), (indices , offsets ), tester_factory , use_random_test_inputs = False )
88+
89+ indices = torch .tensor ([1 , 5 , 2 , 5 , 3 , 5 , 4 ], dtype = torch .long )
90+ offsets = torch .tensor ([0 , 3 , 6 ], dtype = torch .long )
91+
92+ self ._test_op (Model (padding_idx = 5 ), (indices , offsets ), tester_factory , use_random_test_inputs = False )
93+
94+ def test_embedding_bag_include_last_offset (self , tester_factory : Callable ) -> None :
95+ # Test with include_last_offset
96+ indices = torch .tensor ([1 , 2 , 4 , 5 , 4 , 3 , 2 , 9 ], dtype = torch .long )
97+ offsets = torch .tensor ([0 , 4 ], dtype = torch .long )
98+
99+ self ._test_op (Model (include_last_offset = True ), (indices , offsets ), tester_factory , use_random_test_inputs = False )
100+
0 commit comments