@@ -15,14 +15,14 @@ namespace {
15
15
/*
16
16
* Helper functions
17
17
*/
18
-
19
- void create_plugin (ConversionCtx* ctx, const torch::jit::Node* n, nvinfer1::ITensor* in, const char * name,
20
- std::vector<int64_t > in_shape,
21
- std::vector<int64_t > out_shape,
22
- std::vector<int64_t > out_size,
18
+ # if NV_TENSORRT_MAJOR < 7 || (NV_TENSORRT_MAJOR == 7 && NV_TENSORRT_MINOR < 1)
19
+ void create_plugin (ConversionCtx* ctx, const torch::jit::Node* n, nvinfer1::ITensor* in, const char * name,
20
+ std::vector<int64_t > in_shape,
21
+ std::vector<int64_t > out_shape,
22
+ std::vector<int64_t > out_size,
23
23
std::string mode) {
24
- LOG_WARNING (" Interpolation layer will be run through ATen, not TensorRT. Performance may differ. " );
25
-
24
+ LOG_WARNING (" Interpolation layer will be run through ATen, not TensorRT. Performance may be lower than expected " );
25
+
26
26
auto creator = new plugins::InterpolatePluginCreator ();
27
27
auto plugin = creator->createPlugin (name, in_shape, out_shape, out_size, mode, false );
28
28
@@ -35,23 +35,28 @@ void create_plugin(ConversionCtx* ctx, const torch::jit::Node* n, nvinfer1::ITen
35
35
36
36
LOG_DEBUG (" Output tensor shape: " << layer_output->getDimensions ());
37
37
}
38
+ #endif
38
39
39
- void resize_layer_size (ConversionCtx* ctx, const torch::jit::Node* n, nvinfer1::ITensor* in, std::vector<int64_t > out_shape,
40
- nvinfer1::ResizeMode mode) {
40
+ void resize_layer_size (ConversionCtx* ctx, const torch::jit::Node* n, nvinfer1::ITensor* in, std::vector<int64_t > out_shape,
41
+ nvinfer1::ResizeMode mode, bool align_corners= false ) {
41
42
auto resize_layer = ctx->net ->addResize (*in);
42
43
TRTORCH_CHECK (resize_layer, " Unable to create interpolation (resizing) layer from node" << *n);
43
44
44
45
resize_layer->setOutputDimensions (util::toDims (out_shape));
45
46
resize_layer->setResizeMode (mode);
46
47
resize_layer->setName (util::node_info (n).c_str ());
47
-
48
+
48
49
// if interpolation mode is linear, align corners must have been set to true. else, don't use align corners.
49
50
if (mode == nvinfer1::ResizeMode::kLINEAR ) {
50
- resize_layer->setAlignCorners (true );
51
+ #if NV_TENSORRT_MAJOR < 7 || (NV_TENSORRT_MAJOR == 7 && NV_TENSORRT_MINOR < 1)
52
+ resize_layer->setAlignCorners (true );
53
+ #else
54
+ resize_layer->setAlignCorners (align_corners);
55
+ #endif
51
56
}
52
57
53
58
auto layer_output = ctx->AssociateValueAndTensor (n->outputs ()[0 ], resize_layer->getOutput (0 ));
54
-
59
+
55
60
LOG_DEBUG (" Output tensor shape: " << layer_output->getDimensions ());
56
61
}
57
62
@@ -72,7 +77,7 @@ auto interpolate_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
72
77
auto out_size = util::toVec (util::toDims (args[1 ].unwrapToIntList ()));
73
78
74
79
TRTORCH_ASSERT (out_size.size () == 1 , " aten::upsample_nearest1d input Tensor and output size dimension mismatch" );
75
-
80
+
76
81
auto out_shape = in_shape;
77
82
std::copy (out_size.begin (), out_size.end (), out_shape.begin () + (in_shape.size () - out_size.size ()));
78
83
@@ -94,10 +99,10 @@ auto interpolate_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
94
99
auto out_size = util::toVec (util::toDims (args[1 ].unwrapToIntList ()));
95
100
96
101
TRTORCH_ASSERT (out_size.size () == 2 , " aten::upsample_nearest2d input Tensor and output size dimension mismatch" );
97
-
102
+
98
103
auto out_shape = in_shape;
99
104
std::copy (out_size.begin (), out_size.end (), out_shape.begin () + (in_shape.size () - out_size.size ()));
100
-
105
+
101
106
resize_layer_size (ctx, n, in, out_shape, nvinfer1::ResizeMode::kNEAREST );
102
107
} else {
103
108
TRTORCH_THROW_ERROR (" Unable to convert node: " << util::node_info (n) << " \n Scale factor parameter for upsample_nearest2d not supported yet." );
@@ -116,7 +121,7 @@ auto interpolate_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
116
121
auto out_size = util::toVec (util::toDims (args[1 ].unwrapToIntList ()));
117
122
118
123
TRTORCH_ASSERT (out_size.size () == 3 , " aten::upsample_nearest3d input Tensor and output size dimension mismatch" );
119
-
124
+
120
125
auto out_shape = in_shape;
121
126
std::copy (out_size.begin (), out_size.end (), out_shape.begin () + (in_shape.size () - out_size.size ()));
122
127
@@ -139,16 +144,20 @@ auto interpolate_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
139
144
auto out_size = util::toVec (util::toDims (args[1 ].unwrapToIntList ()));
140
145
141
146
TRTORCH_ASSERT (out_size.size () == 1 , " aten::upsample_linear1d input Tensor and output size dimension mismatch" );
142
-
143
- auto out_shape = in_shape;
147
+
148
+ auto out_shape = in_shape;
144
149
std::copy (out_size.begin (), out_size.end (), out_shape.begin () + (in_shape.size () - out_size.size ()));
145
150
151
+ #if NV_TENSORRT_MAJOR < 7 || (NV_TENSORRT_MAJOR == 7 && NV_TENSORRT_MINOR < 1)
146
152
if (!align_corners) {
147
153
// align_corners not supported in TensorRT, create plugin and run layer through PyTorch
148
154
create_plugin (ctx, n, in, " linear1d" , in_shape, out_shape, out_size, std::string (" linear" ));
149
155
} else {
150
- resize_layer_size (ctx, n, in, out_shape, nvinfer1::ResizeMode::kLINEAR );
156
+ resize_layer_size (ctx, n, in, out_shape, nvinfer1::ResizeMode::kLINEAR . true );
151
157
}
158
+ #else
159
+ resize_layer_size (ctx, n, in, out_shape, nvinfer1::ResizeMode::kLINEAR , align_corners);
160
+ #endif
152
161
} else {
153
162
TRTORCH_THROW_ERROR (" Unable to convert node: " << util::node_info (n) << " \n Scale factor parameter for upsample_linear1d not supported yet." );
154
163
}
@@ -167,16 +176,20 @@ auto interpolate_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
167
176
auto out_size = util::toVec (util::toDims (args[1 ].unwrapToIntList ()));
168
177
169
178
TRTORCH_ASSERT (out_size.size () == 2 , " aten::upsample_bilinear2d input Tensor and output size dimension mismatch" );
170
-
179
+
171
180
auto out_shape = in_shape;
172
181
std::copy (out_size.begin (), out_size.end (), out_shape.begin () + (in_shape.size () - out_size.size ()));
173
182
183
+ #if NV_TENSORRT_MAJOR < 7 || (NV_TENSORRT_MAJOR == 7 && NV_TENSORRT_MINOR < 1)
174
184
if (!align_corners) {
175
185
// align_corners not supported in TensorRT, create plugin and run layer through PyTorch
176
186
create_plugin (ctx, n, in, " bilinear2d" , in_shape, out_shape, out_size, std::string (" bilinear" ));
177
187
} else {
178
- resize_layer_size (ctx, n, in, out_shape, nvinfer1::ResizeMode::kLINEAR );
188
+ resize_layer_size (ctx, n, in, out_shape, nvinfer1::ResizeMode::kLINEAR . true );
179
189
}
190
+ #else
191
+ resize_layer_size (ctx, n, in, out_shape, nvinfer1::ResizeMode::kLINEAR , align_corners);
192
+ #endif
180
193
} else {
181
194
TRTORCH_THROW_ERROR (" Unable to convert node: " << util::node_info (n) << " \n Scale factor parameter for upsample_bilinear2d not supported yet." );
182
195
}
@@ -195,16 +208,20 @@ auto interpolate_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
195
208
auto out_size = util::toVec (util::toDims (args[1 ].unwrapToIntList ()));
196
209
197
210
TRTORCH_ASSERT (out_size.size () == 3 , " aten::upsample_trilinear3d input Tensor and output size dimension mismatch" );
198
-
211
+
199
212
auto out_shape = in_shape;
200
213
std::copy (out_size.begin (), out_size.end (), out_shape.begin () + (in_shape.size () - out_size.size ()));
201
214
215
+ #if NV_TENSORRT_MAJOR < 7 || (NV_TENSORRT_MAJOR == 7 && NV_TENSORRT_MINOR < 1)
202
216
if (!align_corners) {
203
217
// align_corners not supported in TensorRT, create plugin and run layer through PyTorch
204
218
create_plugin (ctx, n, in, " trilinear3d" , in_shape, out_shape, out_size, std::string (" trilinear" ));
205
219
} else {
206
- resize_layer_size (ctx, n, in, out_shape, nvinfer1::ResizeMode::kLINEAR );
220
+ resize_layer_size (ctx, n, in, out_shape, nvinfer1::ResizeMode::kLINEAR . true );
207
221
}
222
+ #else
223
+ resize_layer_size (ctx, n, in, out_shape, nvinfer1::ResizeMode::kLINEAR , align_corners);
224
+ #endif
208
225
} else {
209
226
TRTORCH_THROW_ERROR (" Unable to convert node: " << util::node_info (n) << " \n Scale factor parameter for upsample_trilinear3d not supported yet." );
210
227
}
0 commit comments