Skip to content

Commit d00240e

Browse files
committed
feat(aten::hardsigmoid): Unpack hardsigmoid
Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]>
1 parent 460fc9b commit d00240e

File tree

6 files changed

+138
-0
lines changed

6 files changed

+138
-0
lines changed

core/lowering/lowering.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ void LowerGraph(std::shared_ptr<torch::jit::Graph>& g, LowerInfo lower_info) {
4141
passes::MarkNodesForFallback(g, true);
4242
}
4343
passes::UnpackHardSwish(g);
44+
passes::UnpackHardSigmoid(g);
4445
passes::EliminateExceptionOrPassPattern(g);
4546
passes::ReduceToOperation(g);
4647
passes::ReduceGelu(g);

core/lowering/passes/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ cc_library(
3030
"silu_to_sigmoid_multiplication.cpp",
3131
"unpack_addmm.cpp",
3232
"unpack_batch_norm.cpp",
33+
"unpack_hardsigmoid.cpp",
3334
"unpack_hardswish.cpp",
3435
"unpack_log_softmax.cpp",
3536
"unpack_std.cpp",

core/lowering/passes/passes.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ void UnpackVar(std::shared_ptr<torch::jit::Graph>& graph);
3838
void AliasOperators(std::shared_ptr<torch::jit::Graph>& graph);
3939
void SiluToSigmoidMultipication(std::shared_ptr<torch::jit::Graph>& graph);
4040
void UnpackHardSwish(std::shared_ptr<torch::jit::Graph>& graph);
41+
void UnpackHardSigmoid(std::shared_ptr<torch::jit::Graph>& graph);
4142

4243
} // namespace passes
4344
} // namespace lowering
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
#include "torch/csrc/jit/passes/subgraph_rewrite.h"
2+
3+
#include "core/util/prelude.h"
4+
5+
namespace torch_tensorrt {
6+
namespace core {
7+
namespace lowering {
8+
namespace passes {
9+
10+
void UnpackHardSigmoid(std::shared_ptr<torch::jit::Graph>& graph) {
11+
std::string hardsigmoid_pattern = R"IR(
12+
graph(%input):
13+
%result = aten::hardsigmoid(%input)
14+
return (%result))IR";
15+
16+
std::string hardsigmoid_pattern_inplace = R"IR(
17+
graph(%input):
18+
%result = aten::hardsigmoid_(%input)
19+
return (%result))IR";
20+
21+
std::string new_pattern = R"IR(
22+
graph(%x.1):
23+
%22 : float = prim::Constant[value=0.5]()
24+
%3 : int = prim::Constant[value=6]()
25+
%5 : int = prim::Constant[value=1]()
26+
%10 : int = prim::Constant[value=0]()
27+
%4 : Tensor = aten::div(%x.1, %3)
28+
%9 : Tensor = aten::add(%4, %22, %5)
29+
%21 : Tensor = aten::clamp(%9, %10, %5)
30+
return (%21))IR";
31+
32+
torch::jit::SubgraphRewriter rewriter;
33+
rewriter.RegisterRewritePattern(hardsigmoid_pattern, new_pattern);
34+
rewriter.RegisterRewritePattern(hardsigmoid_pattern_inplace, new_pattern);
35+
rewriter.runOnGraph(graph);
36+
37+
LOG_GRAPH("Post unpack hardsigmoid: " << *graph);
38+
}
39+
40+
} // namespace passes
41+
} // namespace lowering
42+
} // namespace core
43+
} // namespace torch_tensorrt

tests/core/lowering/BUILD

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,10 @@ lowering_test(
7575
name = "test_silu_to_sigmoid_multiplication",
7676
)
7777

78+
lowering_test(
79+
name = "test_unpack_hardsigmoid",
80+
)
81+
7882
lowering_test(
7983
name = "test_unpack_hardswish",
8084
)
@@ -98,6 +102,7 @@ test_suite(
98102
":test_remove_detach_pass",
99103
":test_remove_dropout_pass",
100104
":test_remove_unnecessary_casts",
105+
":test_unpack_hardsigmoid",
101106
":test_unpack_hardswish",
102107
":test_unpack_reduce_ops",
103108
":test_view_to_reshape_pass",
Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
#include <string>
2+
#include "core/compiler.h"
3+
#include "core/lowering/passes/passes.h"
4+
#include "gtest/gtest.h"
5+
#include "tests/util/util.h"
6+
#include "torch/csrc/jit/ir/irparser.h"
7+
#include "torch/csrc/jit/ir/subgraph_matcher.h"
8+
9+
TEST(LoweringPasses, UnpackHardSigmoid) {
10+
std::string source_graph = R"IR(
11+
graph(%input):
12+
%result = aten::hardsigmoid(%input)
13+
return (%result))IR";
14+
15+
std::string target_graph = R"IR(
16+
graph(%x.1):
17+
%22 : float = prim::Constant[value=0.5]()
18+
%3 : int = prim::Constant[value=6]()
19+
%5 : int = prim::Constant[value=1]()
20+
%10 : int = prim::Constant[value=0]()
21+
%4 : Tensor = aten::div(%x.1, %3)
22+
%9 : Tensor = aten::add(%4, %22, %5)
23+
%21 : Tensor = aten::clamp(%9, %10, %5)
24+
return (%21))IR";
25+
26+
torch_tensorrt::core::util::logging::get_logger().set_reportable_log_level(
27+
torch_tensorrt::core::util::logging::LogLevel::kGRAPH);
28+
auto sg = std::make_shared<torch::jit::Graph>();
29+
torch::jit::parseIR(source_graph, &*sg);
30+
31+
auto in = at::rand({10, 100}, {at::kCUDA});
32+
auto sg_params = torch_tensorrt::core::ir::get_static_params(sg->inputs(), {});
33+
auto sg_results = torch_tensorrt::tests::util::RunGraph(sg, sg_params, {in});
34+
35+
torch_tensorrt::core::lowering::passes::UnpackHardSigmoid(sg);
36+
37+
auto tg = std::make_shared<torch::jit::Graph>();
38+
torch::jit::parseIR(target_graph, &*tg);
39+
40+
ASSERT_TRUE(!torch::jit::findPatternMatches(*tg, *sg).empty());
41+
42+
in = at::clone(in);
43+
auto tg_params = torch_tensorrt::core::ir::get_static_params(tg->inputs(), {});
44+
auto tg_results = torch_tensorrt::tests::util::RunGraph(tg, tg_params, {in});
45+
46+
ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(sg_results[0], tg_results[0], 2e-6));
47+
}
48+
49+
TEST(LoweringPasses, UnpackHardSigmoidInPlace) {
50+
std::string source_graph = R"IR(
51+
graph(%input):
52+
%result = aten::hardsigmoid_(%input)
53+
return (%result))IR";
54+
55+
std::string target_graph = R"IR(
56+
graph(%x.1):
57+
%22 : float = prim::Constant[value=0.5]()
58+
%3 : int = prim::Constant[value=6]()
59+
%5 : int = prim::Constant[value=1]()
60+
%10 : int = prim::Constant[value=0]()
61+
%4 : Tensor = aten::div(%x.1, %3)
62+
%9 : Tensor = aten::add(%4, %22, %5)
63+
%21 : Tensor = aten::clamp(%9, %10, %5)
64+
return (%21))IR";
65+
66+
torch_tensorrt::core::util::logging::get_logger().set_reportable_log_level(
67+
torch_tensorrt::core::util::logging::LogLevel::kGRAPH);
68+
auto sg = std::make_shared<torch::jit::Graph>();
69+
torch::jit::parseIR(source_graph, &*sg);
70+
71+
auto in = at::rand({10, 100}, {at::kCUDA});
72+
auto sg_params = torch_tensorrt::core::ir::get_static_params(sg->inputs(), {});
73+
auto sg_results = torch_tensorrt::tests::util::RunGraph(sg, sg_params, {in});
74+
75+
torch_tensorrt::core::lowering::passes::UnpackHardSigmoid(sg);
76+
77+
auto tg = std::make_shared<torch::jit::Graph>();
78+
torch::jit::parseIR(target_graph, &*tg);
79+
80+
ASSERT_TRUE(!torch::jit::findPatternMatches(*tg, *sg).empty());
81+
82+
in = at::clone(in);
83+
auto tg_params = torch_tensorrt::core::ir::get_static_params(tg->inputs(), {});
84+
auto tg_results = torch_tensorrt::tests::util::RunGraph(tg, tg_params, {in});
85+
86+
ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(sg_results[0], tg_results[0], 2e-6));
87+
}

0 commit comments

Comments
 (0)