-
Notifications
You must be signed in to change notification settings - Fork 748
Arm backend: Preserve output order #13454
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 4 commits
58fcd50
28b8770
fc17103
4d727f3
6b7e747
71d2b8d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,119 @@ | ||
| # Copyright 2025 Arm Limited and/or its affiliates. | ||
| # | ||
| # This source code is licensed under the BSD-style license found in the | ||
| # LICENSE file in the root directory of this source tree. | ||
| # | ||
| # pyre-unsafe | ||
| import tempfile | ||
| from pathlib import Path | ||
|
|
||
| import pytest | ||
| import torch | ||
| from executorch.backends.arm.arm_backend import ArmCompileSpecBuilder | ||
| from executorch.backends.arm.quantizer.arm_quantizer import ( | ||
| get_symmetric_quantization_config, | ||
| TOSAQuantizer, | ||
| ) | ||
| from executorch.backends.arm.tosa_partitioner import TOSAPartitioner | ||
| from executorch.backends.arm.tosa_specification import TosaSpecification | ||
| from executorch.exir import to_edge_transform_and_lower | ||
| from torch import nn | ||
| from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e | ||
| from tosa import TosaGraph | ||
|
|
||
|
|
||
class Network(nn.Module):
    """Toy conv net: a three-stage trunk followed by three output heads.

    The heads emit 1, 2 and 3 channels respectively, so each output can be
    identified purely by its channel count.

    Args:
        batch_norm: When true, every trunk stage places a ``BatchNorm2d``
            between its convolution and ReLU; otherwise an ``Identity``
            occupies that slot so the layer indices stay the same.
    """

    def __init__(self, batch_norm=False):
        super().__init__()
        # Construction order is kept trunk-first, then heads, so parameter
        # initialisation consumes the RNG in a fixed sequence.
        self.conv2d_0 = self._trunk_stage(1, batch_norm)
        self.conv2d_1 = self._trunk_stage(8, batch_norm)
        self.conv2d_2 = self._trunk_stage(8, batch_norm)
        self.out_0 = self._head(1)
        self.out_1 = self._head(2)
        self.out_2 = self._head(3)

    @staticmethod
    def _trunk_stage(in_channels, batch_norm):
        """Build one 3x3 conv -> (batch norm | identity) -> ReLU stage with 8 output channels."""
        conv = nn.Conv2d(in_channels, 8, 3, padding=1, bias=False)
        norm = nn.BatchNorm2d(8) if batch_norm else nn.Identity()
        return nn.Sequential(conv, norm, nn.ReLU())

    @staticmethod
    def _head(out_channels):
        """Build one 3x3 conv -> ReLU head reading the 8-channel trunk output."""
        return nn.Sequential(
            nn.Conv2d(8, out_channels, 3, padding=1, bias=False),
            nn.ReLU(),
        )

    def forward(self, x):
        """Run the trunk once and return the three head outputs as a tuple."""
        features = self.conv2d_2(self.conv2d_1(self.conv2d_0(x)))
        return self.out_0(features), self.out_1(features), self.out_2(features)
|
|
||
|
|
||
def _read_tosa_outputs(tosa_path: Path):
    """Return the shapes of a serialized TOSA graph's outputs, in output order.

    Only the first block of the first region is inspected.

    Args:
        tosa_path: Path to a ``.tosa`` flatbuffer file.

    Returns:
        A list with one shape (a list of ints) per block output, ordered the
        same way as the block's output list.
    """
    raw = bytearray(tosa_path.read_bytes())
    main_block = TosaGraph.TosaGraph.GetRootAsTosaGraph(raw, 0).Regions(0).Blocks(0)

    # Index every tensor in the block by name so outputs can be resolved below.
    shape_by_name = {}
    for tensor_idx in range(main_block.TensorsLength()):
        tensor = main_block.Tensors(tensor_idx)
        tensor_name = tensor.Name().decode()
        # NOTE(review): the original author labels these shapes as NHWC.
        shape_by_name[tensor_name] = [
            tensor.Shape(dim) for dim in range(tensor.ShapeLength())
        ]

    return [
        shape_by_name[main_block.Outputs(out_idx).decode()]
        for out_idx in range(main_block.OutputsLength())
    ]
|
|
||
|
|
||
@pytest.mark.parametrize("batch_size", [1, 4])
def test_network_output_order_and_restore(tmp_path, batch_size):
    """Check that the lowered TOSA graph lists outputs in the model's order.

    The model is quantized for the ``TOSA-1.0+INT`` profile, lowered through
    the TOSA partitioner with intermediate artefacts dumped to ``tmp_path``,
    and the single resulting ``.tosa`` file is inspected. Each head of
    ``Network`` has a distinct channel count (1, 2, 3), so the last dimension
    of every output shape identifies which head produced it.

    Args:
        tmp_path: pytest-provided per-test temporary directory that receives
            the dumped artefacts.
        batch_size: Batch dimension of the example input.
    """
    model = Network(batch_norm=True).eval()

    # Prepare spec
    spec = TosaSpecification.create_from_string("TOSA-1.0+INT")
    compile_spec = ArmCompileSpecBuilder().tosa_compile_spec(tosa_spec=spec).build()

    # Setup quantizer
    quantizer = TOSAQuantizer(compile_spec)
    quantizer.set_global(
        get_symmetric_quantization_config(is_qat=True, is_per_channel=False)
    )

    # Trace the model, run one calibration pass, then convert to a quantized graph
    dummy = torch.randn(batch_size, 1, 28, 28)
    fx_mod = torch.export.export_for_training(model, (dummy,)).module()
    model = prepare_pt2e(fx_mod, quantizer)
    model(dummy)
    model = convert_pt2e(model)

    # Export to aten dialect
    aten_gm = torch.export.export(model, args=(dummy,), strict=True)

    # Use the pytest-managed directory for artefacts instead of a hand-rolled
    # temporary directory; it is unique per parametrized case.
    part = TOSAPartitioner(
        ArmCompileSpecBuilder()
        .tosa_compile_spec(spec)
        .dump_intermediate_artifacts_to(str(tmp_path))
        .build()
    )
    _ = to_edge_transform_and_lower(aten_gm, partitioner=[part])

    # Expect exactly one .tosa file in the artefact dir
    tosa_files = list(tmp_path.glob("*.tosa"))
    assert (
        len(tosa_files) == 1
    ), f"Expected 1 .tosa artefact, found {len(tosa_files)} in {tmp_path}"

    out_shapes = _read_tosa_outputs(tosa_files[0])
    # We use shape that is unique to output to check
    # that we preserve output order
    channel_dims = [s[-1] for s in out_shapes]
    assert channel_dims == [1, 2, 3], (
        "Outputs in .tosa do not keep author order: "
        f"expected [1, 2, 3], got {channel_dims}"
    )
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -11,7 +11,9 @@ | |
| # JIT compiler flows. | ||
| # | ||
| import logging | ||
| from typing import cast, final, List | ||
| from collections import deque | ||
| from itertools import count | ||
| from typing import cast, Dict, final, List, Set | ||
|
|
||
| import serializer.tosa_serializer as ts # type: ignore | ||
| from executorch.backends.arm.operators.node_visitor import get_node_visitors | ||
|
|
@@ -28,12 +30,38 @@ | |
| from executorch.exir.backend.backend_details import BackendDetails, PreprocessResult | ||
| from executorch.exir.backend.compile_spec_schema import CompileSpec | ||
| from torch.export.exported_program import ExportedProgram | ||
| from torch.fx import Node | ||
| from torch.fx import Graph, Node | ||
|
|
||
| # TOSA backend debug functionality | ||
| logger = logging.getLogger(__name__) | ||
|
|
||
|
|
||
def _annotate_external_ids(ep_graph: Graph) -> Dict[str, int]:
    """Map node names to the index of the graph output they feed.

    Starting from each entry of the graph's output node, the graph is walked
    backwards breadth-first over ``all_input_nodes`` and every node reached
    is tagged with that output's position. Outputs are processed in order
    and a node is tagged only once, so a producer shared by several outputs
    keeps the index of the first output that reaches it.

    Args:
        ep_graph: FX graph containing an ``output`` node.

    Returns:
        Dictionary from node name to output position. Output entries that
        are not ``Node`` instances (for example a ``None`` constant) are
        skipped but still occupy their position, so later outputs keep the
        same index they have in the output tuple.

    Raises:
        StopIteration: If the graph has no ``output`` node.
    """
    node2external_id: Dict[str, int] = {}

    out = next(n for n in ep_graph.nodes if n.op == "output")
    outputs = out.args[0]
    # A graph returning a single value stores the bare Node rather than a
    # sequence; normalise so the loop below handles both forms.
    if isinstance(outputs, Node):
        outputs = (outputs,)

    seen: Set[Node] = set()
    for idx, val in enumerate(outputs):
        # Constant outputs have no name and no producers to trace.
        if not isinstance(val, Node):
            continue
        q = deque([val])
        while q:
            n = q.popleft()
            if n in seen:
                continue
            seen.add(n)
            node2external_id[n.name] = idx
            # Walk backwards so we touch every producer
            q.extend(n.all_input_nodes)
    return node2external_id
|
|
||
|
|
||
| def arm_get_first_delegation_tag(graph_module) -> str: | ||
| """Get the first delegation tag from the graph_module or return empty string.""" | ||
| for node in graph_module.graph.nodes: | ||
|
|
@@ -74,6 +102,9 @@ def preprocess( # noqa: C901 | |
| if output_format != "tosa": | ||
| raise ValueError(f'Invalid output format {output_format}, must be "tosa"') | ||
|
|
||
| # Assign to every node external id | ||
| node_2_id = _annotate_external_ids(edge_program.graph) | ||
digantdesai marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| tosa_spec = get_tosa_spec(compile_spec) | ||
| if tosa_spec is None: | ||
| raise ValueError( | ||
|
|
@@ -95,6 +126,28 @@ def preprocess( # noqa: C901 | |
| exported_program=edge_program | ||
| ) | ||
|
|
||
| # Re-shuffle output nodes to preserve author's order | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. so IIUC the order was correct before we ran passes (i.e. for the incoming edge_program) but then got switched up? If yes, did we find if some pass(es) are injecting things out of order in
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yes, the order was correct here. after we run passes, it changed. |
||
| def _external_id(n: Node, node_2_id, fallback: int) -> int: | ||
| return node_2_id.get(n.name, fallback) | ||
|
|
||
| out_node = next(n for n in graph_module.graph.nodes if n.op == "output") | ||
| _counter = count() | ||
|
|
||
| # sort nodes by the key that is id | ||
| def _sort_key(t: Node) -> int: | ||
| return _external_id(t, node_2_id, next(_counter)) | ||
|
|
||
| orig_ord = tuple(sorted(out_node.args[0], key=_sort_key)) | ||
|
|
||
| current_order = tuple(out_node.args[0]) | ||
| if orig_ord != current_order: | ||
| replacement = ( | ||
| list(orig_ord) if isinstance(out_node.args[0], list) else orig_ord | ||
| ) | ||
| out_node.args = (replacement,) | ||
| graph_module.graph.lint() | ||
| graph_module.recompile() | ||
|
|
||
| node_visitors = get_node_visitors(edge_program, tosa_spec) | ||
| input_count = 0 | ||
| for node in graph_module.graph.nodes: | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can we use FP profile and avoid quantization in this test? Just to simplify
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The reordering has happened only with the INT profile; it does not reproduce with FP.
This test fails without this fix.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is it because we don't run the same test in FP? I am failing to see a connection between the output order and the TOSA profile. Is there a pass we run only in the INT profile which shuffles the order or something?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It just happens that for the FP profile the order is what we want.
This INT test does not fail in the debugger, for example, which is why it was impossible for me to find out where exactly it fails during partitioning, but it does fail when we run it under pytest.
The order of outputs is not deterministic. This change makes sure that we re-order according to the initial order.
The cause of this flakiness may be the use of sets, etc. inside the Python code.
We need this fix for our project, and this is a clean and working solution to make sure that the order matches the original.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Weird. Tracking the output order after each pass might lead to something. You can add a print in the base class for ExportPass or something.
This is surprising TBH.
`export()` does have this guarantee (if flattened and back, then perhaps with the `preserve_module_call_signature` arg). Also, `ExportGraphSignature` does the same, but it adds more stuff if you do buffer modifications. See - https://docs.pytorch.org/docs/stable/export.html#torch.export.graph_signature.ExportGraphSignature
I get this and am also OK with landing this as a TOSA level solution.
That said, I would like to understand the root cause a bit better and see what's the right place to fix this.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
OK, I will create a ticket for us to investigate further. I will close this PR for now, then.