Skip to content

Commit 29fba26

Browse files
committed
Arm backend: Enable serializing to different regions.
Each conditional submodule in the graph_module gets its own region. The TOSA reference model requires all tensor names in one model to be unique, regardless of region. Pytorch's naming semantics, however don't guarantee this. To fix this, attach a suffix containing the submodule name to tensors in submodules. Signed-off-by: Erik Lundell <[email protected]> Change-Id: I910a7d71f0b5da2d2d9219746efd012f9bd251fb
1 parent 1882bc1 commit 29fba26

File tree

3 files changed

+32
-13
lines changed

3 files changed

+32
-13
lines changed

backends/arm/process_node.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ def process_inputs_to_buffers(
159159
buffer_values = np.transpose(buffer_values, tosa_arg.dim_order)
160160

161161
tosa_graph.addConst(
162-
buffer_values.shape, tosa_arg.dtype, buffer_values, name=node.name
162+
buffer_values.shape, tosa_arg.dtype, buffer_values, name=tosa_arg.name
163163
)
164164

165165

@@ -216,11 +216,9 @@ def process_placeholder(
216216
raise RuntimeError(f"Placeholder '{node.name}' is of unknown type.")
217217

218218

219-
def process_output(
220-
node: torch.fx.Node,
221-
tosa_graph: Any,
222-
):
219+
def process_output(node: torch.fx.Node, tosa_graph: Any, tosa_spec: TosaSpecification):
223220
for output in cast(tuple[torch.fx.Node, ...], node.args[0]):
221+
output_arg = TosaArg(output, tosa_spec)
224222
tosa_graph.addOutputTensor(
225-
tosa_graph.currRegion.currBasicBlock.tensors[output.name]
223+
tosa_graph.currRegion.currBasicBlock.tensors[output_arg.name]
226224
)

backends/arm/tosa/backend.py

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,14 @@
2424
process_placeholder,
2525
)
2626
from executorch.backends.arm.tosa.compile_spec import TosaCompileSpec
27+
from executorch.backends.arm.tosa.mapping import TOSA_TENSOR_NAME_META
2728
from executorch.exir.backend.backend_details import BackendDetails, PreprocessResult
2829
from executorch.exir.backend.compile_spec_schema import CompileSpec
30+
from executorch.exir.graph_module import get_control_flow_submodules
2931
from torch.export.exported_program import ExportedProgram
3032
from torch.fx import Graph, GraphModule, Node
3133

34+
3235
# TOSA backend debug functionality
3336
logger = logging.getLogger(__name__)
3437

@@ -169,12 +172,13 @@ def _preprocess( # noqa: C901
169172
return PreprocessResult(processed_bytes=binary)
170173

171174
@staticmethod
172-
def _preprocess_module(
175+
def _preprocess_module( # noqa: C901
173176
graph_module: GraphModule,
174177
edge_program: ExportedProgram,
175178
compile_spec: TosaCompileSpec,
176179
tosa_graph: ts.TosaSerializer,
177180
debug_hook: DebugHook | None,
181+
submodule_name: str | None = None,
178182
):
179183
"""Convert 'graph_module' to a tosa_graph"""
180184
tosa_spec = compile_spec.tosa_spec
@@ -194,7 +198,13 @@ def _preprocess_module(
194198
node_visitors = get_node_visitors(edge_program, tosa_spec, debug_hook)
195199
graph_module = _sort_outputs(graph_module, node_to_id_map)
196200

197-
input_count = 0
201+
if submodule_name is not None:
202+
tosa_graph.startRegion(submodule_name)
203+
tosa_graph.currRegion.addBasicBlock(submodule_name)
204+
suffix = f"_{submodule_name}"
205+
for loop_node in graph_module.graph.nodes:
206+
loop_node.meta[TOSA_TENSOR_NAME_META] = suffix
207+
198208
for node in graph_module.graph.nodes:
199209
node = cast(Node, node)
200210
try:
@@ -204,18 +214,27 @@ def _preprocess_module(
204214
if len(node.users) == 0:
205215
continue
206216
process_placeholder(node, tosa_graph, edge_program, tosa_spec)
207-
if node.name in edge_program.graph_signature.user_inputs:
208-
input_count += 1
209217
elif node.op == "output":
210-
process_output(node, tosa_graph)
218+
process_output(node, tosa_graph, tosa_spec)
211219
else:
212220
# This will only happen if an unpartitioned graph is passed without
213221
# any checking of compatibility.
214222
raise RuntimeError(f"{node.name} is unsupported op {node.op}")
215223
except Exception:
216-
debug_fail(node, graph_module, tosa_graph.serialize(), artifact_path)
224+
debug_fail(node, graph_module, tosa_graph, artifact_path)
217225
raise
218226

227+
# Recursively preprocess controlflow submodules.
228+
for name, submodule, _ in get_control_flow_submodules(graph_module):
229+
TOSABackend._preprocess_module(
230+
submodule,
231+
edge_program,
232+
compile_spec,
233+
tosa_graph,
234+
debug_hook,
235+
submodule_name=name,
236+
)
237+
219238
@staticmethod
220239
def filter_tosa_compile_specs(
221240
compile_spec: ArmCompileSpec,

backends/arm/tosa/mapping.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
import tosa_serializer as ts
1818
from executorch.backends.arm.tosa.specification import TosaSpecification
1919

20+
TOSA_TENSOR_NAME_META = "tosa_tensor_name"
21+
2022
UNSUPPORTED_DTYPES = (
2123
torch.float64,
2224
torch.double,
@@ -144,7 +146,7 @@ def __process_node(self, argument: torch.fx.Node):
144146
argument (torch.fx.Node): FX node to inspect.
145147
146148
"""
147-
self.name: str = argument.name
149+
self.name = argument.name + argument.meta.get(TOSA_TENSOR_NAME_META, "")
148150
output_dtype, self.shape, self.dim_order = extract_tensor_meta(
149151
argument.meta, self.tosa_spec
150152
)

0 commit comments

Comments
 (0)