Skip to content

Commit 0a45bdf

Browse files
committed
refactor(supportedops): refactor registered ops code
Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]>
1 parent 0c340b4 commit 0a45bdf

File tree

4 files changed

+124
-4
lines changed

4 files changed

+124
-4
lines changed

core/conversion/converters/NodeConverterRegistry.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,8 @@ class NodeConverterRegistry {
8686

8787
std::vector<std::string> GetRegisteredConverterList() {
8888
std::vector<std::string> converter_list;
89-
std::copy(registered_converter_schemas_.begin(), registered_converter_schemas_.end(), std::back_inserter(converter_list));
89+
std::copy(
90+
registered_converter_schemas_.begin(), registered_converter_schemas_.end(), std::back_inserter(converter_list));
9091
return converter_list;
9192
}
9293

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
#include "NvInfer.h"
2+
#include "NvInferRuntimeCommon.h"
3+
#include "core/conversion/converters/converters.h"
4+
#include "core/util/prelude.h"
5+
#include "plugins/interpolate_plugin.h"
6+
#include "torch/torch.h"
7+
8+
namespace trtorch {
9+
namespace core {
10+
namespace conversion {
11+
namespace converters {
12+
namespace impl {
13+
namespace {
14+
15+
/*
16+
* Helper functions
17+
*/
18+
#if NV_TENSORRT_MAJOR < 7 || (NV_TENSORRT_MAJOR == 7 && NV_TENSORRT_MINOR < 1)
19+
void create_plugin(
20+
ConversionCtx* ctx,
21+
const torch::jit::Node* n,
22+
nvinfer1::ITensor* in,
23+
const char* name,
24+
std::vector<int64_t> in_shape,
25+
std::vector<int64_t> out_shape,
26+
std::vector<int64_t> out_size,
27+
std::string mode) {
28+
LOG_WARNING("Interpolation layer will be run through ATen, not TensorRT. Performance may be lower than expected");
29+
30+
auto creator = new plugins::InterpolatePluginCreator();
31+
auto plugin = creator->createPlugin(name, in_shape, out_shape, out_size, mode, false);
32+
33+
auto resize_layer = ctx->net->addPluginV2(reinterpret_cast<nvinfer1::ITensor* const*>(&in), 1, *plugin);
34+
TRTORCH_CHECK(resize_layer, "Unable to create interpolation plugin from node" << *n);
35+
36+
resize_layer->setName(util::node_info(n).c_str());
37+
38+
auto layer_output = ctx->AssociateValueAndTensor(n->outputs()[0], resize_layer->getOutput(0));
39+
40+
LOG_DEBUG("Output tensor shape: " << layer_output->getDimensions());
41+
}
42+
#endif
43+
44+
void resize_layer_size(
45+
ConversionCtx* ctx,
46+
const torch::jit::Node* n,
47+
nvinfer1::ITensor* in,
48+
std::vector<int64_t> out_shape,
49+
nvinfer1::ResizeMode mode,
50+
bool align_corners = false) {
51+
auto resize_layer = ctx->net->addResize(*in);
52+
TRTORCH_CHECK(resize_layer, "Unable to create interpolation (resizing) layer from node" << *n);
53+
54+
resize_layer->setOutputDimensions(util::toDims(out_shape));
55+
resize_layer->setResizeMode(mode);
56+
resize_layer->setName(util::node_info(n).c_str());
57+
58+
// if interpolation mode is linear, align corners must have been set to true.
59+
// else, don't use align corners.
60+
if (mode == nvinfer1::ResizeMode::kLINEAR) {
61+
#if NV_TENSORRT_MAJOR < 7 || (NV_TENSORRT_MAJOR == 7 && NV_TENSORRT_MINOR < 1)
62+
resize_layer->setAlignCorners(true);
63+
#else
64+
resize_layer->setAlignCorners(align_corners);
65+
#endif
66+
}
67+
68+
auto layer_output = ctx->AssociateValueAndTensor(n->outputs()[0], resize_layer->getOutput(0));
69+
70+
LOG_DEBUG("Output tensor shape: " << layer_output->getDimensions());
71+
}
72+
73+
/*
74+
* Interpolate Converter
75+
*/
76+
77+
auto interpolate_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns().pattern(
78+
{"aten::instance_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool use_input_stats, float momentum, float eps, bool cudnn_enabled) -> (Tensor))",
79+
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
80+
auto in = args[0].ITensor();
81+
auto in_shape = util::toVec(in->getDimensions());
82+
bool align_corners = args[2].unwrapToBool();
83+
84+
// Case 1: user uses output size and not scales_h, scales_w
85+
if (!args[1].IValue()->isNone() && args[3].IValue()->isNone() && args[4].IValue()->isNone()) {
86+
auto out_size = util::toVec(util::toDims(args[1].unwrapToIntList()));
87+
88+
TRTORCH_ASSERT(
89+
out_size.size() == 2, "aten::upsample_bilinear2d input Tensor and output size dimension mismatch");
90+
91+
auto out_shape = in_shape;
92+
std::copy(out_size.begin(), out_size.end(), out_shape.begin() + (in_shape.size() - out_size.size()));
93+
94+
#if NV_TENSORRT_MAJOR < 7 || (NV_TENSORRT_MAJOR == 7 && NV_TENSORRT_MINOR < 1)
95+
if (!align_corners) {
96+
// align_corners not supported in TensorRT, create plugin and
97+
// run layer through PyTorch
98+
create_plugin(ctx, n, in, "bilinear2d", in_shape, out_shape, out_size, std::string("bilinear"));
99+
} else {
100+
resize_layer_size(ctx, n, in, out_shape, nvinfer1::ResizeMode::kLINEAR, true);
101+
}
102+
#else
103+
resize_layer_size(ctx, n, in, out_shape, nvinfer1::ResizeMode::kLINEAR, align_corners);
104+
#endif
105+
} else {
106+
TRTORCH_THROW_ERROR(
107+
"Unable to convert node: " << util::node_info(n)
108+
<< "\nScale factor parameter for upsample_bilinear2d not supported yet.");
109+
}
110+
111+
return true;
112+
}})
113+
} // namespace
114+
} // namespace impl
115+
} // namespace converters
116+
} // namespace conversion
117+
} // namespace core
118+
} // namespace trtorch

core/conversion/evaluators/NodeEvaluatorRegistry.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,8 @@ class NodeEvaluatorRegistry {
8181

8282
std::vector<std::string> GetRegisteredEvaluatorList() {
8383
std::vector<std::string> evaluator_list;
84-
std::copy(registered_evaluator_schemas_.begin(), registered_evaluator_schemas_.end(), std::back_inserter(evaluator_list));
84+
std::copy(
85+
registered_evaluator_schemas_.begin(), registered_evaluator_schemas_.end(), std::back_inserter(evaluator_list));
8586
return evaluator_list;
8687
}
8788

cpp/supportedops/main.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
#include "core/conversion/converters/converters.h"
22
#include "core/conversion/evaluators/evaluators.h"
33

4-
#include <string>
4+
#include <iostream>
55
#include <sstream>
6+
#include <string>
67
#include <vector>
7-
#include <iostream>
88

99
int main(int argc, const char* argv[]) {
1010
std::vector<std::string> converters = trtorch::core::conversion::converters::get_converter_list();

0 commit comments

Comments
 (0)