Skip to content

Commit b5567be

Browse files
Support approximate gelu
Differential Revision: D75454999 Pull Request resolved: #11246
1 parent 97c8bf7 commit b5567be

File tree

2 files changed

+39
-13
lines changed

2 files changed

+39
-13
lines changed

backends/cadence/aot/replace_ops.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2065,11 +2065,10 @@ def call_operator(
20652065
return super().call_operator(op, args, kwargs, meta)
20662066

20672067

2068-
@register_cadence_pass(CadencePassAttribute(opt_level=2))
2069-
class ReplaceGeluWithApproximateGeluPass(ExportPass):
2068+
@register_cadence_pass(CadencePassAttribute(opt_level=0))
2069+
class ReplaceAtenApproxGeluWithApproxGeluPass(ExportPass):
20702070
"""
2071-
Replace the gelu op with an approximate gelu op. The approximate gelu op
2072-
is more efficient on DSP backends.
2071+
Replace the aten gelu op that has an approximate arg with an approximate gelu op.
20732072
"""
20742073

20752074
def call_operator(
@@ -2079,6 +2078,9 @@ def call_operator(
20792078
kwargs: Dict[str, Argument],
20802079
meta: NodeMetadata,
20812080
) -> ProxyValue:
2081+
if "approximate" not in kwargs:
2082+
return super().call_operator(op, args, kwargs, meta)
2083+
20822084
if op not in {
20832085
exir_ops.edge.aten.gelu.default,
20842086
}:
@@ -2414,7 +2416,7 @@ class CadenceReplaceOpsInGraph:
24142416
ReplaceSingleElementTensorArgumentsFromFullOpWithScalarPass,
24152417
ReplaceAtenAvgPoolWithJarvisAvgPoolPass,
24162418
ReplaceWhereWithFullArgsWithWhereScalar,
2417-
ReplaceGeluWithApproximateGeluPass,
2419+
ReplaceAtenApproxGeluWithApproxGeluPass,
24182420
ReplaceSplitWithSlicePass,
24192421
ReplacePowWithMulPass,
24202422
]

backends/cadence/aot/tests/test_replace_ops_passes.py

Lines changed: 32 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -26,13 +26,13 @@
2626
ForceChannelLastForConvPass,
2727
MakeSliceAndCatDimOutermostPass,
2828
ReplaceAddMMWithLinearPass,
29+
ReplaceAtenApproxGeluWithApproxGeluPass,
2930
ReplaceAtenConvolutionWithJarvisConvolutionPass,
3031
ReplaceConstantPadNdWithSlicePass,
3132
ReplaceConvolutionOptionalArgsWithConcreteArgsPass,
3233
ReplaceConvWithIm2RowAndLinear,
3334
ReplaceEmptyTensorsWithFullPass,
3435
ReplaceFunctionallyEquivalentOpTargets,
35-
ReplaceGeluWithApproximateGeluPass,
3636
ReplaceIm2RowWithViewPass,
3737
ReplaceLinearWithFullyConnectedOpPass,
3838
ReplaceMatmulWithTransposedMatmulPass,
@@ -1287,17 +1287,41 @@ def forward(self, cond: torch.Tensor):
12871287
1,
12881288
)
12891289

1290-
def test_replace_aten_gelu_with_approximate_gelu(self):
1291-
class Gelu(torch.nn.Module):
1292-
def forward(self, input):
1293-
return torch.nn.functional.gelu(input)
1290+
def test_no_replace_aten_gelu_with_approximate_gelu(self):
1291+
inputs = torch.randn(2, 1, 64)
1292+
1293+
gm = single_op_builder(
1294+
placeholders=(inputs,),
1295+
op=exir_ops.edge.aten.gelu.default,
1296+
args=(inputs,),
1297+
)
1298+
gm = ExportPass().call(gm).graph_module
1299+
1300+
p = ReplaceAtenApproxGeluWithApproxGeluPass()
1301+
graph_after_passes = p.call(gm).graph_module
12941302

1303+
# Assert that aten.gelu op was not decomposed, since it didn't have an approximate argument
1304+
self.assertEqual(
1305+
count_node(
1306+
graph_after_passes,
1307+
exir_ops.edge.aten.gelu.default,
1308+
),
1309+
1,
1310+
)
1311+
1312+
def test_replace_aten_approximate_gelu_with_approximate_gelu(self):
12951313
inputs = torch.randn(2, 1, 64)
12961314

1297-
graph_module = export_to_edge(Gelu(), (inputs,)).exported_program().graph_module
1315+
gm = single_op_builder(
1316+
placeholders=(inputs,),
1317+
op=exir_ops.edge.aten.gelu.default,
1318+
args=(inputs,),
1319+
kwargs={"approximate": "tanh"},
1320+
)
1321+
gm = ExportPass().call(gm).graph_module
12981322

1299-
p = ReplaceGeluWithApproximateGeluPass()
1300-
graph_after_passes = cast(PassResult, p(graph_module)).graph_module
1323+
p = ReplaceAtenApproxGeluWithApproxGeluPass()
1324+
graph_after_passes = p.call(gm).graph_module
13011325

13021326
# Assert that aten.gelu op was decomposed
13031327
self.assertEqual(

0 commit comments

Comments
 (0)