1
- #include " torch/torch.h"
2
-
1
+ #include " core/conversion/converters/converter_util.h"
3
2
#include " core/conversion/converters/converters.h"
4
3
#include " core/util/prelude.h"
4
+ #include " torch/torch.h"
5
5
6
6
namespace trtorch {
7
7
namespace core {
@@ -14,15 +14,49 @@ bool add_conv_deconv(ConversionCtx* ctx, const torch::jit::Node* n, args& args)
14
14
auto in = args[0 ].ITensor (); // assumes non-static input Tensor
15
15
auto w = Weights (ctx, args[1 ].unwrapToTensor ());
16
16
auto stride = util::toDims (args[3 ].unwrapToIntList ());
17
- LOG_DEBUG (" stride: " << stride);
18
17
auto padding = util::toDims (args[4 ].unwrapToIntList ());
19
- LOG_DEBUG (" padding: " << padding);
20
18
auto dilation = util::toDims (args[5 ].unwrapToIntList ());
21
- LOG_DEBUG (" dilation: " << dilation);
22
19
bool transposed = args[6 ].unwrapToBool ();
23
20
auto out_padding = util::toDims (args[7 ].unwrapToIntList ());
24
- LOG_DEBUG (" out_padding: " << out_padding);
25
21
int64_t groups = args[8 ].unwrapToInt ();
22
+
23
+ auto dims = in->getDimensions ();
24
+ auto orig_dims = dims;
25
+ LOG_DEBUG (" Original input dims: " << orig_dims);
26
+
27
+ // Expand spatial dims from 1D to 2D if needed
28
+ bool expandDims = (orig_dims.nbDims < 4 );
29
+ if (expandDims) {
30
+ in = addPadding (ctx, n, in, 4 );
31
+ dims = in->getDimensions ();
32
+ }
33
+ if (w.shape .nbDims < 4 ) {
34
+ for (int i = w.shape .nbDims ; i < 4 ; ++i) {
35
+ w.shape .d [i] = 1 ;
36
+ }
37
+ w.shape .nbDims = 4 ;
38
+ w.kernel_shape .nbDims = 2 ;
39
+ w.kernel_shape .d [1 ] = 1 ;
40
+ }
41
+ if (stride.nbDims ==1 ) {
42
+ stride = util::unsqueezeDims (stride, 1 , 1 );
43
+ }
44
+ if (dilation.nbDims ==1 ) {
45
+ dilation = util::unsqueezeDims (dilation, 1 , 1 );
46
+ }
47
+ if (padding.nbDims ==1 ) {
48
+ padding = util::unsqueezeDims (padding, 1 , 0 );
49
+ }
50
+ if (out_padding.nbDims ==1 ) {
51
+ out_padding = util::unsqueezeDims (out_padding, 1 , 0 );
52
+ }
53
+
54
+ LOG_DEBUG (" Input dims: " << dims);
55
+ LOG_DEBUG (" Weights: " << w);
56
+ LOG_DEBUG (" stride: " << stride);
57
+ LOG_DEBUG (" padding: " << padding);
58
+ LOG_DEBUG (" dilation: " << dilation);
59
+ LOG_DEBUG (" out_padding: " << out_padding);
26
60
LOG_DEBUG (" groups: " << groups);
27
61
28
62
nvinfer1::ILayer* new_layer;
@@ -31,12 +65,11 @@ bool add_conv_deconv(ConversionCtx* ctx, const torch::jit::Node* n, args& args)
31
65
if (args[2 ].IValue ()->isTensor ()) {
32
66
bias = Weights (ctx, args[2 ].unwrapToTensor ());
33
67
} else {
34
- bias = Weights (ctx, torch::zeros (args[ 1 ]. unwrapToTensor (). sizes () [1 ] * groups));
68
+ bias = Weights (ctx, torch::zeros (w. shape . d [1 ] * groups));
35
69
}
36
70
37
71
// shape of deconvolution's weight: [in, out/groups, ...]
38
- auto deconv = ctx->net ->addDeconvolutionNd (
39
- *in, args[1 ].unwrapToTensor ().sizes ()[1 ] * groups, w.kernel_shape , w.data , bias.data );
72
+ auto deconv = ctx->net ->addDeconvolutionNd (*in, w.shape .d [1 ] * groups, w.kernel_shape , w.data , bias.data );
40
73
TRTORCH_CHECK (deconv, " Unable to create deconvolution layer from node: " << *n);
41
74
42
75
deconv->setStrideNd (stride);
@@ -56,11 +89,11 @@ bool add_conv_deconv(ConversionCtx* ctx, const torch::jit::Node* n, args& args)
56
89
if (args[2 ].IValue ()->isTensor ()) {
57
90
bias = Weights (ctx, args[2 ].unwrapToTensor ());
58
91
} else {
59
- bias = Weights (ctx, torch::zeros (args[ 1 ]. unwrapToTensor (). sizes () [0 ]));
92
+ bias = Weights (ctx, torch::zeros (w. shape . d [0 ]));
60
93
}
61
94
62
95
// shape of convolution's weight: [out, in/groups, ...]
63
- auto conv = ctx->net ->addConvolutionNd (*in, args[ 1 ]. unwrapToTensor (). sizes () [0 ], w.kernel_shape , w.data , bias.data );
96
+ auto conv = ctx->net ->addConvolutionNd (*in, w. shape . d [0 ], w.kernel_shape , w.data , bias.data );
64
97
TRTORCH_CHECK (conv, " Unable to create convolution layer from node: " << *n);
65
98
66
99
conv->setStrideNd (stride);
@@ -71,9 +104,13 @@ bool add_conv_deconv(ConversionCtx* ctx, const torch::jit::Node* n, args& args)
71
104
conv->setNbGroups (groups);
72
105
new_layer = conv;
73
106
}
107
+
74
108
new_layer->setName (util::node_info (n).c_str ());
75
-
76
- auto out = ctx->AssociateValueAndTensor (n->outputs ()[0 ], new_layer->getOutput (0 ));
109
+
110
+ // Un-expand spatial dims back to 1D if needed
111
+ auto out = addUnpadding (ctx, n, new_layer->getOutput (0 ), orig_dims.nbDims );
112
+
113
+ ctx->AssociateValueAndTensor (n->outputs ()[0 ], out);
77
114
78
115
LOG_DEBUG (" Output tensor shape: " << out->getDimensions ());
79
116
0 commit comments