Skip to content

Commit 7473851

Browse files
Zonglin Pengfacebook-github-bot
authored andcommitted
add reorder pass testing
Differential Revision: D66078188
1 parent ad348db commit 7473851

File tree

4 files changed

+502
-4
lines changed

4 files changed

+502
-4
lines changed

backends/cadence/aot/TARGETS

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -335,3 +335,22 @@ python_unittest(
335335
"//executorch/exir/dialects:lib",
336336
],
337337
)
338+
339+
python_unittest(
340+
name = "test_reorder_ops_passes",
341+
srcs = [
342+
"tests/test_reorder_ops_passes.py",
343+
],
344+
typing = True,
345+
deps = [
346+
":compiler",
347+
":pass_utils",
348+
"//caffe2:torch",
349+
"//executorch/backends/cadence/aot:compiler",
350+
"//executorch/backends/cadence/aot:fuse_ops",
351+
"//executorch/backends/cadence/aot:ops_registrations",
352+
"//executorch/backends/cadence/aot:pass_utils",
353+
"//executorch/backends/cadence/aot:reorder_ops",
354+
"//executorch/exir/dialects:lib",
355+
],
356+
)

backends/cadence/aot/compiler.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -194,9 +194,6 @@ def export_to_edge(
194194
return edge_prog_manager
195195

196196

197-
# Export the model and lower it to an EdgeProgramManager (in edge IR), and
198-
# apply passes specific to Cadence DSP execution. Return both to print the
199-
# differences.
200197
def export_to_cadence(
201198
model: torch.nn.Module,
202199
inputs: tuple[object, ...],
@@ -216,6 +213,25 @@ def export_to_cadence(
216213
return cadence_prog_manager
217214

218215

216+
def quantize_and_export_to_cadence(
217+
model: torch.nn.Module,
218+
inputs: tuple[object, ...],
219+
dump_graphs: bool = False,
220+
opt_level: int = 1,
221+
) -> EdgeProgramManager:
222+
quantized_model = quantize_pt2(model, inputs)
223+
224+
return export_to_cadence(
225+
quantized_model,
226+
inputs,
227+
opt_level=opt_level,
228+
dump_graphs=dump_graphs,
229+
)
230+
231+
232+
# Export the model and lower it to an EdgeProgramManager (in edge IR), and
233+
# apply passes specific to Cadence DSP execution. Return both to print the
234+
# differences.
219235
def export_to_executorch_gen_etrecord(
220236
model: torch.nn.Module,
221237
inputs: tuple[object, ...],

backends/cadence/aot/pass_utils.py

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
# pyre-strict
44

55
from dataclasses import dataclass
6-
from typing import Callable, Optional, Set, Union
6+
from typing import Callable, List, Optional, Set, Union
77

88
import torch
99
from executorch.backends.cadence.aot.utils import get_edge_overload_packet
@@ -98,3 +98,47 @@ def count_node(graph_module: torch.fx.GraphModule, target: torch.fx.node.Target)
9898
if node.op == "call_function" and node.target == target:
9999
total += 1
100100
return total
101+
102+
103+
# Testing utils
104+
# Return the compute/function nodes in the graph
105+
def get_compute_nodes_in_gm(graph_module: torch.fx.GraphModule) -> List[torch.fx.Node]:
106+
nodes = []
107+
for x in graph_module.graph.nodes:
108+
if x.op == "call_function":
109+
if isinstance(x.target, torch._ops.OpOverload):
110+
nodes.append(x.target.overloadpacket)
111+
elif isinstance(x.target, EdgeOpOverload):
112+
nodes.append(get_edge_overload_packet(x.target))
113+
return nodes
114+
115+
116+
# Return true if there is no edge from a node with target pred_target to a
117+
# node with target succ_target in the graph.
118+
def nodes_not_connected_in_gm(
119+
graph_module: torch.fx.GraphModule,
120+
pred_target: torch.fx.Node,
121+
succ_target: torch.fx.Node,
122+
) -> bool:
123+
for node in graph_module.graph.nodes:
124+
if node.target != pred_target:
125+
continue
126+
for user in node.users:
127+
if user.target == succ_target:
128+
return False
129+
return True
130+
131+
132+
# Returns true if there is no instance of a node with target succ_target
133+
# positioned immediately after a node with target pred_target in the graph
134+
def nodes_not_adjacent_in_gm(
135+
graph_module: torch.fx.GraphModule,
136+
pred_target: torch.fx.Node,
137+
succ_target: torch.fx.Node,
138+
) -> bool:
139+
for node in graph_module.graph.nodes:
140+
if node.target != pred_target:
141+
continue
142+
if node.next.target == succ_target:
143+
return False
144+
return True

0 commit comments

Comments
 (0)