Skip to content

Commit 374aafd

Browse files
inocsinnarendasan
authored andcommitted
add aten::detach lowering pass
Signed-off-by: inocsin <[email protected]>
1 parent 2b50334 commit 374aafd

File tree

3 files changed

+46
-1
lines changed

3 files changed

+46
-1
lines changed

core/lowering/passes/remove_to.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ struct ToRemoval {
2121

2222
void run() {
2323
findTo(graph_->block());
24+
findDetach(graph_->block());
2425
torch::jit::EliminateDeadCode(graph_);
2526
LOG_DEBUG(
2627
"RemoveTo - Note: Removing remaining aten::to operators, if type casts need to be preserved, add a pass before this pass is run");
@@ -39,6 +40,17 @@ struct ToRemoval {
3940
}
4041
}
4142

43+
void findDetach(Block* b) {
44+
for (auto it = b->nodes().begin(); it != b->nodes().end(); it++) {
45+
auto n = *it;
46+
if (n->kind() == c10::Symbol::fromQualString("aten::detach")) {
47+
LOG_GRAPH("Found that node " << *n << " is an detach node (RemoveTo)" << std::endl);
48+
n->outputs()[0]->replaceAllUsesWith(n->inputs()[0]);
49+
it.destroyCurrent();
50+
}
51+
}
52+
}
53+
4254
std::shared_ptr<Graph> graph_;
4355
};
4456
} // namespace

tests/core/lowering/BUILD

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,15 @@ lowering_test(
1515
name = "test_remove_to",
1616
)
1717

18+
lowering_test(
19+
name = "test_remove_detach_pass",
20+
)
21+
1822
test_suite(
1923
name = "lowering_tests",
2024
tests = [
2125
":test_remove_contiguous_pass",
22-
":test_remove_to"
26+
":test_remove_to",
27+
":test_remove_detach_pass"
2328
]
2429
)
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
#include <string>
2+
#include "core/compiler.h"
3+
#include "core/lowering/passes/passes.h"
4+
#include "gtest/gtest.h"
5+
#include "tests/util/util.h"
6+
#include "torch/csrc/jit/ir/irparser.h"
7+
#include "torch/csrc/jit/ir/subgraph_matcher.h"
8+
9+
TEST(LoweringPasses, RemoveDetachCorrectly) {
10+
std::string source_graph = R"IR(
11+
graph(%input):
12+
%2 = aten::detach(%input)
13+
%3 = foo::bar(%2)
14+
return (%3))IR";
15+
std::string target_graph = R"IR(
16+
graph(%input):
17+
%3 = foo::bar(%input)
18+
return (%3))IR";
19+
20+
auto sg = std::make_shared<torch::jit::Graph>();
21+
torch::jit::parseIR(source_graph, &*sg);
22+
trtorch::core::lowering::passes::RemoveContiguous(sg);
23+
24+
auto tg = std::make_shared<torch::jit::Graph>();
25+
torch::jit::parseIR(target_graph, &*tg);
26+
27+
ASSERT_TRUE(!torch::jit::findPatternMatches(*tg, *sg).empty());
28+
}

0 commit comments

Comments
 (0)