Commit adccbf1

Koenvandesande: remove duplicate filenames (#448)

* Remove duplicate filenames (e.g. ReLU.py vs relu.py), which collide on Windows and other case-insensitive filesystems, by merging the files
* Fix
* ReLU tests

Co-authored-by: Koen van de Sande <[email protected]>
1 parent d1fa6f9 commit adccbf1

7 files changed: +73 additions, -55 deletions

torch2trt/converters/Identity.py

Lines changed: 0 additions & 11 deletions
This file was deleted.

torch2trt/converters/ReLU.py

Lines changed: 0 additions & 11 deletions
This file was deleted.

torch2trt/converters/ReLU6.py

Lines changed: 0 additions & 23 deletions
This file was deleted.

torch2trt/converters/__init__.py

Lines changed: 0 additions & 3 deletions
@@ -12,11 +12,8 @@
 from .Conv2d import *
 from .ConvTranspose import *
 from .ConvTranspose2d import *
-from .Identity import *
 from .Linear import *
 from .LogSoftmax import *
-from .ReLU import *
-from .ReLU6 import *
 from .activation import *
 from .adaptive_avg_pool2d import *
 from .adaptive_max_pool2d import *

torch2trt/converters/identity.py

Lines changed: 11 additions & 1 deletion
@@ -5,8 +5,18 @@
 @tensorrt_converter('torch.nn.functional.dropout')
 @tensorrt_converter('torch.nn.functional.dropout2d')
 @tensorrt_converter('torch.nn.functional.dropout3d')
-def convert_identity(ctx):
+def convert_functional_identity(ctx):
     input = ctx.method_args[0]
     input_trt = add_missing_trt_tensors(ctx.network, [input])[0]
     output = ctx.method_return
     output._trt = input_trt
+
+
+@tensorrt_converter('torch.nn.Dropout.forward')
+@tensorrt_converter('torch.nn.Dropout2d.forward')
+@tensorrt_converter('torch.nn.Dropout3d.forward')
+def convert_identity(ctx):
+    input = ctx.method_args[1]
+    input_trt = add_missing_trt_tensors(ctx.network, [input])[0]
+    output = ctx.method_return
+    output._trt = input_trt

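As context for the diff above (not part of the commit): the functional and module converters index ctx.method_args differently because of what the hooked callables receive. A minimal eager-mode sketch of the two call shapes:

import torch

# torch.nn.functional.dropout(input, ...): the tensor is the first
# positional argument, hence method_args[0] in convert_functional_identity.
x = torch.randn(1, 3, 4, 5)
y = torch.nn.functional.dropout(x, p=0.5)

# torch.nn.Dropout.forward(self, input): the module instance occupies
# index 0, so the tensor sits at method_args[1] in convert_identity.
m = torch.nn.Dropout(p=0.5)
y = torch.nn.Dropout.forward(m, x)  # same as calling m(x)

Since dropout is a no-op at inference time, both converters simply alias the output to the input's TensorRT tensor instead of adding a layer.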
torch2trt/converters/relu.py

Lines changed: 27 additions & 3 deletions
@@ -1,11 +1,35 @@
 from torch2trt.torch2trt import *
-from .ReLU import *
+from torch2trt.module_test import add_module_test


 @tensorrt_converter('torch.relu')
 @tensorrt_converter('torch.relu_')
 @tensorrt_converter('torch.nn.functional.relu')
 @tensorrt_converter('torch.nn.functional.relu_')
-def convert_relu(ctx):
+def convert_functional_relu(ctx):
     ctx.method_args = (torch.nn.ReLU(),) + ctx.method_args
-    convert_ReLU(ctx)
+    convert_relu(ctx)
+
+
+@tensorrt_converter('torch.nn.ReLU.forward')
+def convert_relu(ctx):
+    input = ctx.method_args[1]
+    input_trt = add_missing_trt_tensors(ctx.network, [input])[0]
+    output = ctx.method_return
+    layer = ctx.network.add_activation(
+        input=input_trt, type=trt.ActivationType.RELU)
+    output._trt = layer.get_output(0)
+
+@add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 4, 5)])
+def test_relu_basic():
+    return torch.nn.ReLU()
+
+
+class FunctionalRelu(torch.nn.Module):
+    def forward(self, x):
+        return torch.nn.functional.relu(x)
+
+
+@add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 4, 5)])
+def test_functional_relu_basic():
+    return FunctionalRelu()

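For reference, a minimal usage sketch that exercises the new convert_relu path through torch2trt's standard entry point, assuming a working CUDA + TensorRT environment:

import torch
from torch2trt import torch2trt

# Toy model whose activation goes through torch.nn.ReLU.forward,
# hitting the convert_relu converter registered above.
model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 8, kernel_size=3, padding=1),
    torch.nn.ReLU(),
).cuda().eval()

x = torch.randn(1, 3, 4, 5).cuda()
model_trt = torch2trt(model, [x])

# The TensorRT engine should closely match the eager model.
print(torch.max(torch.abs(model(x) - model_trt(x))))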
torch2trt/converters/relu6.py

Lines changed: 35 additions & 3 deletions
@@ -1,8 +1,40 @@
 from torch2trt.torch2trt import *
-from .ReLU6 import *
+from torch2trt.module_test import add_module_test


 @tensorrt_converter('torch.nn.functional.relu6')
-def convert_relu6(ctx):
+def convert_functional_relu6(ctx):
     ctx.method_args = (torch.nn.ReLU6(),) + ctx.method_args
-    convert_ReLU6(ctx)
+    convert_relu6(ctx)
+
+
+@tensorrt_converter('torch.nn.ReLU6.forward')
+def convert_relu6(ctx):
+    input = ctx.method_args[1]
+    output = ctx.method_return
+
+    input_a_trt, input_b_trt = add_missing_trt_tensors(ctx.network, [input, 6])
+    input_a_trt, input_b_trt = broadcast_trt_tensors(ctx.network, [input_a_trt, input_b_trt], len(output.shape) - 1)
+
+    layer = ctx.network.add_activation(
+        input=input_a_trt, type=trt.ActivationType.RELU)
+    layer = ctx.network.add_elementwise(
+        layer.get_output(0), input_b_trt, trt.ElementWiseOperation.MIN)
+
+    output._trt = layer.get_output(0)
+
+
+@add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 4, 5)])
+def test_relu6_basic():
+    return torch.nn.ReLU6()
+
+
+class FunctionalRelu6(torch.nn.Module):
+    def forward(self, x):
+        return torch.nn.functional.relu6(x)
+
+
+@add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 4, 5)])
+def test_functional_relu6_basic():
+    return FunctionalRelu6()
+

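The ReLU6 converter decomposes the op as min(relu(x), 6): a RELU activation followed by an elementwise MIN against the broadcast constant 6. The rank passed to broadcast_trt_tensors is len(output.shape) - 1 because the TensorRT network here excludes the batch dimension. A quick eager-mode check of that decomposition (pure PyTorch, no TensorRT needed):

import torch

# ReLU6 as built by the converter: relu followed by min with 6.
x = torch.randn(1, 3, 4, 5) * 10
decomposed = torch.minimum(torch.relu(x), torch.tensor(6.0))
assert torch.equal(decomposed, torch.nn.functional.relu6(x))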