1- # (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
1+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2+ # All rights reserved.
3+ #
4+ # This source code is licensed under the BSD-style license found in the
5+ # LICENSE file in the root directory of this source tree.
26
3- # pyre-strict
7+ # pyre-unsafe
48
5- from typing import Callable , Union , Tuple
9+ from typing import Tuple , Union
610
711import torch
12+ from executorch .backends .test .suite .flow import TestFlow
813
9- from executorch .backends .test .compliance_suite import (
14+ from executorch .backends .test .suite . operators import (
1015 dtype_test ,
1116 operator_test ,
1217 OperatorTest ,
1318)
1419
20+
1521class Model (torch .nn .Module ):
1622 def __init__ (
1723 self ,
1824 in_channels = 3 ,
1925 out_channels = 6 ,
20- kernel_size = 3 ,
21- stride = 1 ,
22- padding = 0 ,
23- dilation = 1 ,
26+ kernel_size : Union [ int , Tuple [ int , int ]] = 3 ,
27+ stride : Union [ int , Tuple [ int , int ]] = 1 ,
28+ padding : Union [ int , Tuple [ int , int ]] = 0 ,
29+ dilation : Union [ int , Tuple [ int , int ]] = 1 ,
2430 groups = 1 ,
2531 bias = True ,
2632 padding_mode = "zeros" ,
@@ -37,60 +43,118 @@ def __init__(
3743 bias = bias ,
3844 padding_mode = padding_mode ,
3945 )
40-
46+
4147 def forward (self , x ):
4248 return self .conv (x )
4349
50+
4451@operator_test
45- class TestConv2d (OperatorTest ):
52+ class Conv2d (OperatorTest ):
4653 @dtype_test
47- def test_conv2d_dtype (self , dtype , tester_factory : Callable ) -> None :
48- # Input shape: (batch_size, in_channels, height, width)
49- self ._test_op (Model ().to (dtype ), ((torch .rand (2 , 3 , 8 , 8 ) * 10 ).to (dtype ),), tester_factory )
50-
51- def test_conv2d_basic (self , tester_factory : Callable ) -> None :
52- # Basic test with default parameters
53- self ._test_op (Model (), (torch .randn (2 , 3 , 8 , 8 ),), tester_factory )
54-
55- def test_conv2d_kernel_size (self , tester_factory : Callable ) -> None :
56- # Test with different kernel sizes
57- self ._test_op (Model (kernel_size = 1 ), (torch .randn (2 , 3 , 8 , 8 ),), tester_factory )
58- self ._test_op (Model (kernel_size = 5 ), (torch .randn (2 , 3 , 8 , 8 ),), tester_factory )
59- self ._test_op (Model (kernel_size = (3 , 5 )), (torch .randn (2 , 3 , 8 , 8 ),), tester_factory )
60-
61- def test_conv2d_stride (self , tester_factory : Callable ) -> None :
62- # Test with different stride values
63- self ._test_op (Model (stride = 2 ), (torch .randn (2 , 3 , 8 , 8 ),), tester_factory )
64- self ._test_op (Model (stride = (2 , 1 )), (torch .randn (2 , 3 , 8 , 8 ),), tester_factory )
65-
66- def test_conv2d_padding (self , tester_factory : Callable ) -> None :
67- # Test with different padding values
68- self ._test_op (Model (padding = 1 ), (torch .randn (2 , 3 , 8 , 8 ),), tester_factory )
69- self ._test_op (Model (padding = (1 , 2 )), (torch .randn (2 , 3 , 8 , 8 ),), tester_factory )
70-
71- def test_conv2d_dilation (self , tester_factory : Callable ) -> None :
72- # Test with different dilation values
73- self ._test_op (Model (dilation = 2 ), (torch .randn (2 , 3 , 8 , 8 ),), tester_factory )
74- self ._test_op (Model (dilation = (2 , 1 )), (torch .randn (2 , 3 , 8 , 8 ),), tester_factory )
75-
76- def test_conv2d_groups (self , tester_factory : Callable ) -> None :
77- # Test with groups=3 (in_channels must be divisible by groups)
78- self ._test_op (Model (in_channels = 6 , out_channels = 6 , groups = 3 ), (torch .randn (2 , 6 , 8 , 8 ),), tester_factory )
79-
80- def test_conv2d_no_bias (self , tester_factory : Callable ) -> None :
81- # Test without bias
82- self ._test_op (Model (bias = False ), (torch .randn (2 , 3 , 8 , 8 ),), tester_factory )
83-
84- def test_conv2d_padding_modes (self , tester_factory : Callable ) -> None :
85- # Test different padding modes
54+ def test_conv2d_dtype (self , flow : TestFlow , dtype ) -> None :
55+ self ._test_op (
56+ Model ().to (dtype ),
57+ ((torch .rand (2 , 3 , 8 , 8 ) * 10 ).to (dtype ),),
58+ flow ,
59+ )
60+
61+ def test_conv2d_basic (self , flow : TestFlow ) -> None :
62+ self ._test_op (
63+ Model (),
64+ (torch .randn (2 , 3 , 8 , 8 ),),
65+ flow ,
66+ )
67+
68+ def test_conv2d_kernel_size (self , flow : TestFlow ) -> None :
69+ self ._test_op (
70+ Model (kernel_size = 1 ),
71+ (torch .randn (2 , 3 , 8 , 8 ),),
72+ flow ,
73+ )
74+ self ._test_op (
75+ Model (kernel_size = 5 ),
76+ (torch .randn (2 , 3 , 8 , 8 ),),
77+ flow ,
78+ )
79+ self ._test_op (
80+ Model (kernel_size = (3 , 5 )),
81+ (torch .randn (2 , 3 , 8 , 8 ),),
82+ flow ,
83+ )
84+
85+ def test_conv2d_stride (self , flow : TestFlow ) -> None :
86+ self ._test_op (
87+ Model (stride = 2 ),
88+ (torch .randn (2 , 3 , 8 , 8 ),),
89+ flow ,
90+ )
91+ self ._test_op (
92+ Model (stride = (2 , 1 )),
93+ (torch .randn (2 , 3 , 8 , 8 ),),
94+ flow ,
95+ )
96+
97+ def test_conv2d_padding (self , flow : TestFlow ) -> None :
98+ self ._test_op (
99+ Model (padding = 1 ),
100+ (torch .randn (2 , 3 , 8 , 8 ),),
101+ flow ,
102+ )
103+ self ._test_op (
104+ Model (padding = (1 , 2 )),
105+ (torch .randn (2 , 3 , 8 , 8 ),),
106+ flow ,
107+ )
108+
109+ def test_conv2d_dilation (self , flow : TestFlow ) -> None :
110+ self ._test_op (
111+ Model (dilation = 2 ),
112+ (torch .randn (2 , 3 , 8 , 8 ),),
113+ flow ,
114+ )
115+ self ._test_op (
116+ Model (dilation = (2 , 1 )),
117+ (torch .randn (2 , 3 , 8 , 8 ),),
118+ flow ,
119+ )
120+
121+ def test_conv2d_groups (self , flow : TestFlow ) -> None :
122+ self ._test_op (
123+ Model (in_channels = 6 , out_channels = 6 , groups = 3 ),
124+ (torch .randn (2 , 6 , 8 , 8 ),),
125+ flow ,
126+ )
127+
128+ def test_conv2d_no_bias (self , flow : TestFlow ) -> None :
129+ self ._test_op (
130+ Model (bias = False ),
131+ (torch .randn (2 , 3 , 8 , 8 ),),
132+ flow ,
133+ )
134+
135+ def test_conv2d_padding_modes (self , flow : TestFlow ) -> None :
86136 for mode in ["zeros" , "reflect" , "replicate" , "circular" ]:
87- self ._test_op (Model (padding = 1 , padding_mode = mode ), (torch .randn (2 , 3 , 8 , 8 ),), tester_factory )
88-
89- def test_conv2d_channels (self , tester_factory : Callable ) -> None :
90- # Test with different channel configurations
91- self ._test_op (Model (in_channels = 1 , out_channels = 1 ), (torch .randn (2 , 1 , 8 , 8 ),), tester_factory )
92- self ._test_op (Model (in_channels = 5 , out_channels = 10 ), (torch .randn (2 , 5 , 8 , 8 ),), tester_factory )
93-
94- def test_conv2d_different_spatial_dims (self , tester_factory : Callable ) -> None :
95- # Test with different height and width
96- self ._test_op (Model (), (torch .randn (2 , 3 , 10 , 8 ),), tester_factory )
137+ self ._test_op (
138+ Model (padding = 1 , padding_mode = mode ),
139+ (torch .randn (2 , 3 , 8 , 8 ),),
140+ flow ,
141+ )
142+
143+ def test_conv2d_channels (self , flow : TestFlow ) -> None :
144+ self ._test_op (
145+ Model (in_channels = 1 , out_channels = 1 ),
146+ (torch .randn (2 , 1 , 8 , 8 ),),
147+ flow ,
148+ )
149+ self ._test_op (
150+ Model (in_channels = 5 , out_channels = 10 ),
151+ (torch .randn (2 , 5 , 8 , 8 ),),
152+ flow ,
153+ )
154+
155+ def test_conv2d_different_spatial_dims (self , flow : TestFlow ) -> None :
156+ self ._test_op (
157+ Model (),
158+ (torch .randn (2 , 3 , 10 , 8 ),),
159+ flow ,
160+ )
0 commit comments