Skip to content

Commit a0848b1

Browse files
committed
test(tests/core/converters): Added test for plugins
Signed-off-by: Abhiram Iyer <[email protected]> Signed-off-by: Abhiram Iyer <[email protected]>
1 parent 2e29e4e commit a0848b1

File tree

1 file changed

+93
-3
lines changed

1 file changed

+93
-3
lines changed

tests/core/converters/test_interpolate.cpp

Lines changed: 93 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ TEST(Converters, ATenUpsampleNearest3dConvertsCorrectly3dOutputSize) {
9191
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6));
9292
}
9393

94-
TEST(Converters, ATenUpsampleLinear1dConvertsCorrectly) {
94+
TEST(Converters, ATenUpsampleLinear1dConvertsCorrectlyWithAlignCorners) {
9595
const auto graph = R"IR(
9696
graph(%0 : Tensor):
9797
%2 : int = prim::Constant[value=10]()
@@ -121,7 +121,37 @@ TEST(Converters, ATenUpsampleLinear1dConvertsCorrectly) {
121121
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6));
122122
}
123123

124-
TEST(Converters, ATenUpsampleBilinear2dConvertsCorrectly2dOutputSize) {
124+
TEST(Converters, ATenUpsampleLinear1dConvertsCorrectlyWithoutAlignCorners) {
125+
const auto graph = R"IR(
126+
graph(%0 : Tensor):
127+
%2 : int = prim::Constant[value=10]()
128+
%3 : int[] = prim::ListConstruct(%2)
129+
%4 : bool = prim::Constant[value=0]()
130+
%5 : None = prim::Constant()
131+
%6 : Tensor = aten::upsample_linear1d(%0, %3, %4, %5)
132+
return (%6))IR";
133+
134+
auto g = std::make_shared<torch::jit::Graph>();
135+
136+
torch::jit::parseIR(graph, &*g);
137+
138+
// Input Tensor needs to be 3D for TensorRT upsample_linear1d
139+
auto in = at::randint(1, 10, {10, 2, 2}, {at::kCUDA});
140+
141+
auto jit_in = at::clone(in);
142+
auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
143+
auto jit_results = trtorch::tests::util::RunGraph(g, params, {jit_in});
144+
145+
auto trt_in = at::clone(in);
146+
params = trtorch::core::conversion::get_named_params(g->inputs(), {});
147+
auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {trt_in});
148+
149+
auto trt = trt_results[0].reshape(jit_results[0].sizes());
150+
151+
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6));
152+
}
153+
154+
TEST(Converters, ATenUpsampleBilinear2dConvertsCorrectly2dOutputSizeWithAlignCorners) {
125155
const auto graph = R"IR(
126156
graph(%0 : Tensor):
127157
%2 : int = prim::Constant[value=10]()
@@ -151,7 +181,37 @@ TEST(Converters, ATenUpsampleBilinear2dConvertsCorrectly2dOutputSize) {
151181
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6));
152182
}
153183

154-
TEST(Converters, ATenUpsampleTrilinear3dConvertsCorrectly3dOutputSize) {
184+
TEST(Converters, ATenUpsampleBilinear2dConvertsCorrectly2dOutputSizeWithoutAlignCorners) {
185+
const auto graph = R"IR(
186+
graph(%0 : Tensor):
187+
%2 : int = prim::Constant[value=10]()
188+
%3 : int[] = prim::ListConstruct(%2, %2)
189+
%4 : bool = prim::Constant[value=0]()
190+
%5 : None = prim::Constant()
191+
%6 : Tensor = aten::upsample_bilinear2d(%0, %3, %4, %5, %5)
192+
return (%6))IR";
193+
194+
auto g = std::make_shared<torch::jit::Graph>();
195+
196+
torch::jit::parseIR(graph, &*g);
197+
198+
// Input Tensor needs to be 4D for TensorRT upsample_bilinear2d
199+
auto in = at::randint(1, 10, {10, 2, 2, 2}, {at::kCUDA});
200+
201+
auto jit_in = at::clone(in);
202+
auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
203+
auto jit_results = trtorch::tests::util::RunGraph(g, params, {jit_in});
204+
205+
auto trt_in = at::clone(in);
206+
params = trtorch::core::conversion::get_named_params(g->inputs(), {});
207+
auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {trt_in});
208+
209+
auto trt = trt_results[0].reshape(jit_results[0].sizes());
210+
211+
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6));
212+
}
213+
214+
TEST(Converters, ATenUpsampleTrilinear3dConvertsCorrectly3dOutputSizeWithAlignCorners) {
155215
const auto graph = R"IR(
156216
graph(%0 : Tensor):
157217
%2 : int = prim::Constant[value=10]()
@@ -178,5 +238,35 @@ TEST(Converters, ATenUpsampleTrilinear3dConvertsCorrectly3dOutputSize) {
178238

179239
auto trt = trt_results[0].reshape(jit_results[0].sizes());
180240

241+
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6));
242+
}
243+
244+
TEST(Converters, ATenUpsampleTrilinear3dConvertsCorrectly3dOutputSizeWithoutAlignCorners) {
245+
const auto graph = R"IR(
246+
graph(%0 : Tensor):
247+
%2 : int = prim::Constant[value=10]()
248+
%3 : int[] = prim::ListConstruct(%2, %2, %2)
249+
%4 : bool = prim::Constant[value=0]()
250+
%5 : None = prim::Constant()
251+
%6 : Tensor = aten::upsample_trilinear3d(%0, %3, %4, %5, %5, %5)
252+
return (%6))IR";
253+
254+
auto g = std::make_shared<torch::jit::Graph>();
255+
256+
torch::jit::parseIR(graph, &*g);
257+
258+
// Input Tensor needs to be 5D for TensorRT upsample_trilinear3d
259+
auto in = at::randint(1, 10, {10, 2, 2, 2, 2}, {at::kCUDA});
260+
261+
auto jit_in = at::clone(in);
262+
auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
263+
auto jit_results = trtorch::tests::util::RunGraph(g, params, {jit_in});
264+
265+
auto trt_in = at::clone(in);
266+
params = trtorch::core::conversion::get_named_params(g->inputs(), {});
267+
auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {trt_in});
268+
269+
auto trt = trt_results[0].reshape(jit_results[0].sizes());
270+
181271
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6));
182272
}

0 commit comments

Comments
 (0)