Skip to content

Commit 3d6b1d0

Browse files
committed
refactor(RemoveNOPs): Rename RemoveTo to RemoveNOPs since it now covers
both to and detach. We should be able to make this generic for a set of operators that have no meaning in TensorRT. Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]>
1 parent 699012c commit 3d6b1d0

File tree

7 files changed

+22
-18
lines changed

7 files changed

+22
-18
lines changed

core/lowering/lowering.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ void LowerGraph(std::shared_ptr<torch::jit::Graph>& g) {
4747
passes::UnpackAddMM(g);
4848
// passes::UnpackBatchNorm(g);
4949
passes::UnpackLogSoftmax(g);
50-
passes::RemoveTo(g);
50+
passes::RemoveNOPs(g);
5151
torch::jit::EliminateDeadCode(g);
5252
LOG_GRAPH(*g);
5353
}

core/lowering/passes/BUILD

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ cc_library(
2121
"remove_bn_dim_check.cpp",
2222
"remove_contiguous.cpp",
2323
"remove_dropout.cpp",
24-
"remove_to.cpp",
24+
"remove_nops.cpp",
2525
"unpack_addmm.cpp",
2626
"unpack_batch_norm.cpp",
2727
"unpack_log_softmax.cpp",

core/lowering/passes/passes.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ void EliminateExceptionOrPassPattern(std::shared_ptr<torch::jit::Graph> graph);
1515
void RemoveBNDimCheck(std::shared_ptr<torch::jit::Graph> graph);
1616
void RemoveContiguous(std::shared_ptr<torch::jit::Graph>& graph);
1717
void RemoveDropout(std::shared_ptr<torch::jit::Graph>& graph);
18-
void RemoveTo(std::shared_ptr<torch::jit::Graph> graph);
18+
void RemoveNOPs(std::shared_ptr<torch::jit::Graph> graph);
1919
void UnpackAddMM(std::shared_ptr<torch::jit::Graph>& graph);
2020
void UnpackBatchNorm(std::shared_ptr<torch::jit::Graph>& graph);
2121
void UnpackLogSoftmax(std::shared_ptr<torch::jit::Graph>& graph);

core/lowering/passes/remove_to.cpp renamed to core/lowering/passes/remove_nops.cpp

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -16,35 +16,35 @@ namespace lowering {
1616
namespace passes {
1717
namespace {
1818
using namespace torch::jit;
19-
struct ToRemoval {
20-
ToRemoval(std::shared_ptr<Graph> graph) : graph_(std::move(graph)) {}
19+
struct NOPRemoval {
20+
NOPRemoval(std::shared_ptr<Graph> graph) : graph_(std::move(graph)) {}
2121

2222
void run() {
23-
findTo(graph_->block());
24-
findDetach(graph_->block());
23+
removeTo(graph_->block());
24+
removeDetach(graph_->block());
2525
torch::jit::EliminateDeadCode(graph_);
2626
LOG_DEBUG(
27-
"RemoveTo - Note: Removing remaining aten::to operators, if type casts need to be preserved, add a pass before this pass is run");
27+
"RemoveNOPs - Note: Removing remaining aten::to operators (in addition to other ops that have no meaning in TRT), if type casts need to be preserved, add a pass before this pass is run");
2828
LOG_GRAPH("Post aten::to removal: " << *graph_);
2929
}
3030

3131
private:
32-
void findTo(Block* b) {
32+
void removeTo(Block* b) {
3333
for (auto it = b->nodes().begin(); it != b->nodes().end(); it++) {
3434
auto n = *it;
3535
if (n->kind() == c10::Symbol::fromQualString("aten::to")) {
36-
LOG_GRAPH("Found that node " << *n << " is an to node (RemoveTo)" << std::endl);
36+
LOG_GRAPH("Found that node " << *n << " is an to node (RemoveNOPs)" << std::endl);
3737
n->outputs()[0]->replaceAllUsesWith(n->inputs()[0]);
3838
it.destroyCurrent();
3939
}
4040
}
4141
}
4242

43-
void findDetach(Block* b) {
43+
void removeDetach(Block* b) {
4444
for (auto it = b->nodes().begin(); it != b->nodes().end(); it++) {
4545
auto n = *it;
4646
if (n->kind() == c10::Symbol::fromQualString("aten::detach")) {
47-
LOG_GRAPH("Found that node " << *n << " is an detach node (RemoveTo)" << std::endl);
47+
LOG_GRAPH("Found that node " << *n << " is an detach node (RemoveNOPs)" << std::endl);
4848
n->outputs()[0]->replaceAllUsesWith(n->inputs()[0]);
4949
it.destroyCurrent();
5050
}
@@ -55,8 +55,8 @@ struct ToRemoval {
5555
};
5656
} // namespace
5757

58-
void RemoveTo(std::shared_ptr<Graph> graph) {
59-
ToRemoval tr(std::move(graph));
58+
void RemoveNOPs(std::shared_ptr<Graph> graph) {
59+
NOPRemoval tr(std::move(graph));
6060
tr.run();
6161
}
6262

tests/core/lowering/test_remove_contiguous_pass.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ TEST(LoweringPasses, RemoveContiguousLowersCorrectly) {
1717
%3 = foo::bar(%input)
1818
return (%3))IR";
1919

20+
trtorch::core::util::logging::get_logger().set_reportable_log_level(trtorch::core::util::logging::LogLevel::kGRAPH);
2021
auto sg = std::make_shared<torch::jit::Graph>();
2122
torch::jit::parseIR(source_graph, &*sg);
2223
trtorch::core::lowering::passes::RemoveContiguous(sg);

tests/core/lowering/test_remove_detach_pass.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,10 @@ TEST(LoweringPasses, RemoveDetachCorrectly) {
1717
%3 = aten::sin(%input)
1818
return (%3))IR";
1919

20+
trtorch::core::util::logging::get_logger().set_reportable_log_level(trtorch::core::util::logging::LogLevel::kGRAPH);
2021
auto sg = std::make_shared<torch::jit::Graph>();
2122
torch::jit::parseIR(source_graph, &*sg);
22-
trtorch::core::lowering::passes::RemoveTo(sg);
23+
trtorch::core::lowering::passes::RemoveNOPs(sg);
2324

2425
auto tg = std::make_shared<torch::jit::Graph>();
2526
torch::jit::parseIR(target_graph, &*tg);

tests/core/lowering/test_remove_to.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,30 @@
11
#include <string>
22
#include "core/compiler.h"
33
#include "core/lowering/passes/passes.h"
4+
#include "core/util/prelude.h"
45
#include "gtest/gtest.h"
56
#include "tests/util/util.h"
67
#include "torch/csrc/jit/ir/irparser.h"
78
#include "torch/csrc/jit/ir/subgraph_matcher.h"
89

9-
TEST(LoweringPasses, RemoveContiguousLowersCorrectly) {
10+
TEST(LoweringPasses, RemoveToLowersCorrectly) {
1011
std::string source_graph = R"IR(
1112
graph(%x.1):
1213
%6 : None = prim::Constant()
1314
%4 : bool = prim::Constant[value=0]()
1415
%3 : int = prim::Constant[value=5]() # experiments/test.py:8:17
1516
%y.1 : Tensor = aten::to(%x.1, %3, %4, %4, %6)
1617
%11 : Tensor = aten::relu(%y.1)
17-
return (%3))IR";
18+
return (%11))IR";
1819
std::string target_graph = R"IR(
1920
graph(%x.1):
2021
%11 : Tensor = aten::relu(%x.1)
2122
return (%11))IR";
2223

24+
trtorch::core::util::logging::get_logger().set_reportable_log_level(trtorch::core::util::logging::LogLevel::kGRAPH);
2325
auto sg = std::make_shared<torch::jit::Graph>();
2426
torch::jit::parseIR(source_graph, &*sg);
25-
trtorch::core::lowering::passes::RemoveContiguous(sg);
27+
trtorch::core::lowering::passes::RemoveNOPs(sg);
2628

2729
auto tg = std::make_shared<torch::jit::Graph>();
2830
torch::jit::parseIR(target_graph, &*tg);

0 commit comments

Comments
 (0)