1+ # (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
2+
3+ # pyre-strict
4+
5+ from typing import Callable
6+
7+ import torch
8+
9+ from executorch .backends .test .compliance_suite import (
10+ dtype_test ,
11+ operator_test ,
12+ OperatorTest ,
13+ )
14+
15+ class Model (torch .nn .Module ):
16+ def __init__ (
17+ self ,
18+ in_features = 10 ,
19+ out_features = 5 ,
20+ bias = True ,
21+ ):
22+ super ().__init__ ()
23+ self .linear = torch .nn .Linear (
24+ in_features = in_features ,
25+ out_features = out_features ,
26+ bias = bias ,
27+ )
28+
29+ def forward (self , x ):
30+ return self .linear (x )
31+
32+ @operator_test
33+ class TestLinear (OperatorTest ):
34+ @dtype_test
35+ def test_linear_dtype (self , dtype , tester_factory : Callable ) -> None :
36+ # Input shape: (batch_size, in_features)
37+ model = Model ().to (dtype )
38+ self ._test_op (model , ((torch .rand (2 , 10 ) * 10 ).to (dtype ),), tester_factory )
39+
40+ @dtype_test
41+ def test_linear_no_bias_dtype (self , dtype , tester_factory : Callable ) -> None :
42+ # Input shape: (batch_size, in_features)
43+ model = Model (bias = False ).to (dtype )
44+ self ._test_op (model , ((torch .rand (2 , 10 ) * 10 ).to (dtype ),), tester_factory )
45+
46+ def test_linear_basic (self , tester_factory : Callable ) -> None :
47+ # Basic test with default parameters
48+ self ._test_op (Model (), (torch .randn (2 , 10 ),), tester_factory )
49+
50+ def test_linear_feature_sizes (self , tester_factory : Callable ) -> None :
51+ # Test with different input and output feature sizes
52+ self ._test_op (Model (in_features = 5 , out_features = 3 ), (torch .randn (2 , 5 ),), tester_factory )
53+ self ._test_op (Model (in_features = 20 , out_features = 10 ), (torch .randn (2 , 20 ),), tester_factory )
54+ self ._test_op (Model (in_features = 100 , out_features = 1 ), (torch .randn (2 , 100 ),), tester_factory )
55+ self ._test_op (Model (in_features = 1 , out_features = 100 ), (torch .randn (2 , 1 ),), tester_factory )
56+
57+ def test_linear_no_bias (self , tester_factory : Callable ) -> None :
58+ # Test without bias
59+ self ._test_op (Model (bias = False ), (torch .randn (2 , 10 ),), tester_factory )
60+ self ._test_op (Model (in_features = 20 , out_features = 15 , bias = False ), (torch .randn (2 , 20 ),), tester_factory )
61+
62+ def test_linear_batch_sizes (self , tester_factory : Callable ) -> None :
63+ # Test with different batch sizes
64+ self ._test_op (Model (), (torch .randn (1 , 10 ),), tester_factory )
65+ self ._test_op (Model (), (torch .randn (5 , 10 ),), tester_factory )
66+ self ._test_op (Model (), (torch .randn (100 , 10 ),), tester_factory )
67+
68+ def test_linear_unbatched (self , tester_factory : Callable ) -> None :
69+ # Test with unbatched input (just features)
70+ self ._test_op (Model (), (torch .randn (10 ),), tester_factory )
71+
72+ def test_linear_multi_dim_input (self , tester_factory : Callable ) -> None :
73+ # Test with multi-dimensional input
74+ # For multi-dimensional inputs, the linear transformation is applied to the last dimension
75+ self ._test_op (Model (), (torch .randn (3 , 4 , 10 ),), tester_factory )
76+ self ._test_op (Model (), (torch .randn (2 , 3 , 4 , 10 ),), tester_factory )
77+
0 commit comments