Skip to content

Commit 1f0a8cb

Browse files
authored
Merge pull request #413 from guoruoqian/feature_dropout
Support feature_dropout and feature_alpha_dropout converters
2 parents 9eeea49 + f0a4a10 commit 1f0a8cb

File tree

4 files changed

+218
-2
lines changed

4 files changed

+218
-2
lines changed

core/conversion/conversion_ignorelist.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,11 @@ const std::unordered_set<std::string>& get_non_convertable_nodes() {
2424
"prim::CallMethod",
2525
"prim::Drop",
2626
"aten::dropout",
27-
"aten::dropout_"};
27+
"aten::dropout_",
28+
"aten::feature_dropout",
29+
"aten::feature_dropout_",
30+
"aten::feature_alpha_dropout",
31+
"aten::feature_alpha_dropout_"};
2832
return nonconvertable_nodes;
2933
}
3034
// clang-format on

core/lowering/passes/remove_dropout.cpp

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,61 @@ void RemoveDropout(std::shared_ptr<torch::jit::Graph>& graph) {
3232
remove_dropout_inplace_pattern.RegisterRewritePattern(dropout_inplace_pattern, no_dropout_inplace_pattern);
3333
remove_dropout_inplace_pattern.runOnGraph(graph);
3434

35+
// remove feature_dropout
36+
std::string feature_dropout_pattern = R"IR(
37+
graph(%input, %4, %5):
38+
%6 = aten::feature_dropout(%input, %4, %5)
39+
return (%6))IR";
40+
std::string no_feature_dropout_pattern = R"IR(
41+
graph(%input, %4, %5):
42+
return (%input))IR";
43+
44+
torch::jit::SubgraphRewriter remove_feature_dropout_pattern;
45+
remove_feature_dropout_pattern.RegisterRewritePattern(feature_dropout_pattern, no_feature_dropout_pattern);
46+
remove_feature_dropout_pattern.runOnGraph(graph);
47+
48+
// remove feature_dropout inplace
49+
std::string feature_dropout_inplace_pattern = R"IR(
50+
graph(%input, %4, %5):
51+
%6 = aten::feature_dropout_(%input, %4, %5)
52+
return (%6))IR";
53+
std::string no_feature_dropout_inplace_pattern = R"IR(
54+
graph(%input, %4, %5):
55+
return (%input))IR";
56+
57+
torch::jit::SubgraphRewriter remove_feature_dropout_inplace_pattern;
58+
remove_feature_dropout_inplace_pattern.RegisterRewritePattern(
59+
feature_dropout_inplace_pattern, no_feature_dropout_inplace_pattern);
60+
remove_feature_dropout_inplace_pattern.runOnGraph(graph);
61+
62+
// remove feature_alpha_dropout
63+
std::string feature_alpha_dropout_pattern = R"IR(
64+
graph(%input, %4, %5):
65+
%6 = aten::feature_alpha_dropout(%input, %4, %5)
66+
return (%6))IR";
67+
std::string no_feature_alpha_dropout_pattern = R"IR(
68+
graph(%input, %4, %5):
69+
return (%input))IR";
70+
71+
torch::jit::SubgraphRewriter remove_feature_alpha_dropout_pattern;
72+
remove_feature_alpha_dropout_pattern.RegisterRewritePattern(
73+
feature_alpha_dropout_pattern, no_feature_alpha_dropout_pattern);
74+
remove_feature_alpha_dropout_pattern.runOnGraph(graph);
75+
76+
// remove feature_alpha_dropout inplace
77+
std::string feature_alpha_dropout_inplace_pattern = R"IR(
78+
graph(%input, %4, %5):
79+
%6 = aten::feature_alpha_dropout_(%input, %4, %5)
80+
return (%6))IR";
81+
std::string no_feature_alpha_dropout_inplace_pattern = R"IR(
82+
graph(%input, %4, %5):
83+
return (%input))IR";
84+
85+
torch::jit::SubgraphRewriter remove_feature_alpha_dropout_inplace_pattern;
86+
remove_feature_alpha_dropout_inplace_pattern.RegisterRewritePattern(
87+
feature_alpha_dropout_inplace_pattern, no_feature_alpha_dropout_inplace_pattern);
88+
remove_feature_alpha_dropout_inplace_pattern.runOnGraph(graph);
89+
3590
LOG_GRAPH("Post remove dropout: " << *graph);
3691
}
3792

tests/core/lowering/BUILD

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,10 @@ lowering_test(
1515
name = "test_remove_contiguous_pass",
1616
)
1717

18+
lowering_test(
19+
name = "test_remove_dropout_pass",
20+
)
21+
1822
lowering_test(
1923
name = "test_remove_to",
2024
)
@@ -38,6 +42,7 @@ test_suite(
3842
":test_remove_contiguous_pass",
3943
":test_remove_to",
4044
":test_remove_detach_pass",
41-
":test_operator_aliasing_pass"
45+
":test_operator_aliasing_pass",
46+
":test_remove_dropout_pass"
4247
]
4348
)
Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,152 @@
1+
#include <string>
2+
#include "core/compiler.h"
3+
#include "core/lowering/passes/passes.h"
4+
#include "core/util/prelude.h"
5+
#include "gtest/gtest.h"
6+
#include "tests/util/util.h"
7+
#include "torch/csrc/jit/ir/irparser.h"
8+
#include "torch/csrc/jit/ir/subgraph_matcher.h"
9+
10+
TEST(LoweringPasses, RemoveDropoutLowersCorrectly) {
11+
std::string source_graph = R"IR(
12+
graph(%x.1):
13+
%3 : float = prim::Constant[value=0.5]()
14+
%4 : bool = prim::Constant[value=0]()
15+
%y.1 : Tensor = aten::dropout(%x.1, %3, %4)
16+
%11 : Tensor = aten::relu(%y.1)
17+
return (%11))IR";
18+
std::string target_graph = R"IR(
19+
graph(%x.1):
20+
%11 : Tensor = aten::relu(%x.1)
21+
return (%11))IR";
22+
23+
trtorch::core::util::logging::get_logger().set_reportable_log_level(trtorch::core::util::logging::LogLevel::kGRAPH);
24+
auto sg = std::make_shared<torch::jit::Graph>();
25+
torch::jit::parseIR(source_graph, sg.get());
26+
trtorch::core::lowering::passes::RemoveDropout(sg);
27+
28+
auto tg = std::make_shared<torch::jit::Graph>();
29+
torch::jit::parseIR(target_graph, tg.get());
30+
31+
ASSERT_TRUE(!torch::jit::findPatternMatches(*tg, *sg).empty());
32+
}
33+
34+
TEST(LoweringPasses, RemoveDropoutInplaceLowersCorrectly) {
35+
std::string source_graph = R"IR(
36+
graph(%x.1):
37+
%3 : float = prim::Constant[value=0.5]()
38+
%4 : bool = prim::Constant[value=0]()
39+
%y.1 : Tensor = aten::dropout_(%x.1, %3, %4)
40+
%11 : Tensor = aten::relu(%y.1)
41+
return (%11))IR";
42+
std::string target_graph = R"IR(
43+
graph(%x.1):
44+
%11 : Tensor = aten::relu(%x.1)
45+
return (%11))IR";
46+
47+
trtorch::core::util::logging::get_logger().set_reportable_log_level(trtorch::core::util::logging::LogLevel::kGRAPH);
48+
auto sg = std::make_shared<torch::jit::Graph>();
49+
torch::jit::parseIR(source_graph, sg.get());
50+
trtorch::core::lowering::passes::RemoveDropout(sg);
51+
52+
auto tg = std::make_shared<torch::jit::Graph>();
53+
torch::jit::parseIR(target_graph, tg.get());
54+
55+
ASSERT_TRUE(!torch::jit::findPatternMatches(*tg, *sg).empty());
56+
}
57+
58+
TEST(LoweringPasses, RemoveFeatureDropoutLowersCorrectly) {
59+
std::string source_graph = R"IR(
60+
graph(%x.1):
61+
%3 : float = prim::Constant[value=0.5]()
62+
%4 : bool = prim::Constant[value=0]()
63+
%y.1 : Tensor = aten::feature_dropout(%x.1, %3, %4)
64+
%11 : Tensor = aten::relu(%y.1)
65+
return (%11))IR";
66+
std::string target_graph = R"IR(
67+
graph(%x.1):
68+
%11 : Tensor = aten::relu(%x.1)
69+
return (%11))IR";
70+
71+
trtorch::core::util::logging::get_logger().set_reportable_log_level(trtorch::core::util::logging::LogLevel::kGRAPH);
72+
auto sg = std::make_shared<torch::jit::Graph>();
73+
torch::jit::parseIR(source_graph, sg.get());
74+
trtorch::core::lowering::passes::RemoveDropout(sg);
75+
76+
auto tg = std::make_shared<torch::jit::Graph>();
77+
torch::jit::parseIR(target_graph, tg.get());
78+
79+
ASSERT_TRUE(!torch::jit::findPatternMatches(*tg, *sg).empty());
80+
}
81+
82+
TEST(LoweringPasses, RemoveFeatureDropoutInplaceLowersCorrectly) {
83+
std::string source_graph = R"IR(
84+
graph(%x.1):
85+
%3 : float = prim::Constant[value=0.5]()
86+
%4 : bool = prim::Constant[value=0]()
87+
%y.1 : Tensor = aten::feature_dropout_(%x.1, %3, %4)
88+
%11 : Tensor = aten::relu(%y.1)
89+
return (%11))IR";
90+
std::string target_graph = R"IR(
91+
graph(%x.1):
92+
%11 : Tensor = aten::relu(%x.1)
93+
return (%11))IR";
94+
95+
trtorch::core::util::logging::get_logger().set_reportable_log_level(trtorch::core::util::logging::LogLevel::kGRAPH);
96+
auto sg = std::make_shared<torch::jit::Graph>();
97+
torch::jit::parseIR(source_graph, sg.get());
98+
trtorch::core::lowering::passes::RemoveDropout(sg);
99+
100+
auto tg = std::make_shared<torch::jit::Graph>();
101+
torch::jit::parseIR(target_graph, tg.get());
102+
103+
ASSERT_TRUE(!torch::jit::findPatternMatches(*tg, *sg).empty());
104+
}
105+
106+
TEST(LoweringPasses, RemoveFeatureAlphaDropoutLowersCorrectly) {
107+
std::string source_graph = R"IR(
108+
graph(%x.1):
109+
%3 : float = prim::Constant[value=0.5]()
110+
%4 : bool = prim::Constant[value=0]()
111+
%y.1 : Tensor = aten::feature_alpha_dropout(%x.1, %3, %4)
112+
%11 : Tensor = aten::relu(%y.1)
113+
return (%11))IR";
114+
std::string target_graph = R"IR(
115+
graph(%x.1):
116+
%11 : Tensor = aten::relu(%x.1)
117+
return (%11))IR";
118+
119+
trtorch::core::util::logging::get_logger().set_reportable_log_level(trtorch::core::util::logging::LogLevel::kGRAPH);
120+
auto sg = std::make_shared<torch::jit::Graph>();
121+
torch::jit::parseIR(source_graph, sg.get());
122+
trtorch::core::lowering::passes::RemoveDropout(sg);
123+
124+
auto tg = std::make_shared<torch::jit::Graph>();
125+
torch::jit::parseIR(target_graph, tg.get());
126+
127+
ASSERT_TRUE(!torch::jit::findPatternMatches(*tg, *sg).empty());
128+
}
129+
130+
TEST(LoweringPasses, RemoveFeatureAlphaDropoutInplaceLowersCorrectly) {
131+
std::string source_graph = R"IR(
132+
graph(%x.1):
133+
%3 : float = prim::Constant[value=0.5]()
134+
%4 : bool = prim::Constant[value=0]()
135+
%y.1 : Tensor = aten::feature_alpha_dropout_(%x.1, %3, %4)
136+
%11 : Tensor = aten::relu(%y.1)
137+
return (%11))IR";
138+
std::string target_graph = R"IR(
139+
graph(%x.1):
140+
%11 : Tensor = aten::relu(%x.1)
141+
return (%11))IR";
142+
143+
trtorch::core::util::logging::get_logger().set_reportable_log_level(trtorch::core::util::logging::LogLevel::kGRAPH);
144+
auto sg = std::make_shared<torch::jit::Graph>();
145+
torch::jit::parseIR(source_graph, sg.get());
146+
trtorch::core::lowering::passes::RemoveDropout(sg);
147+
148+
auto tg = std::make_shared<torch::jit::Graph>();
149+
torch::jit::parseIR(target_graph, tg.get());
150+
151+
ASSERT_TRUE(!torch::jit::findPatternMatches(*tg, *sg).empty());
152+
}

0 commit comments

Comments
 (0)