Skip to content

Commit 6768609

Browse files
author
Kaiyu Shi
committed
Add docstring
1 parent 7a910f3 commit 6768609

File tree

1 file changed

+10
-0
lines changed

1 file changed

+10
-0
lines changed

onnxscript/rewriter/fuse_hardswish.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717

1818

1919
class _HardSigmoidFusionBase(pattern.RewriteRuleClassBase):
20+
"""HardSwish requires constant values so we check in base class."""
21+
2022
def check(
2123
self,
2224
op,
@@ -40,6 +42,12 @@ def check(
4042

4143

4244
class HardSwishFusion(_HardSigmoidFusionBase):
45+
"""Fuse Add(_, 3) + Clip<0, 6>(_) + Mul + Div(_, 6) into HardSwish
46+
47+
In this case we can't make HardSigmoid fusion first. The Mul
48+
is placed before Div while HardSigmoid requires Add+Clip+Div.
49+
"""
50+
4351
def pattern(
4452
self,
4553
op,
@@ -66,6 +74,8 @@ def rewrite(
6674

6775

6876
class HardSwishFusionFromHardSigmoid(pattern.RewriteRuleClassBase):
77+
"""Fuse HardSigmoid<alpha=1/6, beta=0.5> + Mul into HardSwish"""
78+
6979
def pattern(self, op, x: ir.Value) -> ir.Value:
7080
# Floating point matching for 1/6 is not exact, so we use isclose below
7181
out = op.HardSigmoid(x, _allow_other_attributes=True, _outputs=["hardsigmoid_out"])

0 commit comments

Comments
 (0)