Skip to content

Commit a25d435

Browse files
committed
NXP backend: Do not infer format for unknown nodes.
1 parent 1851091 commit a25d435

File tree

7 files changed

+58
-24
lines changed

7 files changed

+58
-24
lines changed

backends/nxp/backend/edge_program_converter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from torch.nn.parameter import Parameter
2020
from executorch.backends.nxp.backend.ir.converter.node_converters.ops_converters import * # noqa F403
2121
from executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec
22-
from executorch.backends.nxp.backend.node_format_inference import NXP_NODE_FORMAT
22+
from executorch.backends.nxp.backend.node_format import NXP_NODE_FORMAT
2323
from executorch.exir.dialects._ops import ops as exir_ops
2424

2525
# noinspection PyProtectedMember

backends/nxp/backend/ir/converter/builder/aten_model_builder_director.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from executorch.backends.nxp.backend.ir.converter.conversion import translator
1010
from executorch.backends.nxp.backend.ir.tensor_formatting import TensorFormat
1111
from executorch.backends.nxp.backend.ir.tflite_generator import tflite_model
12-
from executorch.backends.nxp.backend.node_format_inference import NodeFormat
12+
from executorch.backends.nxp.backend.node_format import NodeFormat
1313
from torch.fx import Node
1414
from torch.nn import Parameter
1515

backends/nxp/backend/ir/converter/node_converters/ops_converters/cat_converter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
Concatenation,
2222
)
2323
from executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec
24-
from executorch.backends.nxp.backend.node_format_inference import NXP_NODE_FORMAT
24+
from executorch.backends.nxp.backend.node_format import NXP_NODE_FORMAT
2525
from torch.fx import Node
2626
from torch.nn import Parameter
2727

backends/nxp/backend/ir/converter/node_converters/ops_converters/constant_pad_nd_converter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
)
2929
from executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec
3030

31-
from executorch.backends.nxp.backend.node_format_inference import NXP_NODE_FORMAT
31+
from executorch.backends.nxp.backend.node_format import NXP_NODE_FORMAT
3232
from torch.fx import Node
3333
from torch.nn import Parameter
3434

backends/nxp/backend/ir/tensor_formatting.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
#
77
from enum import Enum
88

9-
from executorch.backends.nxp.backend.node_format_inference import NodeFormat
9+
from executorch.backends.nxp.backend.node_format import NodeFormat
1010

1111

1212
class TensorFormat(Enum):
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
# Copyright 2025 NXP
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
from enum import Enum
7+
8+
# Key into the `meta` attribute of nodes, which is mapped to their inferred node format.
9+
NXP_NODE_FORMAT = "nxp_node_format"
10+
11+
12+
class NodeFormat(Enum):
13+
# Node's output in NCHW format
14+
CHANNELS_FIRST = 0
15+
16+
# Node's output format has no meaning
17+
FORMATLESS = 1
18+
19+
# Format has not been identified
20+
NONE = 2
21+
22+
def is_channels_first(self) -> bool:
23+
return self == NodeFormat.CHANNELS_FIRST

backends/nxp/backend/node_format_inference.py

Lines changed: 30 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -4,31 +4,18 @@
44
# LICENSE file in the root directory of this source tree.
55

66
import logging
7-
from enum import Enum
7+
import operator
88

9+
from executorch.backends.nxp.backend.edge_program_converter import functions_converters
10+
from executorch.backends.nxp.backend.node_format import NodeFormat, NXP_NODE_FORMAT
911
from executorch.exir.dialects._ops import ops as exir_ops
12+
from executorch.exir.dialects.edge._ops import EdgeOpOverload
1013

11-
from torch import Node
1214
from torch.export import ExportedProgram
15+
from torch.fx import Node
1316

1417
logger = logging.getLogger(__name__)
1518

16-
NXP_NODE_FORMAT = "nxp_node_format" # Key into the `meta` attribute of nodes, which is mapped to the inferred format.
17-
18-
19-
class NodeFormat(Enum):
20-
# Node's output in NCHW format
21-
CHANNELS_FIRST = 0
22-
23-
# Node's output format has no meaning
24-
FORMATLESS = 1
25-
26-
# Format has not been identified
27-
NONE = 2
28-
29-
def is_channels_first(self) -> bool:
30-
return self == NodeFormat.CHANNELS_FIRST
31-
3219

3320
class NodeFormatInference:
3421
# Dictionary with Edge Aten ops that always use channels first format.
@@ -53,6 +40,9 @@ class NodeFormatInference:
5340
# Mapping between Node and its children (outputs)
5441
_node_outputs: dict[Node, list[Node]]
5542

43+
# List of all edge operations, which are supported by the converter.
44+
_known_targets: list[EdgeOpOverload]
45+
5646
def __init__(self, edge_program: ExportedProgram):
5747
self._edge_program = edge_program
5848

@@ -66,6 +56,13 @@ def __init__(self, edge_program: ExportedProgram):
6656

6757
self._type_changed_during_last_run = False
6858

59+
self._known_targets = list(functions_converters) + [
60+
exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
61+
exir_ops.edge.quantized_decomposed.dequantize_per_channel.default,
62+
exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
63+
operator.getitem,
64+
]
65+
6966
def identify_node_formats(self):
7067
self._type_changed_during_last_run = True
7168

@@ -100,9 +97,19 @@ def _infer_format_of_nodes(self, node: Node):
10097
logger.error(
10198
f"Node format inference for node type: {op_type} not found!"
10299
)
103-
else:
100+
elif node.op != "call_function" or (
101+
hasattr(node, "target") and node.target in self._known_targets
102+
):
103+
# Generic node, or tensor.
104104
self._handle_node_which_can_use_any_node_format(node)
105105

106+
else:
107+
# Don't infer the format for unknown nodes. These nodes will never be delegated, so they will divide
108+
# delegated partitions. Propagating the format here could unnecessarily enforce the format in one of these
109+
# partitions, which would require extra transpositions.
110+
for processed_node in self._node_inputs[node] + [node]:
111+
self._assign_format_to_node(processed_node, NodeFormat.NONE)
112+
106113
def _infer_format_based_on_io_ranks(self, node: Node):
107114
"""Determine the format of the output tensor of given "reshape style operator" based on the ranks of its input
108115
and output.
@@ -155,6 +162,10 @@ def _assign_format_to_node(self, node: Node, node_format: NodeFormat):
155162
# Once CHANNEL_FIRST was assigned, we don't want to reassign
156163
return
157164

165+
if node_format is NodeFormat.NONE and old_node_format is not NodeFormat.NONE:
166+
# A format has already been assigned to the node before. Don't replace it with `NONE`.
167+
return
168+
158169
if old_node_format != node_format:
159170
self._type_changed_during_last_run = True
160171

0 commit comments

Comments
 (0)