Skip to content

Commit 064ef7c

Browse files
NXP backend: Add edge dialect pass to remove useless as_strided_copy nodes. (#16489)
### Summary This PR integrates the `RemoveUselessOpPass` edge dialect pass from the Samsung backend into the NXP backend. This feature is required to support an internal model. ### Test plan Tested internally by NXP on a private model. cc @robert-kalmar
1 parent 32ab7bb commit 064ef7c

File tree

2 files changed

+71
-1
lines changed

2 files changed

+71
-1
lines changed

backends/nxp/edge_passes/neutron_edge_pass_manager.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2025 NXP
1+
# Copyright 2025-2026 NXP
22
#
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
@@ -8,6 +8,9 @@
88
MoveTrailingAuxiliaryOperatorIntoSeparateQDQClusterPass,
99
)
1010
from executorch.backends.nxp.edge_passes.neutron_edge_pass import NeutronEdgePass
11+
from executorch.backends.nxp.edge_passes.remove_as_strided_copy_nodes import (
12+
RemoveUselessAsStridedCopyNodes,
13+
)
1114
from torch.fx.passes.infra.pass_manager import PassManager
1215

1316

@@ -17,6 +20,7 @@ def __init__(self, passes: list[NeutronEdgePass] = None):
1720
passes: list[NeutronEdgePass] = passes or [
1821
MoveLeadingAuxiliaryOperatorIntoSeparateQDQClusterPass(),
1922
MoveTrailingAuxiliaryOperatorIntoSeparateQDQClusterPass(),
23+
RemoveUselessAsStridedCopyNodes(),
2024
]
2125

2226
super().__init__(
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
# Copyright (c) 2025 Samsung Electronics Co. LTD
2+
# Copyright 2026 NXP
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
from executorch.backends.nxp.edge_passes.neutron_edge_pass import NeutronEdgePass
8+
from executorch.exir.dialects._ops import ops as exir_ops
9+
from executorch.exir.pass_base import PassResult
10+
from executorch.exir.passes import dead_code_elimination_pass
11+
from torch.fx import GraphModule
12+
13+
14+
class RemoveUselessAsStridedCopyNodes(NeutronEdgePass):
15+
def __init__(self):
16+
super().__init__()
17+
18+
def gen_pattern_as_strided_copy(self, graph_module: GraphModule):
19+
# Unedited method taken from `backends/samsung/_passes/remove_useless_ops.py`.
20+
for node in list(graph_module.graph.nodes): # noqa: C416
21+
if node.target != exir_ops.edge.aten.mean.dim:
22+
continue
23+
if len(node.users) != 1:
24+
continue
25+
successor = list(node.users.keys())[0]
26+
if successor.target != exir_ops.edge.aten.as_strided_copy.default:
27+
continue
28+
is_pattern = True
29+
count = 0
30+
for i, stride in enumerate(successor.args[2]):
31+
if stride < node.meta["val"].size()[i]:
32+
if stride == 1:
33+
count += 1
34+
else:
35+
is_pattern = False
36+
break
37+
if count >= 2:
38+
is_pattern = False
39+
break
40+
if is_pattern:
41+
yield successor
42+
43+
def _fold_as_strided_copy(
44+
self,
45+
graph_module: GraphModule,
46+
) -> bool:
47+
# Method based on `_fold_as_strided_copy()` from `backends/samsung/_passes/remove_useless_ops.py`.
48+
made_changes = False
49+
for as_strided_copy_node in self.gen_pattern_as_strided_copy(graph_module):
50+
for user in list(as_strided_copy_node.users.keys()):
51+
user.replace_input_with(
52+
as_strided_copy_node, as_strided_copy_node.args[0]
53+
)
54+
graph_module.graph.erase_node(as_strided_copy_node)
55+
56+
made_changes = True
57+
58+
return made_changes
59+
60+
def run(self, graph_module: GraphModule):
61+
made_changes = self._fold_as_strided_copy(graph_module)
62+
63+
graph_module.recompile()
64+
dead_code_elimination_pass(graph_module)
65+
66+
return PassResult(graph_module, made_changes)

0 commit comments

Comments
 (0)