Skip to content

Commit 770b5a2

Browse files
committed
resolve conflicts
Signed-off-by: Ruoqian Guo <[email protected]>
1 parent fee7ba3 commit 770b5a2

File tree

1 file changed

+108
-5
lines changed

1 file changed

+108
-5
lines changed

tests/core/conversion/evaluators/test_aten_evaluators.cpp

Lines changed: 108 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ TEST(Evaluators, DivIntEvaluatesCorrectly) {
1313
return (%3))IR";
1414

1515
auto g = std::make_shared<torch::jit::Graph>();
16-
torch::jit::parseIR(graph, &*g);
16+
torch::jit::parseIR(graph, g.get());
1717

1818
auto jit_results = trtorch::tests::util::EvaluateGraphJIT(g, {});
1919
auto trt_results = trtorch::tests::util::EvaluateGraph(g->block(), {});
@@ -30,7 +30,7 @@ TEST(Evaluators, DivFloatEvaluatesCorrectly) {
3030
return (%3))IR";
3131

3232
auto g = std::make_shared<torch::jit::Graph>();
33-
torch::jit::parseIR(graph, &*g);
33+
torch::jit::parseIR(graph, g.get());
3434

3535
auto jit_results = trtorch::tests::util::EvaluateGraphJIT(g, {});
3636
auto trt_results = trtorch::tests::util::EvaluateGraph(g->block(), {});
@@ -49,7 +49,7 @@ TEST(Evaluators, ZerosEvaluatesCorrectly) {
4949
auto in = at::randint(1, 10, {1, 5, 5, 5}, {at::kCUDA});
5050

5151
auto g = std::make_shared<torch::jit::Graph>();
52-
torch::jit::parseIR(graph, &*g);
52+
torch::jit::parseIR(graph, g.get());
5353

5454
auto jit_results = trtorch::tests::util::EvaluateGraphJIT(g, {in});
5555
auto trt_results = trtorch::tests::util::EvaluateGraph(g->block(), {in});
@@ -69,15 +69,118 @@ TEST(Evaluators, ZerosDataTypeEvaluatesCorrectly) {
6969
auto in = at::randint(1, 10, {1, 5, 5, 5}, {at::kCUDA});
7070

7171
auto g = std::make_shared<torch::jit::Graph>();
72-
torch::jit::parseIR(graph, &*g);
72+
torch::jit::parseIR(graph, g.get());
7373

7474
auto jit_results = trtorch::tests::util::EvaluateGraphJIT(g, {in});
7575
auto trt_results = trtorch::tests::util::EvaluateGraph(g->block(), {in});
7676

7777
ASSERT_TRUE(at::equal(jit_results[0].toTensor().to(at::kCUDA), trt_results[0].toTensor()));
7878
}
7979

80-
TEST(Evaluators, SizeConvertsCorrectly) {
80+
TEST(Evaluators, ATenArangeIntEvaluatesCorrectly) {
81+
const auto graph = R"IR(
82+
graph():
83+
%0 : int = prim::Constant[value=51]()
84+
%1 : None = prim::Constant()
85+
%2 : Tensor = aten::arange(%0, %1, %1, %1, %1)
86+
return (%2))IR";
87+
88+
auto g = std::make_shared<torch::jit::Graph>();
89+
torch::jit::parseIR(graph, &*g);
90+
91+
auto jit_results = trtorch::tests::util::EvaluateGraphJIT(g, {});
92+
auto trt_results = trtorch::tests::util::EvaluateGraph(g->block(), {});
93+
94+
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0].toTensor(), trt_results[0].toTensor(), 2e-6));
95+
}
96+
97+
TEST(Evaluators, ATenArangeFloatEvaluatesCorrectly) {
98+
const auto graph = R"IR(
99+
graph():
100+
%0 : float = prim::Constant[value=51.2]()
101+
%1 : None = prim::Constant()
102+
%2 : Tensor = aten::arange(%0, %1, %1, %1, %1)
103+
return (%2))IR";
104+
105+
auto g = std::make_shared<torch::jit::Graph>();
106+
torch::jit::parseIR(graph, &*g);
107+
108+
auto jit_results = trtorch::tests::util::EvaluateGraphJIT(g, {});
109+
auto trt_results = trtorch::tests::util::EvaluateGraph(g->block(), {});
110+
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0].toTensor(), trt_results[0].toTensor(), 2e-6));
111+
}
112+
113+
TEST(Evaluators, ATenArangeStartEndIntEvaluatesCorrectly) {
114+
const auto graph = R"IR(
115+
graph():
116+
%0 : int = prim::Constant[value=1]()
117+
%1 : int = prim::Constant[value=51]()
118+
%2 : None = prim::Constant()
119+
%3 : Tensor = aten::arange(%0, %1, %2, %2, %2, %2)
120+
return (%3))IR";
121+
122+
auto g = std::make_shared<torch::jit::Graph>();
123+
torch::jit::parseIR(graph, &*g);
124+
125+
auto jit_results = trtorch::tests::util::EvaluateGraphJIT(g, {});
126+
auto trt_results = trtorch::tests::util::EvaluateGraph(g->block(), {});
127+
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0].toTensor(), trt_results[0].toTensor(), 2e-6));
128+
}
129+
130+
TEST(Evaluators, ATenArangeStartEndFloatEvaluatesCorrectly) {
131+
const auto graph = R"IR(
132+
graph():
133+
%0 : float = prim::Constant[value=1.5]()
134+
%1 : float = prim::Constant[value=51.2]()
135+
%2 : None = prim::Constant()
136+
%3 : Tensor = aten::arange(%0, %1, %2, %2, %2, %2)
137+
return (%3))IR";
138+
139+
auto g = std::make_shared<torch::jit::Graph>();
140+
torch::jit::parseIR(graph, &*g);
141+
142+
auto jit_results = trtorch::tests::util::EvaluateGraphJIT(g, {});
143+
auto trt_results = trtorch::tests::util::EvaluateGraph(g->block(), {});
144+
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0].toTensor(), trt_results[0].toTensor(), 2e-6));
145+
}
146+
147+
TEST(Evaluators, ATenArangeStartEndStepIntEvaluatesCorrectly) {
148+
const auto graph = R"IR(
149+
graph():
150+
%0 : int = prim::Constant[value=1]()
151+
%1 : int = prim::Constant[value=51]()
152+
%2 : int = prim::Constant[value=1]()
153+
%3 : None = prim::Constant()
154+
%4 : Tensor = aten::arange(%0, %1, %2, %3, %3, %3, %3)
155+
return (%4))IR";
156+
157+
auto g = std::make_shared<torch::jit::Graph>();
158+
torch::jit::parseIR(graph, &*g);
159+
160+
auto jit_results = trtorch::tests::util::EvaluateGraphJIT(g, {});
161+
auto trt_results = trtorch::tests::util::EvaluateGraph(g->block(), {});
162+
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0].toTensor(), trt_results[0].toTensor(), 2e-6));
163+
}
164+
165+
TEST(Evaluators, ATenArangeStartEndStepFloatEvaluatesCorrectly) {
166+
const auto graph = R"IR(
167+
graph():
168+
%0 : float = prim::Constant[value=1.2]()
169+
%1 : float = prim::Constant[value=51.6]()
170+
%2 : float = prim::Constant[value=1.5]()
171+
%3 : None = prim::Constant()
172+
%4 : Tensor = aten::arange(%0, %1, %2, %3, %3, %3, %3)
173+
return (%4))IR";
174+
175+
auto g = std::make_shared<torch::jit::Graph>();
176+
torch::jit::parseIR(graph, &*g);
177+
178+
auto jit_results = trtorch::tests::util::EvaluateGraphJIT(g, {});
179+
auto trt_results = trtorch::tests::util::EvaluateGraph(g->block(), {});
180+
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0].toTensor(), trt_results[0].toTensor(), 2e-6));
181+
}
182+
183+
TEST(Evaluators, ATenSizeNegativeConvertsCorrectly) {
81184
const auto graph = R"IR(
82185
graph(%0 : Tensor):
83186
%1 : int = prim::Constant[value=-1]()

0 commit comments

Comments
 (0)