File tree Expand file tree Collapse file tree 6 files changed +68
-1
lines changed Expand file tree Collapse file tree 6 files changed +68
-1
lines changed Original file line number Diff line number Diff line change @@ -49,6 +49,7 @@ void LowerGraph(std::shared_ptr<torch::jit::Graph>& g) {
49
49
passes::UnpackLogSoftmax (g);
50
50
passes::RemoveNOPs (g);
51
51
passes::AliasOperators (g);
52
+ passes::SiluToSigmoidMultipication (g);
52
53
torch::jit::EliminateDeadCode (g);
53
54
LOG_GRAPH (*g);
54
55
}
Original file line number Diff line number Diff line change @@ -25,7 +25,8 @@ cc_library(
25
25
"unpack_addmm.cpp" ,
26
26
"unpack_batch_norm.cpp" ,
27
27
"unpack_log_softmax.cpp" ,
28
- "op_aliasing.cpp"
28
+ "op_aliasing.cpp" ,
29
+ "silu_to_sigmoid_multiplication.cpp"
29
30
],
30
31
deps = [
31
32
"//core/util:prelude" ,
Original file line number Diff line number Diff line change @@ -20,6 +20,7 @@ void UnpackAddMM(std::shared_ptr<torch::jit::Graph>& graph);
20
20
void UnpackBatchNorm (std::shared_ptr<torch::jit::Graph>& graph);
21
21
void UnpackLogSoftmax (std::shared_ptr<torch::jit::Graph>& graph);
22
22
void AliasOperators (std::shared_ptr<torch::jit::Graph>& graph);
23
+ void SiluToSigmoidMultipication (std::shared_ptr<torch::jit::Graph>& graph);
23
24
24
25
} // namespace passes
25
26
} // namespace lowering
Original file line number Diff line number Diff line change
1
+ #include < torch/csrc/jit/passes/subgraph_rewrite.h>
2
+
3
+ #include " core/util/prelude.h"
4
+
5
+ namespace trtorch {
6
+ namespace core {
7
+ namespace lowering {
8
+ namespace passes {
9
+
10
+ void SiluToSigmoidMultipication (std::shared_ptr<torch::jit::Graph>& graph) {
11
+ std::string silu_pattern = R"IR(
12
+ graph(%x):
13
+ %1 : Tensor = aten::silu(%x)
14
+ return (%1))IR" ;
15
+ std::string sigmoid_multiplication_pattern = R"IR(
16
+ graph(%x):
17
+ %1 : Tensor = aten::sigmoid(%x)
18
+ %2 : Tensor = aten::mul(%x, %1)
19
+ return (%2))IR" ;
20
+ ;
21
+
22
+ torch::jit::SubgraphRewriter map_silu_to_sigmoid_multiplication;
23
+ map_silu_to_sigmoid_multiplication.RegisterRewritePattern (silu_pattern, sigmoid_multiplication_pattern);
24
+ map_silu_to_sigmoid_multiplication.runOnGraph (graph);
25
+ LOG_GRAPH (" Post map silu -> x * sigmoid(x): " << *graph);
26
+ }
27
+
28
+ } // namespace passes
29
+ } // namespace lowering
30
+ } // namespace core
31
+ } // namespace trtorch
Original file line number Diff line number Diff line change @@ -23,6 +23,10 @@ lowering_test(
23
23
name = "test_operator_aliasing_pass" ,
24
24
)
25
25
26
+ lowering_test (
27
+ name = "test_silu_to_sigmoid_multiplication" ,
28
+ )
29
+
26
30
test_suite (
27
31
name = "lowering_tests" ,
28
32
tests = [
Original file line number Diff line number Diff line change
1
+ #include < string>
2
+ #include " core/compiler.h"
3
+ #include " core/lowering/passes/passes.h"
4
+ #include " gtest/gtest.h"
5
+ #include " tests/util/util.h"
6
+ #include " torch/csrc/jit/ir/irparser.h"
7
+ #include " torch/csrc/jit/ir/subgraph_matcher.h"
8
+
9
+ TEST (LoweringPasses, RemoveSiluLowersCorrectly) {
10
+ std::string source_graph = R"IR(
11
+ graph(%x.1 : Tensor):
12
+ %2 : Tensor = aten::silu(%x.1)
13
+ return (%2))IR" ;
14
+ std::string target_graph = R"IR(
15
+ graph(%x.1):
16
+ %2 : Tensor = aten::sigmoid(%x.1)
17
+ %3 : Tensor = aten::mul(%x.1, %2)
18
+ return (%3))IR" ;
19
+
20
+ trtorch::core::util::logging::get_logger ().set_reportable_log_level (trtorch::core::util::logging::LogLevel::kGRAPH );
21
+ auto sg = std::make_shared<torch::jit::Graph>();
22
+ torch::jit::parseIR (source_graph, &*sg);
23
+ trtorch::core::lowering::passes::SiluToSigmoidMultipication (sg);
24
+
25
+ auto tg = std::make_shared<torch::jit::Graph>();
26
+ torch::jit::parseIR (target_graph, &*tg);
27
+
28
+ ASSERT_TRUE (!torch::jit::findPatternMatches (*tg, *sg).empty ());
29
+ }
You can’t perform that action at this time.
0 commit comments