Skip to content

Commit b433a53

Browse files
authored
Merge pull request #298 from NVIDIA/fix_remove_to_pass
Fix remove to pass
2 parents c8dc7ad + 111349d commit b433a53

File tree

3 files changed

+38
-2
lines changed

3 files changed

+38
-2
lines changed

core/lowering/passes/remove_to.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ struct ToRemoval {
3333
auto n = *it;
3434
if (n->kind() == c10::Symbol::fromQualString("aten::to")) {
3535
LOG_GRAPH("Found that node " << *n << " is an to node (RemoveTo)" << std::endl);
36-
n->outputs()[0]->replaceAllUsesWith(n->inputs()[1]);
36+
n->outputs()[0]->replaceAllUsesWith(n->inputs()[0]);
3737
it.destroyCurrent();
3838
}
3939
}

tests/core/lowering/BUILD

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,14 @@ lowering_test(
1111
name = "test_remove_contiguous_pass",
1212
)
1313

14+
lowering_test(
15+
name = "test_remove_to",
16+
)
17+
1418
test_suite(
1519
name = "lowering_tests",
1620
tests = [
17-
":test_remove_contiguous_pass"
21+
":test_remove_contiguous_pass",
22+
":test_remove_to"
1823
]
1924
)
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
#include <string>
2+
#include "core/compiler.h"
3+
#include "core/lowering/passes/passes.h"
4+
#include "gtest/gtest.h"
5+
#include "tests/util/util.h"
6+
#include "torch/csrc/jit/ir/irparser.h"
7+
#include "torch/csrc/jit/ir/subgraph_matcher.h"
8+
9+
TEST(LoweringPasses, RemoveContiguousLowersCorrectly) {
10+
std::string source_graph = R"IR(
11+
graph(%x.1):
12+
%6 : None = prim::Constant()
13+
%4 : bool = prim::Constant[value=0]()
14+
%3 : int = prim::Constant[value=5]() # experiments/test.py:8:17
15+
%y.1 : Tensor = aten::to(%x.1, %3, %4, %4, %6)
16+
%11 : Tensor = aten::relu(%y.1)
17+
return (%3))IR";
18+
std::string target_graph = R"IR(
19+
graph(%x.1):
20+
%11 : Tensor = aten::relu(%x.1)
21+
return (%11))IR";
22+
23+
auto sg = std::make_shared<torch::jit::Graph>();
24+
torch::jit::parseIR(source_graph, &*sg);
25+
trtorch::core::lowering::passes::RemoveContiguous(sg);
26+
27+
auto tg = std::make_shared<torch::jit::Graph>();
28+
torch::jit::parseIR(target_graph, &*tg);
29+
30+
ASSERT_TRUE(!torch::jit::findPatternMatches(*tg, *sg).empty());
31+
}

0 commit comments

Comments
 (0)