Skip to content

Commit fc04d4a

Browse files
authored
Merge pull request #1108 from inocsin/fix_aten_to
[fix]: fix bug in aten::to, when network only have aten::to layer wil…
2 parents 377547e + f69cfc4 commit fc04d4a

File tree

1 file changed

+30
-0
lines changed

1 file changed

+30
-0
lines changed

tests/core/conversion/converters/test_cast.cpp

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,36 @@ TEST(Converters, ATenBoolToINT32TensorConvertsCorrectly) {
135135
ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6));
136136
}
137137

138+
139+
TEST(Converters, ATenToSingleConvertsCorrectly) {
140+
const auto graph = R"IR(
141+
graph(%y.1 : Tensor):
142+
%4 : int = prim::Constant[value=6]()
143+
%5 : bool = prim::Constant[value=0]()
144+
%6 : None = prim::Constant()
145+
%y0.1 : Tensor = aten::to(%y.1, %4, %5, %5, %6)
146+
return (%y0.1))IR";
147+
148+
auto g = std::make_shared<torch::jit::Graph>();
149+
150+
torch::jit::parseIR(graph, &*g);
151+
152+
auto in = at::randint(1, 10, {3}, {at::kCUDA});
153+
154+
auto jit_in = at::clone(in);
155+
auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
156+
auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in});
157+
158+
auto trt_in = at::clone(in);
159+
params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
160+
auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in});
161+
162+
auto trt = trt_results[0].reshape(jit_results[0].sizes());
163+
ASSERT_TRUE(jit_results[0].scalar_type() == trt.scalar_type());
164+
ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6));
165+
}
166+
167+
138168
TEST(Converters, ATenTypeAsConvertsCorrectly) {
139169
const auto graph = R"IR(
140170
graph(%0 : Tensor,

0 commit comments

Comments
 (0)