@@ -1,7 +1,7 @@
-#include "torch/torch.h"
-
+#include "core/conversion/converters/converter_util.h"
 #include "core/conversion/converters/converters.h"
 #include "core/util/prelude.h"
+#include "torch/torch.h"
 
 namespace trtorch {
 namespace core {
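
(The new core/conversion/converters/converter_util.h include supplies the addPaddingLayer / addUnpaddingLayer helpers used below; torch/torch.h is only re-sorted into the alphabetized include order.)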
@@ -14,15 +14,46 @@ bool add_conv_deconv(ConversionCtx* ctx, const torch::jit::Node* n, args& args)
   auto in = args[0].ITensor(); // assumes non-static input Tensor
   auto w = Weights(ctx, args[1].unwrapToTensor());
   auto stride = util::toDims(args[3].unwrapToIntList());
-  LOG_DEBUG("stride: " << stride);
   auto padding = util::toDims(args[4].unwrapToIntList());
-  LOG_DEBUG("padding: " << padding);
   auto dilation = util::toDims(args[5].unwrapToIntList());
-  LOG_DEBUG("dilation: " << dilation);
   bool transposed = args[6].unwrapToBool();
   auto out_padding = util::toDims(args[7].unwrapToIntList());
-  LOG_DEBUG("out_padding: " << out_padding);
   int64_t groups = args[8].unwrapToInt();
+
+  auto dims = in->getDimensions();
+  auto orig_dims = dims;
+  LOG_DEBUG("Original input dims: " << orig_dims);
+
+  // Expand spatial dims from 1D to 2D if needed
+  auto expandDims = addPaddingLayer(ctx, n, in, 4);
+  if (expandDims) {
+    auto tensorPtr = expandDims->getOutput(0);
+    assert(tensorPtr);
+    dims = tensorPtr->getDimensions();
+    in = tensorPtr;
+  }
+  if (w.shape.nbDims < 4) {
+    for (int i = w.shape.nbDims; i < 4; ++i)
+      w.shape.d[i] = 1;
+    w.shape.nbDims = 4;
+    w.kernel_shape.nbDims = 2;
+    w.kernel_shape.d[1] = 1;
+  }
+  if (stride.nbDims == 1)
+    stride = util::unsqueezeDims(stride, 1, 1);
+  if (dilation.nbDims == 1)
+    dilation = util::unsqueezeDims(dilation, 1, 1);
+  if (padding.nbDims == 1)
+    padding = util::unsqueezeDims(padding, 1, 0);
+  if (out_padding.nbDims == 1)
+    out_padding = util::unsqueezeDims(out_padding, 1, 0);
+
+  LOG_DEBUG("Input dims: " << dims);
+  LOG_DEBUG("Weights: " << w);
+  LOG_DEBUG("stride: " << stride);
+  LOG_DEBUG("padding: " << padding);
+  LOG_DEBUG("dilation: " << dilation);
+  LOG_DEBUG("out_padding: " << out_padding);
   LOG_DEBUG("groups: " << groups);
 
   nvinfer1::ILayer* new_layer;
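
For reviewers, a rough sketch of the parameter expansion above. It assumes util::unsqueezeDims(d, pos, val) inserts a new dimension of size val at position pos; that signature is inferred from the call sites here, not quoted from core/util:

// Sketch only, not the library's implementation: insert a dimension of
// size `val` at position `pos` of a TensorRT Dims.
#include <NvInfer.h>

nvinfer1::Dims unsqueezeDimsSketch(nvinfer1::Dims d, int pos, int val) {
  nvinfer1::Dims out{};
  out.nbDims = d.nbDims + 1;
  for (int i = 0, j = 0; i < out.nbDims; ++i) {
    out.d[i] = (i == pos) ? val : d.d[j++];
  }
  return out;
}

Under that reading, a conv1d's stride {2} becomes {2, 1} and its padding {1} becomes {1, 0}: the trailing dummy axis gets identity stride/dilation and zero padding, which is what makes the 2D layer equivalent to the original 1D op.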
@@ -31,12 +62,11 @@ bool add_conv_deconv(ConversionCtx* ctx, const torch::jit::Node* n, args& args)
     if (args[2].IValue()->isTensor()) {
       bias = Weights(ctx, args[2].unwrapToTensor());
     } else {
-      bias = Weights(ctx, torch::zeros(args[1].unwrapToTensor().sizes()[1] * groups));
+      bias = Weights(ctx, torch::zeros(w.shape.d[1] * groups));
     }
 
     // shape of deconvolution's weight: [in, out/groups, ...]
-    auto deconv = ctx->net->addDeconvolutionNd(
-        *in, args[1].unwrapToTensor().sizes()[1] * groups, w.kernel_shape, w.data, bias.data);
+    auto deconv = ctx->net->addDeconvolutionNd(*in, w.shape.d[1] * groups, w.kernel_shape, w.data, bias.data);
     TRTORCH_CHECK(deconv, "Unable to create deconvolution layer from node: " << *n);
 
     deconv->setStrideNd(stride);
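
Switching from args[1].unwrapToTensor().sizes()[1] to w.shape.d[1] keeps the channel count consistent with the Weights struct that was just padded to 4D. The [in, out/groups, ...] layout the comment relies on can be sanity-checked against LibTorch; a hypothetical standalone check, not part of this change:

// Hypothetical check: ConvTranspose weights are laid out
// [in, out/groups, kernel...], so output channels = sizes()[1] * groups.
#include <torch/torch.h>
#include <cassert>

int main() {
  auto deconv = torch::nn::ConvTranspose1d(
      torch::nn::ConvTranspose1dOptions(/*in=*/16, /*out=*/32, /*kernel=*/3).groups(4));
  auto sizes = deconv->weight.sizes(); // {16, 8, 3}: in, out/groups, kernel
  assert(sizes[1] * 4 == 32);          // 8 * 4 recovers the 32 output channels
  return 0;
}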
@@ -56,11 +86,11 @@ bool add_conv_deconv(ConversionCtx* ctx, const torch::jit::Node* n, args& args)
     if (args[2].IValue()->isTensor()) {
       bias = Weights(ctx, args[2].unwrapToTensor());
     } else {
-      bias = Weights(ctx, torch::zeros(args[1].unwrapToTensor().sizes()[0]));
+      bias = Weights(ctx, torch::zeros(w.shape.d[0]));
     }
 
     // shape of convolution's weight: [out, in/groups, ...]
-    auto conv = ctx->net->addConvolutionNd(*in, args[1].unwrapToTensor().sizes()[0], w.kernel_shape, w.data, bias.data);
+    auto conv = ctx->net->addConvolutionNd(*in, w.shape.d[0], w.kernel_shape, w.data, bias.data);
     TRTORCH_CHECK(conv, "Unable to create convolution layer from node: " << *n);
 
     conv->setStrideNd(stride);
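
The convolution branch mirrors this: convolution weights are [out, in/groups, kernel...], e.g. a Conv1d(16, 32, 3, groups=4) stores a {32, 4, 3} weight, so w.shape.d[0] is the output channel count and stays valid after the 1D kernel is padded to {32, 4, 3, 1}.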
@@ -73,6 +103,11 @@ bool add_conv_deconv(ConversionCtx* ctx, const torch::jit::Node* n, args& args)
   }
   new_layer->setName(util::node_info(n).c_str());
 
+  if (expandDims) {
+    // Un-expand spatial dims back to 1D
+    new_layer = addUnpaddingLayer(ctx, n, new_layer->getOutput(0), orig_dims.nbDims);
+  }
+
   auto out = ctx->AssociateValueAndTensor(n->outputs()[0], new_layer->getOutput(0));
 
   LOG_DEBUG("Output tensor shape: " << out->getDimensions());
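
Taken together, these hunks let the converter handle 1D (de)convolutions, which the TensorRT Nd layers used here do not accept directly: pad the input and parameters to 2D, run the 2D layer, then unpad the output. A minimal module that should now convert; a sketch under the assumption that its convolution lowers to the aten::_convolution node this converter handles:

// Sketch, untested: a scriptable module whose convolution carries 1D
// stride/padding/dilation lists - the case this change expands to 2D.
#include <torch/torch.h>

struct Conv1dNet : torch::nn::Module {
  torch::nn::Conv1d conv{nullptr};
  Conv1dNet() {
    conv = register_module("conv", torch::nn::Conv1d(torch::nn::Conv1dOptions(3, 8, 5).stride(2)));
  }
  torch::Tensor forward(torch::Tensor x) { // x: [N, 3, L]
    return conv->forward(x);               // out: [N, 8, (L - 5) / 2 + 1]
  }
};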