Skip to content

Commit 8f9f041

Browse files
committed
tests(aten::zeros): Add a dtype test
Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]>
1 parent 3744847 commit 8f9f041

File tree

1 file changed

+20
-0
lines changed

1 file changed

+20
-0
lines changed

tests/core/conversion/evaluators/test_aten_evaluators.cpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,5 +54,25 @@ TEST(Evaluators, ZerosEvaluatesCorrectly) {
5454
auto jit_results = trtorch::tests::util::EvaluateGraphJIT(g, {in});
5555
auto trt_results = trtorch::tests::util::EvaluateGraph(g->block(), {in});
5656

57+
ASSERT_TRUE(at::equal(jit_results[0].toTensor().to(at::kCUDA), trt_results[0].toTensor()));
58+
}
59+
60+
TEST(Evaluators, ZerosDataTypeEvaluatesCorrectly) {
61+
const auto graph = R"IR(
62+
graph(%x.1 : Tensor):
63+
%2 : int = prim::Constant[value=5]() # :0:0 (Float16)
64+
%3 : None = prim::Constant() # :0:0
65+
%4 : int[] = aten::size(%x.1) # <string>:7:9
66+
%z.1 : Tensor = aten::zeros(%4, %2, %3, %3, %3) # experiments/test_zeros.py:8:12
67+
return (%z.1))IR";
68+
69+
auto in = at::randint(1, 10, {1, 5, 5, 5}, {at::kCUDA});
70+
71+
auto g = std::make_shared<torch::jit::Graph>();
72+
torch::jit::parseIR(graph, &*g);
73+
74+
auto jit_results = trtorch::tests::util::EvaluateGraphJIT(g, {in});
75+
auto trt_results = trtorch::tests::util::EvaluateGraph(g->block(), {in});
76+
5777
ASSERT_TRUE(at::equal(jit_results[0].toTensor().to(at::kCUDA), trt_results[0].toTensor()));
5878
}

0 commit comments

Comments
 (0)