Commit 010b801

fix!(aten::adaptive_avg_pool_2d, aten::interpolate): Moves interpolate align_corners cases to TensorRT and adds a workaround for dynamic adaptive_avg_pool_2d that runs the kernel on CPU.

HACK: Runs adaptive_avg_pool_2d on CPU; this should be replaced once it can be determined why creating a Tensor on the GPU causes a segfault. The current idea is that there is a cuDNN version mismatch.

Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>
1 parent 48b950a commit 010b801

File tree: 5 files changed, +87 −124 lines
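
For context, a minimal libtorch sketch of the kind of call this commit affects (illustrative shapes and variable names, not taken from the repository): after this change, an interpolate call with align_corners=false is converted to a native TensorRT resize layer instead of being routed through the ATen-backed interpolate plugin.

    #include <torch/torch.h>
    #include <vector>

    namespace F = torch::nn::functional;

    int main() {
      // Illustrative NCHW input; any float tensor fed through the converter is handled the same way.
      auto x = torch::randn({1, 3, 16, 16});

      // Before this commit, align_corners=false forced a fallback to the InterpolatePlugin;
      // it now maps onto nvinfer1::IResizeLayer with the flag forwarded via setAlignCorners().
      auto y = F::interpolate(x, F::InterpolateFuncOptions()
                                     .size(std::vector<int64_t>({32, 32}))
                                     .mode(torch::kBilinear)
                                     .align_corners(false));
      return 0;
    }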

core/conversion/converters/impl/interpolate.cpp

Lines changed: 16 additions & 51 deletions

@@ -16,42 +16,22 @@ namespace {
  * Helper functions
  */
 
-void create_plugin(ConversionCtx* ctx, const torch::jit::Node* n, nvinfer1::ITensor* in, const char* name,
-                   std::vector<int64_t> in_shape,
-                   std::vector<int64_t> out_shape,
-                   std::vector<int64_t> out_size,
-                   std::string mode) {
-  LOG_WARNING("Interpolation layer will be run through ATen, not TensorRT. Performance may differ.");
-
-  auto creator = new plugins::InterpolatePluginCreator();
-  auto plugin = creator->createPlugin(name, in_shape, out_shape, out_size, mode, false);
-
-  auto resize_layer = ctx->net->addPluginV2(reinterpret_cast<nvinfer1::ITensor* const*>(&in), 1, *plugin);
-  TRTORCH_CHECK(resize_layer, "Unable to create interpolation plugin from node" << *n);
-
-  resize_layer->setName(util::node_info(n).c_str());
-
-  auto layer_output = ctx->AssociateValueAndTensor(n->outputs()[0], resize_layer->getOutput(0));
-
-  LOG_DEBUG("Output tensor shape: " << layer_output->getDimensions());
-}
-
-void resize_layer_size(ConversionCtx* ctx, const torch::jit::Node* n, nvinfer1::ITensor* in, std::vector<int64_t> out_shape,
-                       nvinfer1::ResizeMode mode) {
+void resize_layer_size(ConversionCtx* ctx, const torch::jit::Node* n, nvinfer1::ITensor* in, std::vector<int64_t> out_shape,
+                       nvinfer1::ResizeMode mode, bool align_corners=false) {
   auto resize_layer = ctx->net->addResize(*in);
   TRTORCH_CHECK(resize_layer, "Unable to create interpolation (resizing) layer from node" << *n);
 
   resize_layer->setOutputDimensions(util::toDims(out_shape));
   resize_layer->setResizeMode(mode);
   resize_layer->setName(util::node_info(n).c_str());
-
+
   // if interpolation mode is linear, align corners must have been set to true. else, don't use align corners.
   if (mode == nvinfer1::ResizeMode::kLINEAR) {
-    resize_layer->setAlignCorners(true);
+    resize_layer->setAlignCorners(align_corners);
   }
 
   auto layer_output = ctx->AssociateValueAndTensor(n->outputs()[0], resize_layer->getOutput(0));
-
+
   LOG_DEBUG("Output tensor shape: " << layer_output->getDimensions());
 }
 
@@ -72,7 +52,7 @@ auto interpolate_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
        auto out_size = util::toVec(util::toDims(args[1].unwrapToIntList()));
 
        TRTORCH_ASSERT(out_size.size() == 1, "aten::upsample_nearest1d input Tensor and output size dimension mismatch");
-
+
        auto out_shape = in_shape;
        std::copy(out_size.begin(), out_size.end(), out_shape.begin() + (in_shape.size() - out_size.size()));
 
@@ -94,10 +74,10 @@ auto interpolate_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
        auto out_size = util::toVec(util::toDims(args[1].unwrapToIntList()));
 
        TRTORCH_ASSERT(out_size.size() == 2, "aten::upsample_nearest2d input Tensor and output size dimension mismatch");
-
+
        auto out_shape = in_shape;
        std::copy(out_size.begin(), out_size.end(), out_shape.begin() + (in_shape.size() - out_size.size()));
-
+
        resize_layer_size(ctx, n, in, out_shape, nvinfer1::ResizeMode::kNEAREST);
      } else {
        TRTORCH_THROW_ERROR("Unable to convert node: " << util::node_info(n) << "\nScale factor parameter for upsample_nearest2d not supported yet.");
@@ -116,7 +96,7 @@ auto interpolate_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
        auto out_size = util::toVec(util::toDims(args[1].unwrapToIntList()));
 
        TRTORCH_ASSERT(out_size.size() == 3, "aten::upsample_nearest3d input Tensor and output size dimension mismatch");
-
+
        auto out_shape = in_shape;
        std::copy(out_size.begin(), out_size.end(), out_shape.begin() + (in_shape.size() - out_size.size()));
 
@@ -139,16 +119,11 @@ auto interpolate_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
        auto out_size = util::toVec(util::toDims(args[1].unwrapToIntList()));
 
        TRTORCH_ASSERT(out_size.size() == 1, "aten::upsample_linear1d input Tensor and output size dimension mismatch");
-
-        auto out_shape = in_shape;
+
+        auto out_shape = in_shape;
        std::copy(out_size.begin(), out_size.end(), out_shape.begin() + (in_shape.size() - out_size.size()));
 
-        if (!align_corners) {
-          // align_corners not supported in TensorRT, create plugin and run layer through PyTorch
-          create_plugin(ctx, n, in, "linear1d", in_shape, out_shape, out_size, std::string("linear"));
-        } else {
-          resize_layer_size(ctx, n, in, out_shape, nvinfer1::ResizeMode::kLINEAR);
-        }
+        resize_layer_size(ctx, n, in, out_shape, nvinfer1::ResizeMode::kLINEAR, align_corners);
      } else {
        TRTORCH_THROW_ERROR("Unable to convert node: " << util::node_info(n) << "\nScale factor parameter for upsample_linear1d not supported yet.");
      }
@@ -167,16 +142,11 @@ auto interpolate_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
        auto out_size = util::toVec(util::toDims(args[1].unwrapToIntList()));
 
        TRTORCH_ASSERT(out_size.size() == 2, "aten::upsample_bilinear2d input Tensor and output size dimension mismatch");
-
+
        auto out_shape = in_shape;
        std::copy(out_size.begin(), out_size.end(), out_shape.begin() + (in_shape.size() - out_size.size()));
 
-        if (!align_corners) {
-          // align_corners not supported in TensorRT, create plugin and run layer through PyTorch
-          create_plugin(ctx, n, in, "bilinear2d", in_shape, out_shape, out_size, std::string("bilinear"));
-        } else {
-          resize_layer_size(ctx, n, in, out_shape, nvinfer1::ResizeMode::kLINEAR);
-        }
+        resize_layer_size(ctx, n, in, out_shape, nvinfer1::ResizeMode::kLINEAR, align_corners);
      } else {
        TRTORCH_THROW_ERROR("Unable to convert node: " << util::node_info(n) << "\nScale factor parameter for upsample_bilinear2d not supported yet.");
      }
@@ -195,16 +165,11 @@ auto interpolate_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
        auto out_size = util::toVec(util::toDims(args[1].unwrapToIntList()));
 
        TRTORCH_ASSERT(out_size.size() == 3, "aten::upsample_trilinear3d input Tensor and output size dimension mismatch");
-
+
        auto out_shape = in_shape;
        std::copy(out_size.begin(), out_size.end(), out_shape.begin() + (in_shape.size() - out_size.size()));
 
-        if (!align_corners) {
-          // align_corners not supported in TensorRT, create plugin and run layer through PyTorch
-          create_plugin(ctx, n, in, "trilinear3d", in_shape, out_shape, out_size, std::string("trilinear"));
-        } else {
-          resize_layer_size(ctx, n, in, out_shape, nvinfer1::ResizeMode::kLINEAR);
-        }
+        resize_layer_size(ctx, n, in, out_shape, nvinfer1::ResizeMode::kLINEAR, align_corners);
      } else {
        TRTORCH_THROW_ERROR("Unable to convert node: " << util::node_info(n) << "\nScale factor parameter for upsample_trilinear3d not supported yet.");
      }
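
For reference, a minimal standalone sketch of the path the updated resize_layer_size helper now takes for the linear modes (hypothetical free function add_linear_resize; assumes a TensorRT 7-era nvinfer1 API and an already-built INetworkDefinition, with error checking omitted):

    #include "NvInfer.h"

    // Sketch of the new behavior: build a TensorRT resize layer and forward
    // align_corners directly instead of falling back to the ATen-backed plugin.
    nvinfer1::ITensor* add_linear_resize(nvinfer1::INetworkDefinition* net,
                                         nvinfer1::ITensor* in,
                                         nvinfer1::Dims out_dims,
                                         bool align_corners) {
      auto* resize_layer = net->addResize(*in);
      resize_layer->setOutputDimensions(out_dims);
      resize_layer->setResizeMode(nvinfer1::ResizeMode::kLINEAR);
      // For linear resizing TensorRT honors the flag directly, so both
      // align_corners settings stay inside the engine.
      resize_layer->setAlignCorners(align_corners);
      return resize_layer->getOutput(0);
    }

With align_corners handled as a plain layer flag, the ATen-backed plugin is only needed for cases TensorRT cannot express, such as the dynamic adaptive_avg_pool_2d path handled in the plugin changes below.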

core/conversion/converters/impl/plugins/BUILD

Lines changed: 6 additions & 0 deletions

@@ -24,6 +24,12 @@ cc_library(
         "//conditions:default": ["@libtorch//:libtorch"],
     }),
     alwayslink = True,
+    copts = [
+        "-pthread"
+    ],
+    linkopts = [
+        "-lpthread",
+    ]
 )
 
 load("@rules_pkg//:pkg.bzl", "pkg_tar")

core/conversion/converters/impl/plugins/interpolate_plugin.cpp

Lines changed: 50 additions & 58 deletions

@@ -9,57 +9,57 @@ namespace converters {
 namespace impl {
 namespace plugins {
 
-/*
+/*
  * InterpolatePlugin class implementations
  */
 
-InterpolatePlugin::InterpolatePlugin(std::vector<int64_t> in_shape, std::vector<int64_t> out_shape, std::vector<int64_t> size, std::string mode, bool align_corners) :
-  in_shape(in_shape), out_shape(out_shape), size(size), mode(mode), align_corners(align_corners)
+InterpolatePlugin::InterpolatePlugin(std::vector<int64_t> in_shape, std::vector<int64_t> out_shape, std::vector<int64_t> size, std::string mode, bool align_corners) :
+  in_shape_(in_shape), out_shape_(out_shape), size_(size), mode_(mode), align_corners_(align_corners)
 {}
 
 InterpolatePlugin::InterpolatePlugin(const char *data, size_t length) {
   std::istringstream data_stream(std::string(data, length));
-
+
   torch::serialize::InputArchive input_archive;
   input_archive.load_from(data_stream);
-
+
   {
     torch::IValue value;
     input_archive.read("in_shape", value);
-    in_shape = value.toIntVector();
+    in_shape_ = value.toIntVector();
   }
   {
     torch::IValue value;
     input_archive.read("out_shape", value);
-    out_shape = value.toIntVector();
+    out_shape_ = value.toIntVector();
   }
   {
     torch::IValue value;
     input_archive.read("size", value);
-    size = value.toIntVector();
+    size_ = value.toIntVector();
   }
   {
     torch::IValue value;
     input_archive.read("mode", value);
-    mode = value.toStringRef();
+    mode_ = value.toStringRef();
   }
   {
     torch::IValue value;
     input_archive.read("align_corners", value);
-    align_corners = value.toBool();
+    align_corners_ = value.toBool();
   }
 }
 
 std::vector<int64_t> InterpolatePlugin::getInputShape() {
-  return in_shape;
+  return in_shape_;
 }
 
 std::vector<int64_t> InterpolatePlugin::getOutputShape() {
-  return out_shape;
+  return out_shape_;
 }
 
 std::vector<int64_t> InterpolatePlugin::getOutputSize() {
-  return size;
+  return size_;
 }
 
 int InterpolatePlugin::getNbOutputs() const {
@@ -80,14 +80,14 @@ const char* InterpolatePlugin::getPluginNamespace() const {
 
 
 nvinfer1::IPluginV2DynamicExt* InterpolatePlugin::clone() const {
-  return new InterpolatePlugin(in_shape, out_shape, size, mode, align_corners);
+  return new InterpolatePlugin(in_shape_, out_shape_, size_, mode_, align_corners_);
 }
 
 nvinfer1::DimsExprs InterpolatePlugin::getOutputDimensions(int outputIndex, const nvinfer1::DimsExprs *inputs, int nbInputs, nvinfer1::IExprBuilder &exprBuilder) {
   nvinfer1::DimsExprs output(inputs[0]);
 
-  for (unsigned int i = 0; i < out_shape.size(); i++) {
-    output.d[i] = exprBuilder.constant(out_shape[i]);
+  for (unsigned int i = 0; i < out_shape_.size(); i++) {
+    output.d[i] = exprBuilder.constant(out_shape_[i]);
   }
 
   return output;
@@ -98,10 +98,10 @@ nvinfer1::DataType InterpolatePlugin::getOutputDataType(int index, const nvinfer
 }
 
 int InterpolatePlugin::initialize() {
-  tensor_options = tensor_options.device(c10::kCUDA);
+  tensor_options_ = tensor_options_.device(c10::kCPU);
 
   // c10::kFloat = FLOAT32
-  tensor_options = tensor_options.dtype(c10::kFloat);
+  tensor_options_ = tensor_options_.dtype(c10::kFloat);
 
   return 0;
 }
@@ -117,11 +117,11 @@ void InterpolatePlugin::serialize(void* buffer) const {
 std::string InterpolatePlugin::serializeToString() const {
   torch::serialize::OutputArchive output_archive;
 
-  output_archive.write("in_shape", torch::IValue(in_shape));
-  output_archive.write("out_shape", torch::IValue(out_shape));
-  output_archive.write("size", torch::IValue(size));
-  output_archive.write("mode", torch::IValue(mode));
-  output_archive.write("align_corners", torch::IValue(align_corners));
+  output_archive.write("in_shape", torch::IValue(in_shape_));
+  output_archive.write("out_shape", torch::IValue(out_shape_));
+  output_archive.write("size", torch::IValue(size_));
+  output_archive.write("mode", torch::IValue(mode_));
+  output_archive.write("align_corners", torch::IValue(align_corners_));
 
   std::ostringstream data_str;
   output_archive.save_to(data_str);
@@ -146,56 +146,48 @@ bool InterpolatePlugin::supportsFormatCombination(int pos, const nvinfer1::Plugi
 
   // pos == 1, accessing information about output tensor
   const PluginTensorDesc& out = inOut[1];
-
+
   return (in.type == out.type) && (in.format == out.format);
 }
 
 void InterpolatePlugin::configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in, int nbInputs, const nvinfer1::DynamicPluginTensorDesc* out, int nbOutputs) {
-  dtype = DataType::kFLOAT;
+  dtype_ = DataType::kFLOAT;
 }
 
 size_t InterpolatePlugin::getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, int nbInputs, const nvinfer1::PluginTensorDesc* outputs, int nbOutputs) const {
   return 0;
 }
 
-int InterpolatePlugin::enqueue(const nvinfer1::PluginTensorDesc* inputDesc, const nvinfer1::PluginTensorDesc* outputDesc, const void *const *inputs,
-                               void *const *outputs, void *workspace,
+int InterpolatePlugin::enqueue(const nvinfer1::PluginTensorDesc* inputDesc, const nvinfer1::PluginTensorDesc* outputDesc, const void* const* inputs,
                               void* const* outputs, void* workspace,
                               cudaStream_t stream) {
-  at::Tensor input = at::from_blob((void*) inputs[0], util::toVec(inputDesc->dims), [](void*){}, tensor_options);
-  at::Tensor output = at::from_blob(outputs[0], out_shape, [](void*){}, tensor_options);
-
-  at::cuda::CUDAStream torch_stream = at::cuda::getStreamFromPool();
-  at::cuda::CUDAStreamGuard torch_guard(torch_stream);
-
-  cudaEvent_t event;
-  cudaEventCreate(&event);
-  cudaEventRecord(event, stream);
-
-  cudaStreamWaitEvent(torch_stream.stream(), event, 0);
-
-  if (mode == "linear") {
-    at::upsample_linear1d_out(output, input, {size[0]}, align_corners);
-  } else if (mode == "bilinear") {
-    at::upsample_bilinear2d_out(output, input, {size[0], size[1]}, align_corners);
-  } else if (mode == "trilinear") {
-    at::upsample_trilinear3d_out(output, input, {size[0], size[1], size[2]}, align_corners);
-  } else if (mode == "adaptive_pool2d") {
-    at::adaptive_avg_pool2d_out(output, input, {size[0], size[1]});
+  // TODO: When PyTorch updates to cuDNN 8 try moving back to CUDA based ATen kernels
+  // HACK: WAR because there is a segfault if you try to create a CUDA Tensor in the context of TensorRT execution
+  float* input_blob = (float*) malloc(util::volume(inputDesc->dims) * sizeof(float));
+  cudaMemcpyAsync(input_blob, static_cast<const void*>(inputs[0]), util::volume(inputDesc->dims) * sizeof(float), cudaMemcpyDeviceToHost, stream);
+  cudaStreamSynchronize(stream);
+
+  at::Tensor input = at::from_blob((void*)input_blob, util::toVec(inputDesc->dims), tensor_options_);
+
+  at::Tensor output;
+  if (mode_ == "adaptive_pool2d") {
+    output = at::adaptive_avg_pool2d(input, {size_[0], size_[1]});
   }
 
-  cudaEvent_t torch_event;
-  cudaEventCreate(&torch_event);
-  cudaEventRecord(torch_event, torch_stream.stream());
+  output = output.contiguous();
+  for (int i = 0; i < util::volume(outputDesc->dims); i++) {
+    std::cout << ((float*)output.data_ptr())[i] << std::endl;
+  }
 
-  cudaStreamWaitEvent(stream, torch_event, 0);
+  cudaMemcpyAsync(outputs[0], output.data_ptr(), util::volume(outputDesc->dims) * sizeof(float), cudaMemcpyHostToDevice, stream);
+  cudaStreamSynchronize(stream);
 
-  cudaEventDestroy(event);
-  cudaEventDestroy(torch_event);
+  free(input_blob);
 
   return 0;
 }
 
-/*
+/*
  * InterpolatePluginCreator class implementations
  */
 const char* InterpolatePluginCreator::getPluginNamespace() const {
@@ -214,15 +206,15 @@ nvinfer1::IPluginV2* InterpolatePluginCreator::createPlugin(const char* name, co
   return nullptr;
 }
 
-InterpolatePlugin* InterpolatePluginCreator::createPlugin(const char* name, std::vector<int64_t> in_shape, std::vector<int64_t> out_shape,
-                                                          std::vector<int64_t> size,
+InterpolatePlugin* InterpolatePluginCreator::createPlugin(const char* name, std::vector<int64_t> in_shape, std::vector<int64_t> out_shape,
+                                                          std::vector<int64_t> size,
                                                           std::string mode, bool align_corners) {
-  name = name;
+  name_ = name;
   return new InterpolatePlugin(in_shape, out_shape, size, mode, align_corners);
 }
 
 nvinfer1::IPluginV2* InterpolatePluginCreator::deserializePlugin(const char* name, const void *serialData, size_t serialLength) {
-  name = name;
+  name_ = name;
   return new InterpolatePlugin((const char*) serialData, serialLength);
 }