
Commit 32b7308

John Welsh committed
added silu converter
1 parent 98a666b commit 32b7308


2 files changed: +22 −0 lines changed


torch2trt/converters/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -50,6 +50,7 @@
 from .relu import *
 from .relu6 import *
 from .sigmoid import *
+from .silu import *
 from .softmax import *
 from .split import *
 from .stack import *

torch2trt/converters/silu.py

Lines changed: 21 additions & 0 deletions
@@ -0,0 +1,21 @@
+from torch2trt.torch2trt import *
+from torch2trt.module_test import add_module_test
+
+
+@tensorrt_converter('torch.nn.functional.silu')
+def convert_silu(ctx):
+    input = get_arg(ctx, 'input', pos=0, default=None)
+    output = ctx.method_return
+    input_trt = add_missing_trt_tensors(ctx.network, [input])[0]
+
+    layer = ctx.network.add_activation(input_trt, trt.ActivationType.SIGMOID)
+    layer = ctx.network.add_elementwise(input_trt, layer.get_output(0), trt.ElementWiseOperation.PROD)
+
+    output._trt = layer.get_output(0)
+
+
+@add_module_test(torch.float32, torch.device('cuda'), [(1, 5)])
+@add_module_test(torch.float32, torch.device('cuda'), [(1, 5, 3)])
+@add_module_test(torch.float32, torch.device('cuda'), [(1, 5, 3, 3)])
+def test_silu():
+    return torch.nn.SiLU()
