@@ -57,7 +57,7 @@ auto constant_pad_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns
57
57
auto shape_gather_out = ctx->net ->addShape (*left_gather_out)->getOutput (0 );
58
58
fill_layer->setInput (0 , *shape_gather_out);
59
59
fill_layer->setInput (1 , *valueTensor);
60
- at::Tensor delta_tensor = torch::zeros (inRank);
60
+ at::Tensor delta_tensor = torch::zeros (inRank, util::TRTDataTypeToScalarType (in-> getType ()) );
61
61
auto deltaTensor = tensor_to_const (ctx, delta_tensor);
62
62
fill_layer->setInput (2 , *deltaTensor);
63
63
auto padTensor = fill_layer->getOutput (0 );
@@ -69,7 +69,7 @@ auto constant_pad_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns
69
69
inDims.d [axis] = padding[padding_index];
70
70
auto fill_layer = ctx->net ->addFill (inDims, nvinfer1::FillOperation::kLINSPACE );
71
71
fill_layer->setInput (1 , *valueTensor);
72
- at::Tensor delta_tensor = torch::zeros (inRank);
72
+ at::Tensor delta_tensor = torch::zeros (inRank, util::TRTDataTypeToScalarType (in-> getType ()) );
73
73
auto deltaTensor = tensor_to_const (ctx, delta_tensor);
74
74
fill_layer->setInput (2 , *deltaTensor);
75
75
auto padTensor = fill_layer->getOutput (0 );
@@ -110,7 +110,7 @@ auto constant_pad_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns
110
110
auto shape_gather_out = ctx->net ->addShape (*right_gather_out)->getOutput (0 );
111
111
fill_layer->setInput (0 , *shape_gather_out);
112
112
fill_layer->setInput (1 , *valueTensor);
113
- at::Tensor delta_tensor = torch::zeros (inRank);
113
+ at::Tensor delta_tensor = torch::zeros (inRank, util::TRTDataTypeToScalarType (in-> getType ()) );
114
114
auto deltaTensor = tensor_to_const (ctx, delta_tensor);
115
115
fill_layer->setInput (2 , *deltaTensor);
116
116
auto padTensor = fill_layer->getOutput (0 );
@@ -122,7 +122,7 @@ auto constant_pad_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns
122
122
inDims.d [axis] = padding[padding_index + 1 ];
123
123
auto fill_layer = ctx->net ->addFill (inDims, nvinfer1::FillOperation::kLINSPACE );
124
124
fill_layer->setInput (1 , *valueTensor);
125
- at::Tensor delta_tensor = torch::zeros (inRank);
125
+ at::Tensor delta_tensor = torch::zeros (inRank, util::TRTDataTypeToScalarType (in-> getType ()) );
126
126
auto deltaTensor = tensor_to_const (ctx, delta_tensor);
127
127
fill_layer->setInput (2 , *deltaTensor);
128
128
auto padTensor = fill_layer->getOutput (0 );
0 commit comments