Skip to content

Commit 18e7b87

Browse files
author
pytorchbot
committed
2024-11-21 nightly release (a39ea29)
1 parent 02d28b4 commit 18e7b87

File tree

41 files changed

+7923
-613
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

41 files changed

+7923
-613
lines changed

.gitmodules

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
[submodule "backends/arm/third-party/ethos-u-core-driver"]
22
path = backends/arm/third-party/ethos-u-core-driver
3-
url = https://git.mlplatform.org/ml/ethos-u/ethos-u-core-driver.git/
3+
url = https://github.com/pytorch-labs/ethos-u-core-driver-mirror
44
[submodule "backends/arm/third-party/serialization_lib"]
55
path = backends/arm/third-party/serialization_lib
6-
url = https://git.mlplatform.org/tosa/serialization_lib.git/
6+
url = https://github.com/pytorch-labs/tosa_serialization_lib-mirror
77
[submodule "backends/vulkan/third-party/Vulkan-Headers"]
88
path = backends/vulkan/third-party/Vulkan-Headers
99
url = https://github.com/KhronosGroup/Vulkan-Headers

backends/cadence/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
## Supported DSPs (in progress)
44
- HiFi Audio
5-
- ...
5+
- Fusion G3
66

77
## Tutorial
88

backends/cadence/aot/TARGETS

Lines changed: 176 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,8 @@ python_library(
3838
deps = [
3939
":passes",
4040
":utils",
41+
":ops_registrations",
42+
":replace_ops",
4143
"//caffe2:torch",
4244
"//executorch/backends/cadence/aot/quantizer:fusion_pass",
4345
"//executorch/backends/cadence/aot/quantizer:quantizer",
@@ -71,12 +73,16 @@ python_library(
7173
],
7274
deps = [
7375
":utils",
76+
":fuse_ops",
77+
":simplify_ops",
78+
":replace_ops",
79+
":reorder_ops",
80+
":remove_ops",
7481
"//caffe2:torch",
7582
"//executorch/exir:pass_base",
7683
"//executorch/exir/dialects:lib",
7784
"//executorch/exir/passes:lib",
7885
"//executorch/exir/passes:spec_prop_pass",
79-
"//executorch/backends/transforms:remove_clone_ops"
8086
],
8187
)
8288

@@ -163,6 +169,77 @@ python_library(
163169
],
164170
)
165171

172+
python_library(
173+
name = "simplify_ops",
174+
srcs = [
175+
"simplify_ops.py",
176+
],
177+
typing = True,
178+
deps = [
179+
":pass_utils",
180+
"//executorch/backends/cadence/aot:pass_utils",
181+
"//executorch/exir:pass_base",
182+
"//executorch/exir/dialects:lib",
183+
],
184+
)
185+
186+
python_library(
187+
name = "remove_ops",
188+
srcs = [
189+
"remove_ops.py",
190+
],
191+
typing = True,
192+
deps = [
193+
"//caffe2:torch",
194+
"//executorch/backends/cadence/aot:pass_utils",
195+
"//executorch/backends/cadence/aot:simplify_ops",
196+
"//executorch/exir:pass_base",
197+
"//executorch/exir/dialects:lib",
198+
"//executorch/exir/dialects/edge:lib",
199+
"//executorch/exir/passes:spec_prop_pass",
200+
"//executorch/backends/transforms:remove_clone_ops"
201+
],
202+
)
203+
204+
python_library(
205+
name = "reorder_ops",
206+
srcs = [
207+
"reorder_ops.py",
208+
],
209+
typing = True,
210+
deps = [
211+
"//caffe2:torch",
212+
"//executorch/backends/cadence/aot:compiler_utils",
213+
"//executorch/backends/cadence/aot:pass_utils",
214+
"//executorch/backends/cadence/aot:utils",
215+
"//executorch/exir:pass_base",
216+
"//executorch/exir:tensor",
217+
"//executorch/exir/dialects:lib",
218+
"//executorch/exir/dialects/edge:lib",
219+
],
220+
)
221+
222+
python_library(
223+
name = "replace_ops",
224+
srcs = [
225+
"replace_ops.py",
226+
],
227+
typing = True,
228+
deps = [
229+
":pass_utils",
230+
"//caffe2:torch",
231+
"//executorch/backends/cadence/aot:compiler_utils",
232+
"//executorch/backends/cadence/aot:fuse_ops",
233+
"//executorch/backends/cadence/aot:pass_utils",
234+
"//executorch/backends/cadence/aot:remove_ops",
235+
"//executorch/backends/cadence/aot:utils",
236+
"//executorch/exir:pass_base",
237+
"//executorch/exir/dialects:lib",
238+
"//executorch/exir/dialects/edge:lib",
239+
"//executorch/exir/passes:spec_prop_pass",
240+
],
241+
)
242+
166243
python_unittest(
167244
name = "test_graph_builder",
168245
srcs = [
@@ -179,3 +256,101 @@ python_unittest(
179256
":ops_registrations"
180257
],
181258
)
259+
260+
python_unittest(
261+
name = "test_replace_ops_passes",
262+
srcs = [
263+
"tests/test_replace_ops_passes.py",
264+
],
265+
supports_static_listing = False,
266+
typing = True,
267+
deps = [
268+
"fbsource//third-party/pypi/parameterized:parameterized",
269+
":compiler",
270+
":replace_ops",
271+
"//caffe2:torch",
272+
"//executorch/backends/cadence/aot:compiler",
273+
"//executorch/backends/cadence/aot:graph_builder",
274+
"//executorch/backends/cadence/aot:pass_utils",
275+
"//executorch/exir:pass_base",
276+
"//executorch/exir/dialects:lib",
277+
"//executorch/exir/passes:lib",
278+
],
279+
)
280+
281+
python_unittest(
282+
name = "test_fusion_ops_passes",
283+
srcs = [
284+
"tests/test_fusion_ops_passes.py",
285+
],
286+
typing = True,
287+
deps = [
288+
":compiler",
289+
"//caffe2:torch",
290+
"//executorch/backends/cadence/aot:compiler",
291+
"//executorch/backends/cadence/aot:fuse_ops",
292+
"//executorch/backends/cadence/aot:graph_builder",
293+
"//executorch/backends/cadence/aot:ops_registrations",
294+
"//executorch/backends/cadence/aot:pass_utils",
295+
"//executorch/exir/dialects:lib",
296+
"//executorch/exir/dialects/edge:lib",
297+
],
298+
)
299+
300+
python_unittest(
301+
name = "test_remove_ops_passes",
302+
srcs = [
303+
"tests/test_remove_ops_passes.py",
304+
],
305+
supports_static_listing = False,
306+
typing = True,
307+
deps = [
308+
"fbsource//third-party/pypi/parameterized:parameterized",
309+
"fbsource//third-party/pypi/pyre-extensions:pyre-extensions",
310+
":compiler",
311+
"//caffe2:torch",
312+
"//executorch/backends/cadence/aot:compiler",
313+
"//executorch/backends/cadence/aot:ops_registrations",
314+
"//executorch/backends/cadence/aot:pass_utils",
315+
"//executorch/backends/cadence/aot:remove_ops",
316+
"//executorch/backends/cadence/aot/quantizer:quantizer",
317+
"//executorch/exir/dialects:lib",
318+
],
319+
)
320+
321+
python_unittest(
322+
name = "test_simplify_ops_passes",
323+
srcs = [
324+
"tests/test_simplify_ops_passes.py",
325+
],
326+
supports_static_listing = False,
327+
typing = True,
328+
deps = [
329+
"fbsource//third-party/pypi/parameterized:parameterized",
330+
"//caffe2:torch",
331+
"//executorch/backends/cadence/aot:compiler",
332+
"//executorch/backends/cadence/aot:ops_registrations",
333+
"//executorch/backends/cadence/aot:pass_utils",
334+
"//executorch/backends/cadence/aot:simplify_ops",
335+
"//executorch/exir/dialects:lib",
336+
],
337+
)
338+
339+
python_unittest(
340+
name = "test_reorder_ops_passes",
341+
srcs = [
342+
"tests/test_reorder_ops_passes.py",
343+
],
344+
typing = True,
345+
deps = [
346+
":compiler",
347+
":pass_utils",
348+
"//caffe2:torch",
349+
"//executorch/backends/cadence/aot:compiler",
350+
"//executorch/backends/cadence/aot:fuse_ops",
351+
"//executorch/backends/cadence/aot:ops_registrations",
352+
"//executorch/backends/cadence/aot:pass_utils",
353+
"//executorch/backends/cadence/aot:reorder_ops",
354+
"//executorch/exir/dialects:lib",
355+
],
356+
)

backends/cadence/aot/compiler.py

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,12 @@
1010
from pathlib import Path
1111
from typing import Callable, cast, Optional
1212

13+
import executorch.backends.cadence.aot.ops_registrations # noqa
1314
import torch
14-
15-
from executorch.backends.cadence.aot.passes import ReplaceSafeSoftmaxWithSoftmax
1615
from executorch.backends.cadence.aot.quantizer.fusion_pass import QuantFusion
1716
from executorch.backends.cadence.aot.quantizer.quantizer import CadenceQuantizer
17+
18+
from executorch.backends.cadence.aot.replace_ops import ReplaceSafeSoftmaxWithSoftmax
1819
from executorch.backends.cadence.aot.utils import model_gm_has_SDPA, model_is_quantized
1920
from executorch.backends.transforms.decompose_sdpa import (
2021
DecomposeScaledDotProductAttention,
@@ -193,9 +194,6 @@ def export_to_edge(
193194
return edge_prog_manager
194195

195196

196-
# Export the model and lower it to an EdgeProgramManager (in edge IR), and
197-
# apply passes specific to Cadence DSP execution. Return both to print the
198-
# differences.
199197
def export_to_cadence(
200198
model: torch.nn.Module,
201199
inputs: tuple[object, ...],
@@ -215,6 +213,25 @@ def export_to_cadence(
215213
return cadence_prog_manager
216214

217215

216+
def quantize_and_export_to_cadence(
217+
model: torch.nn.Module,
218+
inputs: tuple[object, ...],
219+
dump_graphs: bool = False,
220+
opt_level: int = 1,
221+
) -> EdgeProgramManager:
222+
quantized_model = quantize_pt2(model, inputs)
223+
224+
return export_to_cadence(
225+
quantized_model,
226+
inputs,
227+
opt_level=opt_level,
228+
dump_graphs=dump_graphs,
229+
)
230+
231+
232+
# Export the model and lower it to an EdgeProgramManager (in edge IR), and
233+
# apply passes specific to Cadence DSP execution. Return both to print the
234+
# differences.
218235
def export_to_executorch_gen_etrecord(
219236
model: torch.nn.Module,
220237
inputs: tuple[object, ...],

backends/cadence/aot/fuse_ops.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1022,7 +1022,7 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
10221022
return PassResult(graph_module, True)
10231023

10241024

1025-
class FuseOpsInGraph:
1025+
class CadenceFuseOpsInGraph:
10261026
passes = [
10271027
FuseMMWithAdd,
10281028
FuseBatchNormWithConv,

backends/cadence/aot/pass_utils.py

Lines changed: 46 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
# pyre-strict
44

55
from dataclasses import dataclass
6-
from typing import Callable, Optional, Set, Union
6+
from typing import Callable, List, Optional, Set, Union
77

88
import torch
99
from executorch.backends.cadence.aot.utils import get_edge_overload_packet
@@ -50,7 +50,7 @@ def get_all_available_cadence_passes() -> Set[ExportPass]:
5050
return set(ALL_CADENCE_PASSES.keys())
5151

5252

53-
# Create a new filter to filter out relevant passes from all Jarvis passes.
53+
# Create a new filter to filter out relevant passes from all passes.
5454
def create_cadence_pass_filter(
5555
opt_level: int, debug: bool = False
5656
) -> Callable[[ExportPass], bool]:
@@ -98,3 +98,47 @@ def count_node(graph_module: torch.fx.GraphModule, target: torch.fx.node.Target)
9898
if node.op == "call_function" and node.target == target:
9999
total += 1
100100
return total
101+
102+
103+
# Testing utils
104+
# Return the compute/function nodes in the graph
105+
def get_compute_nodes_in_gm(graph_module: torch.fx.GraphModule) -> List[torch.fx.Node]:
106+
nodes = []
107+
for x in graph_module.graph.nodes:
108+
if x.op == "call_function":
109+
if isinstance(x.target, torch._ops.OpOverload):
110+
nodes.append(x.target.overloadpacket)
111+
elif isinstance(x.target, EdgeOpOverload):
112+
nodes.append(get_edge_overload_packet(x.target))
113+
return nodes
114+
115+
116+
# Return true if there is no edge from a node with target pred_target to a
117+
# node with target succ_target in the graph.
118+
def nodes_not_connected_in_gm(
119+
graph_module: torch.fx.GraphModule,
120+
pred_target: torch.fx.Node,
121+
succ_target: torch.fx.Node,
122+
) -> bool:
123+
for node in graph_module.graph.nodes:
124+
if node.target != pred_target:
125+
continue
126+
for user in node.users:
127+
if user.target == succ_target:
128+
return False
129+
return True
130+
131+
132+
# Returns true if there is no instance of a node with target succ_target
133+
# positioned immediately after a node with target pred_target in the graph
134+
def nodes_not_adjacent_in_gm(
135+
graph_module: torch.fx.GraphModule,
136+
pred_target: torch.fx.Node,
137+
succ_target: torch.fx.Node,
138+
) -> bool:
139+
for node in graph_module.graph.nodes:
140+
if node.target != pred_target:
141+
continue
142+
if node.next.target == succ_target:
143+
return False
144+
return True

0 commit comments

Comments
 (0)