Skip to content

Commit 1882bc1

Browse files
committed
Arm backend: Break out processing per graph_module in backend.
This will enable us to process multiple submodules contained in a partitioned ExportedProgram. Signed-off-by: Erik Lundell <[email protected]> Change-Id: I82e41b1e9ff2409ca31e86e4a89747e694ab4ea4
1 parent de56c81 commit 1882bc1

File tree

2 files changed

+88
-65
lines changed

2 files changed

+88
-65
lines changed

backends/arm/_passes/arm_pass_manager.py

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -153,15 +153,15 @@ def _transform(self, graph_module: GraphModule):
153153
with TosaLoweringContext(self.tosa_spec):
154154
return self(graph_module).graph_module
155155

156-
def _tosa_INT_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
156+
def _tosa_INT_pipeline(
157+
self, exported_program: ExportedProgram, graph_module: GraphModule
158+
) -> GraphModule:
157159
self.add_pass(AnnotateOutputDimOrderPass())
158160
self.add_pass(FuseQuantizedActivationPass())
159161
self.add_pass(RemoveGetItemPass())
160162
self.add_pass(ConvertSplitToSlicePass())
161163
self.add_pass(ConvertMmToBmmPass())
162-
self.add_pass(
163-
DecomposeMeanDimPass(exported_program.graph_module, self.tosa_spec)
164-
)
164+
self.add_pass(DecomposeMeanDimPass(graph_module, self.tosa_spec))
165165
self.add_pass(ConvertFullLikeToFullPass())
166166
self.add_pass(ConvertToClampPass())
167167
self.add_pass(ConvertMinMaxPass())
@@ -218,9 +218,11 @@ def _tosa_INT_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
218218
self.add_pass(InsertRescalePass())
219219

220220
self.validate_constraints_mandatory()
221-
return self._transform(exported_program.graph_module)
221+
return self._transform(graph_module)
222222

223-
def _tosa_FP_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
223+
def _tosa_FP_pipeline(
224+
self, exported_program: ExportedProgram, graph_module: GraphModule
225+
) -> GraphModule:
224226
self.add_pass(AnnotateOutputDimOrderPass())
225227
self.add_pass(DecomposeExpm1Pass())
226228
self.add_pass(DecomposeLogitPass())
@@ -255,9 +257,7 @@ def _tosa_FP_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
255257
self.add_pass(DecomposeLayerNormPass())
256258
self.add_pass(DecomposeBatchNormNoStatsPass())
257259
self.add_pass(DecomposeVarPass())
258-
self.add_pass(
259-
DecomposeMeanDimPass(exported_program.graph_module, self.tosa_spec)
260-
)
260+
self.add_pass(DecomposeMeanDimPass(graph_module, self.tosa_spec))
261261
self.add_pass(DecomposeNotEqualPass())
262262
self.add_pass(DecomposeDivPass())
263263
self.add_pass(DecomposeAddSubAlphaPass())
@@ -305,14 +305,16 @@ def _tosa_FP_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
305305
self.add_pass(InsertRescalePass())
306306

307307
self.validate_constraints_mandatory()
308-
return self._transform(exported_program.graph_module)
308+
return self._transform(graph_module)
309309

310-
def transform_to_backend_pipeline(self, exported_program: ExportedProgram):
310+
def transform_to_backend_pipeline(
311+
self, exported_program: ExportedProgram, graph_module: GraphModule
312+
):
311313
"""Apply passes before transforming program to backend"""
312314
if self.tosa_spec == TosaSpecification.create_from_string("TOSA-1.0+FP"):
313-
return self._tosa_FP_pipeline(exported_program)
315+
return self._tosa_FP_pipeline(exported_program, graph_module)
314316
elif self.tosa_spec == TosaSpecification.create_from_string("TOSA-1.0+INT"):
315-
return self._tosa_INT_pipeline(exported_program)
317+
return self._tosa_INT_pipeline(exported_program, graph_module)
316318
else:
317319
raise NotImplementedError(
318320
f"No pass pipeline implemented for {self.tosa_spec=}"

backends/arm/tosa/backend.py

Lines changed: 73 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
from executorch.exir.backend.backend_details import BackendDetails, PreprocessResult
2828
from executorch.exir.backend.compile_spec_schema import CompileSpec
2929
from torch.export.exported_program import ExportedProgram
30-
from torch.fx import Graph, Node
30+
from torch.fx import Graph, GraphModule, Node
3131

3232
# TOSA backend debug functionality
3333
logger = logging.getLogger(__name__)
@@ -52,13 +52,39 @@ def bfs_mark(start_nodes: List[Node], idx: int, seen: Set[Node]):
5252
# Walk backwards so we touch every producer
5353
q.extend(n.all_input_nodes)
5454

55-
out = next(n for n in ep_graph.nodes if n.op == "output")
55+
out = ep_graph.output_node()
56+
# First argument of output node is tuple of outputs
57+
output_list = cast(tuple, out.args[0])
5658
seen: Set[Node] = set()
57-
for idx, val in enumerate(out.args[0]):
59+
for idx, val in enumerate(output_list):
5860
bfs_mark([val], idx, seen)
5961
return node2external_id
6062

6163

64+
def _sort_outputs(graph_module: GraphModule, node_to_id_map: dict[str, int]):
65+
def _external_id(n: Node, node_2_id, fallback: int) -> int:
66+
return node_2_id.get(n.name, fallback)
67+
68+
out_node = graph_module.graph.output_node()
69+
out_list = cast(tuple, out_node.args[0])
70+
_counter = count()
71+
72+
# sort nodes by the key that is id
73+
def _sort_key(t: Node) -> int:
74+
return _external_id(t, node_to_id_map, next(_counter))
75+
76+
orig_ord = tuple(sorted(out_list, key=_sort_key))
77+
78+
current_order = tuple(out_list)
79+
if orig_ord != current_order:
80+
replacement = list(orig_ord) if isinstance(out_node.args[0], list) else orig_ord
81+
out_node.args = (replacement,)
82+
graph_module.graph.lint()
83+
graph_module.recompile()
84+
85+
return graph_module
86+
87+
6288
def arm_get_first_delegation_tag(graph_module) -> str:
6389
"""Get the first delegation tag from the graph_module or return empty string."""
6490
for node in graph_module.graph.nodes:
@@ -93,9 +119,9 @@ def _preprocess( # noqa: C901
93119
artifact_path = compile_spec.get_intermediate_path()
94120
tosa_spec = compile_spec.tosa_spec
95121
dump_debug_info = compile_spec.tosa_debug_mode
96-
97-
# Assign to every node external id
98-
node_2_id = _annotate_external_ids(edge_program.graph)
122+
debug_hook = None
123+
if dump_debug_info is not None:
124+
debug_hook = DebugHook(dump_debug_info)
99125

100126
logger.info(f"Converting ExportedProgram to TOSA: {tosa_spec}")
101127

@@ -116,43 +142,57 @@ def _preprocess( # noqa: C901
116142
f"doesn't match specification {tosa_spec}"
117143
)
118144

145+
TOSABackend._preprocess_module(
146+
edge_program.graph_module,
147+
edge_program,
148+
compile_spec,
149+
tosa_graph,
150+
debug_hook,
151+
)
152+
# Serialize and return the TOSA flatbuffer.
153+
binary = tosa_graph.serialize()
154+
155+
if artifact_path:
156+
tag = arm_get_first_delegation_tag(edge_program.graph_module)
157+
debug_tosa_dump(
158+
binary,
159+
artifact_path,
160+
suffix="{}".format(f"_{tag}" if tag else "") + (f"_{tosa_spec}"),
161+
)
162+
163+
if debug_hook is not None:
164+
if debug_hook.mode == ArmCompileSpec.DebugMode.JSON:
165+
json_output = debug_hook.serialize()
166+
with open(f"{artifact_path}/debug.json", "w") as f:
167+
f.write(json_output)
168+
169+
return PreprocessResult(processed_bytes=binary)
170+
171+
@staticmethod
172+
def _preprocess_module(
173+
graph_module: GraphModule,
174+
edge_program: ExportedProgram,
175+
compile_spec: TosaCompileSpec,
176+
tosa_graph: ts.TosaSerializer,
177+
debug_hook: DebugHook | None,
178+
):
179+
"""Convert 'graph_module' to a tosa_graph"""
180+
tosa_spec = compile_spec.tosa_spec
181+
node_to_id_map = _annotate_external_ids(graph_module.graph)
182+
artifact_path = compile_spec.get_intermediate_path()
183+
119184
# TODO: Fix the need to lazily import this.
120185
from executorch.backends.arm._passes import ArmPassManager
121186

122187
graph_module = ArmPassManager(tosa_spec).transform_to_backend_pipeline( # type: ignore
123-
exported_program=edge_program
188+
exported_program=edge_program, graph_module=graph_module
124189
)
125190

126-
debug_hook = None
127-
if dump_debug_info is not None:
128-
debug_hook = DebugHook(dump_debug_info)
129-
130191
# TODO: Fix the need to lazily import this.
131192
from executorch.backends.arm.operators.node_visitor import get_node_visitors
132193

133194
node_visitors = get_node_visitors(edge_program, tosa_spec, debug_hook)
134-
135-
# Re-shuffle output nodes to preserve author's order
136-
def _external_id(n: Node, node_2_id, fallback: int) -> int:
137-
return node_2_id.get(n.name, fallback)
138-
139-
out_node = next(n for n in graph_module.graph.nodes if n.op == "output")
140-
_counter = count()
141-
142-
# sort nodes by the key that is id
143-
def _sort_key(t: Node) -> int:
144-
return _external_id(t, node_2_id, next(_counter))
145-
146-
orig_ord = tuple(sorted(out_node.args[0], key=_sort_key))
147-
148-
current_order = tuple(out_node.args[0])
149-
if orig_ord != current_order:
150-
replacement = (
151-
list(orig_ord) if isinstance(out_node.args[0], list) else orig_ord
152-
)
153-
out_node.args = (replacement,)
154-
graph_module.graph.lint()
155-
graph_module.recompile()
195+
graph_module = _sort_outputs(graph_module, node_to_id_map)
156196

157197
input_count = 0
158198
for node in graph_module.graph.nodes:
@@ -176,25 +216,6 @@ def _sort_key(t: Node) -> int:
176216
debug_fail(node, graph_module, tosa_graph.serialize(), artifact_path)
177217
raise
178218

179-
# Serialize and return the TOSA flatbuffer.
180-
binary = tosa_graph.serialize()
181-
182-
if artifact_path:
183-
tag = arm_get_first_delegation_tag(graph_module)
184-
debug_tosa_dump(
185-
binary,
186-
artifact_path,
187-
suffix="{}".format(f"_{tag}" if tag else "") + (f"_{tosa_spec}"),
188-
)
189-
190-
if debug_hook is not None:
191-
if debug_hook.mode == ArmCompileSpec.DebugMode.JSON:
192-
json_output = debug_hook.serialize()
193-
with open(f"{artifact_path}/debug.json", "w") as f:
194-
f.write(json_output)
195-
196-
return PreprocessResult(processed_bytes=binary)
197-
198219
@staticmethod
199220
def filter_tosa_compile_specs(
200221
compile_spec: ArmCompileSpec,

0 commit comments

Comments
 (0)