Skip to content

Commit e21336b

Browse files
ethansfngfacebook-github-bot
authored and committed
Create decompose_ops.py and test_decompose_ops.py
Summary: Create new class and test suite for passes that decompose an op into an equivalent series of simpler ops Rollback Plan: Differential Revision: D75826474
1 parent b5567be commit e21336b

File tree

5 files changed

+255
-116
lines changed

5 files changed

+255
-116
lines changed

backends/cadence/aot/TARGETS

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -276,6 +276,24 @@ python_library(
276276
],
277277
)
278278

279+
# Library of graph passes that decompose a single op into an equivalent
# series of simpler ops.
python_library(
    name = "decompose_ops",
    srcs = [
        "decompose_ops.py",
    ],
    typing = True,
    deps = [
        # NOTE(review): this file is backends/cadence/aot/TARGETS, so
        # ":pass_utils" and "//executorch/backends/cadence/aot:pass_utils"
        # name the same target; keep only the relative spelling.
        ":pass_utils",
        "//caffe2:torch",
        "//executorch/exir:pass_base",
        "//executorch/exir/dialects:lib",
        "//executorch/exir/dialects/edge:lib",
        "//executorch/exir/passes:spec_prop_pass",
    ],
)
295+
296+
279297
python_unittest(
280298
name = "test_graph_builder",
281299
srcs = [
@@ -314,6 +332,27 @@ python_unittest(
314332
],
315333
)
316334

335+
# Unit tests for the op-decomposition passes.
python_unittest(
    name = "test_decompose_ops_passes",
    srcs = [
        "tests/test_decompose_ops_passes.py",
    ],
    # Tests are parameterized, so static listing cannot enumerate them.
    supports_static_listing = False,
    typing = True,
    deps = [
        "fbsource//third-party/pypi/parameterized:parameterized",
        # NOTE(review): this file is backends/cadence/aot/TARGETS, so
        # ":compiler" and "//executorch/backends/cadence/aot:compiler" name
        # the same target; keep only the relative spelling.
        ":compiler",
        ":decompose_ops",
        "//caffe2:torch",
        "//executorch/backends/cadence/aot:graph_builder",
        "//executorch/backends/cadence/aot:pass_utils",
        "//executorch/exir:pass_base",
        "//executorch/exir/dialects:lib",
        "//executorch/exir/passes:lib",
    ],
)
355+
317356
python_unittest(
318357
name = "test_fusion_ops_passes",
319358
srcs = [
Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
# Copyright 2025 Arm Limited and/or its affiliates.
4+
#
5+
# This source code is licensed under the BSD-style license found in the
6+
# LICENSE file in the root directory of this source tree.
7+
8+
9+
# This file contains all the functions that decompose one op into simpler ops in the
10+
# graph. The functions decomposing ops for models deployed with Jarvis are grouped
11+
# together in class 'DecomposeOpsInGraph'. Some examples of functions in the class are
12+
# 1. functions that decompose an ATen gelu op into an equivalent series of simpler ops
13+
14+
# pyre-unsafe
15+
16+
from typing import Dict, Tuple
17+
18+
from executorch.backends.cadence.aot.pass_utils import (
19+
CadencePassAttribute,
20+
register_cadence_pass,
21+
)
22+
from executorch.exir.dialects._ops import ops as exir_ops
23+
from executorch.exir.dialects.edge._ops import EdgeOpOverload
24+
from executorch.exir.pass_base import ExportPass, NodeMetadata, ProxyValue
25+
from torch.fx.node import Argument
26+
27+
# A map to represent ops that:
# (a) are functionally equivalent wrt. Jarvis; and
# (b) have identical arguments.
# An op whose target is a 'key' in this dict can be replaced by the functionally
# equivalent op whose target is the 'value'. Because the argument lists are
# identical, the replacement only involves changing the op target.
functionally_equivalent_op_targets: Dict[EdgeOpOverload, EdgeOpOverload] = {
    # In-place relu -> out-of-place relu.
    exir_ops.edge.aten.relu_.default: exir_ops.edge.aten.relu.default,
    # unsafe_split -> split_copy (same splitting semantics, copying variant).
    exir_ops.edge.aten.unsafe_split.Tensor: exir_ops.edge.aten.split_copy.Tensor,
}
36+
37+
@register_cadence_pass(CadencePassAttribute(opt_level=0))
class DecomposeAtenApproxGeluPass(ExportPass):
    """
    Decompose an aten gelu op carrying an 'approximate' kwarg into a series of
    simpler ops implementing the tanh approximation:

        gelu(x) ~= 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
    """

    def call_operator(
        self,
        op,
        args: Tuple[Argument, ...],
        kwargs: Dict[str, Argument],
        meta: NodeMetadata,
    ) -> ProxyValue:
        # Guard: only rewrite gelu ops that request an approximation; pass
        # every other op through unchanged. Without these checks the pass
        # would rewrite *every* operator in the graph into the gelu
        # decomposition (the guards mirror the ones in the replace_ops.py
        # version this code was moved from).
        if op not in {
            exir_ops.edge.aten.gelu.default,
        }:
            return super().call_operator(op, args, kwargs, meta)
        if "approximate" not in kwargs:
            return super().call_operator(op, args, kwargs, meta)

        # compute the approximate gelu (0.7978845608028654 is sqrt(2 / pi))
        # as 0.5 * x * (1 + torch.tanh(0.7978845608028654 * ( x + 0.044715 * x^3)))

        # Get 0.5 * x
        half = super().call_operator(
            exir_ops.edge.aten.mul.Tensor,
            (args[0], 0.5),
            {},
            meta,
        )

        # Get 0.044715 * x
        scaled = super().call_operator(
            exir_ops.edge.aten.mul.Tensor,
            (args[0], 0.044715),
            {},
            meta,
        )

        # Get 0.044715 * x^2 (note that we use mul.Tensor twice instead of
        # pow.Tensor because it is much more efficient on DSP backends)
        scaled_square = super().call_operator(
            exir_ops.edge.aten.mul.Tensor,
            (scaled, args[0]),
            {},
            meta,
        )

        # Get 0.044715 * x^3
        scaled_cubed = super().call_operator(
            exir_ops.edge.aten.mul.Tensor,
            (scaled_square, args[0]),
            {},
            meta,
        )

        # Get x + 0.044715 * x^3
        inner_sum = super().call_operator(
            exir_ops.edge.aten.add.Tensor,
            (scaled_cubed, args[0]),
            {},
            meta,
        )

        # Get 0.7978845608028654 * ( x + 0.044715 * x^3)
        scaled_sum = super().call_operator(
            exir_ops.edge.aten.mul.Tensor,
            (inner_sum, 0.7978845608028654),
            {},
            meta,
        )

        # Get torch.tanh(0.7978845608028654 * ( x + 0.044715 * x^3))
        tanh = super().call_operator(
            exir_ops.edge.aten.tanh.default,
            (scaled_sum,),
            {},
            meta,
        )

        # Get 1 + torch.tanh(0.79788456 * ( x + 0.044715 * x^3))
        # TODO(): Check why this is not working properly with integer values (e.g. 1 instead of 1.)
        outer_sum = super().call_operator(
            exir_ops.edge.aten.add.Tensor,
            (tanh, 1.0),
            {},
            meta,
        )

        # Return the final result: 0.5 * x * (1 + tanh(...))
        return super().call_operator(
            exir_ops.edge.aten.mul.Tensor,
            (half, outer_sum),
            {},
            meta,
        )
125+
126+
127+
# Registry of every pass that decomposes one op in the graph into an
# equivalent series of simpler ops.
class CadenceDecomposeOpsInGraph:
    # Passes are applied in list order.
    passes = [DecomposeAtenApproxGeluPass]

backends/cadence/aot/replace_ops.py

Lines changed: 1 addition & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -2078,89 +2078,11 @@ def call_operator(
20782078
kwargs: Dict[str, Argument],
20792079
meta: NodeMetadata,
20802080
) -> ProxyValue:
2081-
if "approximate" not in kwargs:
2082-
return super().call_operator(op, args, kwargs, meta)
2083-
20842081
if op not in {
20852082
exir_ops.edge.aten.gelu.default,
20862083
}:
20872084
return super().call_operator(op, args, kwargs, meta)
2088-
2089-
# compute the approximate gelu (0.7978845608028654 is sqrt(2 / pi))
2090-
# as 0.5 * x * (1 + torch.tanh(0.7978845608028654 * ( x + 0.044715 * x^3)))
2091-
2092-
# Get 0.5 * x
2093-
half = super().call_operator(
2094-
exir_ops.edge.aten.mul.Tensor,
2095-
(args[0], 0.5),
2096-
{},
2097-
meta,
2098-
)
2099-
2100-
scaled = super().call_operator(
2101-
exir_ops.edge.aten.mul.Tensor,
2102-
(args[0], 0.044715),
2103-
{},
2104-
meta,
2105-
)
2106-
2107-
# Get x^2 (note that we use mul.Tensor twice instead of pow.Tensor because
2108-
# it is much more efficient on DSP backends)
2109-
scaled_square = super().call_operator(
2110-
exir_ops.edge.aten.mul.Tensor,
2111-
(scaled, args[0]),
2112-
{},
2113-
meta,
2114-
)
2115-
2116-
# Get x^3
2117-
scaled_cubed = super().call_operator(
2118-
exir_ops.edge.aten.mul.Tensor,
2119-
(scaled_square, args[0]),
2120-
{},
2121-
meta,
2122-
)
2123-
2124-
# Get x + 0.044715 * x^3
2125-
inner_sum = super().call_operator(
2126-
exir_ops.edge.aten.add.Tensor,
2127-
(scaled_cubed, args[0]),
2128-
{},
2129-
meta,
2130-
)
2131-
2132-
# Get 0.7978845608028654 * ( x + 0.044715 * x^3)
2133-
scaled_sum = super().call_operator(
2134-
exir_ops.edge.aten.mul.Tensor,
2135-
(inner_sum, 0.7978845608028654),
2136-
{},
2137-
meta,
2138-
)
2139-
2140-
# Get torch.tanh(0.7978845608028654 * ( x + 0.044715 * x^3))
2141-
tanh = super().call_operator(
2142-
exir_ops.edge.aten.tanh.default,
2143-
(scaled_sum,),
2144-
{},
2145-
meta,
2146-
)
2147-
2148-
# Get 1 + torch.tanh(0.79788456 * ( x + 0.044715 * x^3))
2149-
# TODO(): Check why this is not working properly with integer values (e.g. 1 instead of 1.)
2150-
outer_sum = super().call_operator(
2151-
exir_ops.edge.aten.add.Tensor,
2152-
(tanh, 1.0),
2153-
{},
2154-
meta,
2155-
)
2156-
2157-
# Retunr the final result
2158-
return super().call_operator(
2159-
exir_ops.edge.aten.mul.Tensor,
2160-
(half, outer_sum),
2161-
{},
2162-
meta,
2163-
)
2085+
return super().call_operator(op, args, kwargs, meta)
21642086

21652087

21662088
# Adapted from fbcode/pyspeech/opt_passes/replace_ops.py
Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-unsafe
8+
9+
import unittest
10+
from typing import Any, Callable, List, Tuple, Union
11+
12+
import torch
13+
from executorch.backends.cadence.aot.decompose_ops import (
14+
DecomposeAtenApproxGeluPass,
15+
)
16+
from executorch.backends.cadence.aot.graph_builder import (
17+
single_op_builder,
18+
)
19+
from executorch.backends.cadence.aot.pass_utils import count_node
20+
from executorch.exir.dialects._ops import ops as exir_ops
21+
from executorch.exir.pass_base import ExportPass
22+
23+
24+
class TestDecomposeOpsPasses(unittest.TestCase):
    """Unit tests for the passes in decompose_ops.py."""

    def assertTargetCountEqual(
        self,
        graph_module: torch.fx.GraphModule,
        target: Union[Callable[..., Any], str],
        expected_count: int,
    ):
        """Helper function to check the number of nodes with a given target."""
        self.assertEqual(
            count_node(graph_module, target),
            expected_count,
            f"{target} count mismatch for graph {graph_module}",
        )

    def assertTargetCountsEqual(
        self,
        graph_module: torch.fx.GraphModule,
        targets_and_counts: List[Tuple[Union[Callable[..., Any], str], int]],
    ):
        """Helper function to check the number of nodes of all types for a given target."""
        for target, expected_count in targets_and_counts:
            self.assertTargetCountEqual(graph_module, target, expected_count)

    def test_decompose_aten_approximate_gelu(self):
        x = torch.randn(2, 1, 64)

        # Build a graph containing a single gelu op with the tanh approximation.
        gm = single_op_builder(
            placeholders=(x,),
            op=exir_ops.edge.aten.gelu.default,
            args=(x,),
            kwargs={"approximate": "tanh"},
        )
        gm = ExportPass().call(gm).graph_module

        graph_after_passes = DecomposeAtenApproxGeluPass().call(gm).graph_module

        # The gelu op must be decomposed into exactly 1 tanh, 2 add and 6 mul.
        for target, expected in (
            (exir_ops.edge.aten.gelu.default, 0),
            (exir_ops.edge.aten.tanh.default, 1),
            (exir_ops.edge.aten.add.Tensor, 2),
            (exir_ops.edge.aten.mul.Tensor, 6),
        ):
            self.assertEqual(count_node(graph_after_passes, target), expected)

0 commit comments

Comments
 (0)