
Commit cf54792

Checkpoint, initial test and implementation
1 parent 61adecf commit cf54792


2 files changed: +26, -10 lines


core/conversion/converters/impl/constant_pad.cpp

Lines changed: 3 additions & 10 deletions
@@ -21,7 +21,8 @@ auto constant_pad_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns
   auto padding = args[1].unwrapToIntList().vec();
   int64_t padSize = padding.size();
   auto value = args[2].unwrapToScalar().to<float>();
-
+  at::Tensor value_tensor = torch::tensor(value, util::TRTDataTypeToScalarType(in->getType()));
+  auto valueTensor = tensor_to_const(ctx, value_tensor);
   TORCHTRT_CHECK(padSize % 2 == 0, "Length of pad must be even but instead it equals " << padSize);

   int64_t l_pad = padSize / 2;
@@ -55,8 +56,6 @@ auto constant_pad_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns
   auto fill_layer = ctx->net->addFill(nvinfer1::Dims{1, {1}}, nvinfer1::FillOperation::kLINSPACE);
   auto shape_gather_out = ctx->net->addShape(*left_gather_out)->getOutput(0);
   fill_layer->setInput(0, *shape_gather_out);
-  at::Tensor value_tensor = torch::tensor(value, torch::kFloat32);
-  auto valueTensor = tensor_to_const(ctx, value_tensor);
   fill_layer->setInput(1, *valueTensor);
   at::Tensor delta_tensor = torch::zeros(inRank);
   auto deltaTensor = tensor_to_const(ctx, delta_tensor);
@@ -69,8 +68,6 @@ auto constant_pad_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns
 } else {
   inDims.d[axis] = padding[padding_index];
   auto fill_layer = ctx->net->addFill(inDims, nvinfer1::FillOperation::kLINSPACE);
-  at::Tensor value_tensor = torch::tensor(value, torch::kFloat32);
-  auto valueTensor = tensor_to_const(ctx, value_tensor);
   fill_layer->setInput(1, *valueTensor);
   at::Tensor delta_tensor = torch::zeros(inRank);
   auto deltaTensor = tensor_to_const(ctx, delta_tensor);
@@ -111,9 +108,7 @@ auto constant_pad_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns
   // fill the right_gather_out with value
   auto fill_layer = ctx->net->addFill(nvinfer1::Dims{1, {1}}, nvinfer1::FillOperation::kLINSPACE);
   auto shape_gather_out = ctx->net->addShape(*right_gather_out)->getOutput(0);
-  fill_layer->setInput(0, *shape_gather_out);
-  at::Tensor value_tensor = torch::tensor(value, torch::kFloat32);
-  auto valueTensor = tensor_to_const(ctx, value_tensor);
+  fill_layer->setInput(0, *shape_gather_out);
   fill_layer->setInput(1, *valueTensor);
   at::Tensor delta_tensor = torch::zeros(inRank);
   auto deltaTensor = tensor_to_const(ctx, delta_tensor);
@@ -126,8 +121,6 @@ auto constant_pad_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns
 } else {
   inDims.d[axis] = padding[padding_index + 1];
   auto fill_layer = ctx->net->addFill(inDims, nvinfer1::FillOperation::kLINSPACE);
-  at::Tensor value_tensor = torch::tensor(value, torch::kFloat32);
-  auto valueTensor = tensor_to_const(ctx, value_tensor);
   fill_layer->setInput(1, *valueTensor);
   at::Tensor delta_tensor = torch::zeros(inRank);
   auto deltaTensor = tensor_to_const(ctx, delta_tensor);
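
In short, the converter now builds the fill value tensor once, mapping the input's TensorRT dtype back to a Torch scalar type via util::TRTDataTypeToScalarType(in->getType()), instead of recreating it with a hard-coded torch::kFloat32 at each fill site. Integer inputs therefore get an integer-typed pad value. A minimal, libtorch-only sketch of that dtype handling (the helper make_pad_value_tensor is hypothetical and only illustrates the idea):

// Sketch: build the pad value tensor with the same scalar type as the input,
// rather than always using torch::kFloat32 as the old code did.
#include <iostream>
#include <torch/torch.h>

at::Tensor make_pad_value_tensor(float value, const at::Tensor& in) {
  // ScalarType converts implicitly to TensorOptions, so this mirrors
  // torch::tensor(value, util::TRTDataTypeToScalarType(in->getType())).
  return torch::tensor(value, in.scalar_type());
}

int main() {
  auto in = torch::randint(1, 10, {1, 3, 4}, torch::kInt);
  auto value_tensor = make_pad_value_tensor(2.0f, in);
  std::cout << value_tensor.dtype() << std::endl;  // prints Int, not Float
  return 0;
}

Built against libtorch alone, this prints Int rather than Float, which is the behavior the new test below checks end to end through TensorRT.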

tests/core/conversion/converters/test_constant_pad.cpp

Lines changed: 23 additions & 0 deletions
@@ -28,6 +28,29 @@ TEST(Converters, ATenConstantPad1dTensorConvertsCorrectly) {
       torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
 }

+TEST(Converters, ATenConstantPad1dIntTensorConvertsCorrectly) {
+  const auto graph = R"IR(
+      graph(%0 : Tensor):
+        %1 : int[] = prim::Constant[value=[2, 3]]()
+        %2 : Scalar = prim::Constant[value=2]()
+        %3 : Tensor = aten::constant_pad_nd(%0, %1, %2)
+        return (%3))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, g.get());
+
+  auto in1 = at::randint(1, 10, {1, 3, 4}, {at::kCUDA}).toType(at::kInt);
+
+  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in1});
+
+  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in1});
+
+  ASSERT_TRUE(
+      torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
+}
+
 TEST(Converters, ATenConstantPad1dRightZeroTensorConvertsCorrectly) {
   const auto graph = R"IR(
       graph(%0 : Tensor):
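
For reference, the new ATenConstantPad1dIntTensorConvertsCorrectly test above pads a {1, 3, 4} Int CUDA tensor with [2, 3] and value 2, then compares the TensorRT engine output against TorchScript. The same computation expressed directly in eager-mode libtorch (a CPU-only sketch, not part of the test suite) looks like this:

#include <iostream>
#include <torch/torch.h>

int main() {
  // Same shape, dtype, padding, and pad value as the new test, but eager mode on CPU.
  auto in = torch::randint(1, 10, {1, 3, 4}, torch::kInt);
  auto out = torch::constant_pad_nd(in, {2, 3}, 2);
  std::cout << out.sizes() << " " << out.dtype() << std::endl;  // [1, 3, 9] Int
  return 0;
}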
