Skip to content

Commit 91711dc

Browse files
committed
NXP backend: Perform node format inference before partitioning.
Before, the format inference was done during conversion to NeutronIR (after partitioning), so the partitioner didn't yet know the formats. Now, the partitioner has the format data, which can be used to accurately select nodes for delegation.
1 parent 22009ae commit 91711dc

File tree

6 files changed

+23
-7
lines changed

6 files changed

+23
-7
lines changed

backends/nxp/backend/edge_program_converter.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,7 @@
1818
from torch.fx import Node
1919
from torch.nn.parameter import Parameter
2020
from executorch.backends.nxp.backend.ir.converter.node_converters.ops_converters import * # noqa F403
21-
from executorch.backends.nxp.backend.node_format_inference import (
22-
NodeFormatInference,
23-
NXP_NODE_FORMAT,
24-
)
21+
from executorch.backends.nxp.backend.node_format_inference import NXP_NODE_FORMAT
2522
from executorch.exir.dialects._ops import ops as exir_ops
2623

2724
# noinspection PyProtectedMember
@@ -70,7 +67,6 @@ def convert_program(
7067
:param custom_delegation_options: Custom user options which affect node delegation.
7168
:return: TFLite flatbuffers as bytes.
7269
"""
73-
NodeFormatInference(edge_program).identify_node_formats()
7470
parameters_mapping = self.map_inputs_to_parameters(edge_program)
7571

7672
cc = self.build_conversion_context(

backends/nxp/neutron_partitioner.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from torch.fx.passes.operator_support import OperatorSupportBase
2525
from torch.nn import Parameter
2626
from executorch.backends.nxp.backend.ir.converter.node_converters.ops_converters import * # noqa F403
27+
from executorch.backends.nxp.backend.node_format_inference import NodeFormatInference
2728
from executorch.backends.nxp.nxp_backend import NeutronBackend
2829
from executorch.exir.backend.compile_spec_schema import CompileSpec
2930
from executorch.exir.backend.partitioner import (
@@ -342,6 +343,10 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult:
342343
allows_single_node_partition=True,
343344
)
344345

346+
# Identify the format (NCHW/NHWC/...) for all nodes in the graph, and store it in the `node.meta`.
347+
# This format will be used by the `CapabilityBasedPartitioner` to determine which nodes will be delegated.
348+
NodeFormatInference(exported_program).identify_node_formats()
349+
345350
partition_list = capability_partitioner.propose_partitions()
346351
for partition in partition_list:
347352
for node in partition.nodes:

backends/nxp/tests/executors.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2023-2024 NXP
1+
# Copyright 2023-2025 NXP
22
#
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
@@ -22,11 +22,12 @@
2222
NodeConverter,
2323
Target,
2424
)
25+
26+
from executorch.backends.nxp.backend.node_format_inference import NodeFormatInference
2527
from torch.export import ExportedProgram
2628
from torch.fx import Node
2729
from torch.fx.graph import Graph
2830

29-
3031
# If executed on i.MX platform, there is no tensorflow module. And typically the intention is to use the tflite python
3132
# interpreter available in tflite_runtime
3233
try:
@@ -310,6 +311,7 @@ def convert_run_compare(
310311
) -> (TFLiteExecutor, EdgeProgramExecutor):
311312

312313
if tfl_model is None:
314+
NodeFormatInference(edge_program).identify_node_formats()
313315
tfl_model, _ = EdgeProgramToIRConverter().convert_program(
314316
edge_program, conversion_config
315317
)

backends/nxp/tests/ir/converter/node_converter/test_cat_converter.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
from executorch.backends.nxp.tests.executors import (
1818
convert_run_compare,
1919
graph_contains_any_of_ops,
20+
ToNCHWPreprocess,
21+
ToNHWCPreprocess,
2022
)
2123
from executorch.exir.dialects._ops import ops as exir_ops
2224
from torch.export import ExportedProgram
@@ -126,6 +128,8 @@ def test_cat__channels_first__same_shapes(dim, num_inputs, mocker):
126128
exported_program,
127129
tfl_model=tflite_flatbuffers_model,
128130
input_data=input_data,
131+
tflite_input_preprocess=ToNHWCPreprocess(),
132+
tflite_output_preprocess=ToNCHWPreprocess(),
129133
atol=1,
130134
)
131135

@@ -241,6 +245,8 @@ def test_cat__channels_first__different_shapes(dim, num_inputs, mocker):
241245
exported_program,
242246
tfl_model=tflite_flatbuffers_model,
243247
input_data=input_data,
248+
tflite_input_preprocess=ToNHWCPreprocess(),
249+
tflite_output_preprocess=ToNCHWPreprocess(),
244250
atol=1,
245251
)
246252

backends/nxp/tests/ir/converter/node_converter/test_softmax_converter.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
EdgeProgramToIRConverter,
1212
)
1313
from executorch.backends.nxp.backend.ir.conversion_config import ConversionConfig
14+
from executorch.backends.nxp.backend.node_format_inference import NodeFormatInference
1415
from executorch.backends.nxp.tests.executorch_pipeline import to_edge_program
1516
from executorch.backends.nxp.tests.executors import convert_run_compare
1617
from executorch.backends.nxp.tests.models import SoftmaxConvModule, SoftmaxModule
@@ -56,6 +57,7 @@ def test_softmax_conversion__unknown_input_format(input_shape, dim: int):
5657
model = SoftmaxModule(dim)
5758

5859
edge_program = to_edge_program(model, input_shape).exported_program()
60+
NodeFormatInference(edge_program).identify_node_formats()
5961

6062
# Currently this test not pass because the convertibility checker doesn't use tensor formats.
6163
with pytest.raises(
@@ -78,6 +80,7 @@ def test_softmax_conversion_channel_last(input_shape, dim: int):
7880
model = SoftmaxConvModule(dim)
7981

8082
edge_program = to_edge_program(model, input_shape).exported_program()
83+
NodeFormatInference(edge_program).identify_node_formats()
8184

8285
# TODO (Robert Kalmar) Currently this test not pass because the convertibility checker doesn't use tensor formats.
8386
with pytest.raises(
@@ -104,6 +107,7 @@ def test_softmax_conversion_unsupported_dims(input_shape, dim: int):
104107
model = SoftmaxModule(dim)
105108

106109
edge_program = to_edge_program(model, input_shape).exported_program()
110+
NodeFormatInference(edge_program).identify_node_formats()
107111

108112
with pytest.raises(
109113
AssertionError, match="`aten__softmax_default` is not convertible"

backends/nxp/tests/test_neutron_converter_manager.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from executorch.backends.nxp.backend.neutron_converter_manager import (
1414
NeutronConverterManager,
1515
)
16+
from executorch.backends.nxp.backend.node_format_inference import NodeFormatInference
1617
from executorch.backends.nxp.tests.models import Conv2dModule
1718

1819

@@ -23,6 +24,7 @@ def test_conv2d_neutron_conversion__default_flavor():
2324
exir_program = torch.export.export(model, example_input)
2425
edge_program_manager = exir.to_edge(exir_program)
2526

27+
NodeFormatInference(edge_program_manager.exported_program()).identify_node_formats()
2628
edge_program_converter = EdgeProgramToIRConverter()
2729
tflite_model, _ = edge_program_converter.convert_program(
2830
edge_program_manager.exported_program()
@@ -45,6 +47,7 @@ def test__conv2d_neutron_conversion__invalid_flavor():
4547
exir_program = torch.export.export(model, example_input)
4648
edge_program_manager = exir.to_edge(exir_program)
4749

50+
NodeFormatInference(edge_program_manager.exported_program()).identify_node_formats()
4851
edge_program_converter = EdgeProgramToIRConverter()
4952
tflite_model, _ = edge_program_converter.convert_program(
5053
edge_program_manager.exported_program()

0 commit comments

Comments
 (0)