Skip to content
Closed
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
119 changes: 119 additions & 0 deletions backends/arm/test/misc/test_outputs_order.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
#
# pyre-unsafe
import tempfile
from pathlib import Path

import pytest
import torch
from executorch.backends.arm.arm_backend import ArmCompileSpecBuilder
from executorch.backends.arm.quantizer.arm_quantizer import (
get_symmetric_quantization_config,
TOSAQuantizer,
)
from executorch.backends.arm.tosa_partitioner import TOSAPartitioner
from executorch.backends.arm.tosa_specification import TosaSpecification
from executorch.exir import to_edge_transform_and_lower
from torch import nn
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
from tosa import TosaGraph


class Network(nn.Module):
def __init__(self, batch_norm=False):
super().__init__()
self.conv2d_0 = nn.Sequential(
nn.Conv2d(1, 8, 3, padding=1, bias=False),
nn.BatchNorm2d(8) if batch_norm else nn.Identity(),
nn.ReLU(),
)
self.conv2d_1 = nn.Sequential(
nn.Conv2d(8, 8, 3, padding=1, bias=False),
nn.BatchNorm2d(8) if batch_norm else nn.Identity(),
nn.ReLU(),
)
self.conv2d_2 = nn.Sequential(
nn.Conv2d(8, 8, 3, padding=1, bias=False),
nn.BatchNorm2d(8) if batch_norm else nn.Identity(),
nn.ReLU(),
)
self.out_0 = nn.Sequential(nn.Conv2d(8, 1, 3, padding=1, bias=False), nn.ReLU())
self.out_1 = nn.Sequential(nn.Conv2d(8, 2, 3, padding=1, bias=False), nn.ReLU())
self.out_2 = nn.Sequential(nn.Conv2d(8, 3, 3, padding=1, bias=False), nn.ReLU())

def forward(self, x):
x = self.conv2d_0(x)
x = self.conv2d_1(x)
x = self.conv2d_2(x)
out0 = self.out_0(x)
out1 = self.out_1(x)
out2 = self.out_2(x)
return out0, out1, out2


def _read_tosa_outputs(tosa_path: Path):
# Find output tensor names in order and return shapes
buf = tosa_path.read_bytes()
buf_arr = bytearray(buf)
graph = TosaGraph.TosaGraph.GetRootAsTosaGraph(buf_arr, 0)
region = graph.Regions(0)
block = region.Blocks(0)
# Build a dict name - tensor‑shape
tensors = {}
for i in range(block.TensorsLength()):
t = block.Tensors(i)
name = t.Name().decode()
# NHWC
shape = [t.Shape(j) for j in range(t.ShapeLength())]
tensors[name] = shape
shapes = []
for i in range(block.OutputsLength()):
out_name = block.Outputs(i).decode()
shapes.append(tensors[out_name])
return shapes


@pytest.mark.parametrize("batch_size", [1, 4])
def test_network_output_order_and_restore(tmp_path, batch_size):
model = Network(batch_norm=True).eval()
# Prepare spec
spec = TosaSpecification.create_from_string("TOSA-1.0+INT")
compile_spec = ArmCompileSpecBuilder().tosa_compile_spec(tosa_spec=spec).build()
# Setup quantizer
quantizer = TOSAQuantizer(compile_spec)
quantizer.set_global(
get_symmetric_quantization_config(is_qat=True, is_per_channel=False)
)
# Trace the model
dummy = torch.randn(batch_size, 1, 28, 28)
fx_mod = torch.export.export_for_training(model, (dummy,)).module()
model = prepare_pt2e(fx_mod, quantizer)
model(dummy)
model = convert_pt2e(model)
Comment on lines +83 to +95
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we use FP profile and avoid quantization in this test? Just to simplify

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the order of inputs has happened only for INT profile, it is not repro in FP.
this test fails without this fix

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it because we don't run the same test in FP? I am failing to see a output order connection with the TOSA profiling? Is there a pass we run only in INT profile which shuffles the order or something?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it just happens here that for FP profile the order is what we want.
this test for INT does not fail in the debugger, for example, and that is why it was impossible for me to find out where exactly it fails during partioning but it fails when we run as a pytest.
the order of outputs is not deterministic. this change makes sure that we re-order according to the initial order.
the reason of these flakiness can be in usage of set, etc inside of Python code.
we need this fix for our project and this is a clean and working solution to make sure that order is as original

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it just happens here that for FP profile the order is what we want.
this test for INT does not fail in the debugger, for example, and that is why it was impossible for me to find out where exactly it fails during partioning but it fails when we run as a pytest.

Weird. Tracking the output order after each pass might lead to something. You can add a print in the base class for ExportPass or something.

the order of outputs is not deterministic.

This is surprising TBH. export() does have this guarantee (if flattened and back then perhaps with preserve_module_call_signature arg). Also, ExportGraphSignature also same, but it adds more stuff if you do buffer modifications.

Inputs = [*parameters_buffers_constant_tensors, *flattened_user_inputs]
Outputs = [*mutated_inputs, *flattened_user_outputs]

See - https://docs.pytorch.org/docs/stable/export.html#torch.export.graph_signature.ExportGraphSignature

we need this fix for our project and this is a clean and working solution to make sure that order is as original

I get this and am also OK with landing this as a TOSA level solution.
That said, I would like to understand the root cause a bit better and see what's the right place to fix this.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, I create a ticket for us to investigate further. I close this PR for now then

# Export to aten dialect
aten_gm = torch.export.export(model, args=(dummy,), strict=True)
with tempfile.TemporaryDirectory() as tmpdir:
art_dir = Path(tmpdir)
part = TOSAPartitioner(
ArmCompileSpecBuilder()
.tosa_compile_spec(spec)
.dump_intermediate_artifacts_to(str(art_dir))
.build()
)
_ = to_edge_transform_and_lower(aten_gm, partitioner=[part])
# Expect exactly one .tosa file in the artefact dir
tosa_files = list(art_dir.glob("*.tosa"))
assert (
len(tosa_files) == 1
), f"Expected 1 .tosa artefact, found {len(tosa_files)} in {art_dir}"
out_shapes = _read_tosa_outputs(tosa_files[0])
# We use shape that is unique to output to check
# that we preserve output order
channel_dims = [s[-1] for s in out_shapes]
assert channel_dims == [1, 2, 3], (
"Outputs in .tosa do not keep author order: "
f"expected [1, 2, 3], got {channel_dims}"
)
57 changes: 55 additions & 2 deletions backends/arm/tosa_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@
# JIT compiler flows.
#
import logging
from typing import cast, final, List
from collections import deque
from itertools import count
from typing import cast, Dict, final, List, Set

import serializer.tosa_serializer as ts # type: ignore
from executorch.backends.arm.operators.node_visitor import get_node_visitors
Expand All @@ -28,12 +30,38 @@
from executorch.exir.backend.backend_details import BackendDetails, PreprocessResult
from executorch.exir.backend.compile_spec_schema import CompileSpec
from torch.export.exported_program import ExportedProgram
from torch.fx import Node
from torch.fx import Graph, Node

# TOSA backend debug functionality
logger = logging.getLogger(__name__)


def _annotate_external_ids(ep_graph: Graph) -> Dict[str, int]:
"""
Returns dictionary: node name -> external ids

Assign id to an output node of the model so we can trace it.
"""
node2external_id = {}

def bfs_mark(start_nodes: List[Node], idx: int, seen: Set[Node]):
q = deque(start_nodes)
while q:
n = q.popleft()
if n in seen:
continue
seen.add(n)
node2external_id[n.name] = idx
# Walk backwards so we touch every producer
q.extend(n.all_input_nodes)

out = next(n for n in ep_graph.nodes if n.op == "output")
seen: Set[Node] = set()
for idx, val in enumerate(out.args[0]):
bfs_mark([val], idx, seen)
return node2external_id


def arm_get_first_delegation_tag(graph_module) -> str:
"""Get the first delegation tag from the graph_module or return empty string."""
for node in graph_module.graph.nodes:
Expand Down Expand Up @@ -74,6 +102,9 @@ def preprocess( # noqa: C901
if output_format != "tosa":
raise ValueError(f'Invalid output format {output_format}, must be "tosa"')

# Assign to every node external id
node_2_id = _annotate_external_ids(edge_program.graph)

tosa_spec = get_tosa_spec(compile_spec)
if tosa_spec is None:
raise ValueError(
Expand All @@ -95,6 +126,28 @@ def preprocess( # noqa: C901
exported_program=edge_program
)

# Re-shuffle output nodes to preserve author's order
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

so IIUC the order was correct before we ran passes (i.e. for the incoming edge_program) but then got switched up? If yes, did we find if some pass(es) are injecting things out of order in output_node.arg[0]?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, the order was correct here. after we run passes, it changed.
because in the debugger this test was working, I was not able to get to the true reason.

def _external_id(n: Node, node_2_id, fallback: int) -> int:
return node_2_id.get(n.name, fallback)

out_node = next(n for n in graph_module.graph.nodes if n.op == "output")
_counter = count()

# sort nodes by the key that is id
def _sort_key(t: Node) -> int:
return _external_id(t, node_2_id, next(_counter))

orig_ord = tuple(sorted(out_node.args[0], key=_sort_key))

current_order = tuple(out_node.args[0])
if orig_ord != current_order:
replacement = (
list(orig_ord) if isinstance(out_node.args[0], list) else orig_ord
)
out_node.args = (replacement,)
graph_module.graph.lint()
graph_module.recompile()

node_visitors = get_node_visitors(edge_program, tosa_spec)
input_count = 0
for node in graph_module.graph.nodes:
Expand Down
Loading