@@ -21,7 +21,8 @@ auto constant_pad_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns
21
21
auto padding = args[1 ].unwrapToIntList ().vec ();
22
22
int64_t padSize = padding.size ();
23
23
auto value = args[2 ].unwrapToScalar ().to <float >();
24
-
24
+ at::Tensor value_tensor = torch::tensor (value, util::TRTDataTypeToScalarType (in->getType ()));
25
+ auto valueTensor = tensor_to_const (ctx, value_tensor);
25
26
TORCHTRT_CHECK (padSize % 2 == 0 , " Length of pad must be even but instead it equals " << padSize);
26
27
27
28
int64_t l_pad = padSize / 2 ;
@@ -55,8 +56,6 @@ auto constant_pad_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns
55
56
auto fill_layer = ctx->net ->addFill (nvinfer1::Dims{1 , {1 }}, nvinfer1::FillOperation::kLINSPACE );
56
57
auto shape_gather_out = ctx->net ->addShape (*left_gather_out)->getOutput (0 );
57
58
fill_layer->setInput (0 , *shape_gather_out);
58
- at::Tensor value_tensor = torch::tensor (value, torch::kFloat32 );
59
- auto valueTensor = tensor_to_const (ctx, value_tensor);
60
59
fill_layer->setInput (1 , *valueTensor);
61
60
at::Tensor delta_tensor = torch::zeros (inRank);
62
61
auto deltaTensor = tensor_to_const (ctx, delta_tensor);
@@ -69,8 +68,6 @@ auto constant_pad_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns
69
68
} else {
70
69
inDims.d [axis] = padding[padding_index];
71
70
auto fill_layer = ctx->net ->addFill (inDims, nvinfer1::FillOperation::kLINSPACE );
72
- at::Tensor value_tensor = torch::tensor (value, torch::kFloat32 );
73
- auto valueTensor = tensor_to_const (ctx, value_tensor);
74
71
fill_layer->setInput (1 , *valueTensor);
75
72
at::Tensor delta_tensor = torch::zeros (inRank);
76
73
auto deltaTensor = tensor_to_const (ctx, delta_tensor);
@@ -111,9 +108,7 @@ auto constant_pad_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns
111
108
// fill the right_gather_out with value
112
109
auto fill_layer = ctx->net ->addFill (nvinfer1::Dims{1 , {1 }}, nvinfer1::FillOperation::kLINSPACE );
113
110
auto shape_gather_out = ctx->net ->addShape (*right_gather_out)->getOutput (0 );
114
- fill_layer->setInput (0 , *shape_gather_out);
115
- at::Tensor value_tensor = torch::tensor (value, torch::kFloat32 );
116
- auto valueTensor = tensor_to_const (ctx, value_tensor);
111
+ fill_layer->setInput (0 , *shape_gather_out);
117
112
fill_layer->setInput (1 , *valueTensor);
118
113
at::Tensor delta_tensor = torch::zeros (inRank);
119
114
auto deltaTensor = tensor_to_const (ctx, delta_tensor);
@@ -126,8 +121,6 @@ auto constant_pad_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns
126
121
} else {
127
122
inDims.d [axis] = padding[padding_index + 1 ];
128
123
auto fill_layer = ctx->net ->addFill (inDims, nvinfer1::FillOperation::kLINSPACE );
129
- at::Tensor value_tensor = torch::tensor (value, torch::kFloat32 );
130
- auto valueTensor = tensor_to_const (ctx, value_tensor);
131
124
fill_layer->setInput (1 , *valueTensor);
132
125
at::Tensor delta_tensor = torch::zeros (inRank);
133
126
auto deltaTensor = tensor_to_const (ctx, delta_tensor);
0 commit comments