Commit d324799

Arm backend: Preserve output order (pytorch#13997)
This change allows us to preserve output order after export.

Signed-off-by: Elena Zhelezina <[email protected]>
1 parent f478f2f commit d324799
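How the change works, in short: at the start of the TOSA backend's preprocess step, every node that feeds a model output is tagged with that output's index (the new _annotate_external_ids helper below), and just before serialization the output node of the partitioned graph module is re-sorted by those ids, so the emitted .tosa keeps the author's output order. The snippet below is a minimal standalone sketch of the same idea on a plain torch.fx graph; the TwoHeads module and variable names are illustrative only and are not part of this commit.

from torch import nn
from torch.fx import symbolic_trace


class TwoHeads(nn.Module):
    def forward(self, x):
        # Author's output order: (x + 1, x * 2)
        return x + 1, x * 2


gm = symbolic_trace(TwoHeads())
out_node = next(n for n in gm.graph.nodes if n.op == "output")

# 1) Record the author's order up front: node name -> output index.
node_2_id = {n.name: i for i, n in enumerate(out_node.args[0])}

# ... graph passes may run here and shuffle the output tuple ...

# 2) Restore the author's order by sorting on the recorded ids.
restored = tuple(sorted(out_node.args[0], key=lambda n: node_2_id[n.name]))
if restored != tuple(out_node.args[0]):
    out_node.args = (restored,)
    gm.graph.lint()
    gm.recompile()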

2 files changed (+175, -2 lines)
Lines changed: 119 additions & 0 deletions
@@ -0,0 +1,119 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+#
+# pyre-unsafe
+import tempfile
+from pathlib import Path
+
+import pytest
+import torch
+from executorch.backends.arm.arm_backend import ArmCompileSpecBuilder
+from executorch.backends.arm.quantizer.arm_quantizer import (
+    get_symmetric_quantization_config,
+    TOSAQuantizer,
+)
+from executorch.backends.arm.tosa.partitioner import TOSAPartitioner
+from executorch.backends.arm.tosa.specification import TosaSpecification
+from executorch.exir import to_edge_transform_and_lower
+from torch import nn
+from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
+from tosa import TosaGraph
+
+
+class Network(nn.Module):
+    def __init__(self, batch_norm=False):
+        super().__init__()
+        self.conv2d_0 = nn.Sequential(
+            nn.Conv2d(1, 8, 3, padding=1, bias=False),
+            nn.BatchNorm2d(8) if batch_norm else nn.Identity(),
+            nn.ReLU(),
+        )
+        self.conv2d_1 = nn.Sequential(
+            nn.Conv2d(8, 8, 3, padding=1, bias=False),
+            nn.BatchNorm2d(8) if batch_norm else nn.Identity(),
+            nn.ReLU(),
+        )
+        self.conv2d_2 = nn.Sequential(
+            nn.Conv2d(8, 8, 3, padding=1, bias=False),
+            nn.BatchNorm2d(8) if batch_norm else nn.Identity(),
+            nn.ReLU(),
+        )
+        self.out_0 = nn.Sequential(nn.Conv2d(8, 1, 3, padding=1, bias=False), nn.ReLU())
+        self.out_1 = nn.Sequential(nn.Conv2d(8, 2, 3, padding=1, bias=False), nn.ReLU())
+        self.out_2 = nn.Sequential(nn.Conv2d(8, 3, 3, padding=1, bias=False), nn.ReLU())
+
+    def forward(self, x):
+        x = self.conv2d_0(x)
+        x = self.conv2d_1(x)
+        x = self.conv2d_2(x)
+        out0 = self.out_0(x)
+        out1 = self.out_1(x)
+        out2 = self.out_2(x)
+        return out0, out1, out2
+
+
+def _read_tosa_outputs(tosa_path: Path):
+    # Find the output tensor names in order and return their shapes
+    buf = tosa_path.read_bytes()
+    buf_arr = bytearray(buf)
+    graph = TosaGraph.TosaGraph.GetRootAsTosaGraph(buf_arr, 0)
+    region = graph.Regions(0)
+    block = region.Blocks(0)
+    # Build a dict: tensor name -> tensor shape
+    tensors = {}
+    for i in range(block.TensorsLength()):
+        t = block.Tensors(i)
+        name = t.Name().decode()
+        # NHWC
+        shape = [t.Shape(j) for j in range(t.ShapeLength())]
+        tensors[name] = shape
+    shapes = []
+    for i in range(block.OutputsLength()):
+        out_name = block.Outputs(i).decode()
+        shapes.append(tensors[out_name])
+    return shapes
+
+
+@pytest.mark.parametrize("batch_size", [1, 4])
+def test_network_output_order_and_restore(tmp_path, batch_size):
+    model = Network(batch_norm=True).eval()
+    # Prepare the TOSA spec
+    spec = TosaSpecification.create_from_string("TOSA-1.0+INT")
+    compile_spec = ArmCompileSpecBuilder().tosa_compile_spec(tosa_spec=spec).build()
+    # Set up the quantizer
+    quantizer = TOSAQuantizer(compile_spec)
+    quantizer.set_global(
+        get_symmetric_quantization_config(is_qat=True, is_per_channel=False)
+    )
+    # Trace the model
+    dummy = torch.randn(batch_size, 1, 28, 28)
+    fx_mod = torch.export.export_for_training(model, (dummy,)).module()
+    model = prepare_pt2e(fx_mod, quantizer)
+    model(dummy)
+    model = convert_pt2e(model)
+    # Export to the ATen dialect
+    aten_gm = torch.export.export(model, args=(dummy,), strict=True)
+    with tempfile.TemporaryDirectory() as tmpdir:
+        art_dir = Path(tmpdir)
+        part = TOSAPartitioner(
+            ArmCompileSpecBuilder()
+            .tosa_compile_spec(spec)
+            .dump_intermediate_artifacts_to(str(art_dir))
+            .build()
+        )
+        _ = to_edge_transform_and_lower(aten_gm, partitioner=[part])
+        # Expect exactly one .tosa file in the artefact dir
+        tosa_files = list(art_dir.glob("*.tosa"))
+        assert (
+            len(tosa_files) == 1
+        ), f"Expected 1 .tosa artefact, found {len(tosa_files)} in {art_dir}"
+        out_shapes = _read_tosa_outputs(tosa_files[0])
+        # Each output has a unique channel count, so the channel dims tell us
+        # whether the author's output order was preserved
+        channel_dims = [s[1] for s in reversed(out_shapes)]
+        assert channel_dims == [1, 2, 3], (
+            "Outputs in .tosa do not keep author order: "
+            f"expected [1, 2, 3], got {channel_dims}"
+        )

backends/arm/tosa/backend.py

Lines changed: 56 additions & 2 deletions
@@ -11,7 +11,9 @@
 # JIT compiler flows.
 #
 import logging
-from typing import cast, final, List
+from collections import deque
+from itertools import count
+from typing import cast, Dict, final, List, Set
 
 import serializer.tosa_serializer as ts  # type: ignore
 from executorch.backends.arm.arm_backend import ArmCompileSpecBuilder
@@ -26,12 +28,38 @@
 from executorch.exir.backend.backend_details import BackendDetails, PreprocessResult
 from executorch.exir.backend.compile_spec_schema import CompileSpec
 from torch.export.exported_program import ExportedProgram
-from torch.fx import Node
+from torch.fx import Graph, Node
 
 # TOSA backend debug functionality
 logger = logging.getLogger(__name__)
 
 
+def _annotate_external_ids(ep_graph: Graph) -> Dict[str, int]:
+    """
+    Return a dictionary: node name -> external id.
+
+    Each output of the model is assigned an id so that we can trace it.
+    """
+    node2external_id = {}
+
+    def bfs_mark(start_nodes: List[Node], idx: int, seen: Set[Node]):
+        q = deque(start_nodes)
+        while q:
+            n = q.popleft()
+            if n in seen:
+                continue
+            seen.add(n)
+            node2external_id[n.name] = idx
+            # Walk backwards so we touch every producer
+            q.extend(n.all_input_nodes)
+
+    out = next(n for n in ep_graph.nodes if n.op == "output")
+    seen: Set[Node] = set()
+    for idx, val in enumerate(out.args[0]):
+        bfs_mark([val], idx, seen)
+    return node2external_id
+
+
 def arm_get_first_delegation_tag(graph_module) -> str:
     """Get the first delegation tag from the graph_module or return empty string."""
     for node in graph_module.graph.nodes:
@@ -75,6 +103,9 @@ def preprocess( # noqa: C901
         if output_format != "tosa":
             raise ValueError(f'Invalid output format {output_format}, must be "tosa"')
 
+        # Assign an external id to every node so model outputs can be traced
+        node_2_id = _annotate_external_ids(edge_program.graph)
+
         tosa_spec = get_tosa_spec(compile_spec)
         if tosa_spec is None:
             raise ValueError(
@@ -107,6 +138,29 @@ def preprocess( # noqa: C901
         from executorch.backends.arm.operators.node_visitor import get_node_visitors
 
         node_visitors = get_node_visitors(edge_program, tosa_spec, debug_hook)
+
+        # Re-shuffle the output nodes to preserve the author's order
+        def _external_id(n: Node, node_2_id, fallback: int) -> int:
+            return node_2_id.get(n.name, fallback)
+
+        out_node = next(n for n in graph_module.graph.nodes if n.op == "output")
+        _counter = count()
+
+        # Sort the output nodes by their external id
+        def _sort_key(t: Node) -> int:
+            return _external_id(t, node_2_id, next(_counter))
+
+        orig_ord = tuple(sorted(out_node.args[0], key=_sort_key))
+
+        current_order = tuple(out_node.args[0])
+        if orig_ord != current_order:
+            replacement = (
+                list(orig_ord) if isinstance(out_node.args[0], list) else orig_ord
+            )
+            out_node.args = (replacement,)
+            graph_module.graph.lint()
+            graph_module.recompile()
+
         input_count = 0
         for node in graph_module.graph.nodes:
             node = cast(Node, node)

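For a concrete picture of what _annotate_external_ids computes, here is a standalone sketch on a symbolically traced toy module (the Heads module and node names are illustrative, not the backend's edge program). The BFS walks backwards from each model output in turn, so every producer is tagged with the index of the first output that reaches it, and producers shared between outputs are claimed by the earliest one.

from collections import deque

from torch import nn
from torch.fx import symbolic_trace


class Heads(nn.Module):
    def forward(self, x):
        shared = x * 2  # feeds both outputs
        return shared + 1, shared - 1


gm = symbolic_trace(Heads())
out = next(n for n in gm.graph.nodes if n.op == "output")

node2external_id = {}
seen = set()
for idx, start in enumerate(out.args[0]):
    q = deque([start])
    while q:
        n = q.popleft()
        if n in seen:
            continue
        seen.add(n)
        node2external_id[n.name] = idx
        q.extend(n.all_input_nodes)  # walk backwards to every producer

# Nodes reachable from the first output (including the shared multiply and the
# input placeholder) get id 0; only the second output's own node gets id 1.
print(node2external_id)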