Skip to content

Commit 53451e1

Browse files
Michael Maitland authored and facebook-github-bot committed
Support approximate gelu (#11246)
Summary: Pull Request resolved: #11246 GELU accepts an `approximate` argument which is either `none` by default, or `tanh` When the `approximate` kwarg is present, decompose the op. We already have an existing test in test_aten_gelu_out to make sure the op is supported. Reviewed By: zonglinpeng Differential Revision: D75454999
1 parent b02ac1b commit 53451e1

File tree

2 files changed

+31
-8
lines changed

2 files changed

+31
-8
lines changed

backends/cadence/aot/replace_ops.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2065,11 +2065,10 @@ def call_operator(
20652065
return super().call_operator(op, args, kwargs, meta)
20662066

20672067

2068-
@register_cadence_pass(CadencePassAttribute(opt_level=2))
2069-
class ReplaceGeluWithApproximateGeluPass(ExportPass):
2068+
@register_cadence_pass(CadencePassAttribute(opt_level=0))
2069+
class ReplaceAtenApproxGeluWithApproxGeluPass(ExportPass):
20702070
"""
2071-
Replace the gelu op with an approximate gelu op. The approximate gelu op
2072-
is more efficient on DSP backends.
2071+
Replace the aten gelu op with an approximate arg with an approximate gelu op.
20732072
"""
20742073

20752074
def call_operator(
@@ -2079,6 +2078,9 @@ def call_operator(
20792078
kwargs: Dict[str, Argument],
20802079
meta: NodeMetadata,
20812080
) -> ProxyValue:
2081+
if "approximate" not in kwargs:
2082+
return super().call_operator(op, args, kwargs, meta)
2083+
20822084
if op not in {
20832085
exir_ops.edge.aten.gelu.default,
20842086
}:
@@ -2414,7 +2416,7 @@ class CadenceReplaceOpsInGraph:
24142416
ReplaceSingleElementTensorArgumentsFromFullOpWithScalarPass,
24152417
ReplaceAtenAvgPoolWithJarvisAvgPoolPass,
24162418
ReplaceWhereWithFullArgsWithWhereScalar,
2417-
ReplaceGeluWithApproximateGeluPass,
2419+
ReplaceAtenApproxGeluWithApproxGeluPass,
24182420
ReplaceSplitWithSlicePass,
24192421
ReplacePowWithMulPass,
24202422
]

backends/cadence/aot/tests/test_replace_ops_passes.py

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
ReplaceConvWithIm2RowAndLinear,
3333
ReplaceEmptyTensorsWithFullPass,
3434
ReplaceFunctionallyEquivalentOpTargets,
35-
ReplaceGeluWithApproximateGeluPass,
35+
ReplaceAtenApproxGeluWithApproxGeluPass,
3636
ReplaceIm2RowWithViewPass,
3737
ReplaceLinearWithFullyConnectedOpPass,
3838
ReplaceMatmulWithTransposedMatmulPass,
@@ -1287,7 +1287,7 @@ def forward(self, cond: torch.Tensor):
12871287
1,
12881288
)
12891289

1290-
def test_replace_aten_gelu_with_approximate_gelu(self):
1290+
def test_no_replace_aten_gelu_with_approximate_gelu(self):
12911291
class Gelu(torch.nn.Module):
12921292
def forward(self, input):
12931293
return torch.nn.functional.gelu(input)
@@ -1296,7 +1296,28 @@ def forward(self, input):
12961296

12971297
graph_module = export_to_edge(Gelu(), (inputs,)).exported_program().graph_module
12981298

1299-
p = ReplaceGeluWithApproximateGeluPass()
1299+
p = ReplaceAtenApproxGeluWithApproxGeluPass()
1300+
graph_after_passes = cast(PassResult, p(graph_module)).graph_module
1301+
1302+
# Assert that aten.gelu op was not decomposed, since it didn't have an approximate argument
1303+
self.assertEqual(
1304+
count_node(
1305+
graph_after_passes,
1306+
exir_ops.edge.aten.gelu.default,
1307+
),
1308+
1,
1309+
)
1310+
1311+
def test_replace_aten_approximate_gelu_with_approximate_gelu(self):
1312+
class Gelu(torch.nn.Module):
1313+
def forward(self, input):
1314+
return torch.nn.functional.gelu(input, approximate = "tanh")
1315+
1316+
inputs = torch.randn(2, 1, 64)
1317+
1318+
graph_module = export_to_edge(Gelu(), (inputs,)).exported_program().graph_module
1319+
1320+
p = ReplaceAtenApproxGeluWithApproxGeluPass()
13001321
graph_after_passes = cast(PassResult, p(graph_module)).graph_module
13011322

13021323
# Assert that aten.gelu op was decomposed

0 commit comments

Comments
 (0)