Skip to content

Commit 5fd2f95

Browse files
committed
mv plugin converters
1 parent 7ea8e65 commit 5fd2f95

File tree

2 files changed

+75
-0
lines changed

2 files changed

+75
-0
lines changed

torch2trt/converters/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
from .unimplemented_converters import *
2+
from .plugin_converters import *
23
from .native_converters import *
torch2trt/converters/plugin_converters.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
import torch
2+
import torch.nn as nn
3+
from torch2trt.torch2trt import *
4+
from torch2trt.module_test import add_module_test
5+
import numpy as np
6+
import ctypes
7+
8+
9+
try:
    # Load the compiled plugin library so the 'ReflectionPad2dPlugin'
    # creator is registered with TensorRT. If the library (or anything
    # below) is unavailable, these converters are silently skipped —
    # they are an optional feature.
    ctypes.CDLL('libtorch2trt_plugins.so')

    def create_reflection_pad_2d_plugin(paddingLeft, paddingRight, paddingTop, paddingBottom):
        """Create a TensorRT ReflectionPad2dPlugin for the given padding.

        Arguments are the per-edge padding sizes in the same
        (left, right, top, bottom) order as nn.ReflectionPad2d.padding.
        Returns the plugin instance produced by the registered creator.
        """
        registry = trt.get_plugin_registry()
        creator = registry.get_plugin_creator('ReflectionPad2dPlugin', '1', '')

        # One INT32 field per edge; built in a comprehension instead of
        # four copy-pasted trt.PluginField constructions.
        fc = trt.PluginFieldCollection([
            trt.PluginField(
                field_name,
                np.array([field_value]).astype(np.int32),
                trt.PluginFieldType.INT32
            )
            for field_name, field_value in [
                ('paddingLeft', paddingLeft),
                ('paddingRight', paddingRight),
                ('paddingTop', paddingTop),
                ('paddingBottom', paddingBottom),
            ]
        ])

        return creator.create_plugin('', fc)

    @tensorrt_converter(nn.ReflectionPad2d.forward)
    def convert_reflection_pad(ctx):
        """Convert nn.ReflectionPad2d.forward into a TensorRT plugin layer."""
        module = get_arg(ctx, 'self', pos=0, default=None)
        input = get_arg(ctx, 'x', pos=1, default=None)
        output = ctx.method_return
        input_trt = input._trt
        # nn.ReflectionPad2d.padding is (left, right, top, bottom).
        plugin = create_reflection_pad_2d_plugin(
            module.padding[0],
            module.padding[1],
            module.padding[2],
            module.padding[3]
        )
        layer = ctx.network.add_plugin_v2([input_trt], plugin)
        output._trt = layer.get_output(0)

    @add_module_test(torch.float32, torch.device("cuda"), [(1, 1, 3, 3)])
    @add_module_test(torch.float32, torch.device("cuda"), [(1, 2, 3, 3)])
    def test_reflection_pad_2d_simple():
        return nn.ReflectionPad2d(1)

    # BUG FIX: all three module tests were previously named
    # test_reflection_pad_2d_simple, so the later definitions shadowed
    # the earlier ones at module level. Each now has a distinct name.
    @add_module_test(torch.float32, torch.device("cuda"), [(1, 1, 3, 3)])
    @add_module_test(torch.float32, torch.device("cuda"), [(1, 2, 3, 3)])
    def test_reflection_pad_2d_pad2():
        return nn.ReflectionPad2d(2)

    @add_module_test(torch.float32, torch.device("cuda"), [(1, 1, 3, 3)])
    @add_module_test(torch.float32, torch.device("cuda"), [(1, 2, 3, 3)])
    def test_reflection_pad_2d_asymmetric():
        return nn.ReflectionPad2d((1, 0, 1, 0))

# Narrowed from a bare 'except:' so KeyboardInterrupt/SystemExit are no
# longer swallowed; any other failure still just disables the plugins.
except Exception:
    pass

0 commit comments

Comments
 (0)