Skip to content

Commit 952fd1d

Browse files
authored
[Backend Tester] Add embedding tests (#12849)
Add tests for embedding and embedding_bag.
1 parent 964fee9 commit 952fd1d

File tree

4 files changed

+215
-2
lines changed

4 files changed

+215
-2
lines changed

backends/test/suite/operators/__init__.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,9 @@ def _create_test_for_backend(
133133

134134

135135
class OperatorTest(unittest.TestCase):
136-
def _test_op(self, model, inputs, flow: TestFlow):
136+
def _test_op(
137+
self, model, inputs, flow: TestFlow, generate_random_test_inputs: bool = True
138+
):
137139
context = get_active_test_context()
138140

139141
# This should be set in the wrapped test. See _make_wrapped_test above.
@@ -145,6 +147,7 @@ def _test_op(self, model, inputs, flow: TestFlow):
145147
flow,
146148
context.test_name,
147149
context.params,
150+
generate_random_test_inputs=generate_random_test_inputs,
148151
)
149152

150153
log_test_summary(run_summary)
Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-unsafe
8+
9+
import torch
10+
from executorch.backends.test.suite.flow import TestFlow
11+
12+
from executorch.backends.test.suite.operators import (
13+
dtype_test,
14+
operator_test,
15+
OperatorTest,
16+
)
17+
18+
19+
class Model(torch.nn.Module):
20+
def __init__(
21+
self,
22+
num_embeddings=100,
23+
embedding_dim=50,
24+
):
25+
super().__init__()
26+
self.embedding = torch.nn.Embedding(
27+
num_embeddings=num_embeddings,
28+
embedding_dim=embedding_dim,
29+
)
30+
31+
def forward(self, x):
32+
return self.embedding(x)
33+
34+
35+
@operator_test
36+
class Embedding(OperatorTest):
37+
# Note that generate_random_test_inputs is used to avoid the tester
38+
# generating random inputs that are out of range of the embedding size.
39+
# The tester's random input generation is not smart enough to know that
40+
# the index inputs must be within a certain range.
41+
42+
@dtype_test
43+
def test_embedding_dtype(self, flow: TestFlow, dtype) -> None:
44+
self._test_op(
45+
Model().to(dtype),
46+
(torch.randint(0, 10, (2, 8), dtype=torch.long),),
47+
flow,
48+
generate_random_test_inputs=False,
49+
)
50+
51+
def test_embedding_sizes(self, flow: TestFlow) -> None:
52+
self._test_op(
53+
Model(num_embeddings=5, embedding_dim=3),
54+
(torch.randint(0, 5, (2, 8), dtype=torch.long),),
55+
flow,
56+
generate_random_test_inputs=False,
57+
)
58+
self._test_op(
59+
Model(num_embeddings=100, embedding_dim=10),
60+
(torch.randint(0, 100, (2, 8), dtype=torch.long),),
61+
flow,
62+
generate_random_test_inputs=False,
63+
)
64+
self._test_op(
65+
Model(num_embeddings=1000, embedding_dim=50),
66+
(torch.randint(0, 1000, (2, 4), dtype=torch.long),),
67+
flow,
68+
generate_random_test_inputs=False,
69+
)
70+
71+
def test_embedding_batch_dim(self, flow: TestFlow) -> None:
72+
self._test_op(
73+
Model(),
74+
(torch.randint(0, 100, (5,), dtype=torch.long),),
75+
flow,
76+
generate_random_test_inputs=False,
77+
)
78+
self._test_op(
79+
Model(),
80+
(torch.randint(0, 100, (2, 8), dtype=torch.long),),
81+
flow,
82+
generate_random_test_inputs=False,
83+
)
84+
self._test_op(
85+
Model(),
86+
(torch.randint(0, 100, (2, 3, 4), dtype=torch.long),),
87+
flow,
88+
generate_random_test_inputs=False,
89+
)
Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-unsafe
8+
9+
import torch
10+
from executorch.backends.test.suite.flow import TestFlow
11+
12+
from executorch.backends.test.suite.operators import (
13+
dtype_test,
14+
operator_test,
15+
OperatorTest,
16+
)
17+
18+
19+
class Model(torch.nn.Module):
20+
def __init__(
21+
self,
22+
num_embeddings=10,
23+
embedding_dim=5,
24+
mode="mean",
25+
include_last_offset: bool = False,
26+
):
27+
super().__init__()
28+
self.embedding_bag = torch.nn.EmbeddingBag(
29+
num_embeddings=num_embeddings,
30+
embedding_dim=embedding_dim,
31+
mode=mode,
32+
include_last_offset=include_last_offset,
33+
)
34+
35+
def forward(self, x, offsets=None):
36+
return self.embedding_bag(x, offsets)
37+
38+
39+
@operator_test
40+
class EmbeddingBag(OperatorTest):
41+
# Note that generate_random_test_inputs is used to avoid the tester
42+
# generating random inputs that are out of range of the embedding size.
43+
# The tester's random input generation is not smart enough to know that
44+
# the index inputs must be within a certain range.
45+
46+
@dtype_test
47+
def test_embedding_bag_dtype(self, flow: TestFlow, dtype) -> None:
48+
indices = torch.tensor([1, 2, 4, 5, 4, 3, 2, 9], dtype=torch.long)
49+
offsets = torch.tensor([0, 4], dtype=torch.long)
50+
self._test_op(
51+
Model().to(dtype),
52+
(indices, offsets),
53+
flow,
54+
generate_random_test_inputs=False,
55+
)
56+
57+
def test_embedding_bag_sizes(self, flow: TestFlow) -> None:
58+
indices = torch.tensor([1, 2, 3, 1], dtype=torch.long)
59+
offsets = torch.tensor([0, 2], dtype=torch.long)
60+
61+
self._test_op(
62+
Model(num_embeddings=5, embedding_dim=3),
63+
(indices, offsets),
64+
flow,
65+
generate_random_test_inputs=False,
66+
)
67+
68+
indices = torch.tensor([5, 20, 10, 43, 7], dtype=torch.long)
69+
offsets = torch.tensor([0, 2, 4], dtype=torch.long)
70+
self._test_op(
71+
Model(num_embeddings=50, embedding_dim=10),
72+
(indices, offsets),
73+
flow,
74+
generate_random_test_inputs=False,
75+
)
76+
77+
indices = torch.tensor([100, 200, 300, 400], dtype=torch.long)
78+
offsets = torch.tensor([0, 2], dtype=torch.long)
79+
self._test_op(
80+
Model(num_embeddings=500, embedding_dim=20),
81+
(indices, offsets),
82+
flow,
83+
generate_random_test_inputs=False,
84+
)
85+
86+
def test_embedding_bag_modes(self, flow: TestFlow) -> None:
87+
indices = torch.tensor([1, 2, 4, 5, 4, 3, 2, 9], dtype=torch.long)
88+
offsets = torch.tensor([0, 4], dtype=torch.long)
89+
90+
self._test_op(
91+
Model(mode="sum"),
92+
(indices, offsets),
93+
flow,
94+
generate_random_test_inputs=False,
95+
)
96+
self._test_op(
97+
Model(mode="mean"),
98+
(indices, offsets),
99+
flow,
100+
generate_random_test_inputs=False,
101+
)
102+
self._test_op(
103+
Model(mode="max"),
104+
(indices, offsets),
105+
flow,
106+
generate_random_test_inputs=False,
107+
)
108+
109+
def test_embedding_bag_include_last_offset(self, flow: TestFlow) -> None:
110+
indices = torch.tensor([1, 2, 4, 5, 4, 3, 2, 9], dtype=torch.long)
111+
offsets = torch.tensor([0, 4], dtype=torch.long)
112+
113+
self._test_op(
114+
Model(include_last_offset=True),
115+
(indices, offsets),
116+
flow,
117+
generate_random_test_inputs=False,
118+
)

backends/test/suite/runner.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ def run_test( # noqa: C901
3333
test_name: str,
3434
params: dict | None,
3535
dynamic_shapes: Any | None = None,
36+
generate_random_test_inputs: bool = True,
3637
) -> TestCaseSummary:
3738
"""
3839
Top-level test run function for a model, input set, and tester. Handles test execution
@@ -102,7 +103,9 @@ def build_result(
102103
# the cause of a failure in run_method_and_compare_outputs. We can look for
103104
# AssertionErrors to catch output mismatches, but this might catch more than that.
104105
try:
105-
tester.run_method_and_compare_outputs()
106+
tester.run_method_and_compare_outputs(
107+
inputs=None if generate_random_test_inputs else inputs
108+
)
106109
except AssertionError as e:
107110
return build_result(TestResult.OUTPUT_MISMATCH_FAIL, e)
108111
except Exception as e:

0 commit comments

Comments
 (0)