Skip to content

Commit df24079

Browse files
committed
Fix error in layer conversion caused by zero/ones tensors of wrong type
1 parent cf54792 commit df24079

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

core/conversion/converters/impl/constant_pad.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ auto constant_pad_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns
5757
auto shape_gather_out = ctx->net->addShape(*left_gather_out)->getOutput(0);
5858
fill_layer->setInput(0, *shape_gather_out);
5959
fill_layer->setInput(1, *valueTensor);
60-
at::Tensor delta_tensor = torch::zeros(inRank);
60+
at::Tensor delta_tensor = torch::zeros(inRank, util::TRTDataTypeToScalarType(in->getType()));
6161
auto deltaTensor = tensor_to_const(ctx, delta_tensor);
6262
fill_layer->setInput(2, *deltaTensor);
6363
auto padTensor = fill_layer->getOutput(0);
@@ -69,7 +69,7 @@ auto constant_pad_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns
6969
inDims.d[axis] = padding[padding_index];
7070
auto fill_layer = ctx->net->addFill(inDims, nvinfer1::FillOperation::kLINSPACE);
7171
fill_layer->setInput(1, *valueTensor);
72-
at::Tensor delta_tensor = torch::zeros(inRank);
72+
at::Tensor delta_tensor = torch::zeros(inRank, util::TRTDataTypeToScalarType(in->getType()));
7373
auto deltaTensor = tensor_to_const(ctx, delta_tensor);
7474
fill_layer->setInput(2, *deltaTensor);
7575
auto padTensor = fill_layer->getOutput(0);
@@ -110,7 +110,7 @@ auto constant_pad_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns
110110
auto shape_gather_out = ctx->net->addShape(*right_gather_out)->getOutput(0);
111111
fill_layer->setInput(0, *shape_gather_out);
112112
fill_layer->setInput(1, *valueTensor);
113-
at::Tensor delta_tensor = torch::zeros(inRank);
113+
at::Tensor delta_tensor = torch::zeros(inRank, util::TRTDataTypeToScalarType(in->getType()));
114114
auto deltaTensor = tensor_to_const(ctx, delta_tensor);
115115
fill_layer->setInput(2, *deltaTensor);
116116
auto padTensor = fill_layer->getOutput(0);
@@ -122,7 +122,7 @@ auto constant_pad_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns
122122
inDims.d[axis] = padding[padding_index + 1];
123123
auto fill_layer = ctx->net->addFill(inDims, nvinfer1::FillOperation::kLINSPACE);
124124
fill_layer->setInput(1, *valueTensor);
125-
at::Tensor delta_tensor = torch::zeros(inRank);
125+
at::Tensor delta_tensor = torch::zeros(inRank, util::TRTDataTypeToScalarType(in->getType()));
126126
auto deltaTensor = tensor_to_const(ctx, delta_tensor);
127127
fill_layer->setInput(2, *deltaTensor);
128128
auto padTensor = fill_layer->getOutput(0);

0 commit comments

Comments
 (0)