Skip to content

Commit e75ce41

Browse files
committed
add some unit tests
1 parent ebeeb4e commit e75ce41

File tree

1 file changed

+173
-37
lines changed

1 file changed

+173
-37
lines changed
Lines changed: 173 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
11
import pytest
22
import torch
33
import torch2trt
4+
import torch.nn as nn
45
from torch2trt.flattener import Flattener
56

67

7-
def _cross_validate(
8+
def cross_validate(
89
module,
910
inputs,
10-
*args,
11-
**kwargs
11+
fp16_mode: bool,
12+
tol: float
1213
):
1314

1415
module = module
@@ -17,63 +18,198 @@ def _cross_validate(
1718
module_trt = torch2trt.torch2trt(
1819
module,
1920
inputs,
20-
*args,
21-
**kwargs
21+
fp16_mode=fp16_mode
2222
)
2323

2424

2525
output = module(*inputs)
2626
output_trt = module_trt(*inputs)
2727

28-
assert torch.allclose(output, output_trt, atol=1e-2, rtol=1e-2)
28+
assert torch.allclose(output, output_trt, atol=tol, rtol=tol)
2929

3030

31+
32+
# MODULES
33+
34+
3135
class UnaryModule(torch.nn.Module):
3236
def __init__(self, fn):
3337
super(UnaryModule, self).__init__()
3438
self.fn = fn
3539

3640
def forward(self, x):
3741
return self.fn(x)
38-
3942

40-
def test_functional_leaky_relu():
41-
_cross_validate(
42-
UnaryModule(lambda x: torch.nn.functional.leaky_relu(x)).cuda().eval(),
43-
[torch.randn(1, 5, 3).cuda()]
44-
)
4543

44+
class BinaryModule(torch.nn.Module):
45+
def __init__(self, fn):
46+
super(BinaryModule, self).__init__()
47+
self.fn = fn
48+
49+
def forward(self, a, b):
50+
return self.fn(a, b)
51+
# TESTS
4652

47-
def test_functional_elu():
48-
_cross_validate(
49-
UnaryModule(lambda x: torch.nn.functional.elu(x)).cuda().eval(),
50-
[torch.randn(1, 5, 3).cuda()]
51-
)
5253

5354

54-
def test_selu():
55-
_cross_validate(
56-
UnaryModule(lambda x: torch.selu(x)).cuda().eval(),
57-
[torch.randn(1, 5, 3).cuda()]
58-
)
55+
@pytest.mark.parametrize("fp16_mode,tol", [(False, 1e-1), (True, 1e-1)])
56+
def test_leaky_relu(fp16_mode, tol):
57+
module = UnaryModule(lambda x: torch.nn.functional.leaky_relu(x)).cuda().eval()
58+
inputs = [torch.randn(1, 3, 4).cuda()]
59+
cross_validate(module, inputs, fp16_mode=fp16_mode, tol=tol)
5960

6061

61-
def test_functional_selu():
62-
_cross_validate(
63-
UnaryModule(lambda x: torch.nn.functional.selu(x)).cuda().eval(),
64-
[torch.randn(1, 5, 3).cuda()]
65-
)
62+
@pytest.mark.parametrize("fp16_mode,tol", [(False, 1e-1), (True, 1e-1)])
63+
def test_elu(fp16_mode, tol):
64+
module = UnaryModule(lambda x: torch.nn.functional.elu(x)).cuda().eval()
65+
inputs = [torch.randn(1, 3, 4).cuda()]
66+
cross_validate(module, inputs, fp16_mode=fp16_mode, tol=tol)
6667

6768

68-
def test_functional_softsign():
69-
_cross_validate(
70-
UnaryModule(lambda x: torch.nn.functional.softsign(x)).cuda().eval(),
71-
[torch.randn(1, 5, 3).cuda()]
72-
)
69+
@pytest.mark.parametrize("fp16_mode,tol", [(False, 1e-1), (True, 1e-1)])
70+
def test_selu(fp16_mode, tol):
71+
module = UnaryModule(lambda x: torch.nn.functional.selu(x)).cuda().eval()
72+
inputs = [torch.randn(1, 3, 4).cuda()]
73+
cross_validate(module, inputs, fp16_mode=fp16_mode, tol=tol)
7374

7475

75-
def test_functional_softplus():
76-
_cross_validate(
77-
UnaryModule(lambda x: torch.nn.functional.softplus(x)).cuda().eval(),
78-
[torch.randn(1, 5, 3).cuda()]
79-
)
76+
@pytest.mark.parametrize("fp16_mode,tol", [(False, 1e-1), (True, 1e-1)])
77+
def test_softsign(fp16_mode, tol):
78+
module = UnaryModule(lambda x: torch.nn.functional.selu(x)).cuda().eval()
79+
inputs = [torch.randn(1, 3, 4).cuda()]
80+
cross_validate(module, inputs, fp16_mode=fp16_mode, tol=tol)
81+
82+
83+
@pytest.mark.parametrize("fp16_mode,tol", [(False, 1e-1), (True, 1e-1)])
84+
def test_softplus(fp16_mode, tol):
85+
module = UnaryModule(lambda x: torch.nn.functional.softplus(x)).cuda().eval()
86+
inputs = [torch.randn(1, 3, 4).cuda()]
87+
cross_validate(module, inputs, fp16_mode=fp16_mode, tol=tol)
88+
89+
90+
@pytest.mark.parametrize("output_size,fp16_mode,tol", [
91+
((1, 1), False, 1e-1),
92+
((2, 2), False, 1e-1),
93+
((1, 1), True, 1e-1)
94+
])
95+
def test_adaptive_avg_pool2d(output_size, fp16_mode, tol):
96+
module = UnaryModule(lambda x: torch.nn.functional.adaptive_avg_pool2d(x, output_size)).cuda().eval()
97+
inputs = [torch.randn(1, 3, 4, 4).cuda()]
98+
cross_validate(module, inputs, fp16_mode=fp16_mode, tol=tol)
99+
100+
101+
@pytest.mark.parametrize("output_size,fp16_mode,tol", [
102+
((1, 1, 1), False, 1e-1),
103+
((2, 2, 2), False, 1e-1),
104+
((1, 1, 1), True, 1e-1)
105+
])
106+
def test_adaptive_avg_pool3d(output_size, fp16_mode, tol):
107+
module = UnaryModule(lambda x: torch.nn.functional.adaptive_avg_pool3d(x, output_size)).cuda().eval()
108+
inputs = [torch.randn(1, 3, 4, 4, 4).cuda()]
109+
cross_validate(module, inputs, fp16_mode=fp16_mode, tol=tol)
110+
111+
112+
@pytest.mark.parametrize("output_size,fp16_mode,tol", [
113+
((1, 1), False, 1e-1),
114+
((2, 2), False, 1e-1),
115+
((1, 1), True, 1e-1)
116+
])
117+
def test_adaptive_max_pool2d(output_size, fp16_mode, tol):
118+
module = UnaryModule(lambda x: torch.nn.functional.adaptive_max_pool2d(x, output_size)).cuda().eval()
119+
inputs = [torch.randn(1, 3, 4, 4).cuda()]
120+
cross_validate(module, inputs, fp16_mode=fp16_mode, tol=tol)
121+
122+
123+
@pytest.mark.parametrize("output_size,fp16_mode,tol", [
124+
((1, 1, 1), False, 1e-1),
125+
((2, 2, 2), False, 1e-1),
126+
((1, 1, 1), True, 1e-1)
127+
])
128+
def test_adaptive_max_pool3d(output_size, fp16_mode, tol):
129+
module = UnaryModule(lambda x: torch.nn.functional.adaptive_max_pool3d(x, output_size)).cuda().eval()
130+
inputs = [torch.randn(1, 3, 4, 4, 4).cuda()]
131+
cross_validate(module, inputs, fp16_mode=fp16_mode, tol=tol)
132+
133+
134+
def test_add():
135+
module = BinaryModule(lambda a, b: a + b).cuda().eval()
136+
inputs = [torch.randn(1, 3, 4).cuda(), torch.randn(1, 3, 4).cuda()]
137+
cross_validate(module, inputs, fp16_mode=False, tol=1e-2)
138+
139+
140+
def test_torch_add():
141+
module = BinaryModule(lambda a, b: torch.add(a, b)).cuda().eval()
142+
inputs = [torch.randn(1, 3, 4).cuda(), torch.randn(1, 3, 4).cuda()]
143+
cross_validate(module, inputs, fp16_mode=False, tol=1e-2)
144+
145+
146+
def test_iadd():
147+
class IAdd(torch.nn.Module):
148+
def __init__(self):
149+
super(IAdd, self).__init__()
150+
151+
def forward(self, x, y):
152+
x += y
153+
return x
154+
155+
module = IAdd().cuda().eval()
156+
inputs = [torch.randn(1, 3, 4).cuda(), torch.randn(1, 3, 4).cuda()]
157+
cross_validate(module, inputs, fp16_mode=False, tol=1e-2)
158+
159+
160+
def test_radd_int():
161+
module = UnaryModule(lambda x: 1 + x).cuda().eval()
162+
inputs = [torch.randn(1, 3, 4).cuda()]
163+
cross_validate(module, inputs, fp16_mode=False, tol=1e-2)
164+
165+
166+
def test_radd_float():
167+
module = UnaryModule(lambda x: 1.0 + x).cuda().eval()
168+
inputs = [torch.randn(1, 3, 4).cuda()]
169+
cross_validate(module, inputs, fp16_mode=False, tol=1e-2)
170+
171+
172+
# TODO: radd, add, iadd
173+
174+
175+
@pytest.mark.parametrize("kernel_size,stride,padding,ceil_mode,count_include_pad", [
176+
(3, 2, 1, False, True),
177+
(3, 2, 1, True, False)
178+
])
179+
def test_avg_pool2d(kernel_size, stride, padding, ceil_mode, count_include_pad):
180+
module = UnaryModule(lambda x: torch.nn.functional.avg_pool2d(
181+
x, kernel_size, stride, padding, ceil_mode, count_include_pad
182+
)).cuda().eval()
183+
inputs = [torch.randn(1, 3, 8, 8).cuda()]
184+
cross_validate(module, inputs, fp16_mode=False, tol=1e-1)
185+
186+
187+
@pytest.mark.parametrize("kernel_size,stride,padding,ceil_mode,count_include_pad", [
188+
(3, 2, 1, False, True),
189+
(3, 2, 1, True, False)
190+
])
191+
def test_avg_pool3d(kernel_size, stride, padding, ceil_mode, count_include_pad):
192+
module = UnaryModule(lambda x: torch.nn.functional.avg_pool3d(
193+
x, kernel_size, stride, padding, ceil_mode, count_include_pad
194+
)).cuda().eval()
195+
inputs = [torch.randn(1, 3, 8, 8, 8).cuda()]
196+
cross_validate(module, inputs, fp16_mode=False, tol=1e-1)
197+
198+
199+
def test_batch_norm_1d():
200+
module = nn.BatchNorm2d(3).cuda().eval()
201+
inputs = [torch.randn(2, 3, 4).cuda()]
202+
cross_validate(module, inputs, fp16_mode=False, tol=1e-1)
203+
204+
205+
def test_batch_norm_2d():
206+
module = nn.BatchNorm2d(3).cuda().eval()
207+
inputs = [torch.randn(2, 3, 4, 4).cuda()]
208+
cross_validate(module, inputs, fp16_mode=False, tol=1e-1)
209+
210+
211+
def test_batch_norm_3d():
212+
module = nn.BatchNorm2d(3).cuda().eval()
213+
inputs = [torch.randn(2, 3, 4, 4, 4).cuda()]
214+
cross_validate(module, inputs, fp16_mode=False, tol=1e-1)
215+

0 commit comments

Comments
 (0)