Skip to content

Commit 2f11791

Browse files
committed
tests(//tests/core/lowering): Adding a lowering testing framework
Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]>
1 parent 55cb1d1 commit 2f11791

File tree

5 files changed

+64
-1
lines changed

5 files changed

+64
-1
lines changed

tests/core/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,5 +2,6 @@ test_suite(
22
name = "core_tests",
33
tests = [
44
"//tests/core/conversion:conversion_tests",
5+
"//tests/core/lowering:lowering_tests",
56
],
67
)

tests/core/conversion/evaluators/test_prim_evaluators.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
#include "tests/util/util.h"
55
#include "torch/csrc/jit/ir/irparser.h"
66

7-
TEST(Evaluators, PrimConstantConvertsCorrectly) {
7+
TEST(Evaluators, PrimConstantEvaluatesCorrectly) {
88
const auto graph = R"IR(
99
graph():
1010
%0 : int = prim::Constant[value=1]()

tests/core/lowering/BUILD

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
load("//tests/core/lowering:lowering_test.bzl", "lowering_test")
2+
3+
config_setting(
4+
name = "use_pre_cxx11_abi",
5+
values = {
6+
"define": "abi=pre_cxx11_abi",
7+
}
8+
)
9+
10+
lowering_test(
11+
name = "test_remove_contiguous_pass",
12+
)
13+
14+
test_suite(
15+
name = "lowering_tests",
16+
tests = [
17+
":test_remove_contiguous_pass"
18+
]
19+
)

tests/core/lowering/lowering_test.bzl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
def lowering_test(name, visibility=None):
2+
native.cc_test(
3+
name = name,
4+
srcs = [name + ".cpp"],
5+
visibility = visibility,
6+
deps = [
7+
"//tests/util",
8+
"//core",
9+
"@googletest//:gtest_main",
10+
] + select({
11+
":use_pre_cxx11_abi": ["@libtorch_pre_cxx11_abi//:libtorch"],
12+
"//conditions:default": ["@libtorch//:libtorch"],
13+
}),
14+
timeout="short"
15+
)
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
#include <string>
2+
#include "gtest/gtest.h"
3+
#include "tests/util/util.h"
4+
#include "core/compiler.h"
5+
#include "core/lowering/passes/passes.h"
6+
#include "torch/csrc/jit/ir/irparser.h"
7+
#include "torch/csrc/jit/ir/subgraph_matcher.h"
8+
9+
TEST(LoweringPasses, RemoveContiguousLowersCorrectly) {
10+
std::string source_graph = R"IR(
11+
graph(%input, %1):
12+
%2 = aten::contiguous(%input, %1)
13+
%3 = foo::bar(%2)
14+
return (%3))IR";
15+
std::string target_graph = R"IR(
16+
graph(%input, %1):
17+
%3 = foo::bar(%input)
18+
return (%3))IR";
19+
20+
auto sg = std::make_shared<torch::jit::Graph>();
21+
torch::jit::parseIR(source_graph, &*sg);
22+
trtorch::core::lowering::passes::RemoveContiguous(sg);
23+
24+
auto tg = std::make_shared<torch::jit::Graph>();
25+
torch::jit::parseIR(target_graph, &*tg);
26+
27+
ASSERT_TRUE(!torch::jit::findPatternMatches(*tg, *sg).empty());
28+
}

0 commit comments

Comments
 (0)