Skip to content

Commit d06a2b1

Browse files
authored
Merge pull request #134 from NVIDIA/fix_interpolate_trt_7.1
Fix for interpolate plugin segfault
2 parents ce77963 + fa4b2db commit d06a2b1

File tree

8 files changed

+190
-85
lines changed

8 files changed

+190
-85
lines changed

core/conversion/converters/impl/conv_deconv.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,9 @@ auto conv_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
4545

4646
deconv->setStrideNd(stride);
4747
deconv->setPaddingNd(padding);
48+
deconv->setDilationNd(dilation);
49+
deconv->setNbGroups(groups);
50+
4851
new_layer = deconv;
4952
} else {
5053
nvinfer1::IConvolutionLayer* conv;

core/conversion/converters/impl/interpolate.cpp

Lines changed: 40 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,14 @@ namespace {
1515
/*
1616
* Helper functions
1717
*/
18-
19-
void create_plugin(ConversionCtx* ctx, const torch::jit::Node* n, nvinfer1::ITensor* in, const char* name,
20-
std::vector<int64_t> in_shape,
21-
std::vector<int64_t> out_shape,
22-
std::vector<int64_t> out_size,
18+
#if NV_TENSORRT_MAJOR < 7 || (NV_TENSORRT_MAJOR == 7 && NV_TENSORRT_MINOR < 1)
19+
void create_plugin(ConversionCtx* ctx, const torch::jit::Node* n, nvinfer1::ITensor* in, const char* name,
20+
std::vector<int64_t> in_shape,
21+
std::vector<int64_t> out_shape,
22+
std::vector<int64_t> out_size,
2323
std::string mode) {
24-
LOG_WARNING("Interpolation layer will be run through ATen, not TensorRT. Performance may differ.");
25-
24+
LOG_WARNING("Interpolation layer will be run through ATen, not TensorRT. Performance may be lower than expected");
25+
2626
auto creator = new plugins::InterpolatePluginCreator();
2727
auto plugin = creator->createPlugin(name, in_shape, out_shape, out_size, mode, false);
2828

@@ -35,23 +35,28 @@ void create_plugin(ConversionCtx* ctx, const torch::jit::Node* n, nvinfer1::ITen
3535

3636
LOG_DEBUG("Output tensor shape: " << layer_output->getDimensions());
3737
}
38+
#endif
3839

39-
void resize_layer_size(ConversionCtx* ctx, const torch::jit::Node* n, nvinfer1::ITensor* in, std::vector<int64_t> out_shape,
40-
nvinfer1::ResizeMode mode) {
40+
void resize_layer_size(ConversionCtx* ctx, const torch::jit::Node* n, nvinfer1::ITensor* in, std::vector<int64_t> out_shape,
41+
nvinfer1::ResizeMode mode, bool align_corners=false) {
4142
auto resize_layer = ctx->net->addResize(*in);
4243
TRTORCH_CHECK(resize_layer, "Unable to create interpolation (resizing) layer from node" << *n);
4344

4445
resize_layer->setOutputDimensions(util::toDims(out_shape));
4546
resize_layer->setResizeMode(mode);
4647
resize_layer->setName(util::node_info(n).c_str());
47-
48+
4849
// if interpolation mode is linear, align corners must have been set to true. else, don't use align corners.
4950
if (mode == nvinfer1::ResizeMode::kLINEAR) {
50-
resize_layer->setAlignCorners(true);
51+
#if NV_TENSORRT_MAJOR < 7 || (NV_TENSORRT_MAJOR == 7 && NV_TENSORRT_MINOR < 1)
52+
resize_layer->setAlignCorners(true);
53+
#else
54+
resize_layer->setAlignCorners(align_corners);
55+
#endif
5156
}
5257

5358
auto layer_output = ctx->AssociateValueAndTensor(n->outputs()[0], resize_layer->getOutput(0));
54-
59+
5560
LOG_DEBUG("Output tensor shape: " << layer_output->getDimensions());
5661
}
5762

@@ -72,7 +77,7 @@ auto interpolate_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
7277
auto out_size = util::toVec(util::toDims(args[1].unwrapToIntList()));
7378

7479
TRTORCH_ASSERT(out_size.size() == 1, "aten::upsample_nearest1d input Tensor and output size dimension mismatch");
75-
80+
7681
auto out_shape = in_shape;
7782
std::copy(out_size.begin(), out_size.end(), out_shape.begin() + (in_shape.size() - out_size.size()));
7883

@@ -94,10 +99,10 @@ auto interpolate_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
9499
auto out_size = util::toVec(util::toDims(args[1].unwrapToIntList()));
95100

96101
TRTORCH_ASSERT(out_size.size() == 2, "aten::upsample_nearest2d input Tensor and output size dimension mismatch");
97-
102+
98103
auto out_shape = in_shape;
99104
std::copy(out_size.begin(), out_size.end(), out_shape.begin() + (in_shape.size() - out_size.size()));
100-
105+
101106
resize_layer_size(ctx, n, in, out_shape, nvinfer1::ResizeMode::kNEAREST);
102107
} else {
103108
TRTORCH_THROW_ERROR("Unable to convert node: " << util::node_info(n) << "\nScale factor parameter for upsample_nearest2d not supported yet.");
@@ -116,7 +121,7 @@ auto interpolate_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
116121
auto out_size = util::toVec(util::toDims(args[1].unwrapToIntList()));
117122

118123
TRTORCH_ASSERT(out_size.size() == 3, "aten::upsample_nearest3d input Tensor and output size dimension mismatch");
119-
124+
120125
auto out_shape = in_shape;
121126
std::copy(out_size.begin(), out_size.end(), out_shape.begin() + (in_shape.size() - out_size.size()));
122127

@@ -139,16 +144,20 @@ auto interpolate_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
139144
auto out_size = util::toVec(util::toDims(args[1].unwrapToIntList()));
140145

141146
TRTORCH_ASSERT(out_size.size() == 1, "aten::upsample_linear1d input Tensor and output size dimension mismatch");
142-
143-
auto out_shape = in_shape;
147+
148+
auto out_shape = in_shape;
144149
std::copy(out_size.begin(), out_size.end(), out_shape.begin() + (in_shape.size() - out_size.size()));
145150

151+
#if NV_TENSORRT_MAJOR < 7 || (NV_TENSORRT_MAJOR == 7 && NV_TENSORRT_MINOR < 1)
146152
if (!align_corners) {
147153
// align_corners not supported in TensorRT, create plugin and run layer through PyTorch
148154
create_plugin(ctx, n, in, "linear1d", in_shape, out_shape, out_size, std::string("linear"));
149155
} else {
150-
resize_layer_size(ctx, n, in, out_shape, nvinfer1::ResizeMode::kLINEAR);
156+
resize_layer_size(ctx, n, in, out_shape, nvinfer1::ResizeMode::kLINEAR. true);
151157
}
158+
#else
159+
resize_layer_size(ctx, n, in, out_shape, nvinfer1::ResizeMode::kLINEAR, align_corners);
160+
#endif
152161
} else {
153162
TRTORCH_THROW_ERROR("Unable to convert node: " << util::node_info(n) << "\nScale factor parameter for upsample_linear1d not supported yet.");
154163
}
@@ -167,16 +176,20 @@ auto interpolate_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
167176
auto out_size = util::toVec(util::toDims(args[1].unwrapToIntList()));
168177

169178
TRTORCH_ASSERT(out_size.size() == 2, "aten::upsample_bilinear2d input Tensor and output size dimension mismatch");
170-
179+
171180
auto out_shape = in_shape;
172181
std::copy(out_size.begin(), out_size.end(), out_shape.begin() + (in_shape.size() - out_size.size()));
173182

183+
#if NV_TENSORRT_MAJOR < 7 || (NV_TENSORRT_MAJOR == 7 && NV_TENSORRT_MINOR < 1)
174184
if (!align_corners) {
175185
// align_corners not supported in TensorRT, create plugin and run layer through PyTorch
176186
create_plugin(ctx, n, in, "bilinear2d", in_shape, out_shape, out_size, std::string("bilinear"));
177187
} else {
178-
resize_layer_size(ctx, n, in, out_shape, nvinfer1::ResizeMode::kLINEAR);
188+
resize_layer_size(ctx, n, in, out_shape, nvinfer1::ResizeMode::kLINEAR. true);
179189
}
190+
#else
191+
resize_layer_size(ctx, n, in, out_shape, nvinfer1::ResizeMode::kLINEAR, align_corners);
192+
#endif
180193
} else {
181194
TRTORCH_THROW_ERROR("Unable to convert node: " << util::node_info(n) << "\nScale factor parameter for upsample_bilinear2d not supported yet.");
182195
}
@@ -195,16 +208,20 @@ auto interpolate_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
195208
auto out_size = util::toVec(util::toDims(args[1].unwrapToIntList()));
196209

197210
TRTORCH_ASSERT(out_size.size() == 3, "aten::upsample_trilinear3d input Tensor and output size dimension mismatch");
198-
211+
199212
auto out_shape = in_shape;
200213
std::copy(out_size.begin(), out_size.end(), out_shape.begin() + (in_shape.size() - out_size.size()));
201214

215+
#if NV_TENSORRT_MAJOR < 7 || (NV_TENSORRT_MAJOR == 7 && NV_TENSORRT_MINOR < 1)
202216
if (!align_corners) {
203217
// align_corners not supported in TensorRT, create plugin and run layer through PyTorch
204218
create_plugin(ctx, n, in, "trilinear3d", in_shape, out_shape, out_size, std::string("trilinear"));
205219
} else {
206-
resize_layer_size(ctx, n, in, out_shape, nvinfer1::ResizeMode::kLINEAR);
220+
resize_layer_size(ctx, n, in, out_shape, nvinfer1::ResizeMode::kLINEAR. true);
207221
}
222+
#else
223+
resize_layer_size(ctx, n, in, out_shape, nvinfer1::ResizeMode::kLINEAR, align_corners);
224+
#endif
208225
} else {
209226
TRTORCH_THROW_ERROR("Unable to convert node: " << util::node_info(n) << "\nScale factor parameter for upsample_trilinear3d not supported yet.");
210227
}

core/conversion/converters/impl/plugins/BUILD

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,12 @@ cc_library(
2424
"//conditions:default": ["@libtorch//:libtorch"],
2525
}),
2626
alwayslink = True,
27+
copts = [
28+
"-pthread"
29+
],
30+
linkopts = [
31+
"-lpthread",
32+
]
2733
)
2834

2935
load("@rules_pkg//:pkg.bzl", "pkg_tar")

0 commit comments

Comments
 (0)