@@ -41,10 +41,6 @@ def __init__(
4141
4242 self .in_channels = in_channels
4343
44- def get_example_inputs (self ) -> tuple [torch .Tensor ]:
45- input_1 = torch .randn (1 , self .in_channels , 24 , 24 )
46- return (input_1 ,)
47-
4844 def forward (self , x : torch .Tensor ) -> torch .Tensor :
4945 return self .conv (x )
5046
@@ -62,19 +58,15 @@ def __init__(self) -> None:
6258 bias = True ,
6359 )
6460
65- def get_example_inputs (self ) -> tuple [torch .Tensor ]:
66- input_1 = torch .randn (1 , 32 , 24 , 24 )
67- return (input_1 ,)
68-
6961 def forward (self , x : torch .Tensor ) -> torch .Tensor :
7062 return self .conv (x )
7163
7264
7365class TestConv2d (unittest .TestCase ):
74- def _test (self , module : torch .nn .Module ):
66+ def _test (self , module : torch .nn .Module , inputs ):
7567 tester = SamsungTester (
7668 module ,
77- module . get_example_inputs () ,
69+ inputs ,
7870 [gen_samsung_backend_compile_spec ("E9955" )],
7971 )
8072 (
@@ -83,16 +75,21 @@ def _test(self, module: torch.nn.Module):
8375 .check_not (["executorch_exir_dialects_edge__ops_aten_convolution_default" ])
8476 .check_count ({"torch.ops.higher_order.executorch_call_delegate" : 1 })
8577 .to_executorch ()
78+ .run_method_and_compare_outputs (inputs = inputs )
8679 )
8780
8881 def test_fp32_conv2d_without_bias (self ):
89- self ._test (Conv2d (bias = False ))
82+ inputs = (torch .randn (1 , 3 , 24 , 24 ),)
83+ self ._test (Conv2d (bias = False ), inputs )
9084
9185 def test_fp32_conv2d_with_bias (self ):
92- self ._test (Conv2d (bias = True ))
86+ inputs = (torch .randn (1 , 3 , 24 , 24 ),)
87+ self ._test (Conv2d (bias = True ), inputs )
9388
9489 def test_fp32_depthwise_conv2d (self ):
95- self ._test (Conv2d (in_channels = 8 , out_channels = 8 , groups = 8 ))
90+ inputs = (torch .randn (1 , 8 , 24 , 24 ),)
91+ self ._test (Conv2d (in_channels = 8 , out_channels = 8 , groups = 8 ), inputs )
9692
9793 def test_fp32_transpose_conv2d (self ):
98- self ._test (TransposeConv2d ())
94+ inputs = (torch .randn (1 , 32 , 24 , 24 ),)
95+ self ._test (TransposeConv2d (), inputs )
0 commit comments