Skip to content

Commit b696497

Browse files
committed
add comments
Signed-off-by: uni19 <[email protected]>
1 parent e994354 commit b696497

File tree

1 file changed

+6
-2
lines changed

1 file changed

+6
-2
lines changed

core/conversion/converters/impl/conv_deconv.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,9 @@ auto conv_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns().patter
3939
bias = Weights(ctx, torch::zeros(args[1].unwrapToTensor().sizes()[1] * groups));
4040
}
4141

42-
auto deconv = ctx->net->addDeconvolutionNd(*in, w.num_input_maps * groups, w.kernel_shape, w.data, bias.data);
42+
// shape of deconvolution's weight: [in, out/groups, ...]
43+
auto deconv = ctx->net->addDeconvolutionNd(
44+
*in, args[1].unwrapToTensor().sizes()[1] * groups, w.kernel_shape, w.data, bias.data);
4345
TRTORCH_CHECK(deconv, "Unable to create deconvolution layer from node: " << *n);
4446

4547
deconv->setStrideNd(stride);
@@ -62,7 +64,9 @@ auto conv_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns().patter
6264
bias = Weights(ctx, torch::zeros(args[1].unwrapToTensor().sizes()[0]));
6365
}
6466

65-
auto conv = ctx->net->addConvolutionNd(*in, w.num_output_maps, w.kernel_shape, w.data, bias.data);
67+
// shape of convolution's weight: [out, in/groups, ...]
68+
auto conv =
69+
ctx->net->addConvolutionNd(*in, args[1].unwrapToTensor().sizes()[0], w.kernel_shape, w.data, bias.data);
6670
TRTORCH_CHECK(conv, "Unable to create convolution layer from node: " << *n);
6771

6872
conv->setStrideNd(stride);

0 commit comments

Comments
 (0)