Skip to content

Commit f3a2081

Browse files
committed
[Backend Tester] Add embedding tests
ghstack-source-id: 6aecb24 ghstack-comment-id: 3116316430 Pull-Request: #12849
1 parent a1c1523 commit f3a2081

File tree

4 files changed

+270
-2
lines changed

4 files changed

+270
-2
lines changed

backends/test/suite/operators/__init__.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,9 @@ def _create_test_for_backend(
133133

134134

135135
class OperatorTest(unittest.TestCase):
136-
def _test_op(self, model, inputs, flow: TestFlow):
136+
def _test_op(
137+
self, model, inputs, flow: TestFlow, generate_random_test_inputs: bool = True
138+
):
137139
context = get_active_test_context()
138140

139141
# This should be set in the wrapped test. See _make_wrapped_test above.
@@ -145,6 +147,7 @@ def _test_op(self, model, inputs, flow: TestFlow):
145147
flow,
146148
context.test_name,
147149
context.params,
150+
generate_random_test_inputs=generate_random_test_inputs,
148151
)
149152

150153
log_test_summary(run_summary)
Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-unsafe
8+
9+
from typing import Optional
10+
11+
import torch
12+
from executorch.backends.test.suite.flow import TestFlow
13+
14+
from executorch.backends.test.suite.operators import (
15+
dtype_test,
16+
operator_test,
17+
OperatorTest,
18+
)
19+
20+
21+
class Model(torch.nn.Module):
22+
def __init__(
23+
self,
24+
num_embeddings=10,
25+
embedding_dim=5,
26+
padding_idx: Optional[int] = None,
27+
norm_type: float = 2.0,
28+
):
29+
super().__init__()
30+
self.embedding = torch.nn.Embedding(
31+
num_embeddings=num_embeddings,
32+
embedding_dim=embedding_dim,
33+
padding_idx=padding_idx,
34+
norm_type=norm_type,
35+
)
36+
37+
def forward(self, x):
38+
return self.embedding(x)
39+
40+
41+
@operator_test
42+
class Embedding(OperatorTest):
43+
@dtype_test
44+
def test_embedding_dtype(self, flow: TestFlow, dtype) -> None:
45+
self._test_op(
46+
Model().to(dtype),
47+
(torch.randint(0, 10, (2, 8), dtype=torch.long),),
48+
flow,
49+
generate_random_test_inputs=False,
50+
)
51+
52+
def test_embedding_basic(self, flow: TestFlow) -> None:
53+
self._test_op(
54+
Model(),
55+
(torch.randint(0, 10, (2, 8), dtype=torch.long),),
56+
flow,
57+
generate_random_test_inputs=False,
58+
)
59+
60+
def test_embedding_sizes(self, flow: TestFlow) -> None:
61+
self._test_op(
62+
Model(num_embeddings=5, embedding_dim=3),
63+
(torch.randint(0, 5, (2, 8), dtype=torch.long),),
64+
flow,
65+
generate_random_test_inputs=False,
66+
)
67+
self._test_op(
68+
Model(num_embeddings=100, embedding_dim=10),
69+
(torch.randint(0, 100, (2, 8), dtype=torch.long),),
70+
flow,
71+
generate_random_test_inputs=False,
72+
)
73+
self._test_op(
74+
Model(num_embeddings=1000, embedding_dim=50),
75+
(torch.randint(0, 1000, (2, 4), dtype=torch.long),),
76+
flow,
77+
generate_random_test_inputs=False,
78+
)
79+
80+
def test_embedding_padding_idx(self, flow: TestFlow) -> None:
81+
self._test_op(
82+
Model(padding_idx=0),
83+
(torch.randint(0, 10, (2, 8), dtype=torch.long),),
84+
flow,
85+
generate_random_test_inputs=False,
86+
)
87+
self._test_op(
88+
Model(padding_idx=5),
89+
(torch.randint(0, 10, (2, 8), dtype=torch.long),),
90+
flow,
91+
generate_random_test_inputs=False,
92+
)
93+
94+
def test_embedding_input_shapes(self, flow: TestFlow) -> None:
95+
self._test_op(
96+
Model(),
97+
(torch.randint(0, 10, (5,), dtype=torch.long),),
98+
flow,
99+
generate_random_test_inputs=False,
100+
)
101+
self._test_op(
102+
Model(),
103+
(torch.randint(0, 10, (2, 8), dtype=torch.long),),
104+
flow,
105+
generate_random_test_inputs=False,
106+
)
107+
self._test_op(
108+
Model(),
109+
(torch.randint(0, 10, (2, 3, 4), dtype=torch.long),),
110+
flow,
111+
generate_random_test_inputs=False,
112+
)
Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,150 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-unsafe
8+
9+
from typing import Optional
10+
11+
import torch
12+
from executorch.backends.test.suite.flow import TestFlow
13+
14+
from executorch.backends.test.suite.operators import (
15+
dtype_test,
16+
operator_test,
17+
OperatorTest,
18+
)
19+
20+
21+
class Model(torch.nn.Module):
22+
def __init__(
23+
self,
24+
num_embeddings=10,
25+
embedding_dim=5,
26+
mode="mean",
27+
padding_idx: Optional[int] = None,
28+
norm_type: float = 2.0,
29+
include_last_offset: bool = False,
30+
):
31+
super().__init__()
32+
self.embedding_bag = torch.nn.EmbeddingBag(
33+
num_embeddings=num_embeddings,
34+
embedding_dim=embedding_dim,
35+
mode=mode,
36+
padding_idx=padding_idx,
37+
norm_type=norm_type,
38+
include_last_offset=include_last_offset,
39+
)
40+
41+
def forward(self, x, offsets=None):
42+
return self.embedding_bag(x, offsets)
43+
44+
45+
@operator_test
46+
class EmbeddingBag(OperatorTest):
47+
@dtype_test
48+
def test_embedding_bag_dtype(self, flow: TestFlow, dtype) -> None:
49+
indices = torch.tensor([1, 2, 4, 5, 4, 3, 2, 9], dtype=torch.long)
50+
offsets = torch.tensor([0, 4], dtype=torch.long)
51+
self._test_op(
52+
Model().to(dtype),
53+
(indices, offsets),
54+
flow,
55+
generate_random_test_inputs=False,
56+
)
57+
58+
def test_embedding_bag_basic(self, flow: TestFlow) -> None:
59+
indices = torch.tensor([1, 2, 4, 5, 4, 3, 2, 9], dtype=torch.long)
60+
offsets = torch.tensor([0, 4], dtype=torch.long)
61+
self._test_op(
62+
Model(),
63+
(indices, offsets),
64+
flow,
65+
generate_random_test_inputs=False,
66+
)
67+
68+
def test_embedding_bag_sizes(self, flow: TestFlow) -> None:
69+
indices = torch.tensor([1, 2, 3, 1], dtype=torch.long)
70+
offsets = torch.tensor([0, 2], dtype=torch.long)
71+
72+
self._test_op(
73+
Model(num_embeddings=5, embedding_dim=3),
74+
(indices, offsets),
75+
flow,
76+
generate_random_test_inputs=False,
77+
)
78+
79+
indices = torch.tensor([5, 20, 10, 43, 7], dtype=torch.long)
80+
offsets = torch.tensor([0, 2, 4], dtype=torch.long)
81+
self._test_op(
82+
Model(num_embeddings=50, embedding_dim=10),
83+
(indices, offsets),
84+
flow,
85+
generate_random_test_inputs=False,
86+
)
87+
88+
indices = torch.tensor([100, 200, 300, 400], dtype=torch.long)
89+
offsets = torch.tensor([0, 2], dtype=torch.long)
90+
self._test_op(
91+
Model(num_embeddings=500, embedding_dim=20),
92+
(indices, offsets),
93+
flow,
94+
generate_random_test_inputs=False,
95+
)
96+
97+
def test_embedding_bag_modes(self, flow: TestFlow) -> None:
98+
indices = torch.tensor([1, 2, 4, 5, 4, 3, 2, 9], dtype=torch.long)
99+
offsets = torch.tensor([0, 4], dtype=torch.long)
100+
101+
self._test_op(
102+
Model(mode="sum"),
103+
(indices, offsets),
104+
flow,
105+
generate_random_test_inputs=False,
106+
)
107+
self._test_op(
108+
Model(mode="mean"),
109+
(indices, offsets),
110+
flow,
111+
generate_random_test_inputs=False,
112+
)
113+
self._test_op(
114+
Model(mode="max"),
115+
(indices, offsets),
116+
flow,
117+
generate_random_test_inputs=False,
118+
)
119+
120+
def test_embedding_bag_padding_idx(self, flow: TestFlow) -> None:
121+
indices = torch.tensor([0, 1, 2, 0, 3, 0, 4], dtype=torch.long)
122+
offsets = torch.tensor([0, 3, 6], dtype=torch.long)
123+
124+
self._test_op(
125+
Model(padding_idx=0),
126+
(indices, offsets),
127+
flow,
128+
generate_random_test_inputs=False,
129+
)
130+
131+
indices = torch.tensor([1, 5, 2, 5, 3, 5, 4], dtype=torch.long)
132+
offsets = torch.tensor([0, 3, 6], dtype=torch.long)
133+
134+
self._test_op(
135+
Model(padding_idx=5),
136+
(indices, offsets),
137+
flow,
138+
generate_random_test_inputs=False,
139+
)
140+
141+
def test_embedding_bag_include_last_offset(self, flow: TestFlow) -> None:
142+
indices = torch.tensor([1, 2, 4, 5, 4, 3, 2, 9], dtype=torch.long)
143+
offsets = torch.tensor([0, 4], dtype=torch.long)
144+
145+
self._test_op(
146+
Model(include_last_offset=True),
147+
(indices, offsets),
148+
flow,
149+
generate_random_test_inputs=False,
150+
)

backends/test/suite/runner.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ def run_test( # noqa: C901
3333
test_name: str,
3434
params: dict | None,
3535
dynamic_shapes: Any | None = None,
36+
generate_random_test_inputs: bool = True,
3637
) -> TestCaseSummary:
3738
"""
3839
Top-level test run function for a model, input set, and tester. Handles test execution
@@ -102,7 +103,9 @@ def build_result(
102103
# the cause of a failure in run_method_and_compare_outputs. We can look for
103104
# AssertionErrors to catch output mismatches, but this might catch more than that.
104105
try:
105-
tester.run_method_and_compare_outputs()
106+
tester.run_method_and_compare_outputs(
107+
None if generate_random_test_inputs else inputs
108+
)
106109
except AssertionError as e:
107110
return build_result(TestResult.OUTPUT_MISMATCH_FAIL, e)
108111
except Exception as e:

0 commit comments

Comments
 (0)