Skip to content

Commit f7dcc47

Browse files
committed
[Backend Tester] Add embedding tests
ghstack-source-id: 2f7abec ghstack-comment-id: 3116316430 Pull-Request: #12849
1 parent 9d29fbb commit f7dcc47

File tree

4 files changed

+217
-2
lines changed

4 files changed

+217
-2
lines changed

backends/test/suite/operators/__init__.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,9 @@ def _create_test_for_backend(
133133

134134

135135
class OperatorTest(unittest.TestCase):
136-
def _test_op(self, model, inputs, flow: TestFlow):
136+
def _test_op(
137+
self, model, inputs, flow: TestFlow, generate_random_test_inputs: bool = True
138+
):
137139
context = get_active_test_context()
138140

139141
# This should be set in the wrapped test. See _make_wrapped_test above.
@@ -145,6 +147,7 @@ def _test_op(self, model, inputs, flow: TestFlow):
145147
flow,
146148
context.test_name,
147149
context.params,
150+
generate_random_test_inputs=generate_random_test_inputs,
148151
)
149152

150153
log_test_summary(run_summary)
Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-unsafe
8+
9+
import torch
10+
from executorch.backends.test.suite.flow import TestFlow
11+
12+
from executorch.backends.test.suite.operators import (
13+
dtype_test,
14+
operator_test,
15+
OperatorTest,
16+
)
17+
18+
19+
class Model(torch.nn.Module):
20+
def __init__(
21+
self,
22+
num_embeddings=100,
23+
embedding_dim=50,
24+
):
25+
super().__init__()
26+
self.embedding = torch.nn.Embedding(
27+
num_embeddings=num_embeddings,
28+
embedding_dim=embedding_dim,
29+
)
30+
31+
def forward(self, x):
32+
return self.embedding(x)
33+
34+
35+
@operator_test
36+
class Embedding(OperatorTest):
37+
# Note that generate_random_test_inputs is used to avoid the tester
38+
# generating random inputs that are out of range of the embedding size.
39+
# The tester's random input generation is not smart enough to know that
40+
# the index inputs must be within a certain range.
41+
42+
@dtype_test
43+
def test_embedding_dtype(self, flow: TestFlow, dtype) -> None:
44+
self._test_op(
45+
Model().to(dtype),
46+
(torch.randint(0, 10, (2, 8), dtype=torch.long),),
47+
flow,
48+
generate_random_test_inputs=False,
49+
)
50+
51+
def test_embedding_sizes(self, flow: TestFlow) -> None:
52+
self._test_op(
53+
Model(num_embeddings=5, embedding_dim=3),
54+
(torch.randint(0, 5, (2, 8), dtype=torch.long),),
55+
flow,
56+
generate_random_test_inputs=False,
57+
)
58+
self._test_op(
59+
Model(num_embeddings=100, embedding_dim=10),
60+
(torch.randint(0, 100, (2, 8), dtype=torch.long),),
61+
flow,
62+
generate_random_test_inputs=False,
63+
)
64+
self._test_op(
65+
Model(num_embeddings=1000, embedding_dim=50),
66+
(torch.randint(0, 1000, (2, 4), dtype=torch.long),),
67+
flow,
68+
generate_random_test_inputs=False,
69+
)
70+
71+
def test_embedding_batch_dim(self, flow: TestFlow) -> None:
72+
self._test_op(
73+
Model(),
74+
(torch.randint(0, 100, (5,), dtype=torch.long),),
75+
flow,
76+
generate_random_test_inputs=False,
77+
)
78+
self._test_op(
79+
Model(),
80+
(torch.randint(0, 100, (2, 8), dtype=torch.long),),
81+
flow,
82+
generate_random_test_inputs=False,
83+
)
84+
self._test_op(
85+
Model(),
86+
(torch.randint(0, 100, (2, 3, 4), dtype=torch.long),),
87+
flow,
88+
generate_random_test_inputs=False,
89+
)
Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-unsafe
8+
9+
from typing import Optional
10+
11+
import torch
12+
from executorch.backends.test.suite.flow import TestFlow
13+
14+
from executorch.backends.test.suite.operators import (
15+
dtype_test,
16+
operator_test,
17+
OperatorTest,
18+
)
19+
20+
21+
class Model(torch.nn.Module):
22+
def __init__(
23+
self,
24+
num_embeddings=10,
25+
embedding_dim=5,
26+
mode="mean",
27+
include_last_offset: bool = False,
28+
):
29+
super().__init__()
30+
self.embedding_bag = torch.nn.EmbeddingBag(
31+
num_embeddings=num_embeddings,
32+
embedding_dim=embedding_dim,
33+
mode=mode,
34+
include_last_offset=include_last_offset,
35+
)
36+
37+
def forward(self, x, offsets=None):
38+
return self.embedding_bag(x, offsets)
39+
40+
41+
@operator_test
42+
class EmbeddingBag(OperatorTest):
43+
# Note that generate_random_test_inputs is used to avoid the tester
44+
# generating random inputs that are out of range of the embedding size.
45+
# The tester's random input generation is not smart enough to know that
46+
# the index inputs must be within a certain range.
47+
48+
@dtype_test
49+
def test_embedding_bag_dtype(self, flow: TestFlow, dtype) -> None:
50+
indices = torch.tensor([1, 2, 4, 5, 4, 3, 2, 9], dtype=torch.long)
51+
offsets = torch.tensor([0, 4], dtype=torch.long)
52+
self._test_op(
53+
Model().to(dtype),
54+
(indices, offsets),
55+
flow,
56+
generate_random_test_inputs=False,
57+
)
58+
59+
def test_embedding_bag_sizes(self, flow: TestFlow) -> None:
60+
indices = torch.tensor([1, 2, 3, 1], dtype=torch.long)
61+
offsets = torch.tensor([0, 2], dtype=torch.long)
62+
63+
self._test_op(
64+
Model(num_embeddings=5, embedding_dim=3),
65+
(indices, offsets),
66+
flow,
67+
generate_random_test_inputs=False,
68+
)
69+
70+
indices = torch.tensor([5, 20, 10, 43, 7], dtype=torch.long)
71+
offsets = torch.tensor([0, 2, 4], dtype=torch.long)
72+
self._test_op(
73+
Model(num_embeddings=50, embedding_dim=10),
74+
(indices, offsets),
75+
flow,
76+
generate_random_test_inputs=False,
77+
)
78+
79+
indices = torch.tensor([100, 200, 300, 400], dtype=torch.long)
80+
offsets = torch.tensor([0, 2], dtype=torch.long)
81+
self._test_op(
82+
Model(num_embeddings=500, embedding_dim=20),
83+
(indices, offsets),
84+
flow,
85+
generate_random_test_inputs=False,
86+
)
87+
88+
def test_embedding_bag_modes(self, flow: TestFlow) -> None:
89+
indices = torch.tensor([1, 2, 4, 5, 4, 3, 2, 9], dtype=torch.long)
90+
offsets = torch.tensor([0, 4], dtype=torch.long)
91+
92+
self._test_op(
93+
Model(mode="sum"),
94+
(indices, offsets),
95+
flow,
96+
generate_random_test_inputs=False,
97+
)
98+
self._test_op(
99+
Model(mode="mean"),
100+
(indices, offsets),
101+
flow,
102+
generate_random_test_inputs=False,
103+
)
104+
self._test_op(
105+
Model(mode="max"),
106+
(indices, offsets),
107+
flow,
108+
generate_random_test_inputs=False,
109+
)
110+
111+
def test_embedding_bag_include_last_offset(self, flow: TestFlow) -> None:
112+
indices = torch.tensor([1, 2, 4, 5, 4, 3, 2, 9], dtype=torch.long)
113+
offsets = torch.tensor([0, 4], dtype=torch.long)
114+
115+
self._test_op(
116+
Model(include_last_offset=True),
117+
(indices, offsets),
118+
flow,
119+
generate_random_test_inputs=False,
120+
)

backends/test/suite/runner.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ def run_test( # noqa: C901
3333
test_name: str,
3434
params: dict | None,
3535
dynamic_shapes: Any | None = None,
36+
generate_random_test_inputs: bool = True,
3637
) -> TestCaseSummary:
3738
"""
3839
Top-level test run function for a model, input set, and tester. Handles test execution
@@ -102,7 +103,9 @@ def build_result(
102103
# the cause of a failure in run_method_and_compare_outputs. We can look for
103104
# AssertionErrors to catch output mismatches, but this might catch more than that.
104105
try:
105-
tester.run_method_and_compare_outputs()
106+
tester.run_method_and_compare_outputs(
107+
inputs=None if generate_random_test_inputs else inputs
108+
)
106109
except AssertionError as e:
107110
return build_result(TestResult.OUTPUT_MISMATCH_FAIL, e)
108111
except Exception as e:

0 commit comments

Comments
 (0)