Skip to content

Commit bf79544

Browse files
NXP backend: Resolve limitations of uncertain tensor formats. (#14576)
### Summary
This PR resolves format-related issues by inferring the format (NCHW/NHWC) for all nodes before partitioning. These formats are then used by the NeutronPartitioner to accurately determine which nodes are supported on Neutron.

### Test plan
Unit tests are provided, and correct behavior is exercised by nearly every test in the NXP backend.

cc @kimishpatel
1 parent 11ff521 commit bf79544

18 files changed

+430
-93
lines changed

backends/nxp/TARGETS

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,18 @@ runtime.python_library(
3232
],
3333
)
3434

35+
runtime.python_library(
36+
name = "_passes",
37+
srcs = glob([
38+
"_passes/*.py",
39+
]),
40+
deps = [
41+
"//caffe2:torch",
42+
"//executorch/exir:lib",
43+
"//executorch/exir:pass_manager",
44+
],
45+
)
46+
3547
runtime.python_library(
3648
name = "quantizer",
3749
srcs = [
@@ -65,6 +77,7 @@ runtime.python_library(
6577
deps = [
6678
":neutron_sdk",
6779
":aten_passes",
80+
":_passes",
6881
":quantizer",
6982
"fbsource//third-party/pypi/flatbuffers:flatbuffers",
7083
"fbsource//third-party/pypi/ml-dtypes:ml-dtypes",
Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# Copyright 2025 NXP
3+
# All rights reserved.
4+
#
5+
# This source code is licensed under the BSD-style license found in the
6+
# LICENSE file in the root directory of this source tree.
7+
8+
import torch
9+
10+
from executorch.backends.nxp.backend.node_format_inference import (
11+
NodeFormat,
12+
NXP_NODE_FORMAT,
13+
)
14+
from executorch.exir.dialects._ops import ops as exir_ops
15+
from executorch.exir.pass_base import ExportPass, PassResult
16+
17+
18+
class RemoveGetItemPass(ExportPass):
    """
    Replace a multi-output max operator and its single `getitem` user with a
    single-output operator.

    The pass handles `aten.max_pool2d_with_indices.default` and `aten.max.dim`
    nodes whose only user is a `getitem` node selecting output `0` (the max
    values). Such a pair is replaced by `max_pool2d.default` or `amax.default`
    respectively, created with the same args and kwargs.

    Before Pass:
        MaxPool2d ---> GetItem[max_values, max_indexes]
    After Pass:
        MaxPool2d -> max_values

    A `max.dim` node with exactly two users is left untouched; any other user
    count, a non-`getitem` user, or a `getitem` index other than `0` raises
    `AssertionError`.
    """

    def call(self, graph_module: torch.fx.GraphModule):
        """
        Rewrite the graph in place, recompile it and retrace it.

        :param graph_module: Module whose graph is searched for max nodes.
        :return: `PassResult` holding the retraced module. The `modified` flag
                  is always `True`.
        :raises AssertionError: If a handled max node has an unsupported number
                                 of users, a user that is not `getitem`, or a
                                 `getitem` index different from `0`.
        """
        max_pool_name = "aten.max_pool2d_with_indices.default"
        max_dim_name = "aten.max.dim"

        module = graph_module
        # Iterate over a snapshot, because nodes are erased inside the loop.
        for node in list(module.graph.nodes):
            if node.op != "call_function":
                continue

            # Targets without a `__name__` cannot be one of the handled ops.
            target_name = getattr(node.target, "__name__", None)
            if target_name not in (max_pool_name, max_dim_name):
                continue

            users = list(node.users)
            if len(users) != 1:
                if len(users) == 2 and target_name == max_dim_name:
                    # Two users is allowed for max.dim. For that case,
                    # rather than removing the getitem node in this
                    # pass, we handle the getitem nodes in the op's
                    # visitor when serializing
                    continue
                raise AssertionError(
                    f"Invalid number of users for {target_name}: {len(users)}"
                )

            getitem_node = users[0]
            user_target = getitem_node.target
            # String targets (e.g. of the `output` node) have no `__name__`;
            # report them through the intended AssertionError instead of
            # failing with an AttributeError.
            user_target_name = getattr(user_target, "__name__", None)
            if user_target_name != "getitem":
                raise AssertionError(
                    f"Expected max node's user to be getitem, got {user_target_name or user_target}"
                )

            getitem_index = getitem_node.args[1]

            # NOTE(review): the messages below mention the XNNPACK delegate
            # although this pass works with NXP node-format metadata. The
            # wording is presumably inherited from another backend's pass;
            # confirm whether it should be updated.
            if target_name == max_pool_name:
                replacement_op = exir_ops.edge.aten.max_pool2d.default
                unsupported_index_message = (
                    f"Expected second argument of getitem node for {target_name} to be 0, got "
                    f"{getitem_index}. XNNPACK delegate currently only supports getting just the max "
                    "values from the op but not getting the corresponding indices."
                )
            else:
                replacement_op = exir_ops.edge.aten.amax.default
                unsupported_index_message = (
                    f"Expected second argument of getitem node for {target_name} to be 0, got "
                    f"{getitem_index}. XNNPACK delegate currently only supports getting just the max "
                    "values or getting both the max values and their corresponding indices from the "
                    "op, but not getting the indices alone."
                )

            if getitem_index != 0:
                raise AssertionError(unsupported_index_message)

            with module.graph.inserting_before(node):
                new_max_wd = module.graph.create_node(
                    "call_function",
                    replacement_op,
                    args=node.args,
                    kwargs=node.kwargs,
                )

            # MODIFIED PART START
            # Make sure to preserve the inferred node format.
            new_max_wd.meta[NXP_NODE_FORMAT] = node.meta.get(
                NXP_NODE_FORMAT, NodeFormat.NONE
            )
            # MODIFIED PART END

            getitem_node.replace_all_uses_with(new_max_wd)

            # The getitem node must be erased first; it is a user of `node`.
            module.graph.erase_node(getitem_node)
            module.graph.erase_node(node)

        graph_module.recompile()
        # Propagate metadata and retrace module
        graph_module = super().call(graph_module).graph_module

        return PassResult(graph_module, True)

backends/nxp/backend/edge_program_converter.py

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,7 @@
1919
from torch.nn.parameter import Parameter
2020
from executorch.backends.nxp.backend.ir.converter.node_converters.ops_converters import * # noqa F403
2121
from executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec
22-
from executorch.backends.nxp.backend.node_format_inference import (
23-
NodeFormat,
24-
NodeFormatInference,
25-
)
22+
from executorch.backends.nxp.backend.node_format import NXP_NODE_FORMAT
2623
from executorch.exir.dialects._ops import ops as exir_ops
2724

2825
# noinspection PyProtectedMember
@@ -76,12 +73,10 @@ def convert_program(
7673
:param custom_delegation_options: Custom user options which affect node delegation.
7774
:return: TFLite flatbuffers as bytes.
7875
"""
79-
node_formats = NodeFormatInference(edge_program).identify_node_formats()
8076
parameters_mapping = self.map_inputs_to_parameters(edge_program)
8177

8278
cc = self.build_conversion_context(
8379
parameters_mapping,
84-
node_formats,
8580
neutron_target_spec,
8681
conversion_config,
8782
custom_delegation_options,
@@ -108,7 +103,7 @@ def convert_program(
108103
def append_placeholders_and_tensors(nodes: list[Node], context: ConversionContext):
109104
for node in nodes:
110105
if node.op == "placeholder":
111-
node_format = context.node_formats[node]
106+
node_format = node.meta[NXP_NODE_FORMAT]
112107

113108
if node.name in context.parameters_mapping:
114109
# Node is placeholder and has data -> append as static tensor with data
@@ -121,7 +116,7 @@ def append_placeholders_and_tensors(nodes: list[Node], context: ConversionContex
121116
context.tflite_builder.append_as_fake_tensor(node, node_format)
122117
elif node.op == "call_function":
123118
# Node is call function -> append only output as a tensor
124-
node_format = context.node_formats[node]
119+
node_format = node.meta[NXP_NODE_FORMAT]
125120
context.tflite_builder.append_as_fake_tensor(node, node_format)
126121
elif node.op == "output":
127122
# Nothing to do
@@ -179,7 +174,6 @@ def map_inputs_to_parameters(edge_program: ExportedProgram) -> dict[str, Paramet
179174
@staticmethod
180175
def build_conversion_context(
181176
parameters_mapping: dict,
182-
node_formats: dict[Node, NodeFormat],
183177
neutron_target_spec: NeutronTargetSpec,
184178
conversion_config: ConversionConfig = _default_conversion_config,
185179
custom_delegation_options: CustomDelegationOptions = _default_delegation_options,
@@ -195,7 +189,6 @@ def build_conversion_context(
195189
tflite_builder,
196190
conversion_config,
197191
parameters_mapping,
198-
node_formats,
199192
custom_delegation_options,
200193
)
201194

backends/nxp/backend/ir/conversion_context.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,24 +10,20 @@
1010
from executorch.backends.nxp.backend.ir.converter.builder.aten_model_builder_director import (
1111
AtenModelBuilderDirector,
1212
)
13-
from executorch.backends.nxp.backend.node_format_inference import NodeFormat
14-
from torch import Node
1513
from torch.nn import Parameter
1614

1715

1816
class ConversionContext:
1917
tflite_builder: AtenModelBuilderDirector
2018
conversion_config: ConversionConfig
2119
parameters_mapping: dict[str, Parameter]
22-
node_formats: dict[Node, NodeFormat]
2320
custom_delegation_options: CustomDelegationOptions
2421

2522
def __init__(
2623
self,
2724
tflite_builder: AtenModelBuilderDirector,
2825
conversion_config: ConversionConfig,
2926
parameters_mapping: dict,
30-
node_formats: dict[Node, NodeFormat],
3127
custom_delegation_options: CustomDelegationOptions,
3228
):
3329
"""
@@ -39,5 +35,4 @@ def __init__(
3935
self.tflite_builder = tflite_builder
4036
self.conversion_config = conversion_config
4137
self.parameters_mapping = parameters_mapping
42-
self.node_formats = node_formats
4338
self.custom_delegation_options = custom_delegation_options

backends/nxp/backend/ir/converter/builder/aten_model_builder_director.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from executorch.backends.nxp.backend.ir.converter.conversion import translator
1010
from executorch.backends.nxp.backend.ir.tensor_formatting import TensorFormat
1111
from executorch.backends.nxp.backend.ir.tflite_generator import tflite_model
12-
from executorch.backends.nxp.backend.node_format_inference import NodeFormat
12+
from executorch.backends.nxp.backend.node_format import NodeFormat
1313
from torch.fx import Node
1414
from torch.nn import Parameter
1515

backends/nxp/backend/ir/converter/node_converters/ops_converters/cat_converter.py

Lines changed: 41 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,9 @@
99
CustomDelegationOptions,
1010
)
1111
from executorch.backends.nxp.backend.ir.converter.conversion import translator
12+
from executorch.backends.nxp.backend.ir.converter.conversion.translator import (
13+
create_channels_first_to_channels_last_permutation,
14+
)
1215
from executorch.backends.nxp.backend.ir.converter.node_converter import (
1316
_is_dequant_node,
1417
_is_quant_node,
@@ -18,6 +21,7 @@
1821
Concatenation,
1922
)
2023
from executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec
24+
from executorch.backends.nxp.backend.node_format import NXP_NODE_FORMAT
2125
from torch.fx import Node
2226
from torch.nn import Parameter
2327

@@ -85,39 +89,48 @@ def _is_supported_on_target(
8589
if dim == 0:
8690
return False
8791

88-
# If all input shapes are equal, the neutron is able to pad the last dimension of inputs and outputs.
89-
input_shapes = [_get_shape(input_) for input_ in node.all_input_nodes]
90-
if input_shapes.count(input_shapes[0]) == len(input_shapes):
91-
if dim == len(input_shapes[0]) - 1:
92-
return True
92+
# Neutron requires the channels to be a multiple of `num_macs`. The channels could either be the second or the
93+
# last dimension, depending on the formats of the node.
94+
if node.meta[NXP_NODE_FORMAT].is_channels_first():
95+
# During conversion to IR, the shape will be permuted to channels last, and the dimension on index
96+
# `1` will end up being the channels (last dim in NHWC).
97+
channels_index = 1
98+
to_nhwc_perm = create_channels_first_to_channels_last_permutation(
99+
len(node.meta["val"].shape), True
100+
)
101+
dim = to_nhwc_perm.index(
102+
dim
103+
) # Make sure the dim points to the NHWC dimension.
104+
else:
105+
# The shape will not be permuted during conversion, so the channels will remain the last dimension.
106+
channels_index = -1
93107

94-
# Neutron requires the channels to be a multiple of numMacs. The channels could either be the second or the
95-
# last dimension, depending on the formats of the node. The format, however, cannot be determined
96-
# during conversion, as it depends on what other nodes are delegated.
97108
input_channels = [
98-
# The second dimension is the channels in PyTorch. If the inputs/output are not channels first, it
99-
# will still be the channels in the IR.
100-
_get_shape(input_)[1]
101-
for input_ in node.all_input_nodes
102-
] + [
103-
# If the inputs/outputs are channels first, the last dimension will be the channels.
104-
_get_shape(input_)[-1]
105-
for input_ in node.all_input_nodes
109+
_get_shape(input_)[channels_index] for input_ in node.all_input_nodes
106110
]
107-
if any(
108-
(input_channel % neutron_target_spec.get_num_macs()) != 0
109-
for input_channel in input_channels
110-
):
111+
output_channels = _get_shape(node)[channels_index]
112+
113+
num_macs = neutron_target_spec.get_num_macs()
114+
input_shapes = [_get_shape(input_) for input_ in node.all_input_nodes]
115+
if any((input_channel % num_macs) != 0 for input_channel in input_channels):
111116
# neutron-library/src/utils/NeutronLibraryInterrogation.cpp#1492
112-
return False
113117

114-
output_channels = [_get_shape(node)[1], _get_shape(node)[-1]]
115-
# neutron-library/src/utils/NeutronLibraryInterrogation.cpp#1493
116-
if any(
117-
(out_c % neutron_target_spec.get_num_macs()) != 0
118-
for out_c in output_channels
119-
):
120-
return False
118+
# If all input shapes are equal, the neutron is able to pad the last dimension of the inputs.
119+
if not (
120+
input_shapes.count(input_shapes[0]) == len(input_shapes)
121+
and dim == len(input_shapes[0]) - 1
122+
):
123+
return False
124+
125+
if (output_channels % num_macs) != 0:
126+
# neutron-library/src/utils/NeutronLibraryInterrogation.cpp#1493
127+
128+
# If all input shapes are equal, the neutron is able to pad the last dimension of the output.
129+
if not (
130+
input_shapes.count(input_shapes[0]) == len(input_shapes)
131+
and dim == len(input_shapes[0]) - 1
132+
):
133+
return False
121134

122135
if len(node.all_input_nodes) < 2: # Not supported on Neutron
123136
# TODO Try to skip the operator if this case is realistic.

backends/nxp/backend/ir/converter/node_converters/ops_converters/constant_pad_nd_converter.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@
2727
pad_v2_options,
2828
)
2929
from executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec
30+
31+
from executorch.backends.nxp.backend.node_format import NXP_NODE_FORMAT
3032
from torch.fx import Node
3133
from torch.nn import Parameter
3234

@@ -40,9 +42,16 @@ def _is_supported_on_target(
4042
custom_delegation_options: CustomDelegationOptions,
4143
) -> bool:
4244
paddings = node.args[1]
43-
if len(paddings) > 4 and paddings[4:6] != [0, 0]:
44-
# Attempt to Pad channels dimension, which is not supported on Neutron.
45-
return False
45+
if node.meta[NXP_NODE_FORMAT].is_channels_first():
46+
# Dim `1` will end up being the channels. It is padded by paddings[4:6].
47+
if len(paddings) > 4 and paddings[4:6] != [0, 0]:
48+
# Attempt to Pad channels dimension -> currently not supported
49+
return False
50+
else:
51+
# Dim `-1` will end up being the channels. It is padded by paddings[:2].
52+
if len(paddings) > 0 and paddings[:2] != [0, 0]:
53+
# Attempt to Pad channels dimension -> currently not supported
54+
return False
4655

4756
return True
4857

@@ -65,10 +74,6 @@ def _is_supported_in_IR(
6574
if not NodeConverter._has_shared_q_params_if_quantized(node):
6675
return False
6776

68-
if len(paddings) > 4 and paddings[4:6] != [0, 0]:
69-
# Attempt to Pad channels dimension -> currently not supported
70-
return False
71-
7277
return True
7378

7479
# noinspection PyMethodMayBeStatic

backends/nxp/backend/ir/tensor_formatting.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
#
77
from enum import Enum
88

9-
from executorch.backends.nxp.backend.node_format_inference import NodeFormat
9+
from executorch.backends.nxp.backend.node_format import NodeFormat
1010

1111

1212
class TensorFormat(Enum):

0 commit comments

Comments
 (0)