File tree Expand file tree Collapse file tree 5 files changed +64
-1
lines changed Expand file tree Collapse file tree 5 files changed +64
-1
lines changed Original file line number Diff line number Diff line change @@ -2,5 +2,6 @@ test_suite(
2
2
name = "core_tests" ,
3
3
tests = [
4
4
"//tests/core/conversion:conversion_tests" ,
5
+ "//tests/core/lowering:lowering_tests" ,
5
6
],
6
7
)
Original file line number Diff line number Diff line change 4
4
#include " tests/util/util.h"
5
5
#include " torch/csrc/jit/ir/irparser.h"
6
6
7
- TEST (Evaluators, PrimConstantConvertsCorrectly ) {
7
+ TEST (Evaluators, PrimConstantEvaluatesCorrectly ) {
8
8
const auto graph = R"IR(
9
9
graph():
10
10
%0 : int = prim::Constant[value=1]()
Original file line number Diff line number Diff line change
1
+ load ("//tests/core/lowering:lowering_test.bzl" , "lowering_test" )
2
+
3
+ config_setting (
4
+ name = "use_pre_cxx11_abi" ,
5
+ values = {
6
+ "define" : "abi=pre_cxx11_abi" ,
7
+ }
8
+ )
9
+
10
+ lowering_test (
11
+ name = "test_remove_contiguous_pass" ,
12
+ )
13
+
14
+ test_suite (
15
+ name = "lowering_tests" ,
16
+ tests = [
17
+ ":test_remove_contiguous_pass"
18
+ ]
19
+ )
Original file line number Diff line number Diff line change
1
+ def lowering_test (name , visibility = None ):
2
+ native .cc_test (
3
+ name = name ,
4
+ srcs = [name + ".cpp" ],
5
+ visibility = visibility ,
6
+ deps = [
7
+ "//tests/util" ,
8
+ "//core" ,
9
+ "@googletest//:gtest_main" ,
10
+ ] + select ({
11
+ ":use_pre_cxx11_abi" : ["@libtorch_pre_cxx11_abi//:libtorch" ],
12
+ "//conditions:default" : ["@libtorch//:libtorch" ],
13
+ }),
14
+ timeout = "short"
15
+ )
Original file line number Diff line number Diff line change
1
+ #include < string>
2
+ #include " gtest/gtest.h"
3
+ #include " tests/util/util.h"
4
+ #include " core/compiler.h"
5
+ #include " core/lowering/passes/passes.h"
6
+ #include " torch/csrc/jit/ir/irparser.h"
7
+ #include " torch/csrc/jit/ir/subgraph_matcher.h"
8
+
9
+ TEST (LoweringPasses, RemoveContiguousLowersCorrectly) {
10
+ std::string source_graph = R"IR(
11
+ graph(%input, %1):
12
+ %2 = aten::contiguous(%input, %1)
13
+ %3 = foo::bar(%2)
14
+ return (%3))IR" ;
15
+ std::string target_graph = R"IR(
16
+ graph(%input, %1):
17
+ %3 = foo::bar(%input)
18
+ return (%3))IR" ;
19
+
20
+ auto sg = std::make_shared<torch::jit::Graph>();
21
+ torch::jit::parseIR (source_graph, &*sg);
22
+ trtorch::core::lowering::passes::RemoveContiguous (sg);
23
+
24
+ auto tg = std::make_shared<torch::jit::Graph>();
25
+ torch::jit::parseIR (target_graph, &*tg);
26
+
27
+ ASSERT_TRUE (!torch::jit::findPatternMatches (*tg, *sg).empty ());
28
+ }
You can’t perform that action at this time.
0 commit comments