@@ -160,7 +160,7 @@ def forward(self, x, y):
             return x
 
     module = IAdd().cuda().eval()
-    inputs = [torch.randn(1, 3, 4).cuda(), torch.randn(1, 3, 4).cuda()]
+    inputs = [torch.ones(1, 3, 4).cuda(), torch.ones(1, 3, 4).cuda()]
     cross_validate(module, inputs, fp16_mode=False, tol=1e-2)
 
 
@@ -179,29 +179,6 @@ def test_radd_float():
 # TODO: radd, add, iadd
 
 
-@pytest.mark.parametrize("kernel_size,stride,padding,ceil_mode,count_include_pad", [
-    (3, 2, 1, False, True),
-    (3, 2, 1, True, False)
-])
-def test_avg_pool2d(kernel_size, stride, padding, ceil_mode, count_include_pad):
-    module = UnaryModule(lambda x: torch.nn.functional.avg_pool2d(
-        x, kernel_size, stride, padding, ceil_mode, count_include_pad
-    )).cuda().eval()
-    inputs = [torch.randn(1, 3, 8, 8).cuda()]
-    cross_validate(module, inputs, fp16_mode=False, tol=1e-1)
-
-
-@pytest.mark.parametrize("kernel_size,stride,padding,ceil_mode,count_include_pad", [
-    (3, 2, 1, False, True),
-    (3, 2, 1, True, False)
-])
-def test_avg_pool3d(kernel_size, stride, padding, ceil_mode, count_include_pad):
-    module = UnaryModule(lambda x: torch.nn.functional.avg_pool3d(
-        x, kernel_size, stride, padding, ceil_mode, count_include_pad
-    )).cuda().eval()
-    inputs = [torch.randn(1, 3, 8, 8, 8).cuda()]
-    cross_validate(module, inputs, fp16_mode=False, tol=1e-1)
-
-
 
 def test_batch_norm_1d():
     module = nn.BatchNorm1d(3).cuda().eval()
@@ -413,7 +390,7 @@ def fn(x):
         x /= val
         return x
     module = UnaryModule(fn).cuda().eval()
-    inputs = [torch.randn(1, 4, 4).cuda()]
+    inputs = [torch.ones(1, 4, 4).cuda()]
     cross_validate(module, inputs, fp16_mode=False, tol=1e-1)
 
 
@@ -455,10 +432,11 @@ def test_flatten(start_dim, end_dim):
     cross_validate(module, inputs, fp16_mode=False, tol=1e-1)
 
 
-def test_floordiv():
+@pytest.mark.parametrize("denom", [1., 2.])
+def test_floordiv(denom):
     module = BinaryModule(lambda x, y: x // y).cuda().eval()
-    inputs = [torch.randn(1, 2, 3, 4, 5).cuda()]
-    inputs.append(torch.ones_like(inputs[0])*2)
+    inputs = [torch.ones(1, 2, 3, 4, 5).cuda()]
+    inputs.append(torch.ones_like(inputs[0]) * denom)
     cross_validate(module, inputs, fp16_mode=False, tol=1e-1)
 
 
@@ -578,15 +556,13 @@ def test_matmul(shape_a, shape_b):
     cross_validate(module, inputs, fp16_mode=False, tol=1e-1)
 
 
-@pytest.mark.parametrize("nd", [1, 2, 3])
 @pytest.mark.parametrize(
     "kernel_size,stride,padding,dilation,ceil_mode", [
         (3, 2, 1, 1, False),
-        (3, 2, 1, 1, False),
-        (3, 2, 1, 1, False),
     ]
 )
-def test_max_pool(nd, kernel_size, stride, padding, dilation, ceil_mode):
+@pytest.mark.parametrize("nd", [1, 2, 3])
+def test_max_pool_nd(nd, kernel_size, stride, padding, dilation, ceil_mode):
     if nd == 1:
         cls = nn.MaxPool1d
     elif nd == 2:
@@ -598,6 +574,26 @@ def test_max_pool(nd, kernel_size, stride, padding, dilation, ceil_mode):
     inputs = [torch.randn(*input_size).cuda()]
     cross_validate(module, inputs, fp16_mode=False, tol=1e-1)
 
+
+@pytest.mark.parametrize(
+    "kernel_size,stride,padding,ceil_mode,count_include_pad", [
+        (3, 2, 1, False, False),
+    ]
+)
+@pytest.mark.parametrize("nd", [1, 2, 3])
+def test_avg_pool_nd(nd, kernel_size, stride, padding, ceil_mode, count_include_pad):
+    if nd == 1:
+        cls = nn.AvgPool1d
+    elif nd == 2:
+        cls = nn.AvgPool2d
+    elif nd == 3:
+        cls = nn.AvgPool3d
+    module = cls(kernel_size, stride, padding, ceil_mode=ceil_mode, count_include_pad=count_include_pad).cuda().eval()
+    input_size = [1, 3] + [4] * nd
+    inputs = [torch.randn(*input_size).cuda()]
+    cross_validate(module, inputs, fp16_mode=False, tol=1e-1)
+
+
 @pytest.mark.parametrize("op", ["min", "max", "fmod"])
 def test_binary_op_elementwise(op):
     if op == "max":
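Note: these tests lean on a small harness (UnaryModule, BinaryModule, cross_validate) defined elsewhere in the file. A minimal sketch of what that harness plausibly looks like, assuming torch2trt's torch2trt() conversion entry point; the names and details below are an approximation, not the repository's actual implementation:

import torch
import torch.nn as nn
from torch2trt import torch2trt  # assumed conversion entry point


class UnaryModule(nn.Module):
    # Wraps a one-argument callable so it can be converted like a module.
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x):
        return self.fn(x)


class BinaryModule(nn.Module):
    # Wraps a two-argument callable so it can be converted like a module.
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, y):
        return self.fn(x, y)


def cross_validate(module, inputs, fp16_mode=False, tol=1e-1):
    # Convert with torch2trt, then assert the TensorRT engine's output
    # stays within tol (max absolute error) of the PyTorch output.
    module_trt = torch2trt(module, inputs, fp16_mode=fp16_mode)
    out = module(*inputs)
    out_trt = module_trt(*inputs)
    assert torch.max(torch.abs(out - out_trt)) < tol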
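A likely motivation for swapping torch.randn inputs for torch.ones in the iadd, idiv, and floordiv tests (an inference, not stated in the diff): floor division is discontinuous at integer boundaries, so a sub-ulp disagreement between the PyTorch and TensorRT kernels can flip a result by a whole unit and exceed any reasonable tol. A small CPU-only illustration:

import torch

a = torch.tensor([2.0000001])  # rounds to 2.0 in float32
b = torch.tensor([1.9999999])  # rounds to the float32 just below 2.0
print(a // 1.0)  # tensor([2.])
print(b // 1.0)  # tensor([1.]) -- an error of 1.0 from ~1e-7 of noise

With all-ones inputs the quotients land safely away from those boundaries, so the comparison is deterministic.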
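One detail of the consolidated test_max_pool_nd / test_avg_pool_nd style: stacked pytest.mark.parametrize decorators take the Cartesian product of their arguments, so each pooling configuration runs once per nd value. A self-contained illustration with hypothetical parameters:

import pytest


@pytest.mark.parametrize("b", [10, 20])
@pytest.mark.parametrize("a", [1, 2, 3])
def test_cross_product(a, b):
    # Collected as 3 x 2 = 6 test cases: (1,10), (2,10), ..., (3,20).
    assert a * b > 0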