Skip to content

Commit 964fee9

Browse files
authored
[Backend Tester] Add linear tests (#12848)
Add tests for linear.
1 parent 574e109 commit 964fee9

File tree

1 file changed

+124
-0
lines changed

1 file changed

+124
-0
lines changed
Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-unsafe
8+
9+
10+
import torch
11+
from executorch.backends.test.suite.flow import TestFlow
12+
13+
from executorch.backends.test.suite.operators import (
14+
dtype_test,
15+
operator_test,
16+
OperatorTest,
17+
)
18+
19+
20+
class Model(torch.nn.Module):
21+
def __init__(
22+
self,
23+
in_features=67,
24+
out_features=43,
25+
bias=True,
26+
):
27+
super().__init__()
28+
self.linear = torch.nn.Linear(
29+
in_features=in_features,
30+
out_features=out_features,
31+
bias=bias,
32+
)
33+
34+
def forward(self, x):
35+
return self.linear(x)
36+
37+
38+
@operator_test
39+
class Linear(OperatorTest):
40+
@dtype_test
41+
def test_linear_dtype(self, flow: TestFlow, dtype) -> None:
42+
self._test_op(
43+
Model().to(dtype),
44+
((torch.rand(16, 64) * 10).to(dtype),),
45+
flow,
46+
)
47+
48+
@dtype_test
49+
def test_linear_no_bias_dtype(self, flow: TestFlow, dtype) -> None:
50+
self._test_op(
51+
Model(bias=False).to(dtype),
52+
((torch.rand(16, 64) * 10).to(dtype),),
53+
flow,
54+
)
55+
56+
def test_linear_feature_sizes(self, flow: TestFlow) -> None:
57+
self._test_op(
58+
Model(in_features=32, out_features=16),
59+
(torch.randn(20, 32),),
60+
flow,
61+
)
62+
self._test_op(
63+
Model(in_features=128, out_features=64),
64+
(torch.randn(8, 128),),
65+
flow,
66+
)
67+
self._test_op(
68+
Model(in_features=256, out_features=1),
69+
(torch.randn(4, 256),),
70+
flow,
71+
)
72+
self._test_op(
73+
Model(in_features=1, out_features=512),
74+
(torch.randn(1024, 1),),
75+
flow,
76+
)
77+
78+
def test_linear_no_bias(self, flow: TestFlow) -> None:
79+
self._test_op(
80+
Model(bias=False),
81+
(torch.randn(16, 64),),
82+
flow,
83+
)
84+
self._test_op(
85+
Model(in_features=128, out_features=96, bias=False),
86+
(torch.randn(8, 128),),
87+
flow,
88+
)
89+
90+
def test_linear_batch_sizes(self, flow: TestFlow) -> None:
91+
self._test_op(
92+
Model(),
93+
(torch.randn(8, 64),),
94+
flow,
95+
)
96+
self._test_op(
97+
Model(),
98+
(torch.randn(32, 64),),
99+
flow,
100+
)
101+
self._test_op(
102+
Model(),
103+
(torch.randn(100, 64),),
104+
flow,
105+
)
106+
107+
def test_linear_unbatched(self, flow: TestFlow) -> None:
108+
self._test_op(
109+
Model(in_features=512),
110+
(torch.randn(512),),
111+
flow,
112+
)
113+
114+
def test_linear_leading_batch(self, flow: TestFlow) -> None:
115+
self._test_op(
116+
Model(),
117+
(torch.randn(4, 8, 64),),
118+
flow,
119+
)
120+
self._test_op(
121+
Model(),
122+
(torch.randn(2, 4, 8, 64),),
123+
flow,
124+
)

0 commit comments

Comments
 (0)