Skip to content

Commit d4fe8da

Browse files
committed
refactor(interpolate_plugin): Leaves the old parts for continued
support of TensorRT 7.0 but uses new systems for TRT 7.1 Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]>
1 parent da28184 commit d4fe8da

File tree

2 files changed

+92
-1
lines changed

2 files changed

+92
-1
lines changed

core/conversion/converters/impl/interpolate.cpp

Lines changed: 53 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,27 @@ namespace {
1515
/*
1616
* Helper functions
1717
*/
18+
#if NV_TENSORRT_MAJOR < 7 || (NV_TENSORRT_MAJOR == 7 && NV_TENSORRT_MINOR < 1)
19+
void create_plugin(ConversionCtx* ctx, const torch::jit::Node* n, nvinfer1::ITensor* in, const char* name,
20+
std::vector<int64_t> in_shape,
21+
std::vector<int64_t> out_shape,
22+
std::vector<int64_t> out_size,
23+
std::string mode) {
24+
LOG_WARNING("Interpolation layer will be run through ATen, not TensorRT. Performance may differ.");
25+
26+
auto creator = new plugins::InterpolatePluginCreator();
27+
auto plugin = creator->createPlugin(name, in_shape, out_shape, out_size, mode, false);
28+
29+
auto resize_layer = ctx->net->addPluginV2(reinterpret_cast<nvinfer1::ITensor* const*>(&in), 1, *plugin);
30+
TRTORCH_CHECK(resize_layer, "Unable to create interpolation plugin from node" << *n);
31+
32+
resize_layer->setName(util::node_info(n).c_str());
33+
34+
auto layer_output = ctx->AssociateValueAndTensor(n->outputs()[0], resize_layer->getOutput(0));
35+
36+
LOG_DEBUG("Output tensor shape: " << layer_output->getDimensions());
37+
}
38+
#endif
1839

1940
void resize_layer_size(ConversionCtx* ctx, const torch::jit::Node* n, nvinfer1::ITensor* in, std::vector<int64_t> out_shape,
2041
nvinfer1::ResizeMode mode, bool align_corners=false) {
@@ -27,7 +48,11 @@ void resize_layer_size(ConversionCtx* ctx, const torch::jit::Node* n, nvinfer1::
2748

2849
// if interpolation mode is linear, align corners must have been set to true. else, don't use align corners.
2950
if (mode == nvinfer1::ResizeMode::kLINEAR) {
30-
resize_layer->setAlignCorners(align_corners);
51+
#if NV_TENSORRT_MAJOR < 7 || (NV_TENSORRT_MAJOR == 7 && NV_TENSORRT_MINOR < 1)
52+
resize_layer->setAlignCorners(true);
53+
#else
54+
resize_layer->setAlignCorners(align_corners);
55+
#endif
3156
}
3257

3358
auto layer_output = ctx->AssociateValueAndTensor(n->outputs()[0], resize_layer->getOutput(0));
@@ -123,7 +148,16 @@ auto interpolate_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
123148
auto out_shape = in_shape;
124149
std::copy(out_size.begin(), out_size.end(), out_shape.begin() + (in_shape.size() - out_size.size()));
125150

151+
#if NV_TENSORRT_MAJOR < 7 || (NV_TENSORRT_MAJOR == 7 && NV_TENSORRT_MINOR < 1)
152+
if (!align_corners) {
153+
// align_corners not supported in TensorRT, create plugin and run layer through PyTorch
154+
create_plugin(ctx, n, in, "linear1d", in_shape, out_shape, out_size, std::string("linear"));
155+
} else {
156+
resize_layer_size(ctx, n, in, out_shape, nvinfer1::ResizeMode::kLINEAR, true);
157+
}
158+
#else
126159
resize_layer_size(ctx, n, in, out_shape, nvinfer1::ResizeMode::kLINEAR, align_corners);
160+
#endif
127161
} else {
128162
TRTORCH_THROW_ERROR("Unable to convert node: " << util::node_info(n) << "\nScale factor parameter for upsample_linear1d not supported yet.");
129163
}
@@ -146,7 +180,16 @@ auto interpolate_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
146180
auto out_shape = in_shape;
147181
std::copy(out_size.begin(), out_size.end(), out_shape.begin() + (in_shape.size() - out_size.size()));
148182

183+
#if NV_TENSORRT_MAJOR < 7 || (NV_TENSORRT_MAJOR == 7 && NV_TENSORRT_MINOR < 1)
184+
if (!align_corners) {
185+
// align_corners not supported in TensorRT, create plugin and run layer through PyTorch
186+
create_plugin(ctx, n, in, "bilinear2d", in_shape, out_shape, out_size, std::string("bilinear"));
187+
} else {
188+
resize_layer_size(ctx, n, in, out_shape, nvinfer1::ResizeMode::kLINEAR, true);
189+
}
190+
#else
149191
resize_layer_size(ctx, n, in, out_shape, nvinfer1::ResizeMode::kLINEAR, align_corners);
192+
#endif
150193
} else {
151194
TRTORCH_THROW_ERROR("Unable to convert node: " << util::node_info(n) << "\nScale factor parameter for upsample_bilinear2d not supported yet.");
152195
}
@@ -169,7 +212,16 @@ auto interpolate_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
169212
auto out_shape = in_shape;
170213
std::copy(out_size.begin(), out_size.end(), out_shape.begin() + (in_shape.size() - out_size.size()));
171214

215+
#if NV_TENSORRT_MAJOR < 7 || (NV_TENSORRT_MAJOR == 7 && NV_TENSORRT_MINOR < 1)
216+
if (!align_corners) {
217+
// align_corners not supported in TensorRT, create plugin and run layer through PyTorch
218+
create_plugin(ctx, n, in, "trilinear3d", in_shape, out_shape, out_size, std::string("trilinear"));
219+
} else {
220+
resize_layer_size(ctx, n, in, out_shape, nvinfer1::ResizeMode::kLINEAR, true);
221+
}
222+
#else
172223
resize_layer_size(ctx, n, in, out_shape, nvinfer1::ResizeMode::kLINEAR, align_corners);
224+
#endif
173225
} else {
174226
TRTORCH_THROW_ERROR("Unable to convert node: " << util::node_info(n) << "\nScale factor parameter for upsample_trilinear3d not supported yet.");
175227
}

core/conversion/converters/impl/plugins/interpolate_plugin.cpp

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,11 @@ nvinfer1::DataType InterpolatePlugin::getOutputDataType(int index, const nvinfer
9898
}
9999

100100
int InterpolatePlugin::initialize() {
101+
#if NV_TENSORRT_MAJOR < 7 || (NV_TENSORRT_MAJOR == 7 && NV_TENSORRT_MINOR < 1)
102+
tensor_options_ = tensor_options_.device(c10::kCUDA);
103+
#else
101104
tensor_options_ = tensor_options_.device(c10::kCPU);
105+
#endif
102106

103107
// c10::kFloat = FLOAT32
104108
tensor_options_ = tensor_options_.dtype(c10::kFloat);
@@ -161,6 +165,40 @@ size_t InterpolatePlugin::getWorkspaceSize(const nvinfer1::PluginTensorDesc* inp
161165
int InterpolatePlugin::enqueue(const nvinfer1::PluginTensorDesc* inputDesc, const nvinfer1::PluginTensorDesc* outputDesc, const void* const* inputs,
162166
void* const* outputs, void* workspace,
163167
cudaStream_t stream) {
168+
#if NV_TENSORRT_MAJOR < 7 || (NV_TENSORRT_MAJOR == 7 && NV_TENSORRT_MINOR < 1)
169+
at::Tensor input = at::from_blob((void*) inputs[0], util::toVec(inputDesc->dims), [](void*){}, tensor_options_);
170+
at::Tensor output = at::from_blob(outputs[0], util::volume(outputDesc->dims), [](void*){}, tensor_options_);
171+
172+
at::cuda::CUDAStream torch_stream = at::cuda::getStreamFromPool();
173+
at::cuda::CUDAStreamGuard torch_guard(torch_stream);
174+
175+
cudaEvent_t event;
176+
cudaEventCreate(&event);
177+
cudaEventRecord(event, stream);
178+
179+
cudaStreamWaitEvent(torch_stream.stream(), event, 0);
180+
181+
if (mode == "linear") {
182+
at::upsample_linear1d_out(output, input, {size_[0]}, align_corners_);
183+
} else if (mode == "bilinear") {
184+
at::upsample_bilinear2d_out(output, input, {size_[0], size_[1]}, align_corners_);
185+
} else if (mode == "trilinear") {
186+
at::upsample_trilinear3d_out(output, input, {size_[0], size_[1], size_[2]}, align_corners_);
187+
} else if (mode == "adaptive_pool2d") {
188+
at::adaptive_avg_pool2d_out(output, input, {size_[0], size_[1]});
189+
}
190+
191+
cudaEvent_t torch_event;
192+
cudaEventCreate(&torch_event);
193+
cudaEventRecord(torch_event, torch_stream.stream());
194+
195+
cudaStreamWaitEvent(stream, torch_event, 0);
196+
197+
cudaEventDestroy(event);
198+
cudaEventDestroy(torch_event);
199+
200+
return 0;
201+
#else
164202
// TODO: When PyTorch updates to cuDNN 8 try moving back to CUDA based ATen kernels
165203
// HACK: WAR because there is a segfault if you try to create a CUDA Tensor in the context of TensorRT execution
166204
float* input_blob = (float*) malloc(util::volume(inputDesc->dims) * sizeof(float));
@@ -185,6 +223,7 @@ int InterpolatePlugin::enqueue(const nvinfer1::PluginTensorDesc* inputDesc, cons
185223
free(input_blob);
186224

187225
return 0;
226+
#endif
188227
}
189228

190229
/*

0 commit comments

Comments
 (0)