Skip to content

Commit 071c1d6

Browse files
committed
add test case for aten::div evaluator
Signed-off-by: inocsin <[email protected]>
1 parent e6205a5 commit 071c1d6

File tree

3 files changed

+46
-2
lines changed

3 files changed

+46
-2
lines changed

core/conversion/evaluators/aten.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -427,7 +427,7 @@ auto aten_registrations TRTORCH_UNUSED =
427427
},
428428
EvalOptions().validSchemas({
429429
"aten::div.float(float a, float b) -> (float)",
430-
"aten::div.int(int a, int b) -> (int)",
430+
"aten::div.int(int a, int b) -> (float)",
431431
})})
432432
.evaluator({c10::Symbol::fromQualString("aten::floordiv"),
433433
[](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {

tests/core/conversion/evaluators/BUILD

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,14 @@ evaluator_test(
1111
name = "test_prim_evaluators",
1212
)
1313

14+
evaluator_test(
15+
name = "test_aten_evaluators",
16+
)
17+
1418
test_suite(
1519
name = "evaluator_tests",
1620
tests = [
17-
":test_prim_evaluators"
21+
":test_prim_evaluators",
22+
":test_aten_evaluators"
1823
]
1924
)
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
#include <string>
2+
#include "core/compiler.h"
3+
#include "gtest/gtest.h"
4+
#include "tests/util/util.h"
5+
#include "torch/csrc/jit/ir/irparser.h"
6+
7+
TEST(Evaluators, DivIntEvaluatesCorrectly) {
8+
const auto graph = R"IR(
9+
graph():
10+
%1 : int = prim::Constant[value=9]()
11+
%2 : int = prim::Constant[value=4]()
12+
%3 : float = aten::div(%1, %2)
13+
return (%3))IR";
14+
15+
auto g = std::make_shared<torch::jit::Graph>();
16+
torch::jit::parseIR(graph, &*g);
17+
18+
auto jit_results = trtorch::tests::util::EvaluateGraphJIT(g, {});
19+
auto trt_results = trtorch::tests::util::EvaluateGraph(g->block(), {});
20+
21+
ASSERT_TRUE(jit_results[0] == trt_results[0]);
22+
}
23+
24+
TEST(Evaluators, DivFloatEvaluatesCorrectly) {
25+
const auto graph = R"IR(
26+
graph():
27+
%1 : float = prim::Constant[value=9.1]()
28+
%2 : float = prim::Constant[value=4.2]()
29+
%3 : float = aten::div(%1, %2)
30+
return (%3))IR";
31+
32+
auto g = std::make_shared<torch::jit::Graph>();
33+
torch::jit::parseIR(graph, &*g);
34+
35+
auto jit_results = trtorch::tests::util::EvaluateGraphJIT(g, {});
36+
auto trt_results = trtorch::tests::util::EvaluateGraph(g->block(), {});
37+
38+
ASSERT_TRUE(jit_results[0] == trt_results[0]);
39+
}

0 commit comments

Comments
 (0)