Skip to content

Commit a1c1523

Browse files
committed
[Backend Tester] Add linear tests
ghstack-source-id: 78c21f1 ghstack-comment-id: 3116316332 Pull-Request: #12848
1 parent 2131b7e commit a1c1523

File tree

1 file changed

+131
-0
lines changed

1 file changed

+131
-0
lines changed
Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-unsafe
8+
9+
10+
import torch
11+
from executorch.backends.test.suite.flow import TestFlow
12+
13+
from executorch.backends.test.suite.operators import (
14+
dtype_test,
15+
operator_test,
16+
OperatorTest,
17+
)
18+
19+
20+
class Model(torch.nn.Module):
    """Thin wrapper around a single torch.nn.Linear layer for operator testing.

    Defaults mirror the common test shape: 10 input features mapped to 5
    output features, with a bias term.
    """

    def __init__(self, in_features=10, out_features=5, bias=True):
        super().__init__()
        # nn.Linear accepts these positionally; keyword form is equivalent.
        self.linear = torch.nn.Linear(in_features, out_features, bias=bias)

    def forward(self, x):
        # Delegate directly to the wrapped layer.
        return self.linear(x)
36+
37+
38+
@operator_test
class Linear(OperatorTest):
    """Backend operator tests for torch.nn.Linear.

    Covers dtype variations (with and without bias), feature-size extremes,
    batch-size variations, unbatched (1-D) input, and multi-dimensional
    batch inputs.
    """

    @dtype_test
    def test_linear_dtype(self, flow: TestFlow, dtype) -> None:
        # Scale rand() to [0, 10) before casting so low-precision dtypes
        # still receive non-trivial values.
        model = Model().to(dtype)
        self._test_op(model, ((torch.rand(2, 10) * 10).to(dtype),), flow)

    @dtype_test
    def test_linear_no_bias_dtype(self, flow: TestFlow, dtype) -> None:
        model = Model(bias=False).to(dtype)
        self._test_op(model, ((torch.rand(2, 10) * 10).to(dtype),), flow)

    def test_linear_basic(self, flow: TestFlow) -> None:
        self._test_op(Model(), (torch.randn(2, 10),), flow)

    def test_linear_feature_sizes(self, flow: TestFlow) -> None:
        # Shrinking, growing, and extreme (wide-in/narrow-out and the
        # reverse) feature dimensions.
        for in_features, out_features in ((5, 3), (20, 10), (100, 1), (1, 100)):
            self._test_op(
                Model(in_features=in_features, out_features=out_features),
                (torch.randn(2, in_features),),
                flow,
            )

    def test_linear_no_bias(self, flow: TestFlow) -> None:
        self._test_op(Model(bias=False), (torch.randn(2, 10),), flow)
        self._test_op(
            Model(in_features=20, out_features=15, bias=False),
            (torch.randn(2, 20),),
            flow,
        )

    def test_linear_batch_sizes(self, flow: TestFlow) -> None:
        # Single-sample, small, and large batches.
        for batch_size in (1, 5, 100):
            self._test_op(Model(), (torch.randn(batch_size, 10),), flow)

    def test_linear_unbatched(self, flow: TestFlow) -> None:
        # nn.Linear also accepts a 1-D (unbatched) input.
        self._test_op(Model(), (torch.randn(10),), flow)

    def test_linear_multi_dim_input(self, flow: TestFlow) -> None:
        # All leading dimensions except the last act as batch dimensions.
        for shape in ((3, 4, 10), (2, 3, 4, 10)):
            self._test_op(Model(), (torch.randn(*shape),), flow)

0 commit comments

Comments
 (0)