Skip to content

Commit 2de680e

Browse files
committed
Update
[ghstack-poisoned]
1 parent fd73fb9 commit 2de680e

File tree

1 file changed

+77
-0
lines changed

1 file changed

+77
-0
lines changed
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
2+
3+
# pyre-strict
4+
5+
from typing import Callable
6+
7+
import torch
8+
9+
from executorch.backends.test.compliance_suite import (
10+
dtype_test,
11+
operator_test,
12+
OperatorTest,
13+
)
14+
15+
class Model(torch.nn.Module):
16+
def __init__(
17+
self,
18+
in_features=10,
19+
out_features=5,
20+
bias=True,
21+
):
22+
super().__init__()
23+
self.linear = torch.nn.Linear(
24+
in_features=in_features,
25+
out_features=out_features,
26+
bias=bias,
27+
)
28+
29+
def forward(self, x):
30+
return self.linear(x)
31+
32+
@operator_test
33+
class TestLinear(OperatorTest):
34+
@dtype_test
35+
def test_linear_dtype(self, dtype, tester_factory: Callable) -> None:
36+
# Input shape: (batch_size, in_features)
37+
model = Model().to(dtype)
38+
self._test_op(model, ((torch.rand(2, 10) * 10).to(dtype),), tester_factory)
39+
40+
@dtype_test
41+
def test_linear_no_bias_dtype(self, dtype, tester_factory: Callable) -> None:
42+
# Input shape: (batch_size, in_features)
43+
model = Model(bias=False).to(dtype)
44+
self._test_op(model, ((torch.rand(2, 10) * 10).to(dtype),), tester_factory)
45+
46+
def test_linear_basic(self, tester_factory: Callable) -> None:
47+
# Basic test with default parameters
48+
self._test_op(Model(), (torch.randn(2, 10),), tester_factory)
49+
50+
def test_linear_feature_sizes(self, tester_factory: Callable) -> None:
51+
# Test with different input and output feature sizes
52+
self._test_op(Model(in_features=5, out_features=3), (torch.randn(2, 5),), tester_factory)
53+
self._test_op(Model(in_features=20, out_features=10), (torch.randn(2, 20),), tester_factory)
54+
self._test_op(Model(in_features=100, out_features=1), (torch.randn(2, 100),), tester_factory)
55+
self._test_op(Model(in_features=1, out_features=100), (torch.randn(2, 1),), tester_factory)
56+
57+
def test_linear_no_bias(self, tester_factory: Callable) -> None:
58+
# Test without bias
59+
self._test_op(Model(bias=False), (torch.randn(2, 10),), tester_factory)
60+
self._test_op(Model(in_features=20, out_features=15, bias=False), (torch.randn(2, 20),), tester_factory)
61+
62+
def test_linear_batch_sizes(self, tester_factory: Callable) -> None:
63+
# Test with different batch sizes
64+
self._test_op(Model(), (torch.randn(1, 10),), tester_factory)
65+
self._test_op(Model(), (torch.randn(5, 10),), tester_factory)
66+
self._test_op(Model(), (torch.randn(100, 10),), tester_factory)
67+
68+
def test_linear_unbatched(self, tester_factory: Callable) -> None:
69+
# Test with unbatched input (just features)
70+
self._test_op(Model(), (torch.randn(10),), tester_factory)
71+
72+
def test_linear_multi_dim_input(self, tester_factory: Callable) -> None:
73+
# Test with multi-dimensional input
74+
# For multi-dimensional inputs, the linear transformation is applied to the last dimension
75+
self._test_op(Model(), (torch.randn(3, 4, 10),), tester_factory)
76+
self._test_op(Model(), (torch.randn(2, 3, 4, 10),), tester_factory)
77+

0 commit comments

Comments
 (0)