Skip to content

Commit 5b1a947

Browse files
committed
Revert "Arm backend: Serialize controlflow submodules. (pytorch#15381)"
This reverts commit a4e7475.
1 parent 4da1550 commit 5b1a947

File tree

4 files changed

+74
-116
lines changed

4 files changed

+74
-116
lines changed

backends/arm/_passes/arm_pass_manager.py

Lines changed: 13 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -153,15 +153,15 @@ def _transform(self, graph_module: GraphModule):
153153
with TosaLoweringContext(self.tosa_spec):
154154
return self(graph_module).graph_module
155155

156-
def _tosa_INT_pipeline(
157-
self, exported_program: ExportedProgram, graph_module: GraphModule
158-
) -> GraphModule:
156+
def _tosa_INT_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
159157
self.add_pass(AnnotateOutputDimOrderPass())
160158
self.add_pass(FuseQuantizedActivationPass())
161159
self.add_pass(RemoveGetItemPass())
162160
self.add_pass(ConvertSplitToSlicePass())
163161
self.add_pass(ConvertMmToBmmPass())
164-
self.add_pass(DecomposeMeanDimPass(graph_module, self.tosa_spec))
162+
self.add_pass(
163+
DecomposeMeanDimPass(exported_program.graph_module, self.tosa_spec)
164+
)
165165
self.add_pass(ConvertFullLikeToFullPass())
166166
self.add_pass(ConvertToClampPass())
167167
self.add_pass(ConvertMinMaxPass())
@@ -218,11 +218,9 @@ def _tosa_INT_pipeline(
218218
self.add_pass(InsertRescalePass())
219219

220220
self.validate_constraints_mandatory()
221-
return self._transform(graph_module)
221+
return self._transform(exported_program.graph_module)
222222

223-
def _tosa_FP_pipeline(
224-
self, exported_program: ExportedProgram, graph_module: GraphModule
225-
) -> GraphModule:
223+
def _tosa_FP_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
226224
self.add_pass(AnnotateOutputDimOrderPass())
227225
self.add_pass(DecomposeExpm1Pass())
228226
self.add_pass(DecomposeLogitPass())
@@ -257,7 +255,9 @@ def _tosa_FP_pipeline(
257255
self.add_pass(DecomposeLayerNormPass())
258256
self.add_pass(DecomposeBatchNormNoStatsPass())
259257
self.add_pass(DecomposeVarPass())
260-
self.add_pass(DecomposeMeanDimPass(graph_module, self.tosa_spec))
258+
self.add_pass(
259+
DecomposeMeanDimPass(exported_program.graph_module, self.tosa_spec)
260+
)
261261
self.add_pass(DecomposeNotEqualPass())
262262
self.add_pass(DecomposeDivPass())
263263
self.add_pass(DecomposeAddSubAlphaPass())
@@ -305,16 +305,14 @@ def _tosa_FP_pipeline(
305305
self.add_pass(InsertRescalePass())
306306

307307
self.validate_constraints_mandatory()
308-
return self._transform(graph_module)
308+
return self._transform(exported_program.graph_module)
309309

310-
def transform_to_backend_pipeline(
311-
self, exported_program: ExportedProgram, graph_module: GraphModule
312-
):
310+
def transform_to_backend_pipeline(self, exported_program: ExportedProgram):
313311
"""Apply passes before transforming program to backend"""
314312
if self.tosa_spec == TosaSpecification.create_from_string("TOSA-1.0+FP"):
315-
return self._tosa_FP_pipeline(exported_program, graph_module)
313+
return self._tosa_FP_pipeline(exported_program)
316314
elif self.tosa_spec == TosaSpecification.create_from_string("TOSA-1.0+INT"):
317-
return self._tosa_INT_pipeline(exported_program, graph_module)
315+
return self._tosa_INT_pipeline(exported_program)
318316
else:
319317
raise NotImplementedError(
320318
f"No pass pipeline implemented for {self.tosa_spec=}"

backends/arm/process_node.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,7 @@ def process_inputs_to_buffers(
158158
buffer_values = np.transpose(buffer_values, tosa_arg.dim_order)
159159

160160
tosa_graph.addConst(
161-
buffer_values.shape, tosa_arg.dtype, buffer_values, name=tosa_arg.name
161+
buffer_values.shape, tosa_arg.dtype, buffer_values, name=node.name
162162
)
163163

164164

@@ -215,9 +215,11 @@ def process_placeholder(
215215
raise RuntimeError(f"Placeholder '{node.name}' is of unknown type.")
216216

217217

218-
def process_output(node: torch.fx.Node, tosa_graph: Any, tosa_spec: TosaSpecification):
218+
def process_output(
219+
node: torch.fx.Node,
220+
tosa_graph: Any,
221+
):
219222
for output in cast(tuple[torch.fx.Node, ...], node.args[0]):
220-
output_arg = TosaArg(output, tosa_spec)
221223
tosa_graph.addOutputTensor(
222-
tosa_graph.currRegion.currBasicBlock.tensors[output_arg.name]
224+
tosa_graph.currRegion.currBasicBlock.tensors[output.name]
223225
)

backends/arm/tosa/backend.py

Lines changed: 54 additions & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -24,13 +24,10 @@
2424
process_placeholder,
2525
)
2626
from executorch.backends.arm.tosa.compile_spec import TosaCompileSpec
27-
from executorch.backends.arm.tosa.mapping import TOSA_TENSOR_NAME_META
2827
from executorch.exir.backend.backend_details import BackendDetails, PreprocessResult
2928
from executorch.exir.backend.compile_spec_schema import CompileSpec
30-
from executorch.exir.graph_module import get_control_flow_submodules
3129
from torch.export.exported_program import ExportedProgram
32-
from torch.fx import Graph, GraphModule, Node
33-
30+
from torch.fx import Graph, Node
3431

3532
# TOSA backend debug functionality
3633
logger = logging.getLogger(__name__)
@@ -55,39 +52,13 @@ def bfs_mark(start_nodes: List[Node], idx: int, seen: Set[Node]):
5552
# Walk backwards so we touch every producer
5653
q.extend(n.all_input_nodes)
5754

58-
out = ep_graph.output_node()
59-
# First argument of output node is tuple of outputs
60-
output_list = cast(tuple, out.args[0])
55+
out = next(n for n in ep_graph.nodes if n.op == "output")
6156
seen: Set[Node] = set()
62-
for idx, val in enumerate(output_list):
57+
for idx, val in enumerate(out.args[0]):
6358
bfs_mark([val], idx, seen)
6459
return node2external_id
6560

6661

67-
def _sort_outputs(graph_module: GraphModule, node_to_id_map: dict[str, int]):
68-
def _external_id(n: Node, node_2_id, fallback: int) -> int:
69-
return node_2_id.get(n.name, fallback)
70-
71-
out_node = graph_module.graph.output_node()
72-
out_list = cast(tuple, out_node.args[0])
73-
_counter = count()
74-
75-
# sort nodes by the key that is id
76-
def _sort_key(t: Node) -> int:
77-
return _external_id(t, node_to_id_map, next(_counter))
78-
79-
orig_ord = tuple(sorted(out_list, key=_sort_key))
80-
81-
current_order = tuple(out_list)
82-
if orig_ord != current_order:
83-
replacement = list(orig_ord) if isinstance(out_node.args[0], list) else orig_ord
84-
out_node.args = (replacement,)
85-
graph_module.graph.lint()
86-
graph_module.recompile()
87-
88-
return graph_module
89-
90-
9162
def arm_get_first_delegation_tag(graph_module) -> str:
9263
"""Get the first delegation tag from the graph_module or return empty string."""
9364
for node in graph_module.graph.nodes:
@@ -122,9 +93,9 @@ def _preprocess( # noqa: C901
12293
artifact_path = compile_spec.get_intermediate_path()
12394
tosa_spec = compile_spec.tosa_spec
12495
dump_debug_info = compile_spec.tosa_debug_mode
125-
debug_hook = None
126-
if dump_debug_info is not None:
127-
debug_hook = DebugHook(dump_debug_info)
96+
97+
# Assign to every node external id
98+
node_2_id = _annotate_external_ids(edge_program.graph)
12899

129100
logger.info(f"Converting ExportedProgram to TOSA: {tosa_spec}")
130101

@@ -145,66 +116,45 @@ def _preprocess( # noqa: C901
145116
f"doesn't match specification {tosa_spec}"
146117
)
147118

148-
TOSABackend._preprocess_module(
149-
edge_program.graph_module,
150-
edge_program,
151-
compile_spec,
152-
tosa_graph,
153-
debug_hook,
154-
)
155-
# Serialize and return the TOSA flatbuffer.
156-
binary = tosa_graph.serialize()
157-
158-
if artifact_path:
159-
tag = arm_get_first_delegation_tag(edge_program.graph_module)
160-
debug_tosa_dump(
161-
binary,
162-
artifact_path,
163-
suffix="{}".format(f"_{tag}" if tag else "") + (f"_{tosa_spec}"),
164-
)
165-
166-
if debug_hook is not None:
167-
if debug_hook.mode == ArmCompileSpec.DebugMode.JSON:
168-
json_output = debug_hook.serialize()
169-
with open(f"{artifact_path}/debug.json", "w") as f:
170-
f.write(json_output)
171-
172-
return PreprocessResult(processed_bytes=binary)
173-
174-
@staticmethod
175-
def _preprocess_module( # noqa: C901
176-
graph_module: GraphModule,
177-
edge_program: ExportedProgram,
178-
compile_spec: TosaCompileSpec,
179-
tosa_graph: ts.TosaSerializer,
180-
debug_hook: DebugHook | None,
181-
submodule_name: str | None = None,
182-
):
183-
"""Convert 'graph_module' to a tosa_graph"""
184-
tosa_spec = compile_spec.tosa_spec
185-
node_to_id_map = _annotate_external_ids(graph_module.graph)
186-
artifact_path = compile_spec.get_intermediate_path()
187-
188119
# TODO: Fix the need to lazily import this.
189120
from executorch.backends.arm._passes import ArmPassManager
190121

191122
graph_module = ArmPassManager(tosa_spec).transform_to_backend_pipeline( # type: ignore
192-
exported_program=edge_program, graph_module=graph_module
123+
exported_program=edge_program
193124
)
194125

126+
debug_hook = None
127+
if dump_debug_info is not None:
128+
debug_hook = DebugHook(dump_debug_info)
129+
195130
# TODO: Fix the need to lazily import this.
196131
from executorch.backends.arm.operators.node_visitor import get_node_visitors
197132

198133
node_visitors = get_node_visitors(edge_program, tosa_spec, debug_hook)
199-
graph_module = _sort_outputs(graph_module, node_to_id_map)
200134

201-
if submodule_name is not None:
202-
tosa_graph.startRegion(submodule_name)
203-
tosa_graph.currRegion.addBasicBlock(submodule_name)
204-
suffix = f"_{submodule_name}"
205-
for loop_node in graph_module.graph.nodes:
206-
loop_node.meta[TOSA_TENSOR_NAME_META] = suffix
135+
# Re-shuffle output nodes to preserve author's order
136+
def _external_id(n: Node, node_2_id, fallback: int) -> int:
137+
return node_2_id.get(n.name, fallback)
138+
139+
out_node = next(n for n in graph_module.graph.nodes if n.op == "output")
140+
_counter = count()
141+
142+
# sort nodes by the key that is id
143+
def _sort_key(t: Node) -> int:
144+
return _external_id(t, node_2_id, next(_counter))
207145

146+
orig_ord = tuple(sorted(out_node.args[0], key=_sort_key))
147+
148+
current_order = tuple(out_node.args[0])
149+
if orig_ord != current_order:
150+
replacement = (
151+
list(orig_ord) if isinstance(out_node.args[0], list) else orig_ord
152+
)
153+
out_node.args = (replacement,)
154+
graph_module.graph.lint()
155+
graph_module.recompile()
156+
157+
input_count = 0
208158
for node in graph_module.graph.nodes:
209159
node = cast(Node, node)
210160
try:
@@ -214,27 +164,37 @@ def _preprocess_module( # noqa: C901
214164
if len(node.users) == 0:
215165
continue
216166
process_placeholder(node, tosa_graph, edge_program, tosa_spec)
167+
if node.name in edge_program.graph_signature.user_inputs:
168+
input_count += 1
217169
elif node.op == "output":
218-
process_output(node, tosa_graph, tosa_spec)
170+
process_output(node, tosa_graph)
219171
else:
220172
# This will only happen if an unpartitioned graph is passed without
221173
# any checking of compatibility.
222174
raise RuntimeError(f"{node.name} is unsupported op {node.op}")
223175
except Exception:
224-
debug_fail(node, graph_module, tosa_graph, artifact_path)
176+
debug_fail(node, graph_module, tosa_graph.serialize(), artifact_path)
225177
raise
226178

227-
# Recursively preprocess controlflow submodules.
228-
for name, submodule, _ in get_control_flow_submodules(graph_module):
229-
TOSABackend._preprocess_module(
230-
submodule,
231-
edge_program,
232-
compile_spec,
233-
tosa_graph,
234-
debug_hook,
235-
submodule_name=name,
179+
# Serialize and return the TOSA flatbuffer.
180+
binary = tosa_graph.serialize()
181+
182+
if artifact_path:
183+
tag = arm_get_first_delegation_tag(graph_module)
184+
debug_tosa_dump(
185+
binary,
186+
artifact_path,
187+
suffix="{}".format(f"_{tag}" if tag else "") + (f"_{tosa_spec}"),
236188
)
237189

190+
if debug_hook is not None:
191+
if debug_hook.mode == ArmCompileSpec.DebugMode.JSON:
192+
json_output = debug_hook.serialize()
193+
with open(f"{artifact_path}/debug.json", "w") as f:
194+
f.write(json_output)
195+
196+
return PreprocessResult(processed_bytes=binary)
197+
238198
@staticmethod
239199
def filter_tosa_compile_specs(
240200
compile_spec: ArmCompileSpec,

backends/arm/tosa/mapping.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,6 @@
1717
import tosa_serializer as ts
1818
from executorch.backends.arm.tosa.specification import TosaSpecification
1919

20-
TOSA_TENSOR_NAME_META = "tosa_tensor_name"
21-
2220
UNSUPPORTED_DTYPES = (
2321
torch.float64,
2422
torch.double,
@@ -146,7 +144,7 @@ def __process_node(self, argument: torch.fx.Node):
146144
argument (torch.fx.Node): FX node to inspect.
147145
148146
"""
149-
self.name = argument.name + argument.meta.get(TOSA_TENSOR_NAME_META, "")
147+
self.name: str = argument.name
150148
output_dtype, self.shape, self.dim_order = extract_tensor_meta(
151149
argument.meta, self.tosa_spec
152150
)

0 commit comments

Comments
 (0)