11import torch
22import torch .nn as nn
33from torch2trt .torch2trt import *
4- from torch2trt .module_test import add_module_test
54import numpy as np
65import ctypes
76
@@ -53,22 +52,5 @@ def convert_reflection_pad(ctx):
5352 layer = ctx .network .add_plugin_v2 ([input_trt ], plugin )
5453 output ._trt = layer .get_output (0 )
5554
56-
57- @add_module_test (torch .float32 , torch .device ("cuda" ), [(1 , 1 , 3 , 3 )])
58- @add_module_test (torch .float32 , torch .device ("cuda" ), [(1 , 2 , 3 , 3 )])
59- def test_reflection_pad_2d_simple ():
60- return nn .ReflectionPad2d (1 )
61-
62-
63- @add_module_test (torch .float32 , torch .device ("cuda" ), [(1 , 1 , 3 , 3 )])
64- @add_module_test (torch .float32 , torch .device ("cuda" ), [(1 , 2 , 3 , 3 )])
65- def test_reflection_pad_2d_simple ():
66- return nn .ReflectionPad2d (2 )
67-
68-
69- @add_module_test (torch .float32 , torch .device ("cuda" ), [(1 , 1 , 3 , 3 )])
70- @add_module_test (torch .float32 , torch .device ("cuda" ), [(1 , 2 , 3 , 3 )])
71- def test_reflection_pad_2d_simple ():
72- return nn .ReflectionPad2d ((1 , 0 , 1 , 0 ))
7355except :
7456 pass
0 commit comments