@@ -173,18 +173,6 @@ def get_inputs(self):
173173 return (torch .randn (2 , 2 , 4 , 4 ),)
174174
175175
176- class Conv2dDQ (torch .nn .Module ):
177- def __init__ (self ):
178- super ().__init__ ()
179- self .conv = torch .nn .Conv2d (in_channels = 3 , out_channels = 10 , kernel_size = 3 )
180-
181- def forward (self , x ):
182- return self .conv (x )
183-
184- def get_inputs (self ):
185- return (torch .randn (1 , 3 , 8 , 8 ),)
186-
187-
188176class Conv2dDQSeq (torch .nn .Module ):
189177 def __init__ (self ):
190178 super ().__init__ ()
@@ -210,7 +198,7 @@ def __init__(self):
210198 in_channels = 3 , out_channels = 8 , kernel_size = 3 , padding = 1
211199 )
212200 self .second = torch .nn .Conv2d (
213- in_channels = 3 , out_channels = 10 , kernel_size = 3 , padding = 1
201+ in_channels = 3 , out_channels = 8 , kernel_size = 3 , padding = 1
214202 )
215203
216204 def forward (self , x ):
@@ -785,13 +773,24 @@ def forward(self, x):
785773 )
786774
787775 def test_dq_conv2d (self ) -> None :
788- model = Conv2dDQ ()
776+ model = Conv2d (
777+ in_channels = 3 ,
778+ out_channels = 10 ,
779+ kernel_size = (3 , 3 ),
780+ stride = (1 , 1 ),
781+ padding = (0 , 0 ),
782+ batches = 1 ,
783+ width = 8 ,
784+ height = 8 ,
785+ )
789786 self ._test_dq (model )
790787
791788 def test_dq_conv2d_seq (self ) -> None :
792789 model = Conv2dDQSeq ()
793- self ._test_dq (model , conv_count = 2 )
790+ conv_count = sum (1 for m in model .modules () if type (m ) is torch .nn .Conv2d )
791+ self ._test_dq (model , conv_count )
794792
795793 def test_dq_conv2d_parallel (self ) -> None :
796794 model = Conv2dDQParallel ()
797- self ._test_dq (model , conv_count = 2 )
795+ conv_count = sum (1 for m in model .modules () if type (m ) is torch .nn .Conv2d )
796+ self ._test_dq (model , conv_count )
0 commit comments