1
+ #include " NvInfer.h"
2
+ #include " NvInferRuntimeCommon.h"
3
+ #include " core/conversion/converters/converters.h"
4
+ #include " core/util/prelude.h"
5
+ #include " plugins/interpolate_plugin.h"
6
+ #include " torch/torch.h"
7
+
8
+ namespace trtorch {
9
+ namespace core {
10
+ namespace conversion {
11
+ namespace converters {
12
+ namespace impl {
13
+ namespace {
14
+
15
+ /*
16
+ * Helper functions
17
+ */
18
+ #if NV_TENSORRT_MAJOR < 7 || (NV_TENSORRT_MAJOR == 7 && NV_TENSORRT_MINOR < 1)
19
+ void create_plugin (
20
+ ConversionCtx* ctx,
21
+ const torch::jit::Node* n,
22
+ nvinfer1::ITensor* in,
23
+ const char * name,
24
+ std::vector<int64_t > in_shape,
25
+ std::vector<int64_t > out_shape,
26
+ std::vector<int64_t > out_size,
27
+ std::string mode) {
28
+ LOG_WARNING (" Interpolation layer will be run through ATen, not TensorRT. Performance may be lower than expected" );
29
+
30
+ auto creator = new plugins::InterpolatePluginCreator ();
31
+ auto plugin = creator->createPlugin (name, in_shape, out_shape, out_size, mode, false );
32
+
33
+ auto resize_layer = ctx->net ->addPluginV2 (reinterpret_cast <nvinfer1::ITensor* const *>(&in), 1 , *plugin);
34
+ TRTORCH_CHECK (resize_layer, " Unable to create interpolation plugin from node" << *n);
35
+
36
+ resize_layer->setName (util::node_info (n).c_str ());
37
+
38
+ auto layer_output = ctx->AssociateValueAndTensor (n->outputs ()[0 ], resize_layer->getOutput (0 ));
39
+
40
+ LOG_DEBUG (" Output tensor shape: " << layer_output->getDimensions ());
41
+ }
42
+ #endif
43
+
44
+ void resize_layer_size (
45
+ ConversionCtx* ctx,
46
+ const torch::jit::Node* n,
47
+ nvinfer1::ITensor* in,
48
+ std::vector<int64_t > out_shape,
49
+ nvinfer1::ResizeMode mode,
50
+ bool align_corners = false ) {
51
+ auto resize_layer = ctx->net ->addResize (*in);
52
+ TRTORCH_CHECK (resize_layer, " Unable to create interpolation (resizing) layer from node" << *n);
53
+
54
+ resize_layer->setOutputDimensions (util::toDims (out_shape));
55
+ resize_layer->setResizeMode (mode);
56
+ resize_layer->setName (util::node_info (n).c_str ());
57
+
58
+ // if interpolation mode is linear, align corners must have been set to true.
59
+ // else, don't use align corners.
60
+ if (mode == nvinfer1::ResizeMode::kLINEAR ) {
61
+ #if NV_TENSORRT_MAJOR < 7 || (NV_TENSORRT_MAJOR == 7 && NV_TENSORRT_MINOR < 1)
62
+ resize_layer->setAlignCorners (true );
63
+ #else
64
+ resize_layer->setAlignCorners (align_corners);
65
+ #endif
66
+ }
67
+
68
+ auto layer_output = ctx->AssociateValueAndTensor (n->outputs ()[0 ], resize_layer->getOutput (0 ));
69
+
70
+ LOG_DEBUG (" Output tensor shape: " << layer_output->getDimensions ());
71
+ }
72
+
73
+ /*
74
+ * Interpolate Converter
75
+ */
76
+
77
+ auto interpolate_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns().pattern(
78
+ {" aten::instance_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool use_input_stats, float momentum, float eps, bool cudnn_enabled) -> (Tensor))" ,
79
+ [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
80
+ auto in = args[0 ].ITensor ();
81
+ auto in_shape = util::toVec (in->getDimensions ());
82
+ bool align_corners = args[2 ].unwrapToBool ();
83
+
84
+ // Case 1: user uses output size and not scales_h, scales_w
85
+ if (!args[1 ].IValue ()->isNone () && args[3 ].IValue ()->isNone () && args[4 ].IValue ()->isNone ()) {
86
+ auto out_size = util::toVec (util::toDims (args[1 ].unwrapToIntList ()));
87
+
88
+ TRTORCH_ASSERT (
89
+ out_size.size () == 2 , " aten::upsample_bilinear2d input Tensor and output size dimension mismatch" );
90
+
91
+ auto out_shape = in_shape;
92
+ std::copy (out_size.begin (), out_size.end (), out_shape.begin () + (in_shape.size () - out_size.size ()));
93
+
94
+ #if NV_TENSORRT_MAJOR < 7 || (NV_TENSORRT_MAJOR == 7 && NV_TENSORRT_MINOR < 1)
95
+ if (!align_corners) {
96
+ // align_corners not supported in TensorRT, create plugin and
97
+ // run layer through PyTorch
98
+ create_plugin (ctx, n, in, " bilinear2d" , in_shape, out_shape, out_size, std::string (" bilinear" ));
99
+ } else {
100
+ resize_layer_size (ctx, n, in, out_shape, nvinfer1::ResizeMode::kLINEAR , true );
101
+ }
102
+ #else
103
+ resize_layer_size (ctx, n, in, out_shape, nvinfer1::ResizeMode::kLINEAR , align_corners);
104
+ #endif
105
+ } else {
106
+ TRTORCH_THROW_ERROR (
107
+ " Unable to convert node: " << util::node_info (n)
108
+ << " \n Scale factor parameter for upsample_bilinear2d not supported yet." );
109
+ }
110
+
111
+ return true ;
112
+ }})
113
+ } // namespace
114
+ } // namespace impl
115
+ } // namespace converters
116
+ } // namespace conversion
117
+ } // namespace core
118
+ } // namespace trtorch
0 commit comments