
Commit 27c9d42

migrate converters and remove unused
1 parent 2d593d5 commit 27c9d42


81 files changed: +2264 −29 lines
Lines changed: 79 additions & 0 deletions
@@ -0,0 +1,79 @@
import pytest
import torch
import torch2trt
from torch2trt.flattener import Flattener


def _cross_validate(module, inputs, *args, **kwargs):

    module_trt = torch2trt.torch2trt(module, inputs, *args, **kwargs)

    output = module(*inputs)
    output_trt = module_trt(*inputs)

    assert torch.allclose(output, output_trt, atol=1e-2, rtol=1e-2)


class UnaryModule(torch.nn.Module):
    def __init__(self, fn):
        super(UnaryModule, self).__init__()
        self.fn = fn

    def forward(self, x):
        return self.fn(x)


def test_functional_leaky_relu():
    _cross_validate(
        UnaryModule(lambda x: torch.nn.functional.leaky_relu(x)).cuda().eval(),
        [torch.randn(1, 5, 3).cuda()]
    )


def test_functional_elu():
    _cross_validate(
        UnaryModule(lambda x: torch.nn.functional.elu(x)).cuda().eval(),
        [torch.randn(1, 5, 3).cuda()]
    )


def test_selu():
    _cross_validate(
        UnaryModule(lambda x: torch.selu(x)).cuda().eval(),
        [torch.randn(1, 5, 3).cuda()]
    )


def test_functional_selu():
    _cross_validate(
        UnaryModule(lambda x: torch.nn.functional.selu(x)).cuda().eval(),
        [torch.randn(1, 5, 3).cuda()]
    )


def test_functional_softsign():
    _cross_validate(
        UnaryModule(lambda x: torch.nn.functional.softsign(x)).cuda().eval(),
        [torch.randn(1, 5, 3).cuda()]
    )


def test_functional_softplus():
    _cross_validate(
        UnaryModule(lambda x: torch.nn.functional.softplus(x)).cuda().eval(),
        [torch.randn(1, 5, 3).cuda()]
    )
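
For reference, the cross-validation pattern these tests rely on can also be run by hand. A minimal sketch, assuming a CUDA device and a working torch2trt install; `ReLU` here is a hypothetical stand-in for any op with a registered converter, not part of this commit:

    import torch
    import torch2trt

    # Hypothetical stand-in module; any op with a registered converter works.
    module = torch.nn.ReLU().cuda().eval()
    inputs = [torch.randn(1, 5, 3).cuda()]

    # Build a TensorRT engine from the example inputs, then compare outputs.
    module_trt = torch2trt.torch2trt(module, inputs)
    print(torch.max(torch.abs(module(*inputs) - module_trt(*inputs))))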
Lines changed: 161 additions & 0 deletions
@@ -0,0 +1,161 @@
import pytest
import torch
import torch.nn as nn
from torch2trt import torch2trt, trt


class YOLOXFocusTestModule(nn.Module):

    def forward(self, x):
        patch_top_left = x[..., ::2, ::2]
        patch_top_right = x[..., ::2, 1::2]
        patch_bot_left = x[..., 1::2, ::2]
        patch_bot_right = x[..., 1::2, 1::2]
        x = torch.cat(
            (
                patch_top_left,
                patch_bot_left,
                patch_top_right,
                patch_bot_right,
            ),
            dim=1,
        )
        return x


def test_getitem_dynamic_yolox_layer():

    class YOLOXFocusTestModule(nn.Module):

        def forward(self, x):
            patch_top_left = x[..., ::2, ::2]
            patch_top_right = x[..., ::2, 1::2]
            patch_bot_left = x[..., 1::2, ::2]
            patch_bot_right = x[..., 1::2, 1::2]
            x = torch.cat(
                (
                    patch_top_left,
                    patch_bot_left,
                    patch_top_right,
                    patch_bot_right,
                ),
                dim=1,
            )
            return x

    module = YOLOXFocusTestModule().cuda().eval()

    data = torch.randn(1, 3, 112, 112).cuda()

    module_trt = torch2trt(module, [data], max_batch_size=4, log_level=trt.Logger.VERBOSE)

    data = torch.randn(1, 3, 112, 112).cuda()
    assert torch.allclose(module_trt(data), module(data), atol=1e-4, rtol=1e-4)

    data = torch.randn(4, 3, 112, 112).cuda()
    assert torch.allclose(module_trt(data), module(data), atol=1e-4, rtol=1e-4)


def test_getitem_dynamic_add_dim():

    class TestModule(nn.Module):

        def forward(self, x):
            patch_top_left = x[..., None]
            patch_top_right = x[..., None]
            patch_bot_left = x[..., None]
            patch_bot_right = x[..., None]
            x = torch.cat(
                (
                    patch_top_left,
                    patch_bot_left,
                    patch_top_right,
                    patch_bot_right,
                ),
                dim=1,
            )
            return x

    module = TestModule().cuda().eval()

    data = torch.randn(1, 3, 112, 112).cuda()

    module_trt = torch2trt(module, [data], max_batch_size=4, log_level=trt.Logger.VERBOSE)

    data = torch.randn(1, 3, 112, 112).cuda()
    assert torch.allclose(module_trt(data), module(data), atol=1e-4, rtol=1e-4)

    data = torch.randn(4, 3, 112, 112).cuda()
    assert torch.allclose(module_trt(data), module(data), atol=1e-4, rtol=1e-4)


def test_getitem_dynamic_remove_dim():

    class TestModule(nn.Module):

        def forward(self, x):
            patch_top_left = x[..., 0]
            patch_top_right = x[..., 0]
            patch_bot_left = x[..., 0]
            patch_bot_right = x[..., 0]
            x = torch.cat(
                (
                    patch_top_left,
                    patch_bot_left,
                    patch_top_right,
                    patch_bot_right,
                ),
                dim=1,
            )
            return x

    module = TestModule().cuda().eval()

    data = torch.randn(1, 3, 112, 112).cuda()

    module_trt = torch2trt(module, [data], max_batch_size=4, log_level=trt.Logger.VERBOSE)

    data = torch.randn(1, 3, 112, 112).cuda()
    assert torch.allclose(module_trt(data), module(data), atol=1e-4, rtol=1e-4)

    data = torch.randn(4, 3, 112, 112).cuda()
    assert torch.allclose(module_trt(data), module(data), atol=1e-4, rtol=1e-4)


def test_getitem_dynamic_remove_add_dim():

    class TestModule(nn.Module):

        def forward(self, x):
            patch_top_left = x[..., 0, None]
            patch_top_right = x[..., 0, None]
            patch_bot_left = x[..., 0, None]
            patch_bot_right = x[..., 0, None]
            x = torch.cat(
                (
                    patch_top_left,
                    patch_bot_left,
                    patch_top_right,
                    patch_bot_right,
                ),
                dim=1,
            )
            return x

    module = TestModule().cuda().eval()

    data = torch.randn(1, 3, 112, 112).cuda()

    module_trt = torch2trt(module, [data], max_batch_size=4, log_level=trt.Logger.VERBOSE)

    data = torch.randn(1, 3, 112, 112).cuda()
    assert torch.allclose(module_trt(data), module(data), atol=1e-4, rtol=1e-4)

    data = torch.randn(4, 3, 112, 112).cuda()
    assert torch.allclose(module_trt(data), module(data), atol=1e-4, rtol=1e-4)
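
The Focus-style module above is a space-to-depth rearrangement: each strided `::2` slice halves H and W, and the `dim=1` concat quadruples the channels, so an (N, C, H, W) input becomes (N, 4C, H/2, W/2). Each test then builds the engine from a batch-1 example but asserts at batch 4, which is the dynamic-batch path that `max_batch_size=4` exercises. A quick pure-PyTorch shape check, reusing the module-level class above (no TensorRT needed):

    import torch

    x = torch.randn(1, 3, 112, 112)   # CPU is fine; forward is pure indexing
    y = YOLOXFocusTestModule()(x)     # the module-level class defined above
    assert y.shape == (1, 12, 56, 56) # (N, 4C, H/2, W/2)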

tests/model_tests/torchvision/test_classification_models.py

Lines changed: 1 addition & 1 deletion
@@ -10,7 +10,7 @@ def _cross_validate_module(model, shape=(224, 224)):
     data = torch.randn(1, 3, *shape).cuda()
     out = model(data)
     out_trt = model_trt(data)
-    assert torch.allclose(out, out_trt, rtol=1e-2, atol=1e-2)
+    assert torch.allclose(out, out_trt, rtol=1e-1, atol=1e-1)