Skip to content

Commit 2e29e4e

Browse files
committed
feat(//core/conversion/converters/impl):
Added support for interpolate plugin, used when align_corners=False and mode is linear Signed-off-by: Abhiram Iyer <[email protected]> Signed-off-by: Abhiram Iyer <[email protected]>
1 parent a0d8586 commit 2e29e4e

File tree

1 file changed

+110
-94
lines changed

1 file changed

+110
-94
lines changed

core/conversion/converters/impl/interpolate.cpp

Lines changed: 110 additions & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,9 @@
33
#include "core/conversion/converters/converters.h"
44
#include "NvInfer.h"
55
#include "plugins/interpolate_plugin.h"
6+
#include "NvInferRuntimeCommon.h"
67

7-
#include <csignal>
8+
#include <tuple>
89

910
namespace trtorch {
1011
namespace core {
@@ -13,12 +14,75 @@ namespace converters {
1314
namespace impl {
1415
namespace {
1516

17+
/*
18+
* Helper functions
19+
*/
20+
21+
auto parse_nearest(args args) {
22+
auto in = args[0].ITensor();
23+
auto in_shape = util::toVec(in->getDimensions());
24+
25+
return std::make_tuple(in, in_shape);
26+
}
27+
28+
auto parse_linear(args args) {
29+
auto in = args[0].ITensor();
30+
auto in_shape = util::toVec(in->getDimensions());
31+
bool align_corners = args[2].unwrapToBool();
32+
33+
return std::make_tuple(in, in_shape, align_corners);
34+
}
35+
36+
void create_plugin(ConversionCtx* ctx, const torch::jit::Node* n, nvinfer1::ITensor* in, const char* name,
37+
std::vector<int64_t> in_shape,
38+
std::vector<int64_t> out_shape,
39+
std::vector<int64_t> out_size,
40+
std::string mode) {
41+
auto creator = new plugins::InterpolatePluginCreator();
42+
auto plugin = creator->createPlugin(name, in_shape, out_shape, out_size, mode, false);
43+
44+
auto resize_layer = ctx->net->addPluginV2(reinterpret_cast<nvinfer1::ITensor* const*>(&in), 1, *plugin);
45+
TRTORCH_CHECK(resize_layer, "Unable to create interpolation plugin from node" << *n);
46+
47+
resize_layer->setName(util::node_info(n).c_str());
48+
49+
auto layer_output = ctx->AssociateValueAndTensor(n->outputs()[0], resize_layer->getOutput(0));
50+
51+
LOG_DEBUG("Output tensor shape: " << layer_output->getDimensions());
52+
}
53+
54+
void resize_layer_size(ConversionCtx* ctx, const torch::jit::Node* n, nvinfer1::ITensor* in, std::vector<int64_t> out_shape,
55+
nvinfer1::ResizeMode mode) {
56+
auto resize_layer = ctx->net->addResize(*in);
57+
TRTORCH_CHECK(resize_layer, "Unable to create interpolation (resizing) layer from node" << *n);
58+
59+
resize_layer->setOutputDimensions(util::toDims(out_shape));
60+
resize_layer->setResizeMode(mode);
61+
resize_layer->setName(util::node_info(n).c_str());
62+
63+
// if interpolation mode is linear, align corners must have been set to true. else, don't use align corners.
64+
if (mode == nvinfer1::ResizeMode::kLINEAR) {
65+
resize_layer->setAlignCorners(true);
66+
}
67+
68+
auto layer_output = ctx->AssociateValueAndTensor(n->outputs()[0], resize_layer->getOutput(0));
69+
70+
LOG_DEBUG("Output tensor shape: " << layer_output->getDimensions());
71+
}
72+
73+
74+
/*
75+
* Interpolate Converter
76+
*/
77+
1678
auto interpolate_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
1779
.pattern({
1880
"aten::upsample_nearest1d(Tensor self, int[1] output_size, float? scales=None) -> (Tensor)",
1981
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
20-
auto in = args[0].ITensor();
21-
auto in_shape = util::toVec(in->getDimensions());
82+
auto parsed = parse_nearest(args);
83+
84+
auto in = std::get<0>(parsed);
85+
auto in_shape = std::get<1>(parsed);
2286

2387
// Case 1: user uses output size and not scales
2488
if (!args[1].IValue()->isNone() && args[2].IValue()->isNone()) {
@@ -29,15 +93,7 @@ auto interpolate_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
2993
auto out_shape = in_shape;
3094
std::copy(out_size.begin(), out_size.end(), out_shape.begin() + (in_shape.size() - out_size.size()));
3195

32-
auto resize_layer = ctx->net->addResize(*in);
33-
TRTORCH_CHECK(resize_layer, "Unable to create interpolation (resizing) layer from node" << *n);
34-
35-
resize_layer->setOutputDimensions(util::toDims(out_shape));
36-
resize_layer->setResizeMode(nvinfer1::ResizeMode::kNEAREST);
37-
resize_layer->setName(util::node_info(n).c_str());
38-
39-
auto layer_output = ctx->AssociateValueAndTensor(n->outputs()[0], resize_layer->getOutput(0));
40-
LOG_DEBUG("Output tensor shape: " << layer_output->getDimensions());
96+
resize_layer_size(ctx, n, in, out_shape, nvinfer1::ResizeMode::kNEAREST);
4197
} else {
4298
TRTORCH_THROW_ERROR("Unable to convert node: " << util::node_info(n) << "\nScale factor parameter for upsample_nearest1d not supported yet.");
4399
}
@@ -47,8 +103,10 @@ auto interpolate_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
47103
}).pattern({
48104
"aten::upsample_nearest2d(Tensor self, int[2] output_size, float? scales_h=None, float? scales_w=None) -> (Tensor)",
49105
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
50-
auto in = args[0].ITensor();
51-
auto in_shape = util::toVec(in->getDimensions());
106+
auto parsed = parse_nearest(args);
107+
108+
auto in = std::get<0>(parsed);
109+
auto in_shape = std::get<1>(parsed);
52110

53111
// Case 1: user uses output_size and not scales_h, scales_w
54112
if (!args[1].IValue()->isNone() && args[2].IValue()->isNone() && args[3].IValue()->isNone()){
@@ -59,15 +117,7 @@ auto interpolate_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
59117
auto out_shape = in_shape;
60118
std::copy(out_size.begin(), out_size.end(), out_shape.begin() + (in_shape.size() - out_size.size()));
61119

62-
auto resize_layer = ctx->net->addResize(*in);
63-
TRTORCH_CHECK(resize_layer, "Unable to create interpolation (resizing) layer from node" << *n);
64-
65-
resize_layer->setOutputDimensions(util::toDims(out_shape));
66-
resize_layer->setResizeMode(nvinfer1::ResizeMode::kNEAREST);
67-
resize_layer->setName(util::node_info(n).c_str());
68-
69-
auto layer_output = ctx->AssociateValueAndTensor(n->outputs()[0], resize_layer->getOutput(0));
70-
LOG_DEBUG("Output tensor shape: " << layer_output->getDimensions());
120+
resize_layer_size(ctx, n, in, out_shape, nvinfer1::ResizeMode::kNEAREST);
71121
} else {
72122
TRTORCH_THROW_ERROR("Unable to convert node: " << util::node_info(n) << "\nScale factor parameter for upsample_nearest2d not supported yet.");
73123
}
@@ -77,8 +127,10 @@ auto interpolate_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
77127
}).pattern({
78128
"aten::upsample_nearest3d(Tensor self, int[3] output_size, float? scales_d=None, float? scales_h=None, float? scales_w=None) -> (Tensor)",
79129
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
80-
auto in = args[0].ITensor();
81-
auto in_shape = util::toVec(in->getDimensions());
130+
auto parsed = parse_nearest(args);
131+
132+
auto in = std::get<0>(parsed);
133+
auto in_shape = std::get<1>(parsed);
82134

83135
// Case 1: user uses output size and not scales_d, scales_h, scales_w
84136
if (!args[1].IValue()->isNone() && args[2].IValue()->isNone() && args[3].IValue()->isNone() && args[4].IValue()->isNone()) {
@@ -88,16 +140,8 @@ auto interpolate_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
88140

89141
auto out_shape = in_shape;
90142
std::copy(out_size.begin(), out_size.end(), out_shape.begin() + (in_shape.size() - out_size.size()));
91-
92-
auto resize_layer = ctx->net->addResize(*in);
93-
TRTORCH_CHECK(resize_layer, "Unable to create interpolation (resizing) layer from node" << *n);
94-
95-
resize_layer->setOutputDimensions(util::toDims(out_shape));
96-
resize_layer->setResizeMode(nvinfer1::ResizeMode::kNEAREST);
97-
resize_layer->setName(util::node_info(n).c_str());
98143

99-
auto layer_output = ctx->AssociateValueAndTensor(n->outputs()[0], resize_layer->getOutput(0));
100-
LOG_DEBUG("Output tensor shape: " << layer_output->getDimensions());
144+
resize_layer_size(ctx, n, in, out_shape, nvinfer1::ResizeMode::kNEAREST);
101145
} else {
102146
TRTORCH_THROW_ERROR("Unable to convert node: " << util::node_info(n) << "\nScale factor parameter for upsample_nearest3d not supported yet.");
103147
}
@@ -107,10 +151,11 @@ auto interpolate_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
107151
}).pattern({
108152
"aten::upsample_linear1d(Tensor self, int[1] output_size, bool align_corners, float? scales=None) -> (Tensor)",
109153
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
110-
auto in = args[0].ITensor();
111-
auto in_shape = util::toVec(in->getDimensions());
112-
113-
bool align_corners = args[2].unwrapToBool();
154+
auto parsed = parse_linear(args);
155+
156+
auto in = std::get<0>(parsed);
157+
auto in_shape = std::get<1>(parsed);
158+
auto align_corners = std::get<2>(parsed);
114159

115160
// Case 1: user uses output size and not scales
116161
if (!args[1].IValue()->isNone() && args[3].IValue()->isNone()) {
@@ -122,34 +167,10 @@ auto interpolate_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
122167
std::copy(out_size.begin(), out_size.end(), out_shape.begin() + (in_shape.size() - out_size.size()));
123168

124169
if (!align_corners) {
125-
//auto plugin = creator->createPlugin(util::node_info(n).c_str(), in_shape, out_shape, out_size, std::string("linear"), align_corners);
126-
std::raise(SIGINT);
127-
128-
//auto creator_auto = getPluginRegistry()->getPluginCreator("interpolate", "1");
129-
//auto plugin_auto = creator_auto->createPlugin(util::node_info(n).c_str(), nullptr);
130-
131-
//auto creator = getPluginRegistry()->getPluginCreator("interpolate", "1");
132-
133-
auto creator = new plugins::InterpolatePluginCreator();
134-
auto plugin = creator->createPlugin("interpolate_plugin", in_shape, out_shape, out_size, std::string("linear"), align_corners);
135-
136-
auto resize_layer = ctx->net->addPluginV2(reinterpret_cast<nvinfer1::ITensor* const*>(&in), 1, *plugin);
137-
resize_layer->setName(util::node_info(n).c_str());
138-
139-
auto layer_output = ctx->AssociateValueAndTensor(n->outputs()[0], resize_layer->getOutput(0));
140-
141-
LOG_DEBUG("Output tensor shape: " << layer_output->getDimensions());
170+
// align_corners not supported in TensorRT, create plugin and run layer through PyTorch
171+
create_plugin(ctx, n, in, "linear1d", in_shape, out_shape, out_size, std::string("linear"));
142172
} else {
143-
auto resize_layer = ctx->net->addResize(*in);
144-
TRTORCH_CHECK(resize_layer, "Unable to create interpolation (resizing) layer from node" << *n);
145-
146-
resize_layer->setOutputDimensions(util::toDims(out_shape));
147-
resize_layer->setResizeMode(nvinfer1::ResizeMode::kLINEAR);
148-
resize_layer->setAlignCorners(align_corners);
149-
resize_layer->setName(util::node_info(n).c_str());
150-
151-
auto layer_output = ctx->AssociateValueAndTensor(n->outputs()[0], resize_layer->getOutput(0));
152-
LOG_DEBUG("Output tensor shape: " << layer_output->getDimensions());
173+
resize_layer_size(ctx, n, in, out_shape, nvinfer1::ResizeMode::kLINEAR);
153174
}
154175
} else {
155176
TRTORCH_THROW_ERROR("Unable to convert node: " << util::node_info(n) << "\nScale factor parameter for upsample_linear1d not supported yet.");
@@ -160,10 +181,11 @@ auto interpolate_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
160181
}).pattern({
161182
"aten::upsample_bilinear2d(Tensor self, int[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> (Tensor)",
162183
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
163-
auto in = args[0].ITensor();
164-
auto in_shape = util::toVec(in->getDimensions());
165-
166-
bool align_corners = args[2].IValue()->to<bool>();
184+
auto parsed = parse_linear(args);
185+
186+
auto in = std::get<0>(parsed);
187+
auto in_shape = std::get<1>(parsed);
188+
auto align_corners = std::get<2>(parsed);
167189

168190
// Case 1: user uses output size and not scales_h, scales_w
169191
if (!args[1].IValue()->isNone() && args[3].IValue()->isNone() && args[4].IValue()->isNone()) {
@@ -174,16 +196,12 @@ auto interpolate_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
174196
auto out_shape = in_shape;
175197
std::copy(out_size.begin(), out_size.end(), out_shape.begin() + (in_shape.size() - out_size.size()));
176198

177-
auto resize_layer = ctx->net->addResize(*in);
178-
TRTORCH_CHECK(resize_layer, "Unable to create interpolation (resizing) layer from node" << *n);
179-
180-
resize_layer->setOutputDimensions(util::toDims(out_shape));
181-
resize_layer->setResizeMode(nvinfer1::ResizeMode::kLINEAR);
182-
resize_layer->setAlignCorners(align_corners);
183-
resize_layer->setName(util::node_info(n).c_str());
184-
185-
auto layer_output = ctx->AssociateValueAndTensor(n->outputs()[0], resize_layer->getOutput(0));
186-
LOG_DEBUG("Output tensor shape: " << layer_output->getDimensions());
199+
if (!align_corners) {
200+
// align_corners not supported in TensorRT, create plugin and run layer through PyTorch
201+
create_plugin(ctx, n, in, "bilinear2d", in_shape, out_shape, out_size, std::string("bilinear"));
202+
} else {
203+
resize_layer_size(ctx, n, in, out_shape, nvinfer1::ResizeMode::kLINEAR);
204+
}
187205
} else {
188206
TRTORCH_THROW_ERROR("Unable to convert node: " << util::node_info(n) << "\nScale factor parameter for upsample_bilinear2d not supported yet.");
189207
}
@@ -193,10 +211,11 @@ auto interpolate_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
193211
}).pattern({
194212
"aten::upsample_trilinear3d(Tensor self, int[3] output_size, bool align_corners, float? scales_d=None, float? scales_h=None, float? scales_w=None) -> (Tensor)",
195213
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
196-
auto in = args[0].ITensor();
197-
auto in_shape = util::toVec(in->getDimensions());
198-
199-
bool align_corners = args[2].IValue()->to<bool>();
214+
auto parsed = parse_linear(args);
215+
216+
auto in = std::get<0>(parsed);
217+
auto in_shape = std::get<1>(parsed);
218+
auto align_corners = std::get<2>(parsed);
200219

201220
// Case 1: user uses output size and not scales_d, scales_h, scales_w
202221
if (!args[1].IValue()->isNone() && args[3].IValue()->isNone() && args[4].IValue()->isNone() && args[5].IValue()->isNone()) {
@@ -207,16 +226,12 @@ auto interpolate_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
207226
auto out_shape = in_shape;
208227
std::copy(out_size.begin(), out_size.end(), out_shape.begin() + (in_shape.size() - out_size.size()));
209228

210-
auto resize_layer = ctx->net->addResize(*in);
211-
TRTORCH_CHECK(resize_layer, "Unable to create interpolation (resizing) layer from node" << *n);
212-
213-
resize_layer->setOutputDimensions(util::toDims(out_shape));
214-
resize_layer->setResizeMode(nvinfer1::ResizeMode::kLINEAR);
215-
resize_layer->setAlignCorners(align_corners);
216-
resize_layer->setName(util::node_info(n).c_str());
217-
218-
auto layer_output = ctx->AssociateValueAndTensor(n->outputs()[0], resize_layer->getOutput(0));
219-
LOG_DEBUG("Output tensor shape: " << layer_output->getDimensions());
229+
if (!align_corners) {
230+
// align_corners not supported in TensorRT, create plugin and run layer through PyTorch
231+
create_plugin(ctx, n, in, "trilinear3d", in_shape, out_shape, out_size, std::string("trilinear"));
232+
} else {
233+
resize_layer_size(ctx, n, in, out_shape, nvinfer1::ResizeMode::kLINEAR);
234+
}
220235
} else {
221236
TRTORCH_THROW_ERROR("Unable to convert node: " << util::node_info(n) << "\nScale factor parameter for upsample_trilinear3d not supported yet.");
222237
}
@@ -225,9 +240,10 @@ auto interpolate_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
225240
}
226241
});
227242

243+
228244
} // namespace
229245
} // namespace impl
230246
} // namespace converters
231247
} // namespace conversion
232248
} // namespace core
233-
} // namespace trtorch
249+
} // namespace trtorch

0 commit comments

Comments
 (0)