
Commit 0391fe7

NXP backend: Add NXP backend support for aten.unsqueeze (#16467)
### Summary
Adds conversion and quantization of aten.unsqueeze.

### Test plan
Tests can be manually run using `pytest -c /dev/null backends/nxp/tests/`.

cc @robert-kalmar @MartinPavella
1 parent e847384 commit 0391fe7

File tree

4 files changed: +237 −2 lines changed

backends/nxp/aten_passes/convert_unsqueeze_to_view.py

Lines changed: 79 additions & 0 deletions
```python
# Copyright 2026 NXP
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Optional

import torch
from torch._subclasses import FakeTensor, FakeTensorMode
from torch.fx import GraphModule, Node
from torch.fx.passes.infra.pass_base import PassBase, PassResult


class ConvertUnsqueezeToViewPass(PassBase):
    """Replace 'aten.unsqueeze.default' with 'aten.view.default'.

                  x                                               x
                  │                                               │
    ┌─────────────▼─────────────┐    replace with   ┌─────────────▼─────────────┐
    │  aten.unsqueeze(x, dim)   │  ──────────────►  │  aten.view.default(x, S)  │
    └─────────────┬─────────────┘                   └─────────────┬─────────────┘
                  │                                               │
                  ▼                                               ▼
                 out                                             out
    """

    @staticmethod
    def _is_unsqueeze(node_: Node) -> bool:
        return (
            node_.op == "call_function"
            and node_.target == torch.ops.aten.unsqueeze.default
        )

    def _create_view_node(self, *view_args) -> Node:
        view_target = torch.ops.aten.view.default
        view_node = self.graph_module.graph.call_function(view_target, view_args)

        view_node.meta["source_fn_stack"] = [
            (view_node.name, torch.ops.aten.view.default)
        ]

        x_val = view_args[0].meta["val"]
        with FakeTensorMode() as mode:
            fake_input = FakeTensor.from_tensor(
                torch.empty(x_val.shape, dtype=x_val.dtype), mode
            )
            output_shape = view_target(fake_input, *view_args[1:]).shape
            view_node.meta["val"] = FakeTensor.from_tensor(
                torch.empty(output_shape, dtype=x_val.dtype), mode
            )

        return view_node

    def call(self, graph_module: GraphModule) -> Optional[PassResult]:
        self.graph_module = graph_module
        made_changes = False

        if not any(self._is_unsqueeze(n) for n in graph_module.graph.nodes):
            return PassResult(graph_module, made_changes)

        for node in list(graph_module.graph.nodes):
            if not self._is_unsqueeze(node):
                continue

            input_node = node.all_input_nodes[0]
            target_size = node.meta["val"].shape

            with self.graph_module.graph.inserting_after(node):
                view_node = self._create_view_node(input_node, target_size)

            node.replace_all_uses_with(view_node)
            self.graph_module.graph.erase_node(node)

            made_changes = True

        self.graph_module.recompile()
        self.graph_module.graph.eliminate_dead_code()

        return PassResult(graph_module, made_changes)
```
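The rewrite is behavior preserving because `aten.unsqueeze` only inserts a dimension of size 1, so a `view` to the unsqueezed shape returns the same values. A minimal standalone sketch of that equivalence (the shape and `dim` below are illustrative, not taken from the commit):

```python
import torch

# unsqueeze inserts a size-1 dimension; a view to the resulting shape is equivalent.
x = torch.rand(8, 4, 6)
dim = 2

unsqueezed = torch.unsqueeze(x, dim)   # shape: (8, 4, 1, 6)
viewed = x.view(unsqueezed.shape)      # same data, same shape, no unsqueeze op

assert torch.equal(unsqueezed, viewed)
```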

backends/nxp/aten_passes/neutron_aten_pass_manager.py

Lines changed: 5 additions & 1 deletion
```diff
@@ -1,4 +1,4 @@
-# Copyright 2025 NXP
+# Copyright 2026 NXP
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
@@ -7,6 +7,9 @@
 
 import torch
 
+from executorch.backends.nxp.aten_passes.convert_unsqueeze_to_view import (
+    ConvertUnsqueezeToViewPass,
+)
 from executorch.backends.nxp.aten_passes.fuse_batch_norm_with_conv_pass import (
     FuseBatchNormWithConvPass,
 )
@@ -49,6 +52,7 @@ def __init__(
             RemoveNodesWithKnownOutputs(),
             FuseLinearAndAddPass(),
             MoveActivationBeforeConcat(neutron_target_spec),
+            ConvertUnsqueezeToViewPass(),
         ]
 
         super().__init__(passes)
```
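Registered at the end of the `NeutronAtenPassManager` pass list, the rewrite now runs for every model lowered through the NXP pipeline. For quick experiments the pass can also be applied on its own; the snippet below is a sketch, assuming the pass is invoked directly on a `torch.export.export(...).module()` graph (the `TinyModel` here is hypothetical, not part of the commit):

```python
import torch

from executorch.backends.nxp.aten_passes.convert_unsqueeze_to_view import (
    ConvertUnsqueezeToViewPass,
)


class TinyModel(torch.nn.Module):
    # Hypothetical toy model: its graph contains one `aten.unsqueeze.default` node.
    def forward(self, x):
        return torch.unsqueeze(x + 1.0, 0)


module = torch.export.export(TinyModel(), (torch.rand(4, 6),)).module()
result = ConvertUnsqueezeToViewPass()(module)
print(result.modified)  # expected True: the unsqueeze node was replaced by a view
```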

backends/nxp/tests/models.py

Lines changed: 10 additions & 1 deletion
```diff
@@ -1,4 +1,4 @@
-# Copyright (c) 2024-2025 NXP
+# Copyright (c) 2024-2026 NXP
 # All rights reserved.
 #
 # This source code is licensed under the BSD-style license found in the
@@ -670,3 +670,12 @@ def __init__(self):
 
     def forward(self, x):
         return self.sequential(x)
+
+
+class UnsqueezeAddModel(torch.nn.Module):
+    def __init__(self, dim):
+        super().__init__()
+        self.dim = dim
+
+    def forward(self, x, y):
+        return torch.unsqueeze(x + y, self.dim)
```
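`UnsqueezeAddModel` adds its two inputs and unsqueezes the sum at the configured `dim`, so the exported graph contains an `aten.unsqueeze.default` node fed by another operator, which is exactly what the pass targets. A small usage sketch (shapes chosen only for illustration):

```python
import torch

from executorch.backends.nxp.tests.models import UnsqueezeAddModel

model = UnsqueezeAddModel(dim=0)
x, y = torch.rand(3, 5), torch.rand(3, 5)

out = model(x, y)
print(out.shape)  # torch.Size([1, 3, 5]): (x + y) with a new leading dimension
```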
Lines changed: 143 additions & 0 deletions
```python
# Copyright 2026 NXP
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np
import pytest
import torch
from executorch.backends.nxp.aten_passes.neutron_aten_pass_manager import (
    ConvertUnsqueezeToViewPass,
    NeutronAtenPassManager,
)
from executorch.backends.nxp.backend.edge_program_converter import (
    EdgeProgramToIRConverter,
)
from executorch.backends.nxp.tests.executorch_pipeline import (
    neutron_target_spec,
    to_quantized_edge_program,
)
from executorch.backends.nxp.tests.executors import (
    convert_run_compare,
    graph_contains_any_of_ops,
)

from executorch.backends.nxp.tests.models import UnsqueezeAddModel
from executorch.exir.dialects._ops import ops as exir_ops
from torch.export import ExportedProgram


@pytest.fixture(autouse=True)
def reseed_model_per_test_run():
    torch.manual_seed(42)
    np.random.seed(23)


@pytest.mark.parametrize(
    "input_shape, dim",
    [
        pytest.param((2,), 0, id="1D."),
        pytest.param((8, 4, 6), 2, id="3D."),
        pytest.param((8, 4, 6, 8), -2, id="4D, negative dim."),
        pytest.param((8, 4, 6), 3, id="3D, dim arg is clipped."),
        pytest.param((8, 4, 6), -4, id="3D, dim arg is clipped."),
    ],
)
def test_convert_unsqueeze_to_view_simple(mocker, input_shape, dim):
    model = UnsqueezeAddModel(dim)

    example_input_1 = torch.rand(input_shape)
    example_input_2 = torch.rand(input_shape)

    exir_program_aten = torch.export.export(
        model,
        (example_input_1, example_input_2),
    ).module()

    # Check "aten.unsqueeze.default" is present.
    assert graph_contains_any_of_ops(
        exir_program_aten.graph, [torch.ops.aten.unsqueeze.default]
    )

    example_input = (example_input_1, example_input_2)
    outputs_before = [o.detach().numpy() for o in exir_program_aten(*example_input)]

    # Apply the optimization.
    NeutronAtenPassManager(neutron_target_spec, [ConvertUnsqueezeToViewPass()])(
        exir_program_aten
    )

    # Make sure no "aten.unsqueeze.default" is in the model.
    assert not graph_contains_any_of_ops(
        exir_program_aten.graph,
        [torch.ops.aten.unsqueeze.default],
    )

    # Make sure there is "aten.view.default" in the model.
    assert graph_contains_any_of_ops(
        exir_program_aten.graph,
        [torch.ops.aten.view.default],
    )

    outputs_after = [o.detach().numpy() for o in exir_program_aten(*example_input)]

    # Make sure the model still produces the exact same output.
    assert len(outputs_before) == len(outputs_after)

    for i in range(len(outputs_before)):
        assert np.allclose(outputs_before[i], outputs_after[i])


@pytest.mark.parametrize(
    "input_shape, dim",
    [
        pytest.param((2,), 0, id="1D."),
        pytest.param((8, 4, 6), 2, id="3D."),
        pytest.param((8, 4, 6, 8), -2, id="4D, negative dim."),
        pytest.param((8, 4, 6), 3, id="3D, dim arg is clipped."),
        pytest.param((8, 4, 6), -4, id="3D, dim arg is clipped."),
    ],
)
def test_convert_unsqueeze_to_view_full_pipeline(mocker, input_shape, dim):
    model = UnsqueezeAddModel(dim)
    converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program")

    # Run conversion.
    edge_program = to_quantized_edge_program(
        model,
        [input_shape, input_shape],
    ).exported_program()

    # Make sure no "aten.unsqueeze.default" is in the model.
    assert not graph_contains_any_of_ops(
        edge_program.graph,
        [
            torch.ops.aten.unsqueeze.default,
        ],
    )

    # Capture generated model.
    neutron_ir_model = converter_spy.spy_return[0]
    exported_program: ExportedProgram = converter_spy.call_args.args[1]

    # Make sure "edge.aten.view_copy.default" is in the model.
    assert graph_contains_any_of_ops(
        exported_program.graph,
        [
            exir_ops.edge.aten.view_copy.default,
        ],
    )

    example_input_1 = (np.random.random(input_shape).astype(np.float32) * 50).astype(
        np.int8
    )
    example_input_2 = (np.random.random(input_shape).astype(np.float32) * 50).astype(
        np.int8
    )
    example_input = {0: example_input_1, 1: example_input_2}

    convert_run_compare(
        exported_program,
        input_data=example_input,
        tfl_model=neutron_ir_model,
    )
```
