Skip to content

Commit eea344d

Browse files
authored
Merge pull request #423 from guoruoqian/replication_padXd
Support replication_padXd converters
2 parents d8aba0e + d801743 commit eea344d

File tree

4 files changed

+318
-0
lines changed

4 files changed

+318
-0
lines changed

core/conversion/converters/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ cc_library(
4444
"impl/matrix_multiply.cpp",
4545
"impl/pooling.cpp",
4646
"impl/reduce.cpp",
47+
"impl/replication_pad.cpp",
4748
"impl/shuffle.cpp",
4849
"impl/softmax.cpp",
4950
"impl/unary.cpp",
Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
#include <ATen/ATen.h>
2+
#include <vector>
3+
#include "NvInfer.h"
4+
#include "core/conversion/converters/converters.h"
5+
#include "core/util/prelude.h"
6+
#include "torch/torch.h"
7+
8+
namespace trtorch {
9+
namespace core {
10+
namespace conversion {
11+
namespace converters {
12+
namespace impl {
13+
namespace {
14+
15+
// Converts aten::replication_padXd (edge/replicate padding) into TensorRT layers.
// Replication padding is emulated per padded axis: gather the first (resp. last)
// slice along the axis and concatenate the required number of copies before
// (resp. after) the input.
//
// ctx:   conversion context owning the network being built
// n:     the JIT node being converted (its output is associated with the result)
// args:  args[0] = input ITensor, args[1] = int[] padding list
// x_dim: 1, 2 or 3 for replication_pad1d/2d/3d respectively
// Returns true on success; throws via TRTORCH_THROW_ERROR on unsupported ranks.
bool replication_padXd(ConversionCtx* ctx, const torch::jit::Node* n, args& args, int x_dim) {
  auto in = args[0].ITensor();
  auto inDims = in->getDimensions();
  int64_t inRank = inDims.nbDims;
  auto padding = args[1].unwrapToIntList().vec();
  if (padding.size() == 1) {
    // A single padding value applies to both sides of every padded dimension.
    for (int64_t i = 0; i < x_dim * 2 - 1; i++)
      padding.push_back(padding[0]);
  }
  if (inRank == 3) {
    TRTORCH_CHECK(padding.size() == 2, "3D tensors expect 2 values for padding");
  } else if (inRank == 4) {
    TRTORCH_CHECK(padding.size() == 4, "4D tensors expect 4 values for padding");
  } else if (inRank == 5) {
    TRTORCH_CHECK(padding.size() == 6, "5D tensors expect 6 values for padding");
  } else {
    TRTORCH_THROW_ERROR("Only 3D, 4D, 5D padding with non-constant padding are supported for now");
  }

  std::vector<nvinfer1::ITensor*> tensors_vec;
  // input: (N, C, D_in, H_in, W_in).
  // padding: (padding_left, padding_right, padding_top, padding_bottom, padding_front, padding_back)
  // When axis is inRank - 1, making W_out = W_in + padding_left + padding_right.
  // When axis is inRank - 2, making H_out = H_in + padding_top + padding_bottom.
  // When axis is inRank - 3, making D_out = D_in + padding_front + padding_back.
  for (int64_t i = 0; i < int(padding.size() / 2); i++) {
    int64_t axis = inRank - (i + 1); // axis = {inRank - 1, inRank - 2, inRank - 3}
    int64_t padding_index = i * 2;

    if (padding[padding_index] > 0) { // left/top/front padding value
      tensors_vec.clear();
      // Replicate the first slice along `axis` padding[padding_index] times.
      at::Tensor left_indices = torch::tensor({0}, torch::kInt32);
      auto indicesTensor = tensor_to_const(ctx, left_indices);
      auto left_gather_layer = ctx->net->addGather(*in, *indicesTensor, axis);
      TRTORCH_CHECK(left_gather_layer, "Unable to create gather layer from node: " << *n);
      auto left_gather_out = left_gather_layer->getOutput(0);
      for (int64_t count = 0; count < padding[padding_index]; count++) {
        tensors_vec.push_back(left_gather_out);
      }
      tensors_vec.push_back(in);
      auto concat_layer = ctx->net->addConcatenation(tensors_vec.data(), tensors_vec.size());
      TRTORCH_CHECK(concat_layer, "Unable to create concatenation layer from node: " << *n);
      concat_layer->setAxis(axis);
      in = concat_layer->getOutput(0);
      inDims = in->getDimensions();
    }

    if (padding[padding_index + 1] > 0) { // right/bottom/back padding value
      tensors_vec.clear();
      tensors_vec.push_back(in);

      // Gather the LAST slice along `axis`, i.e. index = dim size - 1.
      nvinfer1::ITensor* indicesTensor = NULL;
      if (inDims.d[axis] == -1) {
        // Dynamic dimension: the shape tensor holds the dimension SIZE, but the
        // last valid gather index is size - 1, so subtract 1 at runtime
        // (mirrors the `inDims.d[axis] - 1` used on the static path below).
        auto shapeTensor = ctx->net->addShape(*in)->getOutput(0);
        at::Tensor dimValue = torch::tensor({axis}, torch::kInt32);
        auto dimTensor = tensor_to_const(ctx, dimValue);
        auto dimSizeTensor = ctx->net->addGather(*shapeTensor, *dimTensor, 0)->getOutput(0);
        auto oneTensor = tensor_to_const(ctx, torch::tensor({1}, torch::kInt32));
        indicesTensor =
            ctx->net->addElementWise(*dimSizeTensor, *oneTensor, nvinfer1::ElementWiseOperation::kSUB)->getOutput(0);
      } else {
        auto indices = torch::tensor({inDims.d[axis] - 1}, torch::kInt32);
        indicesTensor = tensor_to_const(ctx, indices);
      }
      auto right_gather_layer = ctx->net->addGather(*in, *indicesTensor, axis);
      TRTORCH_CHECK(right_gather_layer, "Unable to create gather layer from node: " << *n);
      auto right_gather_out = right_gather_layer->getOutput(0);

      for (int64_t count = 0; count < padding[padding_index + 1]; count++) {
        tensors_vec.push_back(right_gather_out);
      }

      auto concat_layer = ctx->net->addConcatenation(tensors_vec.data(), tensors_vec.size());
      TRTORCH_CHECK(concat_layer, "Unable to create concatenation layer from node: " << *n);
      concat_layer->setAxis(axis);
      in = concat_layer->getOutput(0);
      inDims = in->getDimensions();
    }
  }

  auto out = ctx->AssociateValueAndTensor(n->outputs()[0], in);
  LOG_DEBUG("Output tensor shape: " << out->getDimensions());

  return true;
}
93+
94+
// Register converters for the three replication-pad variants. Each pattern
// simply forwards to replication_padXd with the matching dimensionality;
// the helper's result (always true, or a throw) is propagated directly.
auto replication_pad_registrations TRTORCH_UNUSED =
    RegisterNodeConversionPatterns()
        .pattern({"aten::replication_pad1d(Tensor self, int[2] padding) -> (Tensor)",
                  [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
                    return replication_padXd(ctx, n, args, 1);
                  }})
        .pattern({"aten::replication_pad2d(Tensor self, int[4] padding) -> (Tensor)",
                  [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
                    return replication_padXd(ctx, n, args, 2);
                  }})
        .pattern({"aten::replication_pad3d(Tensor self, int[6] padding) -> (Tensor)",
                  [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
                    return replication_padXd(ctx, n, args, 3);
                  }});
111+
112+
} // namespace
113+
} // namespace impl
114+
} // namespace converters
115+
} // namespace conversion
116+
} // namespace core
117+
} // namespace trtorch

tests/core/conversion/converters/BUILD

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,10 @@ converter_test(
4747
name = "test_reduce"
4848
)
4949

50+
converter_test(
51+
name = "test_replication_pad"
52+
)
53+
5054
converter_test(
5155
name = "test_shuffle"
5256
)
@@ -99,6 +103,7 @@ test_suite(
99103
":test_matrix_multiply",
100104
":test_pooling",
101105
":test_reduce",
106+
":test_replication_pad",
102107
":test_shuffle",
103108
":test_softmax",
104109
":test_unary",
Lines changed: 195 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,195 @@
1+
#include <iostream>
2+
#include <string>
3+
#include "core/compiler.h"
4+
#include "gtest/gtest.h"
5+
#include "tests/util/util.h"
6+
#include "torch/csrc/jit/ir/irparser.h"
7+
8+
TEST(Converters, ATenReplication_pad1dTensorConvertsCorrectly) {
9+
const auto graph = R"IR(
10+
graph(%0 : Tensor):
11+
%1 : int[] = prim::Constant[value=[2, 3]]()
12+
%2 : Tensor = aten::replication_pad1d(%0, %1)
13+
return (%2))IR";
14+
15+
auto g = std::make_shared<torch::jit::Graph>();
16+
torch::jit::parseIR(graph, g.get());
17+
18+
auto in1 = at::randint(1, 10, {1, 3, 4}, {at::kCUDA});
19+
20+
auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
21+
auto jit_results = trtorch::tests::util::RunGraph(g, params, {in1});
22+
23+
params = trtorch::core::conversion::get_named_params(g->inputs(), {});
24+
auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in1});
25+
26+
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
27+
}
28+
29+
TEST(Converters, ATenReplication_pad1dRightZeroTensorConvertsCorrectly) {
30+
const auto graph = R"IR(
31+
graph(%0 : Tensor):
32+
%1 : int[] = prim::Constant[value=[2, 0]]()
33+
%2 : Tensor = aten::replication_pad1d(%0, %1)
34+
return (%2))IR";
35+
36+
auto g = std::make_shared<torch::jit::Graph>();
37+
torch::jit::parseIR(graph, g.get());
38+
39+
auto in1 = at::randint(1, 10, {1, 3, 4}, {at::kCUDA});
40+
41+
auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
42+
auto jit_results = trtorch::tests::util::RunGraph(g, params, {in1});
43+
44+
params = trtorch::core::conversion::get_named_params(g->inputs(), {});
45+
auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in1});
46+
47+
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
48+
}
49+
50+
TEST(Converters, ATenReplication_pad1dTensorConvertsCorrectlyWithDynamic) {
51+
const auto graph = R"IR(
52+
graph(%0 : Tensor):
53+
%1 : int[] = prim::Constant[value=[2, 3]]()
54+
%2 : Tensor = aten::replication_pad1d(%0, %1)
55+
return (%2))IR";
56+
57+
auto g = std::make_shared<torch::jit::Graph>();
58+
torch::jit::parseIR(graph, g.get());
59+
60+
auto in1 = at::randint(1, 10, {1, 3, 4}, {at::kCUDA});
61+
62+
auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
63+
auto jit_results = trtorch::tests::util::RunGraph(g, params, {in1});
64+
65+
params = trtorch::core::conversion::get_named_params(g->inputs(), {});
66+
auto trt_results = trtorch::tests::util::RunGraphEngineDynamic(g, params, {in1});
67+
68+
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
69+
}
70+
71+
TEST(Converters, ATenReplication_pad2dTensorConvertsCorrectly) {
72+
const auto graph = R"IR(
73+
graph(%0 : Tensor):
74+
%1 : int[] = prim::Constant[value=[2, 3, 2, 3]]()
75+
%2 : Tensor = aten::replication_pad2d(%0, %1)
76+
return (%2))IR";
77+
78+
auto g = std::make_shared<torch::jit::Graph>();
79+
torch::jit::parseIR(graph, g.get());
80+
81+
auto in1 = at::randint(1, 10, {1, 3, 4, 5}, {at::kCUDA});
82+
83+
auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
84+
auto jit_results = trtorch::tests::util::RunGraph(g, params, {in1});
85+
86+
params = trtorch::core::conversion::get_named_params(g->inputs(), {});
87+
auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in1});
88+
89+
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
90+
}
91+
92+
TEST(Converters, ATenReplication_pad2dRightBottomZeroTensorConvertsCorrectly) {
93+
const auto graph = R"IR(
94+
graph(%0 : Tensor):
95+
%1 : int[] = prim::Constant[value=[2, 0, 2, 0]]()
96+
%2 : Tensor = aten::replication_pad2d(%0, %1)
97+
return (%2))IR";
98+
99+
auto g = std::make_shared<torch::jit::Graph>();
100+
torch::jit::parseIR(graph, g.get());
101+
102+
auto in1 = at::randint(1, 10, {1, 3, 4, 5}, {at::kCUDA});
103+
104+
auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
105+
auto jit_results = trtorch::tests::util::RunGraph(g, params, {in1});
106+
107+
params = trtorch::core::conversion::get_named_params(g->inputs(), {});
108+
auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in1});
109+
110+
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
111+
}
112+
113+
TEST(Converters, ATenReplication_pad2dTensorConvertsCorrectlyWithDynamic) {
114+
const auto graph = R"IR(
115+
graph(%0 : Tensor):
116+
%1 : int[] = prim::Constant[value=[2, 3, 2, 3]]()
117+
%2 : Tensor = aten::replication_pad2d(%0, %1)
118+
return (%2))IR";
119+
120+
auto g = std::make_shared<torch::jit::Graph>();
121+
torch::jit::parseIR(graph, g.get());
122+
123+
auto in1 = at::randint(1, 10, {1, 3, 4, 5}, {at::kCUDA});
124+
125+
auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
126+
auto jit_results = trtorch::tests::util::RunGraph(g, params, {in1});
127+
128+
params = trtorch::core::conversion::get_named_params(g->inputs(), {});
129+
auto trt_results = trtorch::tests::util::RunGraphEngineDynamic(g, params, {in1});
130+
131+
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
132+
}
133+
134+
TEST(Converters, ATenReplication_pad3dTensorConvertsCorrectly) {
135+
const auto graph = R"IR(
136+
graph(%0 : Tensor):
137+
%1 : int[] = prim::Constant[value=[2, 3, 2, 3, 1, 4]]()
138+
%2 : Tensor = aten::replication_pad3d(%0, %1)
139+
return (%2))IR";
140+
141+
auto g = std::make_shared<torch::jit::Graph>();
142+
torch::jit::parseIR(graph, g.get());
143+
144+
auto in1 = at::randint(1, 10, {1, 3, 4, 5, 3}, {at::kCUDA});
145+
146+
auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
147+
auto jit_results = trtorch::tests::util::RunGraph(g, params, {in1});
148+
149+
params = trtorch::core::conversion::get_named_params(g->inputs(), {});
150+
auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in1});
151+
152+
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
153+
}
154+
155+
TEST(Converters, ATenReplication_pad3dRightBottomBackZeroTensorConvertsCorrectly) {
156+
const auto graph = R"IR(
157+
graph(%0 : Tensor):
158+
%1 : int[] = prim::Constant[value=[2, 0, 2, 0, 1, 0]]()
159+
%2 : Tensor = aten::replication_pad3d(%0, %1)
160+
return (%2))IR";
161+
162+
auto g = std::make_shared<torch::jit::Graph>();
163+
torch::jit::parseIR(graph, g.get());
164+
165+
auto in1 = at::randint(1, 10, {1, 3, 4, 5, 3}, {at::kCUDA});
166+
167+
auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
168+
auto jit_results = trtorch::tests::util::RunGraph(g, params, {in1});
169+
170+
params = trtorch::core::conversion::get_named_params(g->inputs(), {});
171+
auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in1});
172+
173+
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
174+
}
175+
176+
TEST(Converters, ATenReplication_pad3dTensorConvertsCorrectlyWithDynamic) {
177+
const auto graph = R"IR(
178+
graph(%0 : Tensor):
179+
%1 : int[] = prim::Constant[value=[2, 3, 2, 3, 1, 4]]()
180+
%2 : Tensor = aten::replication_pad3d(%0, %1)
181+
return (%2))IR";
182+
183+
auto g = std::make_shared<torch::jit::Graph>();
184+
torch::jit::parseIR(graph, g.get());
185+
186+
auto in1 = at::randint(1, 10, {1, 3, 4, 5, 3}, {at::kCUDA});
187+
188+
auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
189+
auto jit_results = trtorch::tests::util::RunGraph(g, params, {in1});
190+
191+
params = trtorch::core::conversion::get_named_params(g->inputs(), {});
192+
auto trt_results = trtorch::tests::util::RunGraphEngineDynamic(g, params, {in1});
193+
194+
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
195+
}

0 commit comments

Comments
 (0)