Skip to content

Commit 42d2f8d

Browse files
committed
NXP backend: Store inferred node format in the node.meta.
1 parent 70ea661 commit 42d2f8d

File tree

4 files changed

+34
-32
lines changed

4 files changed

+34
-32
lines changed

backends/nxp/backend/edge_program_converter.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@
2020
from executorch.backends.nxp.backend.ir.converter.node_converters.ops_converters import * # noqa F403
2121
from executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec
2222
from executorch.backends.nxp.backend.node_format_inference import (
23-
NodeFormat,
2423
NodeFormatInference,
24+
NXP_NODE_FORMAT,
2525
)
2626
from executorch.exir.dialects._ops import ops as exir_ops
2727

@@ -74,12 +74,11 @@ def convert_program(
7474
:param custom_delegation_options: Custom user options which affect node delegation.
7575
:return: TFLite flatbuffers as bytes.
7676
"""
77-
node_formats = NodeFormatInference(edge_program).identify_node_formats()
77+
NodeFormatInference(edge_program).identify_node_formats()
7878
parameters_mapping = self.map_inputs_to_parameters(edge_program)
7979

8080
cc = self.build_conversion_context(
8181
parameters_mapping,
82-
node_formats,
8382
neutron_target_spec,
8483
conversion_config,
8584
custom_delegation_options,
@@ -106,7 +105,7 @@ def convert_program(
106105
def append_placeholders_and_tensors(nodes: list[Node], context: ConversionContext):
107106
for node in nodes:
108107
if node.op == "placeholder":
109-
node_format = context.node_formats[node]
108+
node_format = node.meta[NXP_NODE_FORMAT]
110109

111110
if node.name in context.parameters_mapping:
112111
# Node is placeholder and has data -> append as static tensor with data
@@ -119,7 +118,7 @@ def append_placeholders_and_tensors(nodes: list[Node], context: ConversionContex
119118
context.tflite_builder.append_as_fake_tensor(node, node_format)
120119
elif node.op == "call_function":
121120
# Node is call function -> append only output as a tensor
122-
node_format = context.node_formats[node]
121+
node_format = node.meta[NXP_NODE_FORMAT]
123122
context.tflite_builder.append_as_fake_tensor(node, node_format)
124123
elif node.op == "output":
125124
# Nothing to do
@@ -177,7 +176,6 @@ def map_inputs_to_parameters(edge_program: ExportedProgram) -> dict[str, Paramet
177176
@staticmethod
178177
def build_conversion_context(
179178
parameters_mapping: dict,
180-
node_formats: dict[Node, NodeFormat],
181179
neutron_target_spec: NeutronTargetSpec,
182180
conversion_config: ConversionConfig = _default_conversion_config,
183181
custom_delegation_options: CustomDelegationOptions = _default_delegation_options,
@@ -193,7 +191,6 @@ def build_conversion_context(
193191
tflite_builder,
194192
conversion_config,
195193
parameters_mapping,
196-
node_formats,
197194
custom_delegation_options,
198195
)
199196

backends/nxp/backend/ir/conversion_context.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,24 +10,20 @@
1010
from executorch.backends.nxp.backend.ir.converter.builder.aten_model_builder_director import (
1111
AtenModelBuilderDirector,
1212
)
13-
from executorch.backends.nxp.backend.node_format_inference import NodeFormat
14-
from torch import Node
1513
from torch.nn import Parameter
1614

1715

1816
class ConversionContext:
1917
tflite_builder: AtenModelBuilderDirector
2018
conversion_config: ConversionConfig
2119
parameters_mapping: dict[str, Parameter]
22-
node_formats: dict[Node, NodeFormat]
2320
custom_delegation_options: CustomDelegationOptions
2421

2522
def __init__(
2623
self,
2724
tflite_builder: AtenModelBuilderDirector,
2825
conversion_config: ConversionConfig,
2926
parameters_mapping: dict,
30-
node_formats: dict[Node, NodeFormat],
3127
custom_delegation_options: CustomDelegationOptions,
3228
):
3329
"""
@@ -39,5 +35,4 @@ def __init__(
3935
self.tflite_builder = tflite_builder
4036
self.conversion_config = conversion_config
4137
self.parameters_mapping = parameters_mapping
42-
self.node_formats = node_formats
4338
self.custom_delegation_options = custom_delegation_options

backends/nxp/backend/node_format_inference.py

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2024 NXP
1+
# Copyright 2024-2025 NXP
22
#
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
@@ -13,6 +13,8 @@
1313

1414
logger = logging.getLogger(__name__)
1515

16+
NXP_NODE_FORMAT = "nxp_node_format" # Key into the `meta` attribute of nodes, which is mapped to the inferred format.
17+
1618

1719
class NodeFormat(Enum):
1820
# Node's output in NCHW format
@@ -43,8 +45,6 @@ class NodeFormatInference:
4345
# are channels first but output is formatless).
4446
ops_that_can_change_tensor_format = {exir_ops.edge.aten.view_copy.default}
4547

46-
_node_format_mapping: dict[Node, NodeFormat]
47-
4848
_type_changed_during_last_run: bool
4949

5050
# Mapping between Node and its ancestors (inputs)
@@ -57,7 +57,6 @@ def __init__(self, edge_program: ExportedProgram):
5757
self._edge_program = edge_program
5858

5959
self._nodes = edge_program.graph.nodes
60-
self._node_format_mapping = {}
6160
self._node_inputs = {
6261
node: node.all_input_nodes for node in edge_program.graph.nodes
6362
}
@@ -67,7 +66,7 @@ def __init__(self, edge_program: ExportedProgram):
6766

6867
self._type_changed_during_last_run = False
6968

70-
def identify_node_formats(self) -> dict[Node, NodeFormat]:
69+
def identify_node_formats(self):
7170
self._type_changed_during_last_run = True
7271

7372
# Re-run format inference until there are no changes
@@ -77,7 +76,15 @@ def identify_node_formats(self) -> dict[Node, NodeFormat]:
7776
for node in self._nodes:
7877
self._infer_format_of_nodes(node)
7978

80-
return self._node_format_mapping
79+
for node in self._nodes:
80+
if self._get_node_op_type(node) is None:
81+
continue
82+
if not hasattr(node, "meta"):
83+
logging.warning(f"Node `{node}` does not have the `meta` attribute.")
84+
node.meta = {}
85+
if NXP_NODE_FORMAT not in node.meta:
86+
logging.warning(f"Node `{node}` does not have inferred format.")
87+
node.meta[NXP_NODE_FORMAT] = NodeFormat.NONE
8188

8289
def _infer_format_of_nodes(self, node: Node):
8390
op_type = self._get_node_op_type(node)
@@ -151,7 +158,7 @@ def _assign_format_to_node(self, node: Node, node_format: NodeFormat):
151158
if old_node_format != node_format:
152159
self._type_changed_during_last_run = True
153160

154-
self._node_format_mapping[node] = node_format
161+
node.meta[NXP_NODE_FORMAT] = node_format
155162

156163
def _get_node_op_type(self, node: Node) -> str | None:
157164
"""
@@ -252,8 +259,10 @@ def _node_produces_or_consumes_channels_first_format(self, node) -> bool:
252259
for ancestor_node in input_nodes
253260
)
254261

255-
def _get_node_format(self, node):
256-
return self._node_format_mapping.get(node, NodeFormat.NONE)
262+
def _get_node_format(self, node) -> NodeFormat:
263+
if not hasattr(node, "meta"):
264+
node.meta = {}
265+
return node.meta.get(NXP_NODE_FORMAT, NodeFormat.NONE)
257266

258-
def _node_is_placeholder(self, node: Node):
267+
def _node_is_placeholder(self, node: Node) -> bool:
259268
return node.op == "placeholder"

backends/nxp/tests/test_node_format_inference.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2024 NXP
1+
# Copyright 2024-2025 NXP
22
#
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
@@ -9,6 +9,7 @@
99
from executorch.backends.nxp.backend.node_format_inference import (
1010
NodeFormat,
1111
NodeFormatInference,
12+
NXP_NODE_FORMAT,
1213
)
1314
from executorch.backends.nxp.neutron_pass_manager import NeutronPassManager
1415
from executorch.backends.nxp.tests.models import (
@@ -27,7 +28,7 @@ def test_convolution():
2728
exir_program = torch.export.export(model, example_input)
2829
edge_program = exir.to_edge(exir_program).exported_program()
2930

30-
node_formats = NodeFormatInference(edge_program).identify_node_formats()
31+
NodeFormatInference(edge_program).identify_node_formats()
3132

3233
expected_mapping = {
3334
"p_conv_weight": NodeFormat.CHANNELS_FIRST,
@@ -37,8 +38,8 @@ def test_convolution():
3738
"output": NodeFormat.CHANNELS_FIRST,
3839
}
3940

40-
for node, node_format in node_formats.items():
41-
assert expected_mapping[node.name] == node_format
41+
for node in edge_program.graph.nodes:
42+
assert expected_mapping[node.name] == node.meta[NXP_NODE_FORMAT]
4243

4344

4445
def test_softmax():
@@ -48,16 +49,16 @@ def test_softmax():
4849
exir_program = torch.export.export(model, example_input)
4950
edge_program = exir.to_edge(exir_program).exported_program()
5051

51-
node_formats = NodeFormatInference(edge_program).identify_node_formats()
52+
NodeFormatInference(edge_program).identify_node_formats()
5253

5354
expected_mapping = {
5455
"x": NodeFormat.FORMATLESS,
5556
"aten__softmax_default": NodeFormat.FORMATLESS,
5657
"output": NodeFormat.FORMATLESS,
5758
}
5859

59-
for node, node_format in node_formats.items():
60-
assert expected_mapping[node.name] == node_format
60+
for node in edge_program.graph.nodes:
61+
assert expected_mapping[node.name] == node.meta[NXP_NODE_FORMAT]
6162

6263

6364
def test_maxpool2d():
@@ -78,13 +79,13 @@ def test_maxpool2d():
7879

7980
# Remove MaxPool-related "getitem" nodes from graph
8081
edge_program = NeutronPassManager(edge_program, [RemoveGetItemPass]).transform()
81-
node_formats = NodeFormatInference(edge_program).identify_node_formats()
82+
NodeFormatInference(edge_program).identify_node_formats()
8283

8384
expected_mapping = {
8485
"x": NodeFormat.CHANNELS_FIRST,
8586
"aten_max_pool2d_default": NodeFormat.CHANNELS_FIRST,
8687
"output": NodeFormat.CHANNELS_FIRST,
8788
}
8889

89-
for node, node_format in node_formats.items():
90-
assert expected_mapping[node.name] == node_format
90+
for node in edge_program.graph.nodes:
91+
assert expected_mapping[node.name] == node.meta[NXP_NODE_FORMAT]

0 commit comments

Comments
 (0)