Skip to content

Commit 83cd7ca

Browse files
committed
NXP backend: Explicitly replace IO TensorFormat with NodeFormat.
Before, this change was "hidden", and it was only done to some inputs/outputs.
1 parent 06ff723 commit 83cd7ca

File tree

5 files changed

+48
-16
lines changed

5 files changed

+48
-16
lines changed

backends/nxp/backend/edge_program_converter.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from torch.nn.parameter import Parameter
2020
from executorch.backends.nxp.backend.ir.converter.node_converters.ops_converters import * # noqa F403
2121
from executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec
22-
from executorch.backends.nxp.backend.node_format import NXP_NODE_FORMAT
22+
from executorch.backends.nxp.backend.node_format import NodeFormat, NXP_NODE_FORMAT
2323
from executorch.exir.dialects._ops import ops as exir_ops
2424

2525
# noinspection PyProtectedMember
@@ -63,7 +63,7 @@ def convert_program(
6363
conversion_config: ConversionConfig = _default_conversion_config,
6464
neutron_target_spec: NeutronTargetSpec = _default_target_spec,
6565
custom_delegation_options: CustomDelegationOptions = _default_delegation_options,
66-
) -> (bytes, dict):
66+
) -> (bytes, dict[str, NodeFormat]):
6767
"""
6868
Convert ExportedProgram in Edge dialect to IR (TFLite flatbuffers) as bytes.
6969

backends/nxp/backend/ir/converter/builder/model_builder.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,9 @@
55
# License: MIT
66
# See the LICENSE_MIT for more details.
77
#
8+
89
from copy import deepcopy
10+
from itertools import chain
911
from typing import Dict, List, Optional, Union
1012

1113
import executorch.backends.nxp.backend.ir.converter.conversion.translator as translator
@@ -221,7 +223,7 @@ def channels_first_version_of(self, t_tensor: tflite_model.Tensor):
221223
new_tensor.shape = translator.channels_last_shape_to_channels_first(
222224
t_tensor.shape
223225
)
224-
new_tensor.tensor_format = new_tensor.tensor_format.to_node_format()
226+
new_tensor.tensor_format = TensorFormat.CHANNELS_FIRST
225227

226228
perm = translator.create_channels_last_to_channels_first_permutation(
227229
t_tensor.rank
@@ -382,7 +384,7 @@ def _make_inputs_channels_first(self):
382384
input_tensor, input_tensor.name + "_channels_first"
383385
)
384386
new_input.shape = new_input_shape
385-
new_input.tensor_format = input_tensor.tensor_format.to_node_format()
387+
new_input.tensor_format = TensorFormat.CHANNELS_FIRST
386388

387389
transpose = self._create_transpose_operator(
388390
new_input, input_tensor, perm
@@ -458,6 +460,14 @@ def _keep_one_empty_buffer(self):
458460
# It's safe to replace the buffer.
459461
t.tmp_buffer = empty_buffer
460462

463+
def replace_io_tensor_format_with_node_format(self):
464+
for t in chain(
465+
self.get_sub_graph().inputs.tmp_inputs,
466+
self.get_sub_graph().outputs.tmp_outputs,
467+
):
468+
if isinstance(t.tensor_format, TensorFormat):
469+
t.tensor_format = t.tensor_format.to_equal_node_format()
470+
461471
def finish(self) -> tflite_model.Model:
462472
"""Finalize and optimize the converted TFLite model. Then return it.
463473
@@ -478,6 +488,8 @@ def finish(self) -> tflite_model.Model:
478488

479489
self._keep_one_empty_buffer()
480490

491+
self.replace_io_tensor_format_with_node_format()
492+
481493
# Remove outputs, which are not produced by any node. Otherwise, there would be errors after inference.
482494
operator_outputs = []
483495
for op in self.get_operators().vector:

backends/nxp/backend/ir/tensor_formatting.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,17 +38,32 @@ def is_channels_last(self) -> bool:
3838

3939
@staticmethod
4040
def from_node_format(node_format: NodeFormat):
41-
if node_format.is_channels_first():
42-
return TensorFormat.CHANNELS_LAST
41+
if node_format == NodeFormat.CHANNELS_FIRST:
42+
return TensorFormat.CHANNELS_LAST # Format is swapped.
43+
elif node_format == NodeFormat.CHANNELS_LAST:
44+
return TensorFormat.CHANNELS_FIRST # Format is swapped.
4345
elif node_format == NodeFormat.FORMATLESS:
4446
return TensorFormat.FORMATLESS
4547
else:
4648
return TensorFormat.NONE
4749

4850
def to_node_format(self):
4951
if self == TensorFormat.CHANNELS_LAST:
50-
return NodeFormat.CHANNELS_FIRST
52+
return NodeFormat.CHANNELS_FIRST # Format is swapped.
5153
elif self == TensorFormat.FORMATLESS:
5254
return NodeFormat.FORMATLESS
55+
elif self == TensorFormat.CHANNELS_FIRST:
56+
return NodeFormat.CHANNELS_LAST # Format is swapped.
5357
else:
5458
return NodeFormat.NONE
59+
60+
def to_equal_node_format(self):
61+
match self:
62+
case TensorFormat.CHANNELS_FIRST:
63+
return NodeFormat.CHANNELS_FIRST
64+
case TensorFormat.CHANNELS_LAST:
65+
return NodeFormat.CHANNELS_LAST
66+
case TensorFormat.FORMATLESS:
67+
return NodeFormat.FORMATLESS
68+
case _:
69+
return NodeFormat.NONE

backends/nxp/backend/node_format.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,5 +19,8 @@ class NodeFormat(Enum):
1919
# Format has not been identified
2020
NONE = 2
2121

22+
# NHWC
23+
CHANNELS_LAST = 3
24+
2225
def is_channels_first(self) -> bool:
2326
return self == NodeFormat.CHANNELS_FIRST

backends/nxp/nxp_backend.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,17 +14,17 @@
1414

1515
import numpy as np
1616
import torch
17-
from executorch.backends.nxp._passes.remove_getitem_pass import RemoveGetItemPass
1817

18+
from executorch.backends.nxp._passes.remove_getitem_pass import RemoveGetItemPass
1919
from executorch.backends.nxp.backend.edge_program_converter import (
2020
EdgeProgramToIRConverter,
2121
)
2222
from executorch.backends.nxp.backend.ir.conversion_config import ConversionConfig
23-
from executorch.backends.nxp.backend.ir.tensor_formatting import TensorFormat
2423
from executorch.backends.nxp.backend.neutron_converter_manager import (
2524
NeutronConverterManager,
2625
)
2726
from executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec
27+
from executorch.backends.nxp.backend.node_format import NodeFormat
2828
from executorch.backends.nxp.neutron_node_extraction import (
2929
extract_artifacts_from_neutron_node,
3030
NeutronNodeArtifacts,
@@ -264,7 +264,9 @@ def _format_string_for_array(self, array: np.ndarray) -> str:
264264

265265
return f"{array.size}s{self._padding_format_string_for_array(array)}"
266266

267-
def _create_payload_header(self, io_formats, neutron_artifacts) -> np.ndarray:
267+
def _create_payload_header(
268+
self, io_formats: dict[str, list[NodeFormat]], neutron_artifacts
269+
) -> np.ndarray:
268270
"""
269271
Create bytes header for returned payload. It contains information about
270272
input and output tensor formats. Tensors are ordered based on graph signature
@@ -302,9 +304,7 @@ def _create_payload_header(self, io_formats, neutron_artifacts) -> np.ndarray:
302304
for input_name in neutron_artifacts.input_names:
303305
try:
304306
header_data.append(
305-
1
306-
if inputs[input_name.decode()] == TensorFormat.CHANNELS_LAST
307-
else 0
307+
1 if inputs[input_name.decode()] == NodeFormat.CHANNELS_LAST else 0
308308
)
309309
except KeyError:
310310
raise AssertionError(
@@ -315,7 +315,7 @@ def _create_payload_header(self, io_formats, neutron_artifacts) -> np.ndarray:
315315
try:
316316
header_data.append(
317317
1
318-
if outputs[output_name.decode()] == TensorFormat.CHANNELS_LAST
318+
if outputs[output_name.decode()] == NodeFormat.CHANNELS_LAST
319319
else 0
320320
)
321321
except KeyError:
@@ -354,7 +354,9 @@ def _pack_with_alignment(
354354
neutron_artifacts.kernels.tobytes(),
355355
)
356356

357-
def get_binary_payload(self, io_formats, neutron_model) -> bytes:
357+
def get_binary_payload(
358+
self, io_formats: dict[str, list[NodeFormat]], neutron_model
359+
) -> bytes:
358360
"""
359361
Get binary payload for provided input/output tensor formats and neutron_model. Returned data have
360362
following structure:
@@ -374,7 +376,7 @@ def get_binary_payload(self, io_formats, neutron_model) -> bytes:
374376
Tensor format definition: '0x1' == CHANNELS_LAST, '0x0' == FORMATLESS (no format).
375377
376378
:param io_formats: Dictionary with keys 'inputs' and 'outputs' that contains dictionaries
377-
mapping tensor name to TensorFormat.
379+
mapping tensor name to NodeFormat.
378380
:param neutron_model: Neutron model with single NeutronGraph node.
379381
:return: 16 bytes aligned binary payload.
380382
"""

0 commit comments

Comments
 (0)