@@ -39,7 +39,9 @@ auto conv_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns().patter
39
39
bias = Weights (ctx, torch::zeros (args[1 ].unwrapToTensor ().sizes ()[1 ] * groups));
40
40
}
41
41
42
- auto deconv = ctx->net ->addDeconvolutionNd (*in, w.num_input_maps * groups, w.kernel_shape , w.data , bias.data );
42
+ // shape of deconvolution's weight: [in, out/groups, ...]
43
+ auto deconv = ctx->net ->addDeconvolutionNd (
44
+ *in, args[1 ].unwrapToTensor ().sizes ()[1 ] * groups, w.kernel_shape , w.data , bias.data );
43
45
TRTORCH_CHECK (deconv, " Unable to create deconvolution layer from node: " << *n);
44
46
45
47
deconv->setStrideNd (stride);
@@ -62,7 +64,9 @@ auto conv_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns().patter
62
64
bias = Weights (ctx, torch::zeros (args[1 ].unwrapToTensor ().sizes ()[0 ]));
63
65
}
64
66
65
- auto conv = ctx->net ->addConvolutionNd (*in, w.num_output_maps , w.kernel_shape , w.data , bias.data );
67
+ // shape of convolution's weight: [out, in/groups, ...]
68
+ auto conv =
69
+ ctx->net ->addConvolutionNd (*in, args[1 ].unwrapToTensor ().sizes ()[0 ], w.kernel_shape , w.data , bias.data );
66
70
TRTORCH_CHECK (conv, " Unable to create convolution layer from node: " << *n);
67
71
68
72
conv->setStrideNd (stride);
0 commit comments