Skip to content

Commit 35754d1

Browse files
authored
Create decompose_ops.py and test_decompose_ops.py
Differential Revision: D75826474 Pull Request resolved: #11299
1 parent 8514d86 commit 35754d1

File tree

5 files changed

+242
-116
lines changed

5 files changed

+242
-116
lines changed

backends/cadence/aot/TARGETS

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -276,6 +276,24 @@ python_library(
276276
],
277277
)
278278

279+
# Library for op-decomposition passes (one op -> series of simpler ops).
python_library(
    name = "decompose_ops",
    srcs = [
        "decompose_ops.py",
    ],
    typing = True,
    deps = [
        # NOTE(review): ":pass_utils" and
        # "//executorch/backends/cadence/aot:pass_utils" appear to be the
        # same target spelled relatively and absolutely — confirm and keep
        # only one if so.
        ":pass_utils",
        "//caffe2:torch",
        "//executorch/backends/cadence/aot:pass_utils",
        "//executorch/exir:pass_base",
        "//executorch/exir/dialects:lib",
        "//executorch/exir/dialects/edge:lib",
        "//executorch/exir/passes:spec_prop_pass",
    ],
)
295+
296+
279297
python_unittest(
280298
name = "test_graph_builder",
281299
srcs = [
@@ -314,6 +332,27 @@ python_unittest(
314332
],
315333
)
316334

335+
# Unit tests for the decompose_ops passes.
python_unittest(
    name = "test_decompose_ops_passes",
    srcs = [
        "tests/test_decompose_ops_passes.py",
    ],
    # Parameterized tests generate cases at runtime, so the test list
    # cannot be statically enumerated.
    supports_static_listing = False,
    typing = True,
    deps = [
        "fbsource//third-party/pypi/parameterized:parameterized",
        # NOTE(review): ":compiler" and
        # "//executorch/backends/cadence/aot:compiler" appear to be the same
        # target spelled relatively and absolutely — confirm and keep only
        # one if so.
        ":compiler",
        ":decompose_ops",
        "//caffe2:torch",
        "//executorch/backends/cadence/aot:compiler",
        "//executorch/backends/cadence/aot:graph_builder",
        "//executorch/backends/cadence/aot:pass_utils",
        "//executorch/exir:pass_base",
        "//executorch/exir/dialects:lib",
        "//executorch/exir/passes:lib",
    ],
)
317356
python_unittest(
318357
name = "test_fusion_ops_passes",
319358
srcs = [
Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
# Copyright 2025 Arm Limited and/or its affiliates.
4+
#
5+
# This source code is licensed under the BSD-style license found in the
6+
# LICENSE file in the root directory of this source tree.
7+
8+
9+
# This file contains all the functions that decompose one op into simpler ops in the
10+
# graph. The functions decomposing ops for models deployed with Jarvis are grouped
11+
# together in class 'DecomposeOpsInGraph'. Some examples of functions in the class are
12+
# 1. functions that decompose an ATen gelu op into an equivalent series of simpler ops
13+
14+
# pyre-strict
15+
16+
from typing import Dict
17+
18+
from executorch.backends.cadence.aot.pass_utils import (
19+
CadencePassAttribute,
20+
register_cadence_pass,
21+
)
22+
from executorch.exir.dialects._ops import ops as exir_ops
23+
from executorch.exir.dialects.edge._ops import EdgeOpOverload
24+
from executorch.exir.pass_base import ExportPass, NodeMetadata, ProxyValue
25+
from torch.fx.node import Argument
26+
27+
28+
@register_cadence_pass(CadencePassAttribute(opt_level=0))
class DecomposeAtenApproxGeluPass(ExportPass):
    """
    Decompose the aten gelu op with an approximate arg to a series of simpler ops.

    Rewrites aten.gelu(x, approximate=...) into the tanh approximation
        0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x^3)))
    built from mul/add/tanh edge ops. All other ops pass through unchanged.
    """

    def call_operator(
        self,
        op: EdgeOpOverload,
        args: tuple[Argument, ...],
        kwargs: Dict[str, Argument],
        meta: NodeMetadata,
    ) -> ProxyValue:
        # BUG FIX: the original body had no guard on `op`, so this pass
        # decomposed *every* operator in the graph as if it were gelu.
        # Restore the guards that the replace_ops version of this logic had:
        # only decompose aten.gelu nodes that carry an `approximate` kwarg.
        if "approximate" not in kwargs:
            return super().call_operator(op, args, kwargs, meta)

        if op not in {
            exir_ops.edge.aten.gelu.default,
        }:
            return super().call_operator(op, args, kwargs, meta)

        # compute the approximate gelu (0.7978845608028654 is sqrt(2 / pi))
        # as 0.5 * x * (1 + torch.tanh(0.7978845608028654 * ( x + 0.044715 * x^3)))
        # Note: the decomposed ops deliberately take no kwargs — the
        # `approximate` kwarg is consumed by the decomposition itself.

        # Get 0.5 * x
        half = super().call_operator(
            exir_ops.edge.aten.mul.Tensor,
            (args[0], 0.5),
            {},
            meta,
        )

        # Get 0.044715 * x
        scaled = super().call_operator(
            exir_ops.edge.aten.mul.Tensor,
            (args[0], 0.044715),
            {},
            meta,
        )

        # Get 0.044715 * x^2 (note that we use mul.Tensor twice instead of
        # pow.Tensor because it is much more efficient on DSP backends)
        scaled_square = super().call_operator(
            exir_ops.edge.aten.mul.Tensor,
            (scaled, args[0]),
            {},
            meta,
        )

        # Get 0.044715 * x^3
        scaled_cubed = super().call_operator(
            exir_ops.edge.aten.mul.Tensor,
            (scaled_square, args[0]),
            {},
            meta,
        )

        # Get x + 0.044715 * x^3
        inner_sum = super().call_operator(
            exir_ops.edge.aten.add.Tensor,
            (scaled_cubed, args[0]),
            {},
            meta,
        )

        # Get 0.7978845608028654 * ( x + 0.044715 * x^3)
        scaled_sum = super().call_operator(
            exir_ops.edge.aten.mul.Tensor,
            (inner_sum, 0.7978845608028654),
            {},
            meta,
        )

        # Get torch.tanh(0.7978845608028654 * ( x + 0.044715 * x^3))
        tanh = super().call_operator(
            exir_ops.edge.aten.tanh.default,
            (scaled_sum,),
            {},
            meta,
        )

        # Get 1 + torch.tanh(0.79788456 * ( x + 0.044715 * x^3))
        # TODO(): Check why this is not working properly with integer values (e.g. 1 instead of 1.)
        outer_sum = super().call_operator(
            exir_ops.edge.aten.add.Tensor,
            (tanh, 1.0),
            {},
            meta,
        )

        # Return the final result: 0.5 * x * (1 + tanh(...))
        return super().call_operator(
            exir_ops.edge.aten.mul.Tensor,
            (half, outer_sum),
            {},
            meta,
        )
117+
118+
# This class encapsulates all the functions that decompose one op in the graph.
class CadenceDecomposeOpsInGraph:
    """Registry of all decomposition passes for the Cadence backend."""

    # Ordered list of ExportPass subclasses; the pipeline runner is expected
    # to instantiate and apply each one in turn.
    passes = [
        DecomposeAtenApproxGeluPass,
    ]

backends/cadence/aot/replace_ops.py

Lines changed: 1 addition & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -2078,89 +2078,11 @@ def call_operator(
20782078
kwargs: Dict[str, Argument],
20792079
meta: NodeMetadata,
20802080
) -> ProxyValue:
2081-
if "approximate" not in kwargs:
2082-
return super().call_operator(op, args, kwargs, meta)
2083-
20842081
if op not in {
20852082
exir_ops.edge.aten.gelu.default,
20862083
}:
20872084
return super().call_operator(op, args, kwargs, meta)
2088-
2089-
# compute the approximate gelu (0.7978845608028654 is sqrt(2 / pi))
2090-
# as 0.5 * x * (1 + torch.tanh(0.7978845608028654 * ( x + 0.044715 * x^3)))
2091-
2092-
# Get 0.5 * x
2093-
half = super().call_operator(
2094-
exir_ops.edge.aten.mul.Tensor,
2095-
(args[0], 0.5),
2096-
{},
2097-
meta,
2098-
)
2099-
2100-
scaled = super().call_operator(
2101-
exir_ops.edge.aten.mul.Tensor,
2102-
(args[0], 0.044715),
2103-
{},
2104-
meta,
2105-
)
2106-
2107-
# Get x^2 (note that we use mul.Tensor twice instead of pow.Tensor because
2108-
# it is much more efficient on DSP backends)
2109-
scaled_square = super().call_operator(
2110-
exir_ops.edge.aten.mul.Tensor,
2111-
(scaled, args[0]),
2112-
{},
2113-
meta,
2114-
)
2115-
2116-
# Get x^3
2117-
scaled_cubed = super().call_operator(
2118-
exir_ops.edge.aten.mul.Tensor,
2119-
(scaled_square, args[0]),
2120-
{},
2121-
meta,
2122-
)
2123-
2124-
# Get x + 0.044715 * x^3
2125-
inner_sum = super().call_operator(
2126-
exir_ops.edge.aten.add.Tensor,
2127-
(scaled_cubed, args[0]),
2128-
{},
2129-
meta,
2130-
)
2131-
2132-
# Get 0.7978845608028654 * ( x + 0.044715 * x^3)
2133-
scaled_sum = super().call_operator(
2134-
exir_ops.edge.aten.mul.Tensor,
2135-
(inner_sum, 0.7978845608028654),
2136-
{},
2137-
meta,
2138-
)
2139-
2140-
# Get torch.tanh(0.7978845608028654 * ( x + 0.044715 * x^3))
2141-
tanh = super().call_operator(
2142-
exir_ops.edge.aten.tanh.default,
2143-
(scaled_sum,),
2144-
{},
2145-
meta,
2146-
)
2147-
2148-
# Get 1 + torch.tanh(0.79788456 * ( x + 0.044715 * x^3))
2149-
# TODO(): Check why this is not working properly with integer values (e.g. 1 instead of 1.)
2150-
outer_sum = super().call_operator(
2151-
exir_ops.edge.aten.add.Tensor,
2152-
(tanh, 1.0),
2153-
{},
2154-
meta,
2155-
)
2156-
2157-
# Retunr the final result
2158-
return super().call_operator(
2159-
exir_ops.edge.aten.mul.Tensor,
2160-
(half, outer_sum),
2161-
{},
2162-
meta,
2163-
)
2085+
return super().call_operator(op, args, kwargs, meta)
21642086

21652087

21662088
# Adapted from fbcode/pyspeech/opt_passes/replace_ops.py
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-strict
8+
9+
import unittest
10+
from typing import Union
11+
12+
import torch
13+
from executorch.backends.cadence.aot.decompose_ops import DecomposeAtenApproxGeluPass
14+
from executorch.backends.cadence.aot.graph_builder import single_op_builder
15+
from executorch.backends.cadence.aot.pass_utils import count_node
16+
from executorch.exir.dialects._ops import ops as exir_ops
17+
from executorch.exir.dialects.edge._ops import EdgeOpOverload
18+
from executorch.exir.pass_base import ExportPass
19+
20+
21+
class TestDecomposeOpsPasses(unittest.TestCase):
    """Tests for the op-decomposition passes in decompose_ops.py."""

    def assertTargetCountEqual(
        self,
        graph_module: torch.fx.GraphModule,
        target: Union[EdgeOpOverload, str],
        expected_count: int,
    ) -> None:
        """Assert that `target` appears exactly `expected_count` times in the graph."""
        self.assertEqual(
            count_node(graph_module, target),
            expected_count,
            f"{target} count mismatch for graph {graph_module}",
        )

    def assertTargetCountsEqual(
        self,
        graph_module: torch.fx.GraphModule,
        targets_and_counts: list[tuple[Union[EdgeOpOverload, str], int]],
    ) -> None:
        """Assert expected node counts for several targets at once."""
        for expected_target, expected_count in targets_and_counts:
            self.assertTargetCountEqual(graph_module, expected_target, expected_count)

    def test_decompose_aten_approximate_gelu(self) -> None:
        sample = torch.randn(2, 1, 64)

        # Build a graph containing a single gelu(approximate="tanh") node.
        module = single_op_builder(
            placeholders=(sample,),
            op=exir_ops.edge.aten.gelu.default,
            args=(sample,),
            kwargs={"approximate": "tanh"},
        )
        module = ExportPass().call(module).graph_module

        transformed = DecomposeAtenApproxGeluPass().call(module).graph_module

        # gelu must be fully decomposed; the tanh approximation expands into
        # exactly 1 tanh, 2 add and 6 mul nodes.
        expectations = (
            (exir_ops.edge.aten.gelu.default, 0),
            (exir_ops.edge.aten.tanh.default, 1),
            (exir_ops.edge.aten.add.Tensor, 2),
            (exir_ops.edge.aten.mul.Tensor, 6),
        )
        for expected_target, expected_count in expectations:
            self.assertEqual(
                count_node(transformed, expected_target),
                expected_count,
            )

backends/cadence/aot/tests/test_replace_ops_passes.py

Lines changed: 0 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1309,43 +1309,6 @@ def test_no_replace_aten_gelu_with_approximate_gelu(self):
13091309
1,
13101310
)
13111311

1312-
def test_replace_aten_approximate_gelu_with_approximate_gelu(self):
1313-
inputs = torch.randn(2, 1, 64)
1314-
1315-
gm = single_op_builder(
1316-
placeholders=(inputs,),
1317-
op=exir_ops.edge.aten.gelu.default,
1318-
args=(inputs,),
1319-
kwargs={"approximate": "tanh"},
1320-
)
1321-
gm = ExportPass().call(gm).graph_module
1322-
1323-
p = ReplaceAtenApproxGeluWithApproxGeluPass()
1324-
graph_after_passes = p.call(gm).graph_module
1325-
1326-
# Assert that aten.gelu op was decomposed
1327-
self.assertEqual(
1328-
count_node(
1329-
graph_after_passes,
1330-
exir_ops.edge.aten.gelu.default,
1331-
),
1332-
0,
1333-
)
1334-
1335-
# The decomposition should have one tanh, 2 add and 6 mul
1336-
self.assertEqual(
1337-
count_node(graph_after_passes, exir_ops.edge.aten.tanh.default),
1338-
1,
1339-
)
1340-
self.assertEqual(
1341-
count_node(graph_after_passes, exir_ops.edge.aten.add.Tensor),
1342-
2,
1343-
)
1344-
self.assertEqual(
1345-
count_node(graph_after_passes, exir_ops.edge.aten.mul.Tensor),
1346-
6,
1347-
)
1348-
13491312
def test_replace_split_with_sizes_with_slice(self):
13501313
builder = GraphBuilder()
13511314
x = builder.placeholder("x", torch.randn(1, 16, 8, 4))

0 commit comments

Comments
 (0)