Skip to content

Commit fde3980

Browse files
committed
Arm backend: Preserve output order
This change allows us to preserve output order after export. Change-Id: I7ee55c2877ca1b247f10d2e90da3ba38dc727b6f Signed-off-by: Elena Zhelezina <[email protected]>
1 parent cec1400 commit fde3980

File tree

2 files changed

+176
-2
lines changed

2 files changed

+176
-2
lines changed
Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
#
6+
# pyre-unsafe
7+
import tempfile
8+
from pathlib import Path
9+
10+
import pytest
11+
import torch
12+
from executorch.backends.arm.arm_backend import ArmCompileSpecBuilder
13+
from executorch.backends.arm.quantizer.arm_quantizer import (
14+
get_symmetric_quantization_config,
15+
TOSAQuantizer,
16+
)
17+
from executorch.backends.arm.tosa_partitioner import TOSAPartitioner
18+
from executorch.backends.arm.tosa_specification import TosaSpecification
19+
from executorch.exir import to_edge_transform_and_lower
20+
from torch import nn
21+
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
22+
from tosa import TosaGraph
23+
24+
25+
class Network(nn.Module):
26+
def __init__(self, batch_norm=False):
27+
super().__init__()
28+
self.conv2d_0 = nn.Sequential(
29+
nn.Conv2d(1, 8, 3, padding=1, bias=False),
30+
nn.BatchNorm2d(8) if batch_norm else nn.Identity(),
31+
nn.ReLU(),
32+
)
33+
self.conv2d_1 = nn.Sequential(
34+
nn.Conv2d(8, 8, 3, padding=1, bias=False),
35+
nn.BatchNorm2d(8) if batch_norm else nn.Identity(),
36+
nn.ReLU(),
37+
)
38+
self.conv2d_2 = nn.Sequential(
39+
nn.Conv2d(8, 8, 3, padding=1, bias=False),
40+
nn.BatchNorm2d(8) if batch_norm else nn.Identity(),
41+
nn.ReLU(),
42+
)
43+
self.out_0 = nn.Sequential(nn.Conv2d(8, 1, 3, padding=1, bias=False), nn.ReLU())
44+
self.out_1 = nn.Sequential(nn.Conv2d(8, 2, 3, padding=1, bias=False), nn.ReLU())
45+
self.out_2 = nn.Sequential(nn.Conv2d(8, 3, 3, padding=1, bias=False), nn.ReLU())
46+
47+
def forward(self, x):
48+
x = self.conv2d_0(x)
49+
x = self.conv2d_1(x)
50+
x = self.conv2d_2(x)
51+
out0 = self.out_0(x)
52+
out1 = self.out_1(x)
53+
out2 = self.out_2(x)
54+
return out0, out1, out2
55+
56+
57+
def _read_tosa_outputs(tosa_path: Path):
58+
# Find output tensor names in order and return shapes
59+
buf = tosa_path.read_bytes()
60+
buf_arr = bytearray(buf)
61+
graph = TosaGraph.TosaGraph.GetRootAsTosaGraph(buf_arr, 0)
62+
region = graph.Regions(0)
63+
block = region.Blocks(0)
64+
# Build a dict name - tensor‑shape
65+
tensors = {}
66+
for i in range(block.TensorsLength()):
67+
t = block.Tensors(i)
68+
name = t.Name().decode()
69+
# NHWC
70+
shape = [t.Shape(j) for j in range(t.ShapeLength())]
71+
tensors[name] = shape
72+
shapes = []
73+
for i in range(block.OutputsLength()):
74+
out_name = block.Outputs(i).decode()
75+
shapes.append(tensors[out_name])
76+
return shapes
77+
78+
79+
@pytest.mark.parametrize("batch_size", [1, 4])
80+
def test_network_output_order_and_restore(tmp_path, batch_size):
81+
model = Network(batch_norm=True).eval()
82+
# Prepare spec
83+
spec = TosaSpecification.create_from_string("TOSA-1.0+INT")
84+
compile_spec = ArmCompileSpecBuilder().tosa_compile_spec(tosa_spec=spec).build()
85+
# Setup quantizer
86+
quantizer = TOSAQuantizer(compile_spec)
87+
quantizer.set_global(
88+
get_symmetric_quantization_config(is_qat=True, is_per_channel=False)
89+
)
90+
# Trace the model
91+
dummy = torch.randn(batch_size, 1, 28, 28)
92+
fx_mod = torch.export.export_for_training(model, (dummy,)).module()
93+
model = prepare_pt2e(fx_mod, quantizer)
94+
model(dummy)
95+
model = convert_pt2e(model)
96+
# Export to aten dialect
97+
aten_gm = torch.export.export(model, args=(dummy,), strict=True)
98+
with tempfile.TemporaryDirectory() as tmpdir:
99+
art_dir = Path(tmpdir)
100+
part = TOSAPartitioner(
101+
ArmCompileSpecBuilder()
102+
.tosa_compile_spec(spec)
103+
.dump_intermediate_artifacts_to(str(art_dir))
104+
.build()
105+
)
106+
_ = to_edge_transform_and_lower(aten_gm, partitioner=[part])
107+
# Expect exactly one .tosa file in the artefact dir
108+
tosa_files = list(art_dir.glob("*.tosa"))
109+
assert (
110+
len(tosa_files) == 1
111+
), f"Expected 1 .tosa artefact, found {len(tosa_files)} in {art_dir}"
112+
out_shapes = _read_tosa_outputs(tosa_files[0])
113+
# We use shape that is unique to output to check
114+
# that we preserve output order
115+
channel_dims = [s[-1] for s in out_shapes]
116+
assert channel_dims == [1, 2, 3], (
117+
"Outputs in .tosa do not keep author order: "
118+
f"expected [1, 2, 3], got {channel_dims}"
119+
)

backends/arm/tosa/backend.py

Lines changed: 57 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,9 @@
1111
# JIT compiler flows.
1212
#
1313
import logging
14-
from typing import cast, final, List
14+
from collections import deque
15+
from itertools import count
16+
from typing import cast, Dict, final, List, Set
1517

1618
import serializer.tosa_serializer as ts # type: ignore
1719
from executorch.backends.arm.common.debug import debug_fail, debug_tosa_dump
@@ -25,12 +27,38 @@
2527
from executorch.exir.backend.backend_details import BackendDetails, PreprocessResult
2628
from executorch.exir.backend.compile_spec_schema import CompileSpec
2729
from torch.export.exported_program import ExportedProgram
28-
from torch.fx import Node
30+
from torch.fx import Graph, Node
2931

3032
# TOSA backend debug functionality
3133
logger = logging.getLogger(__name__)
3234

3335

36+
def _annotate_external_ids(ep_graph: Graph) -> Dict[str, int]:
37+
"""
38+
Returns dictionary: node name -> external ids
39+
40+
Assign id to an output node of the model so we can trace it.
41+
"""
42+
node2external_id = {}
43+
44+
def bfs_mark(start_nodes: List[Node], idx: int, seen: Set[Node]):
45+
q = deque(start_nodes)
46+
while q:
47+
n = q.popleft()
48+
if n in seen:
49+
continue
50+
seen.add(n)
51+
node2external_id[n.name] = idx
52+
# Walk backwards so we touch every producer
53+
q.extend(n.all_input_nodes)
54+
55+
out = next(n for n in ep_graph.nodes if n.op == "output")
56+
seen: Set[Node] = set()
57+
for idx, val in enumerate(out.args[0]):
58+
bfs_mark([val], idx, seen)
59+
return node2external_id
60+
61+
3462
def arm_get_first_delegation_tag(graph_module) -> str:
3563
"""Get the first delegation tag from the graph_module or return empty string."""
3664
for node in graph_module.graph.nodes:
@@ -74,6 +102,9 @@ def preprocess( # noqa: C901
74102
if output_format != "tosa":
75103
raise ValueError(f'Invalid output format {output_format}, must be "tosa"')
76104

105+
# Assign to every node external id
106+
node_2_id = _annotate_external_ids(edge_program.graph)
107+
77108
tosa_spec = get_tosa_spec(compile_spec)
78109
if tosa_spec is None:
79110
raise ValueError(
@@ -106,6 +137,30 @@ def preprocess( # noqa: C901
106137
from executorch.backends.arm.operators.node_visitor import get_node_visitors
107138

108139
node_visitors = get_node_visitors(edge_program, tosa_spec, debug_hook)
140+
141+
# Re-shuffle output nodes to preserve author's order
142+
def _external_id(n: Node, node_2_id, fallback: int) -> int:
143+
return node_2_id.get(n.name, fallback)
144+
145+
out_node = next(n for n in graph_module.graph.nodes if n.op == "output")
146+
_counter = count()
147+
148+
# sort nodes by the key that is id
149+
def _sort_key(t: Node) -> int:
150+
return _external_id(t, node_2_id, next(_counter))
151+
152+
orig_ord = tuple(sorted(out_node.args[0], key=_sort_key))
153+
154+
current_order = tuple(out_node.args[0])
155+
if orig_ord != current_order:
156+
replacement = (
157+
list(orig_ord) if isinstance(out_node.args[0], list) else orig_ord
158+
)
159+
out_node.args = (replacement,)
160+
graph_module.graph.lint()
161+
graph_module.recompile()
162+
163+
node_visitors = get_node_visitors(edge_program, tosa_spec)
109164
input_count = 0
110165
for node in graph_module.graph.nodes:
111166
node = cast(Node, node)

0 commit comments

Comments
 (0)