1+ import torch
2+ import torch .nn as nn
3+ from torch2trt .torch2trt import *
4+ from torch2trt .module_test import add_module_test
5+ import numpy as np
6+ import ctypes
7+
8+
9+ try :
10+ ctypes .CDLL ('libtorch2trt_plugins.so' )
11+
12+ def create_reflection_pad_2d_plugin (paddingLeft , paddingRight , paddingTop , paddingBottom ):
13+
14+ registry = trt .get_plugin_registry ()
15+ creator = registry .get_plugin_creator ('ReflectionPad2dPlugin' , '1' , '' )
16+
17+ fc = trt .PluginFieldCollection ([
18+ trt .PluginField (
19+ 'paddingLeft' ,
20+ np .array ([paddingLeft ]).astype (np .int32 ),
21+ trt .PluginFieldType .INT32
22+ ),
23+ trt .PluginField (
24+ 'paddingRight' ,
25+ np .array ([paddingRight ]).astype (np .int32 ),
26+ trt .PluginFieldType .INT32
27+ ),
28+ trt .PluginField (
29+ 'paddingTop' ,
30+ np .array ([paddingTop ]).astype (np .int32 ),
31+ trt .PluginFieldType .INT32
32+ ),
33+ trt .PluginField (
34+ 'paddingBottom' ,
35+ np .array ([paddingBottom ]).astype (np .int32 ),
36+ trt .PluginFieldType .INT32
37+ )
38+ ])
39+
40+ return creator .create_plugin ('' , fc )
41+ @tensorrt_converter (nn .ReflectionPad2d .forward )
42+ def convert_reflection_pad (ctx ):
43+ module = get_arg (ctx , 'self' , pos = 0 , default = None )
44+ input = get_arg (ctx , 'x' , pos = 1 , default = None )
45+ output = ctx .method_return
46+ input_trt = input ._trt
47+ plugin = create_reflection_pad_2d_plugin (
48+ module .padding [0 ],
49+ module .padding [1 ],
50+ module .padding [2 ],
51+ module .padding [3 ]
52+ )
53+ layer = ctx .network .add_plugin_v2 ([input_trt ], plugin )
54+ output ._trt = layer .get_output (0 )
55+
56+
57+ @add_module_test (torch .float32 , torch .device ("cuda" ), [(1 , 1 , 3 , 3 )])
58+ @add_module_test (torch .float32 , torch .device ("cuda" ), [(1 , 2 , 3 , 3 )])
59+ def test_reflection_pad_2d_simple ():
60+ return nn .ReflectionPad2d (1 )
61+
62+
63+ @add_module_test (torch .float32 , torch .device ("cuda" ), [(1 , 1 , 3 , 3 )])
64+ @add_module_test (torch .float32 , torch .device ("cuda" ), [(1 , 2 , 3 , 3 )])
65+ def test_reflection_pad_2d_simple ():
66+ return nn .ReflectionPad2d (2 )
67+
68+
69+ @add_module_test (torch .float32 , torch .device ("cuda" ), [(1 , 1 , 3 , 3 )])
70+ @add_module_test (torch .float32 , torch .device ("cuda" ), [(1 , 2 , 3 , 3 )])
71+ def test_reflection_pad_2d_simple ():
72+ return nn .ReflectionPad2d ((1 , 0 , 1 , 0 ))
73+ except :
74+ pass
0 commit comments