Skip to content

Commit be4702c

Browse files
authored
Merge pull request #352 from NVIDIA/bowa_silu
Lowering pass for SiLU
2 parents 29ef616 + 5d0ab48 commit be4702c

File tree

6 files changed

+68
-1
lines changed

6 files changed

+68
-1
lines changed

core/lowering/lowering.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ void LowerGraph(std::shared_ptr<torch::jit::Graph>& g) {
4949
passes::UnpackLogSoftmax(g);
5050
passes::RemoveNOPs(g);
5151
passes::AliasOperators(g);
52+
passes::SiluToSigmoidMultipication(g);
5253
torch::jit::EliminateDeadCode(g);
5354
LOG_GRAPH(*g);
5455
}

core/lowering/passes/BUILD

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,8 @@ cc_library(
2525
"unpack_addmm.cpp",
2626
"unpack_batch_norm.cpp",
2727
"unpack_log_softmax.cpp",
28-
"op_aliasing.cpp"
28+
"op_aliasing.cpp",
29+
"silu_to_sigmoid_multiplication.cpp"
2930
],
3031
deps = [
3132
"//core/util:prelude",

core/lowering/passes/passes.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ void UnpackAddMM(std::shared_ptr<torch::jit::Graph>& graph);
2020
void UnpackBatchNorm(std::shared_ptr<torch::jit::Graph>& graph);
2121
void UnpackLogSoftmax(std::shared_ptr<torch::jit::Graph>& graph);
2222
void AliasOperators(std::shared_ptr<torch::jit::Graph>& graph);
23+
void SiluToSigmoidMultipication(std::shared_ptr<torch::jit::Graph>& graph);
2324

2425
} // namespace passes
2526
} // namespace lowering
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
#include <torch/csrc/jit/passes/subgraph_rewrite.h>
2+
3+
#include "core/util/prelude.h"
4+
5+
namespace trtorch {
6+
namespace core {
7+
namespace lowering {
8+
namespace passes {
9+
10+
void SiluToSigmoidMultipication(std::shared_ptr<torch::jit::Graph>& graph) {
11+
std::string silu_pattern = R"IR(
12+
graph(%x):
13+
%1 : Tensor = aten::silu(%x)
14+
return (%1))IR";
15+
std::string sigmoid_multiplication_pattern = R"IR(
16+
graph(%x):
17+
%1 : Tensor = aten::sigmoid(%x)
18+
%2 : Tensor = aten::mul(%x, %1)
19+
return (%2))IR";
20+
;
21+
22+
torch::jit::SubgraphRewriter map_silu_to_sigmoid_multiplication;
23+
map_silu_to_sigmoid_multiplication.RegisterRewritePattern(silu_pattern, sigmoid_multiplication_pattern);
24+
map_silu_to_sigmoid_multiplication.runOnGraph(graph);
25+
LOG_GRAPH("Post map silu -> x * sigmoid(x): " << *graph);
26+
}
27+
28+
} // namespace passes
29+
} // namespace lowering
30+
} // namespace core
31+
} // namespace trtorch

tests/core/lowering/BUILD

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,10 @@ lowering_test(
2323
name = "test_operator_aliasing_pass",
2424
)
2525

26+
lowering_test(
27+
name = "test_silu_to_sigmoid_multiplication",
28+
)
29+
2630
test_suite(
2731
name = "lowering_tests",
2832
tests = [
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
#include <string>
2+
#include "core/compiler.h"
3+
#include "core/lowering/passes/passes.h"
4+
#include "gtest/gtest.h"
5+
#include "tests/util/util.h"
6+
#include "torch/csrc/jit/ir/irparser.h"
7+
#include "torch/csrc/jit/ir/subgraph_matcher.h"
8+
9+
TEST(LoweringPasses, RemoveSiluLowersCorrectly) {
10+
std::string source_graph = R"IR(
11+
graph(%x.1 : Tensor):
12+
%2 : Tensor = aten::silu(%x.1)
13+
return (%2))IR";
14+
std::string target_graph = R"IR(
15+
graph(%x.1):
16+
%2 : Tensor = aten::sigmoid(%x.1)
17+
%3 : Tensor = aten::mul(%x.1, %2)
18+
return (%3))IR";
19+
20+
trtorch::core::util::logging::get_logger().set_reportable_log_level(trtorch::core::util::logging::LogLevel::kGRAPH);
21+
auto sg = std::make_shared<torch::jit::Graph>();
22+
torch::jit::parseIR(source_graph, &*sg);
23+
trtorch::core::lowering::passes::SiluToSigmoidMultipication(sg);
24+
25+
auto tg = std::make_shared<torch::jit::Graph>();
26+
torch::jit::parseIR(target_graph, &*tg);
27+
28+
ASSERT_TRUE(!torch::jit::findPatternMatches(*tg, *sg).empty());
29+
}

0 commit comments

Comments
 (0)