Skip to content

Commit 8b11366

Browse files
committed
Update
[ghstack-poisoned]
1 parent 2de680e commit 8b11366

File tree

2 files changed

+168
-0
lines changed

2 files changed

+168
-0
lines changed
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
2+
3+
# pyre-strict
4+
5+
from typing import Callable, Optional
6+
7+
import torch
8+
9+
from executorch.backends.test.compliance_suite import (
10+
dtype_test,
11+
operator_test,
12+
OperatorTest,
13+
)
14+
15+
class Model(torch.nn.Module):
16+
def __init__(
17+
self,
18+
num_embeddings=10,
19+
embedding_dim=5,
20+
padding_idx: Optional[int] = None,
21+
norm_type: float = 2.0,
22+
):
23+
super().__init__()
24+
self.embedding = torch.nn.Embedding(
25+
num_embeddings=num_embeddings,
26+
embedding_dim=embedding_dim,
27+
padding_idx=padding_idx,
28+
norm_type=norm_type,
29+
)
30+
31+
def forward(self, x):
32+
return self.embedding(x)
33+
34+
@operator_test
35+
class TestEmbedding(OperatorTest):
36+
@dtype_test
37+
def test_embedding_dtype(self, dtype, tester_factory: Callable) -> None:
38+
# Input shape: (batch_size, seq_length)
39+
# Note: Input indices should be of type Long (int64)
40+
model = Model().to(dtype)
41+
self._test_op(model, (torch.randint(0, 10, (2, 8), dtype=torch.long),), tester_factory, use_random_test_inputs=False)
42+
43+
def test_embedding_basic(self, tester_factory: Callable) -> None:
44+
# Basic test with default parameters
45+
self._test_op(Model(), (torch.randint(0, 10, (2, 8), dtype=torch.long),), tester_factory, use_random_test_inputs=False)
46+
47+
def test_embedding_sizes(self, tester_factory: Callable) -> None:
48+
# Test with different dictionary sizes and embedding dimensions
49+
self._test_op(Model(num_embeddings=5, embedding_dim=3),
50+
(torch.randint(0, 5, (2, 8), dtype=torch.long),), tester_factory, use_random_test_inputs=False)
51+
self._test_op(Model(num_embeddings=100, embedding_dim=10),
52+
(torch.randint(0, 100, (2, 8), dtype=torch.long),), tester_factory, use_random_test_inputs=False)
53+
self._test_op(Model(num_embeddings=1000, embedding_dim=50),
54+
(torch.randint(0, 1000, (2, 4), dtype=torch.long),), tester_factory, use_random_test_inputs=False)
55+
56+
def test_embedding_padding_idx(self, tester_factory: Callable) -> None:
57+
# Test with padding_idx
58+
self._test_op(Model(padding_idx=0),
59+
(torch.randint(0, 10, (2, 8), dtype=torch.long),), tester_factory, use_random_test_inputs=False)
60+
self._test_op(Model(padding_idx=5),
61+
(torch.randint(0, 10, (2, 8), dtype=torch.long),), tester_factory, use_random_test_inputs=False)
62+
63+
def test_embedding_input_shapes(self, tester_factory: Callable) -> None:
64+
# Test with different input shapes
65+
self._test_op(Model(), (torch.randint(0, 10, (5,), dtype=torch.long),), tester_factory, use_random_test_inputs=False) # 1D input
66+
self._test_op(Model(), (torch.randint(0, 10, (2, 8), dtype=torch.long),), tester_factory, use_random_test_inputs=False) # 2D input
67+
self._test_op(Model(), (torch.randint(0, 10, (2, 3, 4), dtype=torch.long),), tester_factory, use_random_test_inputs=False) # 3D input
68+
Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
2+
3+
# pyre-strict
4+
5+
from typing import Callable, Optional
6+
7+
import torch
8+
9+
from executorch.backends.test.compliance_suite import (
10+
dtype_test,
11+
operator_test,
12+
OperatorTest,
13+
)
14+
15+
class Model(torch.nn.Module):
16+
def __init__(
17+
self,
18+
num_embeddings=10,
19+
embedding_dim=5,
20+
mode='mean',
21+
padding_idx: Optional[int] = None,
22+
norm_type: float = 2.0,
23+
include_last_offset: bool = False,
24+
):
25+
super().__init__()
26+
self.embedding_bag = torch.nn.EmbeddingBag(
27+
num_embeddings=num_embeddings,
28+
embedding_dim=embedding_dim,
29+
mode=mode,
30+
padding_idx=padding_idx,
31+
norm_type=norm_type,
32+
include_last_offset=include_last_offset,
33+
)
34+
35+
def forward(self, x, offsets=None):
36+
return self.embedding_bag(x, offsets)
37+
38+
@operator_test
39+
class TestEmbeddingBag(OperatorTest):
40+
@dtype_test
41+
def test_embedding_bag_dtype(self, dtype, tester_factory: Callable) -> None:
42+
# Input: indices and offsets
43+
# Note: Input indices should be of type Long (int64)
44+
model = Model().to(dtype)
45+
indices = torch.tensor([1, 2, 4, 5, 4, 3, 2, 9], dtype=torch.long)
46+
offsets = torch.tensor([0, 4], dtype=torch.long) # 2 bags
47+
self._test_op(model, (indices, offsets), tester_factory, use_random_test_inputs=False)
48+
49+
def test_embedding_bag_basic(self, tester_factory: Callable) -> None:
50+
# Basic test with default parameters
51+
indices = torch.tensor([1, 2, 4, 5, 4, 3, 2, 9], dtype=torch.long)
52+
offsets = torch.tensor([0, 4], dtype=torch.long) # 2 bags
53+
self._test_op(Model(), (indices, offsets), tester_factory, use_random_test_inputs=False)
54+
55+
def test_embedding_bag_sizes(self, tester_factory: Callable) -> None:
56+
# Test with different dictionary sizes and embedding dimensions
57+
indices = torch.tensor([1, 2, 3, 1], dtype=torch.long)
58+
offsets = torch.tensor([0, 2], dtype=torch.long)
59+
60+
self._test_op(Model(num_embeddings=5, embedding_dim=3),
61+
(indices, offsets), tester_factory, use_random_test_inputs=False)
62+
63+
indices = torch.tensor([5, 20, 10, 43, 7], dtype=torch.long)
64+
offsets = torch.tensor([0, 2, 4], dtype=torch.long)
65+
self._test_op(Model(num_embeddings=50, embedding_dim=10),
66+
(indices, offsets), tester_factory, use_random_test_inputs=False)
67+
68+
indices = torch.tensor([100, 200, 300, 400], dtype=torch.long)
69+
offsets = torch.tensor([0, 2], dtype=torch.long)
70+
self._test_op(Model(num_embeddings=500, embedding_dim=20),
71+
(indices, offsets), tester_factory, use_random_test_inputs=False)
72+
73+
def test_embedding_bag_modes(self, tester_factory: Callable) -> None:
74+
# Test with different modes (sum, mean, max)
75+
indices = torch.tensor([1, 2, 4, 5, 4, 3, 2, 9], dtype=torch.long)
76+
offsets = torch.tensor([0, 4], dtype=torch.long)
77+
78+
self._test_op(Model(mode='sum'), (indices, offsets), tester_factory, use_random_test_inputs=False)
79+
self._test_op(Model(mode='mean'), (indices, offsets), tester_factory, use_random_test_inputs=False)
80+
self._test_op(Model(mode='max'), (indices, offsets), tester_factory, use_random_test_inputs=False)
81+
82+
def test_embedding_bag_padding_idx(self, tester_factory: Callable) -> None:
83+
# Test with padding_idx
84+
indices = torch.tensor([0, 1, 2, 0, 3, 0, 4], dtype=torch.long)
85+
offsets = torch.tensor([0, 3, 6], dtype=torch.long)
86+
87+
self._test_op(Model(padding_idx=0), (indices, offsets), tester_factory, use_random_test_inputs=False)
88+
89+
indices = torch.tensor([1, 5, 2, 5, 3, 5, 4], dtype=torch.long)
90+
offsets = torch.tensor([0, 3, 6], dtype=torch.long)
91+
92+
self._test_op(Model(padding_idx=5), (indices, offsets), tester_factory, use_random_test_inputs=False)
93+
94+
def test_embedding_bag_include_last_offset(self, tester_factory: Callable) -> None:
95+
# Test with include_last_offset
96+
indices = torch.tensor([1, 2, 4, 5, 4, 3, 2, 9], dtype=torch.long)
97+
offsets = torch.tensor([0, 4], dtype=torch.long)
98+
99+
self._test_op(Model(include_last_offset=True), (indices, offsets), tester_factory, use_random_test_inputs=False)
100+

0 commit comments

Comments
 (0)