Skip to content

Commit 97c8f52

Browse files
committed
test(aten::stack): Added test for aten::stack
Signed-off-by: Abhiram Iyer <[email protected]> Signed-off-by: Abhiram Iyer <[email protected]>
1 parent 6659b44 commit 97c8f52

File tree

2 files changed

+59
-1
lines changed

2 files changed

+59
-1
lines changed

tests/core/converters/BUILD

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,10 @@ converter_test(
6363
name = "test_select"
6464
)
6565

66+
converter_test(
67+
name = "test_stack"
68+
)
69+
6670
test_suite(
6771
name = "test_converters",
6872
tests = [
@@ -78,7 +82,8 @@ test_suite(
7882
":test_softmax",
7983
":test_unary",
8084
":test_interpolate",
81-
":test_select"
85+
":test_select",
86+
":test_stack"
8287
]
8388
)
8489

tests/core/converters/test_stack.cpp

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
#include <string>
2+
#include "gtest/gtest.h"
3+
#include "torch/csrc/jit/ir/irparser.h"
4+
#include "tests/util/util.h"
5+
#include "core/compiler.h"
6+
7+
TEST(Converters, ATenStackPureTensorConvertsCorrectly) {
8+
const auto graph = R"IR(
9+
graph(%0 : Tensor,
10+
%1 : Tensor):
11+
%2 : Tensor[] = prim::ListConstruct(%0, %1)
12+
%3 : int = prim::Constant[value=3]()
13+
%4 : Tensor = aten::stack(%2, %3)
14+
return (%4))IR";
15+
16+
auto g = std::make_shared<torch::jit::Graph>();
17+
torch::jit::parseIR(graph, &*g);
18+
19+
auto in1 = at::randint(1, 10, {4, 4, 4}, {at::kCUDA});
20+
auto in2 = at::randint(1, 10, {4, 4, 4}, {at::kCUDA});
21+
22+
auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
23+
auto jit_results = trtorch::tests::util::RunGraph(g, params, {in1, in2});
24+
25+
params = trtorch::core::conversion::get_named_params(g->inputs(), {});
26+
auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in1, in2});
27+
28+
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
29+
}
30+
31+
TEST(Converters, ATenStackDiffTensorConvertsCorrectly) {
32+
const auto graph = R"IR(
33+
graph(%0 : Tensor,
34+
%1 : Float(4, 4, 4)):
35+
%2 : Tensor[] = prim::ListConstruct(%0, %1)
36+
%3 : int = prim::Constant[value=1]()
37+
%4 : Tensor = aten::stack(%2, %3)
38+
return (%4))IR";
39+
40+
auto g = std::make_shared<torch::jit::Graph>();
41+
torch::jit::parseIR(graph, &*g);
42+
43+
auto in1 = at::randint(1, 10, {4, 4, 4}, {at::kCUDA});
44+
auto in2 = at::randint(1, 10, {4, 4, 4}, {at::kCUDA});
45+
46+
auto params = trtorch::core::conversion::get_named_params(g->inputs(), {in2});
47+
auto jit_results = trtorch::tests::util::RunGraph(g, params, {in1});
48+
49+
params = trtorch::core::conversion::get_named_params(g->inputs(), {in2});
50+
auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in1});
51+
52+
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
53+
}

0 commit comments

Comments
 (0)