Skip to content

Commit 22009ae

Browse files
committed
NXP backend: Update pass which removes GetItem nodes, to preserve the node format.
The pass `RemoveGetItemPass` replaces a `max_pool2d_with_indices` node with a `max_pool2d` node, that doesn't require a GetItem afterward. The new operator must, however, preserve the original node format. Therefore, a copy of the pass was created in `backends/nxp/_passes`, where it was modified. The new directory was created, because the pass doesn't follow the `NeutronEdgePass` interface.
1 parent f1a9df5 commit 22009ae

File tree

3 files changed

+117
-1
lines changed

3 files changed

+117
-1
lines changed

backends/nxp/TARGETS

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,18 @@ runtime.python_library(
3232
],
3333
)
3434

35+
runtime.python_library(
36+
name = "_passes",
37+
srcs = glob([
38+
"_passes/*.py",
39+
]),
40+
deps = [
41+
"//caffe2:torch",
42+
"//executorch/exir:lib",
43+
"//executorch/exir:pass_manager",
44+
],
45+
)
46+
3547
runtime.python_library(
3648
name = "quantizer",
3749
srcs = [
@@ -65,6 +77,7 @@ runtime.python_library(
6577
deps = [
6678
":neutron_sdk",
6779
":aten_passes",
80+
":_passes",
6881
":quantizer",
6982
"fbsource//third-party/pypi/flatbuffers:flatbuffers",
7083
"fbsource//third-party/pypi/ml-dtypes:ml-dtypes",
Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# Copyright 2025 NXP
3+
# All rights reserved.
4+
#
5+
# This source code is licensed under the BSD-style license found in the
6+
# LICENSE file in the root directory of this source tree.
7+
8+
import torch
9+
10+
from executorch.backends.nxp.backend.node_format_inference import (
11+
NodeFormat,
12+
NXP_NODE_FORMAT,
13+
)
14+
from executorch.exir.dialects._ops import ops as exir_ops
15+
from executorch.exir.pass_base import ExportPass, PassResult
16+
17+
18+
class RemoveGetItemPass(ExportPass):
19+
"""
20+
This remove item is used to remove getitem operator for max_pool2d_with_indices.default operator, and replace it with a single operator,
21+
that extracts the first output. More specifically, we are only getting the first output from aten::maxpool2d operator.
22+
Before Pass:
23+
MaxPool2d ---> GetItem[max_values, max_indexes]
24+
After Pass:
25+
MaxPool2d -> max_values
26+
"""
27+
28+
def call(self, graph_module: torch.fx.GraphModule):
29+
module = graph_module
30+
for node in module.graph.nodes:
31+
if node.op == "call_function":
32+
if (
33+
node.target.__name__ == "aten.max_pool2d_with_indices.default"
34+
or node.target.__name__ == "aten.max.dim"
35+
):
36+
users = list(node.users.keys())
37+
38+
if len(users) != 1:
39+
if len(users) == 2 and node.target.__name__ == "aten.max.dim":
40+
# Two users is allowed for max.dim. For that case,
41+
# rather than removing the getitem node in this
42+
# pass, we handle the getitem nodes in the op's
43+
# visitor when serializing
44+
continue
45+
else:
46+
raise AssertionError(
47+
f"Invalid number of users for {node.target.__name__}: {len(users)}"
48+
)
49+
50+
getitem_node = list(node.users.keys())[0]
51+
52+
if getitem_node.target.__name__ != "getitem":
53+
raise AssertionError(
54+
f"Expected max node's user to be getitem, got {getitem_node.target.__name__}"
55+
)
56+
57+
getitem_index = getitem_node.args[1]
58+
59+
with module.graph.inserting_before(node):
60+
if (
61+
node.target.__name__
62+
== "aten.max_pool2d_with_indices.default"
63+
):
64+
if getitem_index != 0:
65+
raise AssertionError(
66+
f"Expected second argument of getitem node for {node.target.__name__} to be 0, got {getitem_index}. XNNPACK delegate currently only supports getting just the max values from the op but not getting the corresponding indices."
67+
)
68+
new_max_wd = module.graph.create_node(
69+
"call_function",
70+
exir_ops.edge.aten.max_pool2d.default,
71+
args=node.args,
72+
kwargs=node.kwargs,
73+
)
74+
75+
else:
76+
if getitem_index != 0:
77+
raise AssertionError(
78+
f"Expected second argument of getitem node for {node.target.__name__} to be 0, got {getitem_index}. XNNPACK delegate currently only supports getting just the max values or getting both the max values and their corresponding indices from the op, but not getting the indices alone."
79+
)
80+
new_max_wd = module.graph.create_node(
81+
"call_function",
82+
exir_ops.edge.aten.amax.default,
83+
args=node.args,
84+
kwargs=node.kwargs,
85+
)
86+
87+
# MODIFIED PART START
88+
# Make sure to preserve the inferred node format.
89+
new_max_wd.meta[NXP_NODE_FORMAT] = node.meta.get(
90+
NXP_NODE_FORMAT, NodeFormat.NONE
91+
)
92+
# MODIFIED PART END
93+
94+
getitem_node.replace_all_uses_with(new_max_wd)
95+
96+
module.graph.erase_node(getitem_node)
97+
module.graph.erase_node(node)
98+
99+
graph_module.recompile()
100+
# Propagate metadata and retrace module
101+
graph_module = super().call(graph_module).graph_module
102+
103+
return PassResult(graph_module, True)

backends/nxp/nxp_backend.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
import numpy as np
1616
import torch
17+
from executorch.backends.nxp._passes.remove_getitem_pass import RemoveGetItemPass
1718

1819
from executorch.backends.nxp.backend.edge_program_converter import (
1920
EdgeProgramToIRConverter,
@@ -28,7 +29,6 @@
2829
NeutronNodeArtifacts,
2930
)
3031
from executorch.backends.nxp.neutron_pass_manager import NeutronPassManager
31-
from executorch.backends.transforms.remove_getitem_op import RemoveGetItemPass
3232
from executorch.exir.backend.backend_details import BackendDetails, PreprocessResult
3333
from executorch.exir.backend.compile_spec_schema import CompileSpec
3434
from executorch.exir.verification.verifier import EXIREdgeDialectVerifier

0 commit comments

Comments
 (0)