Skip to content

Commit 111349d

Browse files
committed
tests(remove_to): Adding a basic remove to test
Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]>
1 parent 6c5118a commit 111349d

File tree

2 files changed

+37
-1
lines changed

2 files changed

+37
-1
lines changed

tests/core/lowering/BUILD

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,14 @@ lowering_test(
1111
name = "test_remove_contiguous_pass",
1212
)
1313

14+
lowering_test(
15+
name = "test_remove_to",
16+
)
17+
1418
test_suite(
1519
name = "lowering_tests",
1620
tests = [
17-
":test_remove_contiguous_pass"
21+
":test_remove_contiguous_pass",
22+
":test_remove_to"
1823
]
1924
)
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
#include <string>
2+
#include "core/compiler.h"
3+
#include "core/lowering/passes/passes.h"
4+
#include "gtest/gtest.h"
5+
#include "tests/util/util.h"
6+
#include "torch/csrc/jit/ir/irparser.h"
7+
#include "torch/csrc/jit/ir/subgraph_matcher.h"
8+
9+
TEST(LoweringPasses, RemoveContiguousLowersCorrectly) {
10+
std::string source_graph = R"IR(
11+
graph(%x.1):
12+
%6 : None = prim::Constant()
13+
%4 : bool = prim::Constant[value=0]()
14+
%3 : int = prim::Constant[value=5]() # experiments/test.py:8:17
15+
%y.1 : Tensor = aten::to(%x.1, %3, %4, %4, %6)
16+
%11 : Tensor = aten::relu(%y.1)
17+
return (%3))IR";
18+
std::string target_graph = R"IR(
19+
graph(%x.1):
20+
%11 : Tensor = aten::relu(%x.1)
21+
return (%11))IR";
22+
23+
auto sg = std::make_shared<torch::jit::Graph>();
24+
torch::jit::parseIR(source_graph, &*sg);
25+
trtorch::core::lowering::passes::RemoveContiguous(sg);
26+
27+
auto tg = std::make_shared<torch::jit::Graph>();
28+
torch::jit::parseIR(target_graph, &*tg);
29+
30+
ASSERT_TRUE(!torch::jit::findPatternMatches(*tg, *sg).empty());
31+
}

0 commit comments

Comments
 (0)