Skip to content

Commit f1a9df5

Browse files
committed
NXP backend: Store inferred node format in the node.meta.
1 parent b3f3111 commit f1a9df5

File tree

4 files changed

+32
-30
lines changed

4 files changed

+32
-30
lines changed

backends/nxp/backend/edge_program_converter.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@
1919
from torch.nn.parameter import Parameter
2020
from executorch.backends.nxp.backend.ir.converter.node_converters.ops_converters import * # noqa F403
2121
from executorch.backends.nxp.backend.node_format_inference import (
22-
NodeFormat,
2322
NodeFormatInference,
23+
NXP_NODE_FORMAT,
2424
)
2525
from executorch.exir.dialects._ops import ops as exir_ops
2626

@@ -70,12 +70,11 @@ def convert_program(
7070
:param custom_delegation_options: Custom user options which affect node delegation.
7171
:return: TFLite flatbuffers as bytes.
7272
"""
73-
node_formats = NodeFormatInference(edge_program).identify_node_formats()
73+
NodeFormatInference(edge_program).identify_node_formats()
7474
parameters_mapping = self.map_inputs_to_parameters(edge_program)
7575

7676
cc = self.build_conversion_context(
7777
parameters_mapping,
78-
node_formats,
7978
conversion_config,
8079
custom_delegation_options,
8180
)
@@ -101,7 +100,7 @@ def convert_program(
101100
def append_placeholders_and_tensors(nodes: list[Node], context: ConversionContext):
102101
for node in nodes:
103102
if node.op == "placeholder":
104-
node_format = context.node_formats[node]
103+
node_format = node.meta[NXP_NODE_FORMAT]
105104

106105
if node.name in context.parameters_mapping:
107106
# Node is placeholder and has data -> append as static tensor with data
@@ -114,7 +113,7 @@ def append_placeholders_and_tensors(nodes: list[Node], context: ConversionContex
114113
context.tflite_builder.append_as_fake_tensor(node, node_format)
115114
elif node.op == "call_function":
116115
# Node is call function -> append only output as a tensor
117-
node_format = context.node_formats[node]
116+
node_format = node.meta[NXP_NODE_FORMAT]
118117
context.tflite_builder.append_as_fake_tensor(node, node_format)
119118
elif node.op == "output":
120119
# Nothing to do
@@ -172,7 +171,6 @@ def map_inputs_to_parameters(edge_program: ExportedProgram) -> dict[str, Paramet
172171
@staticmethod
173172
def build_conversion_context(
174173
parameters_mapping: dict,
175-
node_formats: dict[Node, NodeFormat],
176174
conversion_config: ConversionConfig = _default_conversion_config,
177175
custom_delegation_options: CustomDelegationOptions = _default_delegation_options,
178176
) -> ConversionContext:
@@ -187,7 +185,6 @@ def build_conversion_context(
187185
tflite_builder,
188186
conversion_config,
189187
parameters_mapping,
190-
node_formats,
191188
custom_delegation_options,
192189
)
193190

backends/nxp/backend/ir/conversion_context.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,24 +10,20 @@
1010
from executorch.backends.nxp.backend.ir.converter.builder.aten_model_builder_director import (
1111
AtenModelBuilderDirector,
1212
)
13-
from executorch.backends.nxp.backend.node_format_inference import NodeFormat
14-
from torch import Node
1513
from torch.nn import Parameter
1614

1715

1816
class ConversionContext:
1917
tflite_builder: AtenModelBuilderDirector
2018
conversion_config: ConversionConfig
2119
parameters_mapping: dict[str, Parameter]
22-
node_formats: dict[Node, NodeFormat]
2320
custom_delegation_options: CustomDelegationOptions
2421

2522
def __init__(
2623
self,
2724
tflite_builder: AtenModelBuilderDirector,
2825
conversion_config: ConversionConfig,
2926
parameters_mapping: dict,
30-
node_formats: dict[Node, NodeFormat],
3127
custom_delegation_options: CustomDelegationOptions,
3228
):
3329
"""
@@ -39,5 +35,4 @@ def __init__(
3935
self.tflite_builder = tflite_builder
4036
self.conversion_config = conversion_config
4137
self.parameters_mapping = parameters_mapping
42-
self.node_formats = node_formats
4338
self.custom_delegation_options = custom_delegation_options

backends/nxp/backend/node_format_inference.py

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313

1414
logger = logging.getLogger(__name__)
1515

16+
NXP_NODE_FORMAT = "nxp_node_format" # Key into the `meta` attribute of nodes, which is mapped to the inferred format.
17+
1618

1719
class NodeFormat(Enum):
1820
# Node's output in NCHW format
@@ -43,8 +45,6 @@ class NodeFormatInference:
4345
# are channels first but output is formatless).
4446
ops_that_can_change_tensor_format = {exir_ops.edge.aten.view_copy.default}
4547

46-
_node_format_mapping: dict[Node, NodeFormat]
47-
4848
_type_changed_during_last_run: bool
4949

5050
# Mapping between Node and its ancestors (inputs)
@@ -57,7 +57,6 @@ def __init__(self, edge_program: ExportedProgram):
5757
self._edge_program = edge_program
5858

5959
self._nodes = edge_program.graph.nodes
60-
self._node_format_mapping = {}
6160
self._node_inputs = {
6261
node: node.all_input_nodes for node in edge_program.graph.nodes
6362
}
@@ -67,7 +66,7 @@ def __init__(self, edge_program: ExportedProgram):
6766

6867
self._type_changed_during_last_run = False
6968

70-
def identify_node_formats(self) -> dict[Node, NodeFormat]:
69+
def identify_node_formats(self):
7170
self._type_changed_during_last_run = True
7271

7372
# Re-run format inference until there are no changes
@@ -77,7 +76,15 @@ def identify_node_formats(self) -> dict[Node, NodeFormat]:
7776
for node in self._nodes:
7877
self._infer_format_of_nodes(node)
7978

80-
return self._node_format_mapping
79+
for node in self._nodes:
80+
if self._get_node_op_type(node) is None:
81+
continue
82+
if not hasattr(node, "meta"):
83+
logging.warning(f"Node `{node}` does not have the `meta` attribute.")
84+
node.meta = {}
85+
if NXP_NODE_FORMAT not in node.meta:
86+
logging.warning(f"Node `{node}` does not have inferred format.")
87+
node.meta[NXP_NODE_FORMAT] = NodeFormat.NONE
8188

8289
def _infer_format_of_nodes(self, node: Node):
8390
op_type = self._get_node_op_type(node)
@@ -151,7 +158,7 @@ def _assign_format_to_node(self, node: Node, node_format: NodeFormat):
151158
if old_node_format != node_format:
152159
self._type_changed_during_last_run = True
153160

154-
self._node_format_mapping[node] = node_format
161+
node.meta[NXP_NODE_FORMAT] = node_format
155162

156163
def _get_node_op_type(self, node: Node) -> str | None:
157164
"""
@@ -252,8 +259,10 @@ def _node_produces_or_consumes_channels_first_format(self, node) -> bool:
252259
for ancestor_node in input_nodes
253260
)
254261

255-
def _get_node_format(self, node):
256-
return self._node_format_mapping.get(node, NodeFormat.NONE)
262+
def _get_node_format(self, node) -> NodeFormat:
263+
if not hasattr(node, "meta"):
264+
node.meta = {}
265+
return node.meta.get(NXP_NODE_FORMAT, NodeFormat.NONE)
257266

258-
def _node_is_placeholder(self, node: Node):
267+
def _node_is_placeholder(self, node: Node) -> bool:
259268
return node.op == "placeholder"

backends/nxp/tests/test_node_format_inference.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from executorch.backends.nxp.backend.node_format_inference import (
1010
NodeFormat,
1111
NodeFormatInference,
12+
NXP_NODE_FORMAT,
1213
)
1314
from executorch.backends.nxp.neutron_pass_manager import NeutronPassManager
1415
from executorch.backends.nxp.tests.models import (
@@ -27,7 +28,7 @@ def test_convolution():
2728
exir_program = torch.export.export(model, example_input)
2829
edge_program = exir.to_edge(exir_program).exported_program()
2930

30-
node_formats = NodeFormatInference(edge_program).identify_node_formats()
31+
NodeFormatInference(edge_program).identify_node_formats()
3132

3233
expected_mapping = {
3334
"p_conv_weight": NodeFormat.CHANNELS_FIRST,
@@ -37,8 +38,8 @@ def test_convolution():
3738
"output": NodeFormat.CHANNELS_FIRST,
3839
}
3940

40-
for node, node_format in node_formats.items():
41-
assert expected_mapping[node.name] == node_format
41+
for node in edge_program.graph.nodes:
42+
assert expected_mapping[node.name] == node.meta[NXP_NODE_FORMAT]
4243

4344

4445
def test_softmax():
@@ -48,16 +49,16 @@ def test_softmax():
4849
exir_program = torch.export.export(model, example_input)
4950
edge_program = exir.to_edge(exir_program).exported_program()
5051

51-
node_formats = NodeFormatInference(edge_program).identify_node_formats()
52+
NodeFormatInference(edge_program).identify_node_formats()
5253

5354
expected_mapping = {
5455
"x": NodeFormat.FORMATLESS,
5556
"aten__softmax_default": NodeFormat.FORMATLESS,
5657
"output": NodeFormat.FORMATLESS,
5758
}
5859

59-
for node, node_format in node_formats.items():
60-
assert expected_mapping[node.name] == node_format
60+
for node in edge_program.graph.nodes:
61+
assert expected_mapping[node.name] == node.meta[NXP_NODE_FORMAT]
6162

6263

6364
def test_maxpool2d():
@@ -78,13 +79,13 @@ def test_maxpool2d():
7879

7980
# Remove MaxPool-related "getitem" nodes from graph
8081
edge_program = NeutronPassManager(edge_program, [RemoveGetItemPass]).transform()
81-
node_formats = NodeFormatInference(edge_program).identify_node_formats()
82+
NodeFormatInference(edge_program).identify_node_formats()
8283

8384
expected_mapping = {
8485
"x": NodeFormat.CHANNELS_FIRST,
8586
"aten_max_pool2d_default": NodeFormat.CHANNELS_FIRST,
8687
"output": NodeFormat.CHANNELS_FIRST,
8788
}
8889

89-
for node, node_format in node_formats.items():
90-
assert expected_mapping[node.name] == node_format
90+
for node in edge_program.graph.nodes:
91+
assert expected_mapping[node.name] == node.meta[NXP_NODE_FORMAT]

0 commit comments

Comments
 (0)