Skip to content

Commit 06bbdfd

Browse files
ethanng72 and facebook-github-bot
authored and committed
Create decompose_ops.py and test_decompose_ops.py (#11299)
Summary: Create new class and test suite for passes that decompose an op into an equivalent series of simpler ops Rollback Plan: Test Plan: Imported from GitHub, without a `Test Plan:` line. Rollback Plan: Reviewed By: hsharma35 Differential Revision: D75826474 Pulled By: ethanng72
1 parent 93b1a0c commit 06bbdfd

File tree

5 files changed

+239
-116
lines changed

5 files changed

+239
-116
lines changed

backends/cadence/aot/TARGETS

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -276,6 +276,24 @@ python_library(
276276
],
277277
)
278278

279+
python_library(
280+
name = "decompose_ops",
281+
srcs = [
282+
"decompose_ops.py",
283+
],
284+
typing = True,
285+
deps = [
286+
":pass_utils",
287+
"//caffe2:torch",
288+
"//executorch/backends/cadence/aot:pass_utils",
289+
"//executorch/exir:pass_base",
290+
"//executorch/exir/dialects:lib",
291+
"//executorch/exir/dialects/edge:lib",
292+
"//executorch/exir/passes:spec_prop_pass",
293+
],
294+
)
295+
296+
279297
python_unittest(
280298
name = "test_graph_builder",
281299
srcs = [
@@ -314,6 +332,27 @@ python_unittest(
314332
],
315333
)
316334

335+
python_unittest(
336+
name = "test_decompose_ops_passes",
337+
srcs = [
338+
"tests/test_decompose_ops_passes.py",
339+
],
340+
supports_static_listing = False,
341+
typing = True,
342+
deps = [
343+
"fbsource//third-party/pypi/parameterized:parameterized",
344+
":compiler",
345+
":decompose_ops",
346+
"//caffe2:torch",
347+
"//executorch/backends/cadence/aot:compiler",
348+
"//executorch/backends/cadence/aot:graph_builder",
349+
"//executorch/backends/cadence/aot:pass_utils",
350+
"//executorch/exir:pass_base",
351+
"//executorch/exir/dialects:lib",
352+
"//executorch/exir/passes:lib",
353+
],
354+
)
355+
317356
python_unittest(
318357
name = "test_fusion_ops_passes",
319358
srcs = [
Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
# Copyright 2025 Arm Limited and/or its affiliates.
4+
#
5+
# This source code is licensed under the BSD-style license found in the
6+
# LICENSE file in the root directory of this source tree.
7+
8+
9+
# This file contains all the functions that decompose one op into simpler ops in the
10+
# graph. The functions decomposing ops for models deployed with Jarvis are grouped
11+
# together in class 'DecomposeOpsInGraph'. Some examples of functions in the class are
12+
# 1. functions that decompose an ATen gelu op into an equivalent series of simpler ops
13+
14+
15+
from typing import Dict, Tuple
16+
17+
from executorch.backends.cadence.aot.pass_utils import (
18+
CadencePassAttribute,
19+
register_cadence_pass,
20+
)
21+
from executorch.exir.dialects._ops import ops as exir_ops
22+
from executorch.exir.pass_base import ExportPass, NodeMetadata, ProxyValue
23+
from torch.fx.node import Argument
24+
25+
26+
@register_cadence_pass(CadencePassAttribute(opt_level=0))
27+
class DecomposeAtenApproxGeluPass(ExportPass):
28+
"""
29+
Decompose the aten gelu op with an approximate arg to a series of simpler ops
30+
"""
31+
32+
def call_operator(
33+
self,
34+
op,
35+
args: Tuple[Argument, ...],
36+
kwargs: Dict[str, Argument],
37+
meta: NodeMetadata,
38+
) -> ProxyValue:
39+
# compute the approximate gelu (0.7978845608028654 is sqrt(2 / pi))
40+
# as 0.5 * x * (1 + torch.tanh(0.7978845608028654 * ( x + 0.044715 * x^3)))
41+
42+
# Get 0.5 * x
43+
half = super().call_operator(
44+
exir_ops.edge.aten.mul.Tensor,
45+
(args[0], 0.5),
46+
{},
47+
meta,
48+
)
49+
50+
scaled = super().call_operator(
51+
exir_ops.edge.aten.mul.Tensor,
52+
(args[0], 0.044715),
53+
{},
54+
meta,
55+
)
56+
57+
# Get x^2 (note that we use mul.Tensor twice instead of pow.Tensor because
58+
# it is much more efficient on DSP backends)
59+
scaled_square = super().call_operator(
60+
exir_ops.edge.aten.mul.Tensor,
61+
(scaled, args[0]),
62+
{},
63+
meta,
64+
)
65+
66+
# Get x^3
67+
scaled_cubed = super().call_operator(
68+
exir_ops.edge.aten.mul.Tensor,
69+
(scaled_square, args[0]),
70+
{},
71+
meta,
72+
)
73+
74+
# Get x + 0.044715 * x^3
75+
inner_sum = super().call_operator(
76+
exir_ops.edge.aten.add.Tensor,
77+
(scaled_cubed, args[0]),
78+
{},
79+
meta,
80+
)
81+
82+
# Get 0.7978845608028654 * ( x + 0.044715 * x^3)
83+
scaled_sum = super().call_operator(
84+
exir_ops.edge.aten.mul.Tensor,
85+
(inner_sum, 0.7978845608028654),
86+
{},
87+
meta,
88+
)
89+
90+
# Get torch.tanh(0.7978845608028654 * ( x + 0.044715 * x^3))
91+
tanh = super().call_operator(
92+
exir_ops.edge.aten.tanh.default,
93+
(scaled_sum,),
94+
{},
95+
meta,
96+
)
97+
98+
# Get 1 + torch.tanh(0.79788456 * ( x + 0.044715 * x^3))
99+
# TODO(): Check why this is not working properly with integer values (e.g. 1 instead of 1.)
100+
outer_sum = super().call_operator(
101+
exir_ops.edge.aten.add.Tensor,
102+
(tanh, 1.0),
103+
{},
104+
meta,
105+
)
106+
107+
# Return the final result
108+
return super().call_operator(
109+
exir_ops.edge.aten.mul.Tensor,
110+
(half, outer_sum),
111+
{},
112+
meta,
113+
)
114+
115+
116+
# This class encapsulates all the functions that decompose one op in the graph.
117+
class CadenceDecomposeOpsInGraph:
118+
passes = [
119+
DecomposeAtenApproxGeluPass,
120+
]

backends/cadence/aot/replace_ops.py

Lines changed: 1 addition & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -2078,89 +2078,11 @@ def call_operator(
20782078
kwargs: Dict[str, Argument],
20792079
meta: NodeMetadata,
20802080
) -> ProxyValue:
2081-
if "approximate" not in kwargs:
2082-
return super().call_operator(op, args, kwargs, meta)
2083-
20842081
if op not in {
20852082
exir_ops.edge.aten.gelu.default,
20862083
}:
20872084
return super().call_operator(op, args, kwargs, meta)
2088-
2089-
# compute the approximate gelu (0.7978845608028654 is sqrt(2 / pi))
2090-
# as 0.5 * x * (1 + torch.tanh(0.7978845608028654 * ( x + 0.044715 * x^3)))
2091-
2092-
# Get 0.5 * x
2093-
half = super().call_operator(
2094-
exir_ops.edge.aten.mul.Tensor,
2095-
(args[0], 0.5),
2096-
{},
2097-
meta,
2098-
)
2099-
2100-
scaled = super().call_operator(
2101-
exir_ops.edge.aten.mul.Tensor,
2102-
(args[0], 0.044715),
2103-
{},
2104-
meta,
2105-
)
2106-
2107-
# Get x^2 (note that we use mul.Tensor twice instead of pow.Tensor because
2108-
# it is much more efficient on DSP backends)
2109-
scaled_square = super().call_operator(
2110-
exir_ops.edge.aten.mul.Tensor,
2111-
(scaled, args[0]),
2112-
{},
2113-
meta,
2114-
)
2115-
2116-
# Get x^3
2117-
scaled_cubed = super().call_operator(
2118-
exir_ops.edge.aten.mul.Tensor,
2119-
(scaled_square, args[0]),
2120-
{},
2121-
meta,
2122-
)
2123-
2124-
# Get x + 0.044715 * x^3
2125-
inner_sum = super().call_operator(
2126-
exir_ops.edge.aten.add.Tensor,
2127-
(scaled_cubed, args[0]),
2128-
{},
2129-
meta,
2130-
)
2131-
2132-
# Get 0.7978845608028654 * ( x + 0.044715 * x^3)
2133-
scaled_sum = super().call_operator(
2134-
exir_ops.edge.aten.mul.Tensor,
2135-
(inner_sum, 0.7978845608028654),
2136-
{},
2137-
meta,
2138-
)
2139-
2140-
# Get torch.tanh(0.7978845608028654 * ( x + 0.044715 * x^3))
2141-
tanh = super().call_operator(
2142-
exir_ops.edge.aten.tanh.default,
2143-
(scaled_sum,),
2144-
{},
2145-
meta,
2146-
)
2147-
2148-
# Get 1 + torch.tanh(0.79788456 * ( x + 0.044715 * x^3))
2149-
# TODO(): Check why this is not working properly with integer values (e.g. 1 instead of 1.)
2150-
outer_sum = super().call_operator(
2151-
exir_ops.edge.aten.add.Tensor,
2152-
(tanh, 1.0),
2153-
{},
2154-
meta,
2155-
)
2156-
2157-
# Retunr the final result
2158-
return super().call_operator(
2159-
exir_ops.edge.aten.mul.Tensor,
2160-
(half, outer_sum),
2161-
{},
2162-
meta,
2163-
)
2085+
return super().call_operator(op, args, kwargs, meta)
21642086

21652087

21662088
# Adapted from fbcode/pyspeech/opt_passes/replace_ops.py
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
8+
import unittest
9+
from typing import List, Tuple, Union
10+
11+
import torch
12+
from executorch.backends.cadence.aot.decompose_ops import DecomposeAtenApproxGeluPass
13+
from executorch.backends.cadence.aot.graph_builder import single_op_builder
14+
from executorch.backends.cadence.aot.pass_utils import count_node
15+
from executorch.exir.dialects._ops import ops as exir_ops
16+
from executorch.exir.dialects.edge._ops import EdgeOpOverload
17+
from executorch.exir.pass_base import ExportPass
18+
19+
20+
class TestDecomposeOpsPasses(unittest.TestCase):
21+
def assertTargetCountEqual(
22+
self,
23+
graph_module: torch.fx.GraphModule,
24+
target: Union[EdgeOpOverload, str],
25+
expected_count: int,
26+
):
27+
"""Helper function to check the number of nodes with a given target."""
28+
actual_count = count_node(graph_module, target)
29+
self.assertEqual(
30+
actual_count,
31+
expected_count,
32+
f"{target} count mismatch for graph {graph_module}",
33+
)
34+
35+
def assertTargetCountsEqual(
36+
self,
37+
graph_module: torch.fx.GraphModule,
38+
targets_and_counts: List[Tuple[Union[EdgeOpOverload, str], int]],
39+
):
40+
"""Helper function to check the number of nodes of all types for a given target."""
41+
for target, expected_count in targets_and_counts:
42+
self.assertTargetCountEqual(graph_module, target, expected_count)
43+
44+
def test_decompose_aten_approximate_gelu(self):
45+
inputs = torch.randn(2, 1, 64)
46+
47+
gm = single_op_builder(
48+
placeholders=(inputs,),
49+
op=exir_ops.edge.aten.gelu.default,
50+
args=(inputs,),
51+
kwargs={"approximate": "tanh"},
52+
)
53+
gm = ExportPass().call(gm).graph_module
54+
55+
p = DecomposeAtenApproxGeluPass()
56+
graph_after_passes = p.call(gm).graph_module
57+
58+
# Assert that aten.gelu op was decomposed
59+
self.assertEqual(
60+
count_node(
61+
graph_after_passes,
62+
exir_ops.edge.aten.gelu.default,
63+
),
64+
0,
65+
)
66+
67+
# The decomposition should have one tanh, 2 add and 6 mul
68+
self.assertEqual(
69+
count_node(graph_after_passes, exir_ops.edge.aten.tanh.default),
70+
1,
71+
)
72+
self.assertEqual(
73+
count_node(graph_after_passes, exir_ops.edge.aten.add.Tensor),
74+
2,
75+
)
76+
self.assertEqual(
77+
count_node(graph_after_passes, exir_ops.edge.aten.mul.Tensor),
78+
6,
79+
)

backends/cadence/aot/tests/test_replace_ops_passes.py

Lines changed: 0 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1309,43 +1309,6 @@ def test_no_replace_aten_gelu_with_approximate_gelu(self):
13091309
1,
13101310
)
13111311

1312-
def test_replace_aten_approximate_gelu_with_approximate_gelu(self):
1313-
inputs = torch.randn(2, 1, 64)
1314-
1315-
gm = single_op_builder(
1316-
placeholders=(inputs,),
1317-
op=exir_ops.edge.aten.gelu.default,
1318-
args=(inputs,),
1319-
kwargs={"approximate": "tanh"},
1320-
)
1321-
gm = ExportPass().call(gm).graph_module
1322-
1323-
p = ReplaceAtenApproxGeluWithApproxGeluPass()
1324-
graph_after_passes = p.call(gm).graph_module
1325-
1326-
# Assert that aten.gelu op was decomposed
1327-
self.assertEqual(
1328-
count_node(
1329-
graph_after_passes,
1330-
exir_ops.edge.aten.gelu.default,
1331-
),
1332-
0,
1333-
)
1334-
1335-
# The decomposition should have one tanh, 2 add and 6 mul
1336-
self.assertEqual(
1337-
count_node(graph_after_passes, exir_ops.edge.aten.tanh.default),
1338-
1,
1339-
)
1340-
self.assertEqual(
1341-
count_node(graph_after_passes, exir_ops.edge.aten.add.Tensor),
1342-
2,
1343-
)
1344-
self.assertEqual(
1345-
count_node(graph_after_passes, exir_ops.edge.aten.mul.Tensor),
1346-
6,
1347-
)
1348-
13491312
def test_replace_split_with_sizes_with_slice(self):
13501313
builder = GraphBuilder()
13511314
x = builder.placeholder("x", torch.randn(1, 16, 8, 4))

0 commit comments

Comments
 (0)