@@ -1,14 +1,15 @@
 import pytest
 import torch
 import torch2trt
+import torch.nn as nn
 from torch2trt.flattener import Flattener


-def _cross_validate(
+def cross_validate(
     module,
     inputs,
-    *args,
-    **kwargs
+    fp16_mode: bool,
+    tol: float
 ):
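+    """Run `module` in PyTorch and through torch2trt, asserting outputs match within `tol`."""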

     module = module
@@ -17,63 +18,198 @@ def _cross_validate(
     module_trt = torch2trt.torch2trt(
         module,
         inputs,
-        *args,
-        **kwargs
+        fp16_mode=fp16_mode
     )


     output = module(*inputs)
     output_trt = module_trt(*inputs)

-    assert torch.allclose(output, output_trt, atol=1e-2, rtol=1e-2)
+    assert torch.allclose(output, output_trt, atol=tol, rtol=tol)


+
+# MODULES
+
+
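+# Wraps a single-argument function as an nn.Module so it can be converted by torch2trt.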
 class UnaryModule(torch.nn.Module):
     def __init__(self, fn):
         super(UnaryModule, self).__init__()
         self.fn = fn

     def forward(self, x):
         return self.fn(x)
-

-def test_functional_leaky_relu():
-    _cross_validate(
-        UnaryModule(lambda x: torch.nn.functional.leaky_relu(x)).cuda().eval(),
-        [torch.randn(1, 5, 3).cuda()]
-    )

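+# Wraps a two-argument function as an nn.Module, used to test elementwise binary ops.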
+class BinaryModule(torch.nn.Module):
+    def __init__(self, fn):
+        super(BinaryModule, self).__init__()
+        self.fn = fn
+
+    def forward(self, a, b):
+        return self.fn(a, b)
+
+
+# TESTS
+

-def test_functional_elu():
-    _cross_validate(
-        UnaryModule(lambda x: torch.nn.functional.elu(x)).cuda().eval(),
-        [torch.randn(1, 5, 3).cuda()]
-    )


-def test_selu():
-    _cross_validate(
-        UnaryModule(lambda x: torch.selu(x)).cuda().eval(),
-        [torch.randn(1, 5, 3).cuda()]
-    )
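+# fp16 engines accumulate more rounding error than fp32, so each activation test
+# parametrizes the comparison tolerance alongside fp16_mode.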
+@pytest.mark.parametrize("fp16_mode,tol", [(False, 1e-1), (True, 1e-1)])
+def test_leaky_relu(fp16_mode, tol):
+    module = UnaryModule(lambda x: torch.nn.functional.leaky_relu(x)).cuda().eval()
+    inputs = [torch.randn(1, 3, 4).cuda()]
+    cross_validate(module, inputs, fp16_mode=fp16_mode, tol=tol)


-def test_functional_selu():
-    _cross_validate(
-        UnaryModule(lambda x: torch.nn.functional.selu(x)).cuda().eval(),
-        [torch.randn(1, 5, 3).cuda()]
-    )
+@pytest.mark.parametrize("fp16_mode,tol", [(False, 1e-1), (True, 1e-1)])
+def test_elu(fp16_mode, tol):
+    module = UnaryModule(lambda x: torch.nn.functional.elu(x)).cuda().eval()
+    inputs = [torch.randn(1, 3, 4).cuda()]
+    cross_validate(module, inputs, fp16_mode=fp16_mode, tol=tol)


-def test_functional_softsign():
-    _cross_validate(
-        UnaryModule(lambda x: torch.nn.functional.softsign(x)).cuda().eval(),
-        [torch.randn(1, 5, 3).cuda()]
-    )
+@pytest.mark.parametrize("fp16_mode,tol", [(False, 1e-1), (True, 1e-1)])
+def test_selu(fp16_mode, tol):
+    module = UnaryModule(lambda x: torch.nn.functional.selu(x)).cuda().eval()
+    inputs = [torch.randn(1, 3, 4).cuda()]
+    cross_validate(module, inputs, fp16_mode=fp16_mode, tol=tol)


-def test_functional_softplus():
-    _cross_validate(
-        UnaryModule(lambda x: torch.nn.functional.softplus(x)).cuda().eval(),
-        [torch.randn(1, 5, 3).cuda()]
-    )
+@pytest.mark.parametrize("fp16_mode,tol", [(False, 1e-1), (True, 1e-1)])
+def test_softsign(fp16_mode, tol):
+    module = UnaryModule(lambda x: torch.nn.functional.softsign(x)).cuda().eval()
+    inputs = [torch.randn(1, 3, 4).cuda()]
+    cross_validate(module, inputs, fp16_mode=fp16_mode, tol=tol)
+
+
+@pytest.mark.parametrize("fp16_mode,tol", [(False, 1e-1), (True, 1e-1)])
+def test_softplus(fp16_mode, tol):
+    module = UnaryModule(lambda x: torch.nn.functional.softplus(x)).cuda().eval()
+    inputs = [torch.randn(1, 3, 4).cuda()]
+    cross_validate(module, inputs, fp16_mode=fp16_mode, tol=tol)
+
+
+@pytest.mark.parametrize("output_size,fp16_mode,tol", [
+    ((1, 1), False, 1e-1),
+    ((2, 2), False, 1e-1),
+    ((1, 1), True, 1e-1)
+])
+def test_adaptive_avg_pool2d(output_size, fp16_mode, tol):
+    module = UnaryModule(lambda x: torch.nn.functional.adaptive_avg_pool2d(x, output_size)).cuda().eval()
+    inputs = [torch.randn(1, 3, 4, 4).cuda()]
+    cross_validate(module, inputs, fp16_mode=fp16_mode, tol=tol)
+
+
+@pytest.mark.parametrize("output_size,fp16_mode,tol", [
+    ((1, 1, 1), False, 1e-1),
+    ((2, 2, 2), False, 1e-1),
+    ((1, 1, 1), True, 1e-1)
+])
+def test_adaptive_avg_pool3d(output_size, fp16_mode, tol):
+    module = UnaryModule(lambda x: torch.nn.functional.adaptive_avg_pool3d(x, output_size)).cuda().eval()
+    inputs = [torch.randn(1, 3, 4, 4, 4).cuda()]
+    cross_validate(module, inputs, fp16_mode=fp16_mode, tol=tol)
+
+
+@pytest.mark.parametrize("output_size,fp16_mode,tol", [
+    ((1, 1), False, 1e-1),
+    ((2, 2), False, 1e-1),
+    ((1, 1), True, 1e-1)
+])
+def test_adaptive_max_pool2d(output_size, fp16_mode, tol):
+    module = UnaryModule(lambda x: torch.nn.functional.adaptive_max_pool2d(x, output_size)).cuda().eval()
+    inputs = [torch.randn(1, 3, 4, 4).cuda()]
+    cross_validate(module, inputs, fp16_mode=fp16_mode, tol=tol)
+
+
+@pytest.mark.parametrize("output_size,fp16_mode,tol", [
+    ((1, 1, 1), False, 1e-1),
+    ((2, 2, 2), False, 1e-1),
+    ((1, 1, 1), True, 1e-1)
+])
+def test_adaptive_max_pool3d(output_size, fp16_mode, tol):
+    module = UnaryModule(lambda x: torch.nn.functional.adaptive_max_pool3d(x, output_size)).cuda().eval()
+    inputs = [torch.randn(1, 3, 4, 4, 4).cuda()]
+    cross_validate(module, inputs, fp16_mode=fp16_mode, tol=tol)
+
+
+def test_add():
+    module = BinaryModule(lambda a, b: a + b).cuda().eval()
+    inputs = [torch.randn(1, 3, 4).cuda(), torch.randn(1, 3, 4).cuda()]
+    cross_validate(module, inputs, fp16_mode=False, tol=1e-2)
+
+
+def test_torch_add():
+    module = BinaryModule(lambda a, b: torch.add(a, b)).cuda().eval()
+    inputs = [torch.randn(1, 3, 4).cuda(), torch.randn(1, 3, 4).cuda()]
+    cross_validate(module, inputs, fp16_mode=False, tol=1e-2)
+
+
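+# x += y exercises in-place tensor addition (Tensor.__iadd__).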
+def test_iadd():
+    class IAdd(torch.nn.Module):
+        def __init__(self):
+            super(IAdd, self).__init__()
+
+        def forward(self, x, y):
+            x += y
+            return x
+
+    module = IAdd().cuda().eval()
+    inputs = [torch.randn(1, 3, 4).cuda(), torch.randn(1, 3, 4).cuda()]
+    cross_validate(module, inputs, fp16_mode=False, tol=1e-2)
+
+
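+# A Python scalar on the left-hand side (1 + x) dispatches to Tensor.__radd__.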
+def test_radd_int():
+    module = UnaryModule(lambda x: 1 + x).cuda().eval()
+    inputs = [torch.randn(1, 3, 4).cuda()]
+    cross_validate(module, inputs, fp16_mode=False, tol=1e-2)
+
+
+def test_radd_float():
+    module = UnaryModule(lambda x: 1.0 + x).cuda().eval()
+    inputs = [torch.randn(1, 3, 4).cuda()]
+    cross_validate(module, inputs, fp16_mode=False, tol=1e-2)
+
+
+@pytest.mark.parametrize("kernel_size,stride,padding,ceil_mode,count_include_pad", [
+    (3, 2, 1, False, True),
+    (3, 2, 1, True, False)
+])
+def test_avg_pool2d(kernel_size, stride, padding, ceil_mode, count_include_pad):
+    module = UnaryModule(lambda x: torch.nn.functional.avg_pool2d(
+        x, kernel_size, stride, padding, ceil_mode, count_include_pad
+    )).cuda().eval()
+    inputs = [torch.randn(1, 3, 8, 8).cuda()]
+    cross_validate(module, inputs, fp16_mode=False, tol=1e-1)
+
+
+@pytest.mark.parametrize("kernel_size,stride,padding,ceil_mode,count_include_pad", [
+    (3, 2, 1, False, True),
+    (3, 2, 1, True, False)
+])
+def test_avg_pool3d(kernel_size, stride, padding, ceil_mode, count_include_pad):
+    module = UnaryModule(lambda x: torch.nn.functional.avg_pool3d(
+        x, kernel_size, stride, padding, ceil_mode, count_include_pad
+    )).cuda().eval()
+    inputs = [torch.randn(1, 3, 8, 8, 8).cuda()]
+    cross_validate(module, inputs, fp16_mode=False, tol=1e-1)
+
+
+def test_batch_norm_1d():
+    module = nn.BatchNorm1d(3).cuda().eval()
+    inputs = [torch.randn(2, 3, 4).cuda()]
+    cross_validate(module, inputs, fp16_mode=False, tol=1e-1)
+
+
+def test_batch_norm_2d():
+    module = nn.BatchNorm2d(3).cuda().eval()
+    inputs = [torch.randn(2, 3, 4, 4).cuda()]
+    cross_validate(module, inputs, fp16_mode=False, tol=1e-1)
+
+
+def test_batch_norm_3d():
+    module = nn.BatchNorm3d(3).cuda().eval()
+    inputs = [torch.randn(2, 3, 4, 4, 4).cuda()]
+    cross_validate(module, inputs, fp16_mode=False, tol=1e-1)
+
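+# Note: these tests need a CUDA device and a working TensorRT install; run them with pytest.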