diff --git a/backends/nxp/TARGETS b/backends/nxp/TARGETS index 875f9813f43..a5a0508b33c 100644 --- a/backends/nxp/TARGETS +++ b/backends/nxp/TARGETS @@ -32,6 +32,18 @@ runtime.python_library( ], ) +runtime.python_library( + name = "_passes", + srcs = glob([ + "_passes/*.py", + ]), + deps = [ + "//caffe2:torch", + "//executorch/exir:lib", + "//executorch/exir:pass_manager", + ], +) + runtime.python_library( name = "quantizer", srcs = [ @@ -65,6 +77,7 @@ runtime.python_library( deps = [ ":neutron_sdk", ":aten_passes", + ":_passes", ":quantizer", "fbsource//third-party/pypi/flatbuffers:flatbuffers", "fbsource//third-party/pypi/ml-dtypes:ml-dtypes", diff --git a/backends/nxp/_passes/remove_getitem_pass.py b/backends/nxp/_passes/remove_getitem_pass.py new file mode 100644 index 00000000000..646f5083adf --- /dev/null +++ b/backends/nxp/_passes/remove_getitem_pass.py @@ -0,0 +1,103 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# Copyright 2025 NXP +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import torch + +from executorch.backends.nxp.backend.node_format_inference import ( + NodeFormat, + NXP_NODE_FORMAT, +) +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass, PassResult + + +class RemoveGetItemPass(ExportPass): + """ + This remove item is used to remove getitem operator for max_pool2d_with_indices.default operator, and replace it with a single operator, + that extracts the first output. More specifically, we are only getting the first output from aten::maxpool2d operator. + Before Pass: + MaxPool2d ---> GetItem[max_values, max_indexes] + After Pass: + MaxPool2d -> max_values + """ + + def call(self, graph_module: torch.fx.GraphModule): + module = graph_module + for node in module.graph.nodes: + if node.op == "call_function": + if ( + node.target.__name__ == "aten.max_pool2d_with_indices.default" + or node.target.__name__ == "aten.max.dim" + ): + users = list(node.users.keys()) + + if len(users) != 1: + if len(users) == 2 and node.target.__name__ == "aten.max.dim": + # Two users is allowed for max.dim. For that case, + # rather than removing the getitem node in this + # pass, we handle the getitem nodes in the op's + # visitor when serializing + continue + else: + raise AssertionError( + f"Invalid number of users for {node.target.__name__}: {len(users)}" + ) + + getitem_node = list(node.users.keys())[0] + + if getitem_node.target.__name__ != "getitem": + raise AssertionError( + f"Expected max node's user to be getitem, got {getitem_node.target.__name__}" + ) + + getitem_index = getitem_node.args[1] + + with module.graph.inserting_before(node): + if ( + node.target.__name__ + == "aten.max_pool2d_with_indices.default" + ): + if getitem_index != 0: + raise AssertionError( + f"Expected second argument of getitem node for {node.target.__name__} to be 0, got {getitem_index}. XNNPACK delegate currently only supports getting just the max values from the op but not getting the corresponding indices." + ) + new_max_wd = module.graph.create_node( + "call_function", + exir_ops.edge.aten.max_pool2d.default, + args=node.args, + kwargs=node.kwargs, + ) + + else: + if getitem_index != 0: + raise AssertionError( + f"Expected second argument of getitem node for {node.target.__name__} to be 0, got {getitem_index}. XNNPACK delegate currently only supports getting just the max values or getting both the max values and their corresponding indices from the op, but not getting the indices alone." + ) + new_max_wd = module.graph.create_node( + "call_function", + exir_ops.edge.aten.amax.default, + args=node.args, + kwargs=node.kwargs, + ) + + # MODIFIED PART START + # Make sure to preserve the inferred node format. + new_max_wd.meta[NXP_NODE_FORMAT] = node.meta.get( + NXP_NODE_FORMAT, NodeFormat.NONE + ) + # MODIFIED PART END + + getitem_node.replace_all_uses_with(new_max_wd) + + module.graph.erase_node(getitem_node) + module.graph.erase_node(node) + + graph_module.recompile() + # Propagate metadata and retrace module + graph_module = super().call(graph_module).graph_module + + return PassResult(graph_module, True) diff --git a/backends/nxp/backend/edge_program_converter.py b/backends/nxp/backend/edge_program_converter.py index fcfb9787715..4189ac2dc47 100644 --- a/backends/nxp/backend/edge_program_converter.py +++ b/backends/nxp/backend/edge_program_converter.py @@ -19,10 +19,7 @@ from torch.nn.parameter import Parameter from executorch.backends.nxp.backend.ir.converter.node_converters.ops_converters import * # noqa F403 from executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec -from executorch.backends.nxp.backend.node_format_inference import ( - NodeFormat, - NodeFormatInference, -) +from executorch.backends.nxp.backend.node_format import NodeFormat, NXP_NODE_FORMAT from executorch.exir.dialects._ops import ops as exir_ops # noinspection PyProtectedMember @@ -66,7 +63,7 @@ def convert_program( conversion_config: ConversionConfig = _default_conversion_config, neutron_target_spec: NeutronTargetSpec = _default_target_spec, custom_delegation_options: CustomDelegationOptions = _default_delegation_options, - ) -> (bytes, dict): + ) -> (bytes, dict[str, NodeFormat]): """ Convert ExportedProgram in Edge dialect to IR (TFLite flatbuffers) as bytes. @@ -76,12 +73,10 @@ def convert_program( :param custom_delegation_options: Custom user options which affect node delegation. :return: TFLite flatbuffers as bytes. """ - node_formats = NodeFormatInference(edge_program).identify_node_formats() parameters_mapping = self.map_inputs_to_parameters(edge_program) cc = self.build_conversion_context( parameters_mapping, - node_formats, neutron_target_spec, conversion_config, custom_delegation_options, @@ -92,13 +87,16 @@ def convert_program( self._convert_qdq_cluster_q_dq_nodes(edge_program.graph.nodes, cc) self._process_nodes(edge_program.graph.nodes, cc) - # Assign output - io_formats = cc.tflite_builder.assign_model_io_to_subgraph_and_get_io_formats( - edge_program.graph_signature - ) + # Assign the model its inputs and outputs. + cc.tflite_builder.assign_model_io_to_subgraph(edge_program.graph_signature) - # TFLite model generation + # Apply optimizations and finalize the model. internal_tflite_model = cc.tflite_builder.finish() + + # Extract the formats of the model's inputs and outputs. + io_formats = cc.tflite_builder.get_io_formats(edge_program.graph_signature) + + # TFLite model generation flatbuffers_builder = flatbuffers.Builder() internal_tflite_model.gen_tflite(flatbuffers_builder) @@ -108,7 +106,7 @@ def convert_program( def append_placeholders_and_tensors(nodes: list[Node], context: ConversionContext): for node in nodes: if node.op == "placeholder": - node_format = context.node_formats[node] + node_format = node.meta[NXP_NODE_FORMAT] if node.name in context.parameters_mapping: # Node is placeholder and has data -> append as static tensor with data @@ -121,7 +119,7 @@ def append_placeholders_and_tensors(nodes: list[Node], context: ConversionContex context.tflite_builder.append_as_fake_tensor(node, node_format) elif node.op == "call_function": # Node is call function -> append only output as a tensor - node_format = context.node_formats[node] + node_format = node.meta[NXP_NODE_FORMAT] context.tflite_builder.append_as_fake_tensor(node, node_format) elif node.op == "output": # Nothing to do @@ -179,7 +177,6 @@ def map_inputs_to_parameters(edge_program: ExportedProgram) -> dict[str, Paramet @staticmethod def build_conversion_context( parameters_mapping: dict, - node_formats: dict[Node, NodeFormat], neutron_target_spec: NeutronTargetSpec, conversion_config: ConversionConfig = _default_conversion_config, custom_delegation_options: CustomDelegationOptions = _default_delegation_options, @@ -195,7 +192,6 @@ def build_conversion_context( tflite_builder, conversion_config, parameters_mapping, - node_formats, custom_delegation_options, ) diff --git a/backends/nxp/backend/ir/conversion_config.py b/backends/nxp/backend/ir/conversion_config.py index 622735e881f..ee77d0a1d5b 100644 --- a/backends/nxp/backend/ir/conversion_config.py +++ b/backends/nxp/backend/ir/conversion_config.py @@ -13,7 +13,7 @@ def __init__(self, args: dict | None = None): :param args: Optional dictionary with conversion arguments. Unknown arguments are ignored. """ - self.keep_io_format: bool = False + self.use_neutron_for_format_conversion: bool = False self.allow_inputs_stripping: bool = True self.qdq_aware_conversion: bool = True self.symbolic_dimensions_mapping: dict[str, int] | None = None diff --git a/backends/nxp/backend/ir/conversion_context.py b/backends/nxp/backend/ir/conversion_context.py index 6fb7e98424e..d4746fbde01 100644 --- a/backends/nxp/backend/ir/conversion_context.py +++ b/backends/nxp/backend/ir/conversion_context.py @@ -10,8 +10,6 @@ from executorch.backends.nxp.backend.ir.converter.builder.aten_model_builder_director import ( AtenModelBuilderDirector, ) -from executorch.backends.nxp.backend.node_format_inference import NodeFormat -from torch import Node from torch.nn import Parameter @@ -19,7 +17,6 @@ class ConversionContext: tflite_builder: AtenModelBuilderDirector conversion_config: ConversionConfig parameters_mapping: dict[str, Parameter] - node_formats: dict[Node, NodeFormat] custom_delegation_options: CustomDelegationOptions def __init__( @@ -27,7 +24,6 @@ def __init__( tflite_builder: AtenModelBuilderDirector, conversion_config: ConversionConfig, parameters_mapping: dict, - node_formats: dict[Node, NodeFormat], custom_delegation_options: CustomDelegationOptions, ): """ @@ -39,5 +35,4 @@ def __init__( self.tflite_builder = tflite_builder self.conversion_config = conversion_config self.parameters_mapping = parameters_mapping - self.node_formats = node_formats self.custom_delegation_options = custom_delegation_options diff --git a/backends/nxp/backend/ir/converter/builder/aten_model_builder_director.py b/backends/nxp/backend/ir/converter/builder/aten_model_builder_director.py index a420cea9aa7..658b4fc93f7 100644 --- a/backends/nxp/backend/ir/converter/builder/aten_model_builder_director.py +++ b/backends/nxp/backend/ir/converter/builder/aten_model_builder_director.py @@ -9,7 +9,7 @@ from executorch.backends.nxp.backend.ir.converter.conversion import translator from executorch.backends.nxp.backend.ir.tensor_formatting import TensorFormat from executorch.backends.nxp.backend.ir.tflite_generator import tflite_model -from executorch.backends.nxp.backend.node_format_inference import NodeFormat +from executorch.backends.nxp.backend.node_format import NodeFormat from torch.fx import Node from torch.nn import Parameter @@ -88,19 +88,40 @@ def append_operators(self, ops_to_add: list[tflite_model.Operator]): self.check_and_append_operator(op) - def assign_model_io_to_subgraph_and_get_io_formats( - self, graph_signature - ) -> dict[str, dict]: - """ - Assign model's inputs/outputs to SubGraph. + def get_io_formats(self, graph_signature) -> dict[str, dict[str, TensorFormat]]: + """Get a mapping from tensor names to their formats. - :param graph_signature: Instance of GraphSignature. + :param graph_signature: Instance of GraphSignature. :returns: Mapping between IO tensors' names and their formats. """ io_formats = { "inputs": {}, "outputs": {}, } + for input_name in graph_signature.user_inputs: + tensor = self.tensor_for_name(input_name) + assert input_name == tensor.name, ( + "Program's input name doesn't match with tensor name in TFLite. " + "Input was probably redirected." + ) + io_formats["inputs"][tensor.name] = tensor.tensor_format + + for output_name in graph_signature.user_outputs: + tensor = self.tensor_for_name(output_name) + assert output_name == tensor.name, ( + "Program's output name doesn't match with tensor name in TFLite. " + "Output was probably redirected." + ) + io_formats["outputs"][tensor.name] = tensor.tensor_format + + return io_formats + + def assign_model_io_to_subgraph(self, graph_signature): + """ + Assign model's inputs/outputs to SubGraph. + + :param graph_signature: Instance of GraphSignature. + """ self.get_sub_graph().inputs = tflite_model.SubGraphInputs() for input_name in graph_signature.user_inputs: @@ -110,7 +131,6 @@ def assign_model_io_to_subgraph_and_get_io_formats( "Input was probably redirected." ) self.get_sub_graph().inputs.tmp_inputs.append(tensor) - io_formats["inputs"][tensor.name] = tensor.tensor_format self.get_sub_graph().outputs = tflite_model.SubGraphOutputs() for output_name in graph_signature.user_outputs: @@ -120,7 +140,3 @@ def assign_model_io_to_subgraph_and_get_io_formats( "Output was probably redirected." ) self.get_sub_graph().outputs.tmp_outputs.append(tensor) - - io_formats["outputs"][tensor.name] = tensor.tensor_format - - return io_formats diff --git a/backends/nxp/backend/ir/converter/builder/model_builder.py b/backends/nxp/backend/ir/converter/builder/model_builder.py index 643a6231d15..cfd80d8e300 100755 --- a/backends/nxp/backend/ir/converter/builder/model_builder.py +++ b/backends/nxp/backend/ir/converter/builder/model_builder.py @@ -5,7 +5,9 @@ # License: MIT # See the LICENSE_MIT for more details. # + from copy import deepcopy +from itertools import chain from typing import Dict, List, Optional, Union import executorch.backends.nxp.backend.ir.converter.conversion.translator as translator @@ -48,6 +50,9 @@ FlexTranspose, ) from executorch.backends.nxp.backend.ir.tflite_optimizer import optimizer +from executorch.backends.nxp.backend.neutron_operator_support import ( + transposition_is_supported_on_neutron, +) from executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec @@ -218,7 +223,7 @@ def channels_first_version_of(self, t_tensor: tflite_model.Tensor): new_tensor.shape = translator.channels_last_shape_to_channels_first( t_tensor.shape ) - new_tensor.tensor_format = new_tensor.tensor_format.to_node_format() + new_tensor.tensor_format = TensorFormat.CHANNELS_FIRST perm = translator.create_channels_last_to_channels_first_permutation( t_tensor.rank @@ -355,6 +360,19 @@ def _make_inputs_channels_first(self): if input_tensor.tensor_format.is_channels_last(): # Create a Transpose operator and replace the graph input + new_input_shape = translator.channels_last_shape_to_channels_first( + input_tensor.shape + ) + perm = translator.create_channels_first_to_channels_last_permutation( + input_tensor.rank + ) + + if not transposition_is_supported_on_neutron( + new_input_shape.vector, list(perm), self.neutron_target_spec + ): + new_inputs.append(input_tensor) + continue + if input_tensor.rank > 6: msg = ( f"Couldn't preserve the shape of input tensor '{input_tensor.name}', because it has " @@ -365,14 +383,9 @@ def _make_inputs_channels_first(self): new_input = self.duplicate_tensor( input_tensor, input_tensor.name + "_channels_first" ) - new_input.shape = translator.channels_last_shape_to_channels_first( - input_tensor.shape - ) - new_input.tensor_format = input_tensor.tensor_format.to_node_format() + new_input.shape = new_input_shape + new_input.tensor_format = TensorFormat.CHANNELS_FIRST - perm = translator.create_channels_first_to_channels_last_permutation( - input_tensor.rank - ) transpose = self._create_transpose_operator( new_input, input_tensor, perm ) @@ -397,6 +410,16 @@ def _make_outputs_channels_first(self): if output_tensor.tensor_format.is_channels_last(): # Add a Transpose operator, to make the output channels first + shape = output_tensor.shape.vector + perm = translator.create_channels_last_to_channels_first_permutation( + len(shape), True + ) + if not transposition_is_supported_on_neutron( + shape, perm, self.neutron_target_spec + ): + new_outputs.append(output_tensor) + continue + if output_tensor.rank > 6: logger.e( logger.Code.IO_PRESERVATION_ERROR, @@ -437,6 +460,14 @@ def _keep_one_empty_buffer(self): # It's safe to replace the buffer. t.tmp_buffer = empty_buffer + def replace_io_tensor_format_with_node_format(self): + for t in chain( + self.get_sub_graph().inputs.tmp_inputs, + self.get_sub_graph().outputs.tmp_outputs, + ): + if isinstance(t.tensor_format, TensorFormat): + t.tensor_format = t.tensor_format.to_equal_node_format() + def finish(self) -> tflite_model.Model: """Finalize and optimize the converted TFLite model. Then return it. @@ -444,19 +475,23 @@ def finish(self) -> tflite_model.Model: :return: The final TFLite model. """ - if self.conversion_config.keep_io_format: + if self.conversion_config.use_neutron_for_format_conversion: # If the input or output is channels last, add a Transpose operator, to make is channels first. self._make_inputs_channels_first() self._make_outputs_channels_first() # Apply optimizations to the internal TFLite model. - optimizer.Optimizer(self, self.conversion_config).optimize( + optimizer.Optimizer( + self, self.conversion_config, self.neutron_target_spec + ).optimize( self.conversion_config.optimization_whitelist, self.conversion_config.optimization_blacklist, ) self._keep_one_empty_buffer() + self.replace_io_tensor_format_with_node_format() + # Remove outputs, which are not produced by any node. Otherwise, there would be errors after inference. operator_outputs = [] for op in self.get_operators().vector: diff --git a/backends/nxp/backend/ir/converter/node_converter.py b/backends/nxp/backend/ir/converter/node_converter.py index 36266486aac..b69861f85b0 100755 --- a/backends/nxp/backend/ir/converter/node_converter.py +++ b/backends/nxp/backend/ir/converter/node_converter.py @@ -185,6 +185,14 @@ def builder(self) -> AtenModelBuilderDirector: """ return self.context.tflite_builder + @property + def neutron_target_spec(self) -> NeutronTargetSpec: + """ + Get an instance of NeutronTargetSpec from the conversion context. + :return: NeutronTargetSpec instance. + """ + return self.builder.neutron_target_spec + def _create_tflite_op_with_io_tensors(self, node: Node) -> tflite_model.Operator: """ Create TFLite op wrapper with input/output tensors added into 'tmp_inputs' and 'tmp_outputs'. diff --git a/backends/nxp/backend/ir/converter/node_converters/ops_converters/cat_converter.py b/backends/nxp/backend/ir/converter/node_converters/ops_converters/cat_converter.py index 22ca258cd4f..bf6dd35a3d4 100644 --- a/backends/nxp/backend/ir/converter/node_converters/ops_converters/cat_converter.py +++ b/backends/nxp/backend/ir/converter/node_converters/ops_converters/cat_converter.py @@ -18,6 +18,7 @@ Concatenation, ) from executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec +from executorch.backends.nxp.backend.node_format import NXP_NODE_FORMAT from torch.fx import Node from torch.nn import Parameter @@ -85,32 +86,28 @@ def _is_supported_on_target( if dim == 0: return False - # Neutron requires the channels to be a multiple of numMacs. The channels could either be the second or the - # last dimension, depending on the formats of the node. The format, however, cannot be determined - # during conversion, as it depends on what other nodes are delegated. + # Neutron requires the channels to be a multiple of `8`. The channels could either be the second or the + # last dimension, depending on the formats of the node. + if node.meta[NXP_NODE_FORMAT].is_channels_first(): + # During conversion to IR, the shape will be permuted to channels last, and the dimension on index + # `1` will end up being the channels (last dim in NHWC). + channels_index = 1 + else: + # The shape will not be permuted during conversion, so the channels will remain the last dimension. + channels_index = -1 + input_channels = [ - # The second dimension is the channels in PyTorch. If the inputs/output are not channels first, it - # will still be the channels in the IR. - _get_shape(input_)[1] - for input_ in node.all_input_nodes - ] + [ - # If the inputs/outputs are channels first, the last dimension will be the channels. - _get_shape(input_)[-1] - for input_ in node.all_input_nodes + _get_shape(input_)[channels_index] for input_ in node.all_input_nodes ] - if any( - (input_channel % neutron_target_spec.get_num_macs()) != 0 - for input_channel in input_channels - ): + output_channels = _get_shape(node)[channels_index] + + num_macs = neutron_target_spec.get_num_macs() + if any((input_channel % num_macs) != 0 for input_channel in input_channels): # neutron-library/src/utils/NeutronLibraryInterrogation.cpp#1492 return False - output_channels = [_get_shape(node)[1], _get_shape(node)[-1]] - # neutron-library/src/utils/NeutronLibraryInterrogation.cpp#1493 - if any( - (out_c % neutron_target_spec.get_num_macs()) != 0 - for out_c in output_channels - ): + if (output_channels % num_macs) != 0: + # neutron-library/src/utils/NeutronLibraryInterrogation.cpp#1493 return False if len(node.all_input_nodes) < 2: # Not supported on Neutron diff --git a/backends/nxp/backend/ir/converter/node_converters/ops_converters/constant_pad_nd_converter.py b/backends/nxp/backend/ir/converter/node_converters/ops_converters/constant_pad_nd_converter.py index 499541aa58c..29a8f7d51bb 100644 --- a/backends/nxp/backend/ir/converter/node_converters/ops_converters/constant_pad_nd_converter.py +++ b/backends/nxp/backend/ir/converter/node_converters/ops_converters/constant_pad_nd_converter.py @@ -27,6 +27,8 @@ pad_v2_options, ) from executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec + +from executorch.backends.nxp.backend.node_format import NXP_NODE_FORMAT from torch.fx import Node from torch.nn import Parameter @@ -40,9 +42,16 @@ def _is_supported_on_target( custom_delegation_options: CustomDelegationOptions, ) -> bool: paddings = node.args[1] - if len(paddings) > 4 and paddings[4:6] != [0, 0]: - # Attempt to Pad channels dimension, which is not supported on Neutron. - return False + if node.meta[NXP_NODE_FORMAT].is_channels_first(): + # Dim `1` will end up being the channels. It is padded by paddings[4:6]. + if len(paddings) > 4 and paddings[4:6] != [0, 0]: + # Attempt to Pad channels dimension -> currently not supported + return False + else: + # Dim `-1` will end up being the channels. It is padded by paddings[:2]. + if len(paddings) > 0 and paddings[:2] != [0, 0]: + # Attempt to Pad channels dimension -> currently not supported + return False return True @@ -65,10 +74,6 @@ def _is_supported_in_IR( if not NodeConverter._has_shared_q_params_if_quantized(node): return False - if len(paddings) > 4 and paddings[4:6] != [0, 0]: - # Attempt to Pad channels dimension -> currently not supported - return False - return True # noinspection PyMethodMayBeStatic diff --git a/backends/nxp/backend/ir/converter/node_converters/ops_converters/permute_copy_converter.py b/backends/nxp/backend/ir/converter/node_converters/ops_converters/permute_copy_converter.py index f0150b4bc1f..88ba451fbbc 100644 --- a/backends/nxp/backend/ir/converter/node_converters/ops_converters/permute_copy_converter.py +++ b/backends/nxp/backend/ir/converter/node_converters/ops_converters/permute_copy_converter.py @@ -4,28 +4,432 @@ # LICENSE file in the root directory of this source tree. import numpy as np +import torch +from executorch.backends.nxp.backend.edge_helper import ( + node_is_effectively_static_tensor, +) +from executorch.backends.nxp.backend.ir.conversion_context import ConversionContext from executorch.backends.nxp.backend.ir.converter import quantization_utils +from executorch.backends.nxp.backend.ir.converter.conversion import translator from executorch.backends.nxp.backend.ir.converter.conversion.common import OpsList from executorch.backends.nxp.backend.ir.converter.node_converter import ( CustomDelegationOptions, + NeutronTargetSpec, NodeConverter, ) +from executorch.backends.nxp.backend.ir.tensor_formatting import TensorFormat +from executorch.backends.nxp.backend.ir.tflite_generator import tflite_model from executorch.backends.nxp.backend.ir.tflite_generator.builtin_options import ( transpose_options, ) +from executorch.backends.nxp.backend.neutron_operator_support import ( + is_tensor_invariant_permutation, + transposition_is_supported_on_neutron, +) +from executorch.backends.nxp.backend.node_format import NXP_NODE_FORMAT from torch.fx import Node from torch.nn import Parameter +def _get_shape(node: torch.fx.Node) -> list[int]: + return list(node.meta["val"].shape) + + +def get_supported_transpositions(node: Node, neutron_target_spec: NeutronTargetSpec): + """Since ExecuTorch and NeutronIR use different tensor formats, we must consider the different possible cases + which may occur. The main permutation is always done on channels_first/formatless data, and the output is + channels_first/formatless as well. If this is not the case, a `Transpose` is inserted before and/or + after the main `Transpose`, to make the input/output channels_first. These additional `Transpose` + ops must be supported by Neutron as well. Alternatively, consecutive `Transpose` ops can be fused + together. It is possible for a pair of unsupported permutation to result in a supported one. + Therefore, the merged permutations must also be considered. + + This function identifies which of these permutations are supported on neutron, and returns a dictionary with the + support summary and the corresponding permutations. + + :param node: The `permute_copy` node to base the support analysis from/ + :param neutron_target_spec: NeutronTagetSpec instance. + :return: A dictionary containing the support status and permutation, for all the possible permutations which may be + used during the conversion of the `node`. + """ + + input_shape = node.args[0].meta["val"].shape + output_shape = node.meta["val"].shape + perm = list(node.args[1]) + + to_nchw_perm = translator.create_channels_last_to_channels_first_permutation( + len(input_shape), True + ) + to_nhwc_perm = translator.create_channels_first_to_channels_last_permutation( + len(input_shape), True + ) + channels_last_input_shape = translator.apply_permutation_to( + input_shape, to_nhwc_perm + ) + + main_perm_supported = transposition_is_supported_on_neutron( + input_shape, perm, neutron_target_spec + ) + + # "To NCHW" permutation, in case the input is channels last. + separate_pre_transpose_supported = transposition_is_supported_on_neutron( + channels_last_input_shape, to_nchw_perm, neutron_target_spec + ) + # The main permutation and the previous one merged. + merged_pre_transpose_supported = transposition_is_supported_on_neutron( + channels_last_input_shape, + merged_pre_transpose_permutation := translator.combine_permutations( + to_nchw_perm, perm + ), + neutron_target_spec, + ) + + # "To NHWC" permutation after the main `Transpose`. + separate_post_transpose_supported = transposition_is_supported_on_neutron( + output_shape, to_nhwc_perm, neutron_target_spec + ) + + # The main permutation and the previous one merged. + merged_post_transpose_supported = transposition_is_supported_on_neutron( + input_shape, + merged_post_transpose_permutation := translator.combine_permutations( + perm, to_nhwc_perm + ), + neutron_target_spec, + ) + + # "To NCHW", main permutation, and "to NHWC" all merged. + everything_merged_supported = transposition_is_supported_on_neutron( + input_shape, + everything_merged_permutation := translator.combine_permutations( + translator.combine_permutations(to_nchw_perm, perm), to_nhwc_perm + ), + neutron_target_spec, + ) + + return { + "main": {"supported": main_perm_supported, "perm": perm}, + "separate_pre": { + "supported": separate_pre_transpose_supported, + "perm": to_nchw_perm, + }, + "merged_pre": { + "supported": merged_pre_transpose_supported, + "perm": merged_pre_transpose_permutation, + }, + "separate_post": { + "supported": separate_post_transpose_supported, + "perm": to_nhwc_perm, + }, + "merged_post": { + "supported": merged_post_transpose_supported, + "perm": merged_post_transpose_permutation, + }, + "everything_merged": { + "supported": everything_merged_supported, + "perm": everything_merged_permutation, + }, + } + + +Permutation = list[int] + + +class PermuteCopyFormatHandler: + def __init__(self, context: ConversionContext): + self.context = context + + @property + def neutron_target_spec(self): + return self.context.tflite_builder.neutron_target_spec + + @property + def builder(self): + return self.context.tflite_builder + + def _handle_channels_first_input_and_formatless_output( + self, perm_dict, node, t_op, ops + ) -> Permutation: + # The input must be permuted. + # Either combine the permutations, or prepend a `Transpose` operator. + if perm_dict["merged_pre"]["supported"]: + # Use the combined permutation. + perm = perm_dict["merged_pre"]["perm"] + + elif perm_dict["separate_pre"]["supported"] and perm_dict["main"]["supported"]: + # Prepend a `Transpose` operator to make the input channels first. + ops.add_pre( + self.builder.create_transpose_operator_before( + t_op, 0, perm_dict["separate_pre"]["perm"] + ) + ) + perm = perm_dict["main"]["perm"] + + elif node_is_effectively_static_tensor( + node.args[0], self.context.parameters_mapping + ): + perm = perm_dict["main"]["perm"] + + else: + # The `permute_copy` cannot be represented in Neutron. This should never happen. + raise RuntimeError( + "A `permute_copy` node was incorrectly selected for delegation. Please report this." + ) + + t_op.tmp_inputs[0].tensor_format = TensorFormat.CHANNELS_FIRST + + return perm + + def _handle_formatless_input_and_channels_first_output( + self, perm_dict, node, t_op, ops + ) -> Permutation: + # The output must be permuted. + # Either combine the permutations, or append a `Transpose` operator. + if perm_dict["merged_post"]["supported"]: + # Use the combined permutation. + perm = perm_dict["merged_post"]["perm"] + + elif perm_dict["main"]["supported"] and perm_dict["separate_post"]["supported"]: + # Append a `Transpose` operator to make the output channels first. + perm = perm_dict["main"]["perm"] + ops.add_post( + self.builder.create_transpose_operator_after( + t_op, 0, perm_dict["separate_post"]["perm"] + ) + ) + + elif node_is_effectively_static_tensor( + node.args[0], self.context.parameters_mapping + ): + perm = perm_dict["main"]["perm"] + + else: + # The `permute_copy` cannot be represented in Neutron. This should never happen. + raise RuntimeError( + "A `permute_copy` node was incorrectly selected for delegation. Please report this." + ) + + t_op.tmp_outputs[0].tensor_format = TensorFormat.CHANNELS_FIRST + + return perm + + def _handle_channels_first_input_and_output( + self, perm_dict, node, t_op, ops + ) -> Permutation: + # Both input and output must be permuted, or some merged permutations must be supported. + if perm_dict["everything_merged"]["supported"]: + # Combine all 3 permutations into 1. + perm = perm_dict["everything_merged"]["perm"] + + elif ( + perm_dict["merged_pre"]["supported"] + and perm_dict["separate_post"]["supported"] + ): + # Combine the input and main permutations, and append a `Transpose` to handle the output permutation. + perm = perm_dict["merged_pre"]["perm"] + ops.add_post( + self.builder.create_transpose_operator_after( + t_op, 0, perm_dict["separate_post"]["perm"] + ) + ) + + elif ( + perm_dict["separate_pre"]["supported"] + and perm_dict["merged_post"]["supported"] + ): + # Prepend a `Transpose` to handle the input permutation, and combine the main and output permutations. + ops.add_pre( + self.builder.create_transpose_operator_before( + t_op, 0, perm_dict["separate_pre"]["perm"] + ) + ) + perm = perm_dict["everything_merged"]["supported"] + + elif ( + perm_dict["separate_pre"]["supported"] + and perm_dict["main"]["supported"] + and perm_dict["separate_post"]["supported"] + ): + # Handle each permutation separately. + ops.add_pre( + self.builder.create_transpose_operator_before( + t_op, 0, perm_dict["separate_pre"]["perm"] + ) + ) + perm = perm_dict["main"]["perm"] + ops.add_post( + self.builder.create_transpose_operator_after( + t_op, 0, perm_dict["separate_post"]["perm"] + ) + ) + + elif node_is_effectively_static_tensor( + node.args[0], self.context.parameters_mapping + ): + perm = perm_dict["main"]["perm"] + + else: + # The `permute_copy` cannot be represented in Neutron. This should never happen. + raise RuntimeError( + "A `permute_copy` node was incorrectly selected for delegation. Please report this." + ) + + t_op.tmp_inputs[0].tensor_format = TensorFormat.CHANNELS_FIRST + t_op.tmp_outputs[0].tensor_format = TensorFormat.CHANNELS_FIRST + + return perm + + def _handle_formatless_input_and_output( + self, perm_dict, node, t_op, ops + ) -> Permutation: + # Neither the input nor the output have to be permuted. + if perm_dict["main"]["supported"]: + perm = perm_dict["main"]["perm"] + + elif node_is_effectively_static_tensor( + node.args[0], self.context.parameters_mapping + ): + perm = perm_dict["main"]["perm"] + + else: + # The `permute_copy` cannot be represented in Neutron. This should never happen. + raise RuntimeError( + "A `permute_copy` node was incorrectly selected for delegation. Please report this." + ) + + return perm + + def handle_tensor_formats(self, t_op: tflite_model.Operator, node: Node) -> OpsList: + """Due to the different tensor formats used by ExecuTorch and NeutronIR, it may be necessary to modify the + permutation, or insert extra permutations to equalize the tensor formats. + This method identifies the four possible cases of input/output formats, and finds the conversion solution + which minimizes the number of necessary `Transpose` operators. + """ + perm_dict = get_supported_transpositions(node, self.neutron_target_spec) + + ops = OpsList(middle_op=t_op) + input_format, output_format = ( + node.args[0].meta[NXP_NODE_FORMAT], + node.meta[NXP_NODE_FORMAT], + ) + if input_format.is_channels_first() and (not output_format.is_channels_first()): + perm = self._handle_channels_first_input_and_formatless_output( + perm_dict, node, t_op, ops + ) + + elif ( + not input_format.is_channels_first() + ) and output_format.is_channels_first(): + perm = self._handle_formatless_input_and_channels_first_output( + perm_dict, node, t_op, ops + ) + + elif input_format.is_channels_first() and output_format.is_channels_first(): + perm = self._handle_channels_first_input_and_output( + perm_dict, node, t_op, ops + ) + + else: + perm = self._handle_formatless_input_and_output(perm_dict, node, t_op, ops) + + perm_tensor = self.builder.create_tensor_for_data( + np.array(perm, "int32"), "perm" + ) + + # Use the final permutation as the operator's second input. + t_op.tmp_inputs = [t_op.tmp_inputs[0], perm_tensor] + + return ops + + class PermuteCopyConverter(NodeConverter): + @staticmethod + def _is_supported_on_target( + node: Node, + neutron_target_spec: NeutronTargetSpec, + parameters_mapping: dict[str, Parameter], + custom_delegation_options: CustomDelegationOptions, + ) -> bool: + if node_is_effectively_static_tensor(node.args[0], parameters_mapping): + return ( + True # The operator computes on static data. It will be removed later. + ) + + input_shape = _get_shape(node.args[0]) + perm = list(node.args[1]) + + to_nhwc_perm = translator.create_channels_first_to_channels_last_permutation( + len(input_shape), True + ) + channels_last_input_shape = translator.apply_permutation_to( + input_shape, to_nhwc_perm + ) + + if is_tensor_invariant_permutation( + input_shape, perm + ) and is_tensor_invariant_permutation(channels_last_input_shape, perm): + # The `permute_copy` can always be represented as a Reshape. + return True + + perm_dict = get_supported_transpositions(node, neutron_target_spec) + + input_format, output_format = ( + node.args[0].meta[NXP_NODE_FORMAT], + node.meta[NXP_NODE_FORMAT], + ) + if input_format.is_channels_first() and (not output_format.is_channels_first()): + # Just the input must be permuted. + return ( + perm_dict["separate_pre"]["supported"] + and perm_dict["main"]["supported"] + ) or perm_dict["merged_pre"]["supported"] + + elif ( + not input_format.is_channels_first() + ) and output_format.is_channels_first(): + # Just the output must be permuted. + return ( + perm_dict["separate_post"]["supported"] + and perm_dict["main"]["supported"] + ) or perm_dict["merged_post"]["supported"] + + elif input_format.is_channels_first() and output_format.is_channels_first(): + # Both input and output must be permuted. + return ( + # Separate IO transpositions. + ( + perm_dict["separate_pre"]["supported"] + and perm_dict["main"]["supported"] + and perm_dict["separate_post"]["supported"] + ) + # Separate input, merged output. + or ( + perm_dict["separate_pre"]["supported"] + and perm_dict["merged_post"]["supported"] + ) + # Merged input, separate output. + or ( + perm_dict["merged_pre"]["supported"] + and perm_dict["separate_post"]["supported"] + ) + # Merged input and output. + or perm_dict["everything_merged"]["supported"] + ) + else: + # Simplest case. No format changes required. + return perm_dict["main"]["supported"] + @staticmethod def _is_supported_in_IR( node: Node, parameters_mapping: dict[str, Parameter], custom_delegation_options: CustomDelegationOptions, ) -> bool: + if not NodeConverter._has_shared_q_params_if_quantized(node): + return False + return True def convert(self, node: Node): @@ -53,13 +457,6 @@ def convert(self, node: Node): "match. This indicates error in quantizer." ) - perm = np.array(node.args[1], "int32") - perm_tensor = self.builder.create_tensor_for_data(perm, "perm") - - # Assign the operator its TFLite inputs and outputs - t_op.tmp_inputs = [x, perm_tensor] - t_op.tmp_outputs = [y] - - ops_to_add = OpsList(middle_op=t_op) + ops = PermuteCopyFormatHandler(self.context).handle_tensor_formats(t_op, node) - self.builder.append_operators(ops_to_add.flatten()) + self.builder.append_operators(ops.flatten()) diff --git a/backends/nxp/backend/ir/tensor_formatting.py b/backends/nxp/backend/ir/tensor_formatting.py index db24576e81f..71b697a0eba 100644 --- a/backends/nxp/backend/ir/tensor_formatting.py +++ b/backends/nxp/backend/ir/tensor_formatting.py @@ -6,7 +6,7 @@ # from enum import Enum -from executorch.backends.nxp.backend.node_format_inference import NodeFormat +from executorch.backends.nxp.backend.node_format import NodeFormat class TensorFormat(Enum): @@ -38,8 +38,10 @@ def is_channels_last(self) -> bool: @staticmethod def from_node_format(node_format: NodeFormat): - if node_format.is_channels_first(): - return TensorFormat.CHANNELS_LAST + if node_format == NodeFormat.CHANNELS_FIRST: + return TensorFormat.CHANNELS_LAST # Format is swapped. + elif node_format == NodeFormat.CHANNELS_LAST: + return TensorFormat.CHANNELS_FIRST # Format is swapped. elif node_format == NodeFormat.FORMATLESS: return TensorFormat.FORMATLESS else: @@ -47,8 +49,21 @@ def from_node_format(node_format: NodeFormat): def to_node_format(self): if self == TensorFormat.CHANNELS_LAST: - return NodeFormat.CHANNELS_FIRST + return NodeFormat.CHANNELS_FIRST # Format is swapped. elif self == TensorFormat.FORMATLESS: return NodeFormat.FORMATLESS + elif self == TensorFormat.CHANNELS_FIRST: + return NodeFormat.CHANNELS_LAST # Format is swapped. else: return NodeFormat.NONE + + def to_equal_node_format(self): + match self: + case TensorFormat.CHANNELS_FIRST: + return NodeFormat.CHANNELS_FIRST + case TensorFormat.CHANNELS_LAST: + return NodeFormat.CHANNELS_LAST + case TensorFormat.FORMATLESS: + return NodeFormat.FORMATLESS + case _: + return NodeFormat.NONE diff --git a/backends/nxp/backend/ir/tflite_optimizer/optimizations/base_optimization.py b/backends/nxp/backend/ir/tflite_optimizer/optimizations/base_optimization.py index 6001ca961b8..18e397cc1bd 100755 --- a/backends/nxp/backend/ir/tflite_optimizer/optimizations/base_optimization.py +++ b/backends/nxp/backend/ir/tflite_optimizer/optimizations/base_optimization.py @@ -12,16 +12,21 @@ InputTensorToOpsMap, OutputTensorToOpMap, ) +from executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec class BaseOptimization(ABC): _builder: "model_builder.ModelBuilder" def __init__( - self, builder: "model_builder.ModelBuilder", conversion_config: ConversionConfig + self, + builder: "model_builder.ModelBuilder", + conversion_config: ConversionConfig, + neutron_target_spec: NeutronTargetSpec, ): self._builder = builder self._conversion_config = conversion_config + self.neutron_target_spec = neutron_target_spec def _create_tensor_to_operator_dictionaries( self, diff --git a/backends/nxp/backend/ir/tflite_optimizer/optimizations/prune_transpose_operators.py b/backends/nxp/backend/ir/tflite_optimizer/optimizations/prune_transpose_operators.py index 0be46efcaa8..053e53d9df8 100755 --- a/backends/nxp/backend/ir/tflite_optimizer/optimizations/prune_transpose_operators.py +++ b/backends/nxp/backend/ir/tflite_optimizer/optimizations/prune_transpose_operators.py @@ -24,10 +24,14 @@ TensorIsNotModelOutput, TensorsHaveData, ) +from executorch.backends.nxp.backend.neutron_operator_support import ( + transposition_is_supported_on_neutron, +) class FuseTransposeOperators(BaseOptimization): - """Remove some `Transpose` operators in the following pattern. + """Remove some `Transpose` operators in the following pattern. This is only done if the resulting permutation is + supported on Neutron. │ 'x' ┌─────▼─────┐ @@ -61,12 +65,27 @@ def __call__(self) -> bool: ) in matcher.match_patterns(): x = tensor_map["x"] perm1 = tensor_map["perm1"].tmp_buffer.data + combined_perms = [] # Remove the leading transpose. for second_transpose in following_transposes: # Combine the permutations for a new permutation of the second `Transpose`. perm2 = second_transpose.tmp_inputs[1].tmp_buffer.data - combined_perm = np.array(combine_permutations(perm1, perm2), np.int32) + combined_perms.append( + np.array(combine_permutations(perm1, perm2), np.int32) + ) + + if not all( + transposition_is_supported_on_neutron( + x.shape.vector, list(perm), self.neutron_target_spec + ) + for perm in combined_perms + ): + continue # Avoid creating an unsupported permutation. + + for second_transpose, combined_perm in zip( + following_transposes, combined_perms + ): second_transpose.tmp_inputs[1] = self._builder.create_tensor_for_data( combined_perm, "perm" ) diff --git a/backends/nxp/backend/ir/tflite_optimizer/optimizer.py b/backends/nxp/backend/ir/tflite_optimizer/optimizer.py index 69b75b72cdd..c4d809512b6 100755 --- a/backends/nxp/backend/ir/tflite_optimizer/optimizer.py +++ b/backends/nxp/backend/ir/tflite_optimizer/optimizer.py @@ -24,6 +24,7 @@ FuseTransposeOperators, RemoveIdentityTransposeOperators, ) +from executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec class Optimization(Enum): @@ -60,24 +61,25 @@ def __init__( self, builder: "model_builder.ModelBuilder", # noqa F821 conversion_config: ConversionConfig, + neutron_target_spec: NeutronTargetSpec, ): self._builder = builder self.optimization_map = { Optimization.FUSE_ACTIVATION_FUNCTIONS: FuseActivationFunctions( - builder, conversion_config + builder, conversion_config, neutron_target_spec ), Optimization.FUSE_TRANSPOSE_OPERATORS: FuseTransposeOperators( - builder, conversion_config + builder, conversion_config, neutron_target_spec ), Optimization.REMOVE_IDENTITY_TRANSPOSE_OPERATORS: RemoveIdentityTransposeOperators( - builder, conversion_config + builder, conversion_config, neutron_target_spec ), Optimization.PERMUTE_FULLY_CONNECTED_WEIGHTS_AFTER_RESHAPE: PermuteFullyConnectedWeightsAfterReshape( - builder, conversion_config + builder, conversion_config, neutron_target_spec ), Optimization.MOVE_ACTIVATION_BEFORE_CONCAT: MoveActivationBeforeConcatenation( - builder, conversion_config + builder, conversion_config, neutron_target_spec ), } diff --git a/backends/nxp/backend/neutron_converter_manager.py b/backends/nxp/backend/neutron_converter_manager.py index a6884a9ee24..7124929411e 100644 --- a/backends/nxp/backend/neutron_converter_manager.py +++ b/backends/nxp/backend/neutron_converter_manager.py @@ -2,6 +2,7 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. + import importlib import logging import multiprocessing @@ -75,6 +76,7 @@ def convert(self, tflite_model: bytes, target: str) -> bytes: cctx = self.neutron_converter.CompilationContext() cctx.targetOpts = self.neutron_converter.getNeutronTarget(target) cctx.compilationOpts.minNumOpsPerGraph = 1 + cctx.compilationOpts.excludeGraphPasses = "MergeTranspose" logger = multiprocessing.log_to_stderr() logger.setLevel(logging.WARNING) diff --git a/backends/nxp/backend/neutron_operator_support.py b/backends/nxp/backend/neutron_operator_support.py new file mode 100644 index 00000000000..df2ff1b5d0a --- /dev/null +++ b/backends/nxp/backend/neutron_operator_support.py @@ -0,0 +1,77 @@ +# Copyright 2025 NXP +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec + + +def is_tensor_invariant_permutation( + input_shape: list[int], permutation: list[int] +) -> bool: + new_permutation = [ + perm_idx for perm_idx in permutation if input_shape[perm_idx] > 1 + ] + return new_permutation == sorted(new_permutation) + + +def transposition_is_supported_on_neutron( + input_shape: list[int], + permutation: list[int], + neutron_target_spec: NeutronTargetSpec, +) -> bool: + """This function determines if the current NeutronSoftware properly supports a `Transpose` operator with given + `input_shape` and `permutation`. + + :param input_shape: The shape of the main input tensor of the `Transpose` operator. + :param permutation: The permutation the `Transpose` operator is computing. + :param neutron_target_spec: Object holding some parameters of the target platform. + """ + num_macs = neutron_target_spec.get_num_macs() + + if is_tensor_invariant_permutation(input_shape, permutation): + # The `Transpose` will be turned into a `Reshape` by Neutron. The check includes the identity permutation. + return True + + if permutation == [0, 3, 1, 2]: + # NHWC -> NCHW + n, h, w, c = input_shape + + if h * w * c % num_macs != 0: # Official Neutron requirement. + return False + + if not ( + c % num_macs == 0 and h * w % num_macs == 0 + ): # Neutron would produce incorrect outputs. + return False + + if n != 1: + # Neutron only supports `Transpose` operators where the dimensions can be combined into 2 consecutive + # groups. These 2 groups are then transposed like a matrix, and the result is reshaped. Therefore, for the + # [0, 3, 1, 2] permutation, when h * w != 1 and c != 1, batch size must be 1. + return False + + return True + + elif permutation == [0, 2, 3, 1]: + # NCHW -> NHWC + + n, c, h, w = input_shape + + if w % num_macs != 0: # Official Neutron requirement. + return False + + if not ( + c % num_macs == 0 and h * w % num_macs == 0 + ): # Neutron would produce incorrect outputs. + return False + + if n != 1: + # Neutron only supports `Transpose` operators where the dimensions can be combined into 2 consecutive + # groups. These 2 groups are then transposed like a matrix, and the result is reshaped. Therefore, for the + # [0, 2, 3, 1] permutation, when h * w != 1 and c != 1, batch size must be 1. + return False + + return True + + return False diff --git a/backends/nxp/backend/node_format.py b/backends/nxp/backend/node_format.py new file mode 100644 index 00000000000..fd54e2365ed --- /dev/null +++ b/backends/nxp/backend/node_format.py @@ -0,0 +1,26 @@ +# Copyright 2025 NXP +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from enum import Enum + +# Key into the `meta` attribute of nodes, which is mapped to their inferred node format. +NXP_NODE_FORMAT = "nxp_node_format" + + +class NodeFormat(Enum): + # Node's output in NCHW format + CHANNELS_FIRST = 0 + + # Node's output format has no meaning + FORMATLESS = 1 + + # Format has not been identified + NONE = 2 + + # NHWC + CHANNELS_LAST = 3 + + def is_channels_first(self) -> bool: + return self == NodeFormat.CHANNELS_FIRST diff --git a/backends/nxp/backend/node_format_inference.py b/backends/nxp/backend/node_format_inference.py index 76b05d172a4..1f3fe1b4511 100644 --- a/backends/nxp/backend/node_format_inference.py +++ b/backends/nxp/backend/node_format_inference.py @@ -4,30 +4,19 @@ # LICENSE file in the root directory of this source tree. import logging -from enum import Enum +import operator +from executorch.backends.nxp.backend.edge_program_converter import functions_converters +from executorch.backends.nxp.backend.node_format import NodeFormat, NXP_NODE_FORMAT from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.dialects.edge._ops import EdgeOpOverload -from torch import Node from torch.export import ExportedProgram +from torch.fx import Node logger = logging.getLogger(__name__) -class NodeFormat(Enum): - # Node's output in NCHW format - CHANNELS_FIRST = 0 - - # Node's output format has no meaning - FORMATLESS = 1 - - # Format has not been identified - NONE = 2 - - def is_channels_first(self) -> bool: - return self == NodeFormat.CHANNELS_FIRST - - class NodeFormatInference: # Dictionary with Edge Aten ops that always use channels first format. # The op in the dictionary is mapped to a dictionary, which holds indices to input nodes @@ -41,9 +30,10 @@ class NodeFormatInference: # A set of Edge Aten ops, which have the ability to change the format (for example - input nodes # are channels first but output is formatless). - ops_that_can_change_tensor_format = {exir_ops.edge.aten.view_copy.default} - - _node_format_mapping: dict[Node, NodeFormat] + ops_that_can_change_tensor_format = { + exir_ops.edge.aten.view_copy.default, + exir_ops.edge.aten.permute_copy.default, + } _type_changed_during_last_run: bool @@ -53,11 +43,13 @@ class NodeFormatInference: # Mapping between Node and its children (outputs) _node_outputs: dict[Node, list[Node]] + # List of all edge operations, which are supported by the converter. + _known_targets: list[EdgeOpOverload] + def __init__(self, edge_program: ExportedProgram): self._edge_program = edge_program self._nodes = edge_program.graph.nodes - self._node_format_mapping = {} self._node_inputs = { node: node.all_input_nodes for node in edge_program.graph.nodes } @@ -67,7 +59,14 @@ def __init__(self, edge_program: ExportedProgram): self._type_changed_during_last_run = False - def identify_node_formats(self) -> dict[Node, NodeFormat]: + self._known_targets = list(functions_converters) + [ + exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default, + exir_ops.edge.quantized_decomposed.dequantize_per_channel.default, + exir_ops.edge.quantized_decomposed.quantize_per_tensor.default, + operator.getitem, + ] + + def identify_node_formats(self): self._type_changed_during_last_run = True # Re-run format inference until there are no changes @@ -77,25 +76,55 @@ def identify_node_formats(self) -> dict[Node, NodeFormat]: for node in self._nodes: self._infer_format_of_nodes(node) - return self._node_format_mapping + for node in self._nodes: + if self._get_node_op_type(node) is None: + continue + if not hasattr(node, "meta"): + logging.warning(f"Node `{node}` does not have the `meta` attribute.") + node.meta = {} + if NXP_NODE_FORMAT not in node.meta: + logging.warning(f"Node `{node}` does not have inferred format.") + node.meta[NXP_NODE_FORMAT] = NodeFormat.NONE def _infer_format_of_nodes(self, node: Node): op_type = self._get_node_op_type(node) if op_type in self.ops_with_channels_first_nodes: self._handle_node_which_uses_channels_first_format(node) + elif op_type in self.ops_that_can_change_tensor_format: - if op_type == exir_ops.edge.aten.view_copy.default: # view_copy + if op_type in [ + exir_ops.edge.aten.view_copy.default, + exir_ops.edge.aten.permute_copy.default, + ]: + # Try to assign the `formatless` format to the input and output. The converter will then handle the + # transition. + # Note: If the format for the input/output has already been assigned as channels first, it will NOT be + # overwritten. self._assign_format_to_node( self._node_outputs[node][0], NodeFormat.FORMATLESS ) + self._assign_format_to_node( + self._node_inputs[node][0], NodeFormat.FORMATLESS + ) + else: logger.error( f"Node format inference for node type: {op_type} not found!" ) - else: + elif node.op != "call_function" or ( + hasattr(node, "target") and node.target in self._known_targets + ): + # Generic node, or tensor. self._handle_node_which_can_use_any_node_format(node) + else: + # Don't infer the format for unknown nodes. These nodes will never be delegated, so they will divide + # delegated partitions. Propagating the format here could unnecessarily enforce the format in one of these + # partitions, which would require extra transpositions. + for processed_node in self._node_inputs[node] + [node]: + self._assign_format_to_node(processed_node, NodeFormat.NONE) + def _infer_format_based_on_io_ranks(self, node: Node): """Determine the format of the output tensor of given "reshape style operator" based on the ranks of its input and output. @@ -148,10 +177,14 @@ def _assign_format_to_node(self, node: Node, node_format: NodeFormat): # Once CHANNEL_FIRST was assigned, we don't want to reassign return + if node_format is NodeFormat.NONE and old_node_format is not NodeFormat.NONE: + # A format has already been assigned to the node before. Don't replace it with `NONE`. + return + if old_node_format != node_format: self._type_changed_during_last_run = True - self._node_format_mapping[node] = node_format + node.meta[NXP_NODE_FORMAT] = node_format def _get_node_op_type(self, node: Node) -> str | None: """ @@ -252,8 +285,10 @@ def _node_produces_or_consumes_channels_first_format(self, node) -> bool: for ancestor_node in input_nodes ) - def _get_node_format(self, node): - return self._node_format_mapping.get(node, NodeFormat.NONE) + def _get_node_format(self, node) -> NodeFormat: + if not hasattr(node, "meta"): + node.meta = {} + return node.meta.get(NXP_NODE_FORMAT, NodeFormat.NONE) - def _node_is_placeholder(self, node: Node): + def _node_is_placeholder(self, node: Node) -> bool: return node.op == "placeholder" diff --git a/backends/nxp/neutron_partitioner.py b/backends/nxp/neutron_partitioner.py index 965ad41309b..9993e134e35 100644 --- a/backends/nxp/neutron_partitioner.py +++ b/backends/nxp/neutron_partitioner.py @@ -25,6 +25,7 @@ from torch.nn import Parameter from executorch.backends.nxp.backend.ir.converter.node_converters.ops_converters import * # noqa F403 from executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec +from executorch.backends.nxp.backend.node_format_inference import NodeFormatInference from executorch.backends.nxp.nxp_backend import NeutronBackend from executorch.exir.backend.compile_spec_schema import CompileSpec from executorch.exir.backend.partitioner import ( @@ -209,12 +210,13 @@ def tag_qdq_clusters(self, nodes: list[torch.fx.Node]): exir_ops.edge.aten.max_pool2d_with_indices.default: MaxPool2dConverter, # noqa F405 exir_ops.edge.aten.mean.dim: MeanDimConverter, # noqa F405 exir_ops.edge.aten.mm.default: MMConverter, # noqa F405 + exir_ops.edge.aten.permute_copy.default: PermuteCopyConverter, # noqa F405 exir_ops.edge.aten.relu.default: ReLUConverter, # noqa F405 exir_ops.edge.aten._softmax.default: SoftmaxConverter, # noqa F405 + exir_ops.edge.aten.sigmoid.default: SigmoidConverter, # noqa F405 exir_ops.edge.aten.sub.Tensor: SubTensorConverter, # noqa F405 exir_ops.edge.aten.tanh.default: TanhConverter, # noqa F405 exir_ops.edge.aten.view_copy.default: ViewCopyConverter, # noqa F405 - exir_ops.edge.aten.sigmoid.default: SigmoidConverter, # noqa F405 } @@ -376,6 +378,10 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult: allows_single_node_partition=True, ) + # Identify the format (NCHW/NHWC/...) for all nodes in the graph, and store it in the `node.meta`. + # This format will be used by the `CapabilityBasedPartitioner` to determine which nodes will be delegated. + NodeFormatInference(exported_program).identify_node_formats() + iteration_limit = len(exported_program.graph.nodes) for _ in range(iteration_limit): # Run the partitioning. diff --git a/backends/nxp/nxp_backend.py b/backends/nxp/nxp_backend.py index 44e9a19d9f2..457fa335ba6 100644 --- a/backends/nxp/nxp_backend.py +++ b/backends/nxp/nxp_backend.py @@ -15,20 +15,21 @@ import numpy as np import torch +from executorch.backends.nxp._passes.remove_getitem_pass import RemoveGetItemPass from executorch.backends.nxp.backend.edge_program_converter import ( EdgeProgramToIRConverter, ) -from executorch.backends.nxp.backend.ir.tensor_formatting import TensorFormat +from executorch.backends.nxp.backend.ir.conversion_config import ConversionConfig from executorch.backends.nxp.backend.neutron_converter_manager import ( NeutronConverterManager, ) from executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec +from executorch.backends.nxp.backend.node_format import NodeFormat from executorch.backends.nxp.neutron_node_extraction import ( extract_artifacts_from_neutron_node, NeutronNodeArtifacts, ) from executorch.backends.nxp.neutron_pass_manager import NeutronPassManager -from executorch.backends.transforms.remove_getitem_op import RemoveGetItemPass from executorch.exir.backend.backend_details import BackendDetails, PreprocessResult from executorch.exir.backend.compile_spec_schema import CompileSpec from executorch.exir.verification.verifier import EXIREdgeDialectVerifier @@ -44,6 +45,7 @@ def __init__(self): self.output_format = None self.operators_not_to_delegate: List[str] = [] self.neutron_converter_flavor = None + self.use_neutron_for_format_conversion = True def _replace_colons(self, operator: str) -> str: """ @@ -57,6 +59,7 @@ def neutron_compile_spec( neutron_converter_flavor: str, extra_flags: Optional[str] = None, operators_not_to_delegate: Optional[List[str]] = None, + use_neutron_for_format_conversion: bool = True, ): """ Generate compile spec for Neutron NPU @@ -67,6 +70,9 @@ def neutron_compile_spec( "'neutron_converter_SDK_25_09' has flavor 'SDK_25_09'. extra_flags: Extra flags for the Neutron compiler operators_not_to_delegate: List of operators that should not be delegated + use_neutron_for_format_conversion: If True, the EdgeProgramToIRConverter will insert `Transpose` ops to + ensure that the IO matches the executorch partition, which will be + delegated to Neutron. """ self.neutron_converter_flavor = neutron_converter_flavor @@ -86,6 +92,8 @@ def neutron_compile_spec( self._replace_colons(op) for op in operators_not_to_delegate ] + self.use_neutron_for_format_conversion = use_neutron_for_format_conversion + return self def build(self): @@ -104,6 +112,10 @@ def build(self): "operators_not_to_delegate", ",".join(self.operators_not_to_delegate).encode(), ), + CompileSpec( + "use_neutron_for_format_conversion", + f"{self.use_neutron_for_format_conversion}".encode(), + ), ] return self.compile_spec @@ -115,6 +127,7 @@ def generate_neutron_compile_spec( system_config: Optional[str] = None, extra_flags: Optional[str] = None, operators_not_to_delegate: Optional[List[str]] = None, + use_neutron_for_format_conversion: bool = True, ) -> List[CompileSpec]: return ( NeutronCompileSpecBuilder() @@ -123,6 +136,7 @@ def generate_neutron_compile_spec( neutron_converter_flavor, extra_flags=extra_flags, operators_not_to_delegate=operators_not_to_delegate, + use_neutron_for_format_conversion=use_neutron_for_format_conversion, ) .build() ) @@ -145,6 +159,7 @@ def preprocess( # noqa C901 binary = bytes() target = "" neutron_converter_flavor = "" + use_neutron_for_format_conversion = None for spec in compile_spec: if spec.key == "output_format": output_format = spec.value.decode() @@ -154,6 +169,8 @@ def preprocess( # noqa C901 compile_flags.append(spec.value.decode()) if spec.key == "neutron_converter_flavor": neutron_converter_flavor = spec.value.decode() + if spec.key == "use_neutron_for_format_conversion": + use_neutron_for_format_conversion = spec.value.decode() == "True" # Check that the output format is set in the compile spec if not output_format: @@ -180,9 +197,15 @@ def preprocess( # noqa C901 ).transform() # Convert the edge program to TFLite. + conversion_config = ConversionConfig( + {"use_neutron_for_format_conversion": use_neutron_for_format_conversion} + if use_neutron_for_format_conversion is not None + else {} + ) tflite_model, io_formats = EdgeProgramToIRConverter().convert_program( edge_program, neutron_target_spec=NeutronTargetSpec(target, neutron_converter_flavor), + conversion_config=conversion_config, ) neutron_model = NeutronConverterManager(neutron_converter_flavor).convert( @@ -241,7 +264,9 @@ def _format_string_for_array(self, array: np.ndarray) -> str: return f"{array.size}s{self._padding_format_string_for_array(array)}" - def _create_payload_header(self, io_formats, neutron_artifacts) -> np.ndarray: + def _create_payload_header( + self, io_formats: dict[str, list[NodeFormat]], neutron_artifacts + ) -> np.ndarray: """ Create bytes header for returned payload. It contains information about input and output tensor formats. Tensors are ordered based on graph signature @@ -279,9 +304,7 @@ def _create_payload_header(self, io_formats, neutron_artifacts) -> np.ndarray: for input_name in neutron_artifacts.input_names: try: header_data.append( - 1 - if inputs[input_name.decode()] == TensorFormat.CHANNELS_LAST - else 0 + 1 if inputs[input_name.decode()] == NodeFormat.CHANNELS_LAST else 0 ) except KeyError: raise AssertionError( @@ -292,7 +315,7 @@ def _create_payload_header(self, io_formats, neutron_artifacts) -> np.ndarray: try: header_data.append( 1 - if outputs[output_name.decode()] == TensorFormat.CHANNELS_LAST + if outputs[output_name.decode()] == NodeFormat.CHANNELS_LAST else 0 ) except KeyError: @@ -331,7 +354,9 @@ def _pack_with_alignment( neutron_artifacts.kernels.tobytes(), ) - def get_binary_payload(self, io_formats, neutron_model) -> bytes: + def get_binary_payload( + self, io_formats: dict[str, list[NodeFormat]], neutron_model + ) -> bytes: """ Get binary payload for provided input/output tensor formats and neutron_model. Returned data have following structure: @@ -351,7 +376,7 @@ def get_binary_payload(self, io_formats, neutron_model) -> bytes: Tensor format definition: '0x1' == CHANNELS_LAST, '0x0' == FORMATLESS (no format). :param io_formats: Dictionary with keys 'inputs' and 'outputs' that contains dictionaries - mapping tensor name to TensorFormat. + mapping tensor name to NodeFormat. :param neutron_model: Neutron model with single NeutronGraph node. :return: 16 bytes aligned binary payload. """ diff --git a/backends/nxp/quantizer/neutron_quantizer.py b/backends/nxp/quantizer/neutron_quantizer.py index 2681e221869..6f2523a5ca3 100644 --- a/backends/nxp/quantizer/neutron_quantizer.py +++ b/backends/nxp/quantizer/neutron_quantizer.py @@ -39,6 +39,7 @@ SubTensorPattern, TanhInPlacePattern, TanhPattern, + TransposeIntPattern, ViewPattern, ) from executorch.backends.nxp.quantizer.utils import ( @@ -212,6 +213,7 @@ def __init__(self): NeutronAtenQuantizer(SubTensorPattern(), static_qconfig), NeutronAtenQuantizer(TanhPattern(), static_qconfig), NeutronAtenQuantizer(TanhInPlacePattern(), static_qconfig), + NeutronAtenQuantizer(TransposeIntPattern(), static_qconfig), NeutronAtenQuantizer(ViewPattern(), static_qconfig), ] ) diff --git a/backends/nxp/quantizer/patterns.py b/backends/nxp/quantizer/patterns.py index 9588ce24c9e..116e981b37d 100644 --- a/backends/nxp/quantizer/patterns.py +++ b/backends/nxp/quantizer/patterns.py @@ -513,6 +513,15 @@ def partition_types(self): return [torch.ops.aten.permute.default] +class TransposeIntPattern(SharedSpecPattern): + """ + Quantizer for Transpose Int operator. + """ + + def partition_types(self) -> list[OpOverload]: + return [torch.ops.aten.transpose.int] + + class ReluPattern(SharedSpecPattern): """ Quantizer for Relu operator. Shared quantization spec is selected, as ReLU usually follows computation layer. diff --git a/backends/nxp/tests/executorch_pipeline.py b/backends/nxp/tests/executorch_pipeline.py index 09bceb2b0d3..9a70283da37 100644 --- a/backends/nxp/tests/executorch_pipeline.py +++ b/backends/nxp/tests/executorch_pipeline.py @@ -89,6 +89,7 @@ def to_quantized_edge_program( remove_quant_io_ops=False, custom_delegation_options=CustomDelegationOptions(), # noqa B008 get_quantizer_fn=lambda: NeutronQuantizer(), + use_neutron_for_format_conversion=True, ) -> EdgeProgramManager: calibration_inputs = get_calibration_inputs_fn(to_model_input_spec(input_spec)) @@ -120,6 +121,7 @@ def to_quantized_edge_program( target, operators_not_to_delegate=operators_not_to_delegate, neutron_converter_flavor=neutron_converter_flavor, + use_neutron_for_format_conversion=use_neutron_for_format_conversion, ) partitioner = NeutronPartitioner(compile_spec, custom_delegation_options) edge_program_manager = edge_program_manager.to_backend(partitioner) @@ -130,8 +132,13 @@ def to_quantized_edge_program( def to_quantized_executorch_program( model: torch.nn.Module, input_spec: tuple[ModelInputSpec, ...] | tuple[int, ...] | list[tuple[int, ...]], + use_neutron_for_format_conversion: bool = True, ) -> ExecutorchProgramManager: - edge_program_manager = to_quantized_edge_program(model, input_spec) + edge_program_manager = to_quantized_edge_program( + model, + input_spec, + use_neutron_for_format_conversion=use_neutron_for_format_conversion, + ) return edge_program_manager.to_executorch( config=ExecutorchBackendConfig(extract_delegate_segments=False) diff --git a/backends/nxp/tests/executors.py b/backends/nxp/tests/executors.py index 632e3da055f..fa99046ff33 100644 --- a/backends/nxp/tests/executors.py +++ b/backends/nxp/tests/executors.py @@ -20,11 +20,12 @@ ) from executorch.backends.nxp.backend.ir.converter.node_converter import NodeConverter from executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec + +from executorch.backends.nxp.backend.node_format_inference import NodeFormatInference from torch.export import ExportedProgram from torch.fx import Node from torch.fx.graph import Graph - # If executed on i.MX platform, there is no tensorflow module. And typically the intention is to use the tflite python # interpreter available in tflite_runtime try: @@ -308,6 +309,7 @@ def convert_run_compare( ) -> (TFLiteExecutor, EdgeProgramExecutor): if tfl_model is None: + NodeFormatInference(edge_program).identify_node_formats() tfl_model, _ = EdgeProgramToIRConverter().convert_program( edge_program, conversion_config ) diff --git a/backends/nxp/tests/ir/converter/node_converter/test_abs_converter.py b/backends/nxp/tests/ir/converter/node_converter/test_abs_converter.py index 315c76a7614..96b9abfe117 100644 --- a/backends/nxp/tests/ir/converter/node_converter/test_abs_converter.py +++ b/backends/nxp/tests/ir/converter/node_converter/test_abs_converter.py @@ -14,9 +14,10 @@ from executorch.backends.nxp.tests.executors import ( convert_run_compare, graph_contains_any_of_ops, - ToNCHWPreprocess, - ToNHWCPreprocess, + ToChannelFirstPreprocess, + ToChannelLastPreprocess, ) + from executorch.exir.dialects._ops import ops as exir_ops from torch.export import ExportedProgram @@ -67,7 +68,9 @@ def test_conv_abs(mocker, input_shape: tuple[int] = (1, 3, 112, 112)): converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program") - quantized_program = to_quantized_edge_program(model, input_shape).exported_program() + quantized_program = to_quantized_edge_program( + model, input_shape, use_neutron_for_format_conversion=False + ).exported_program() tflite_flatbuffers_model, io_formats = converter_spy.spy_return exported_program: ExportedProgram = converter_spy.call_args.args[1] @@ -80,8 +83,8 @@ def test_conv_abs(mocker, input_shape: tuple[int] = (1, 3, 112, 112)): convert_run_compare( exported_program, tfl_model=tflite_flatbuffers_model, - tflite_input_preprocess=ToNHWCPreprocess(), - tflite_output_preprocess=ToNCHWPreprocess(), + tflite_input_preprocess=ToChannelLastPreprocess(), + tflite_output_preprocess=ToChannelFirstPreprocess(), input_data=input_data, atol=1.0, ) diff --git a/backends/nxp/tests/ir/converter/node_converter/test_adaptive_avg_pool2d_converter.py b/backends/nxp/tests/ir/converter/node_converter/test_adaptive_avg_pool2d_converter.py index 9c8235f7eda..a80d2014487 100644 --- a/backends/nxp/tests/ir/converter/node_converter/test_adaptive_avg_pool2d_converter.py +++ b/backends/nxp/tests/ir/converter/node_converter/test_adaptive_avg_pool2d_converter.py @@ -47,7 +47,9 @@ def test_adaptive_avg_pool_2d_delegated_quant_conversion( converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program") # Run conversion - edge_program = to_quantized_edge_program(model, input_shape).exported_program() + edge_program = to_quantized_edge_program( + model, input_shape, use_neutron_for_format_conversion=False + ).exported_program() nodes = [str(node) for node in edge_program.graph.nodes] # Input size is a multiple of output size, can be converted to AveragePool, node is delegated @@ -91,7 +93,9 @@ def test_adaptive_avg_pool_2d_non_delegated_quant_conversion( converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program") # Run conversion - edge_program = to_quantized_edge_program(model, input_shape).exported_program() + edge_program = to_quantized_edge_program( + model, input_shape, use_neutron_for_format_conversion=False + ).exported_program() nodes = list(edge_program.graph.nodes) # Input size is not a multiple of output size, cannot be converted to AveragePool, node is not delegated @@ -122,7 +126,9 @@ def test_adaptive_avg_pool_2d_mean_dim_quant_conversion(mocker): converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program") # Run conversion - _ = to_quantized_edge_program(model, input_shape) + _ = to_quantized_edge_program( + model, input_shape, use_neutron_for_format_conversion=False + ) # Capture generated model tflite_flatbuffers_model, io_formats = converter_spy.spy_return diff --git a/backends/nxp/tests/ir/converter/node_converter/test_add_tensor_converter.py b/backends/nxp/tests/ir/converter/node_converter/test_add_tensor_converter.py index 2c3107eae77..02e799723d4 100644 --- a/backends/nxp/tests/ir/converter/node_converter/test_add_tensor_converter.py +++ b/backends/nxp/tests/ir/converter/node_converter/test_add_tensor_converter.py @@ -103,7 +103,9 @@ def test_add_tensor_w_conv_quant_conversion(mocker, input_shape): converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program") # Run conversion - _ = to_quantized_edge_program(model, input_shape) + _ = to_quantized_edge_program( + model, input_shape, use_neutron_for_format_conversion=False + ) # Capture generated model tflite_flatbuffers_model, io_formats = converter_spy.spy_return diff --git a/backends/nxp/tests/ir/converter/node_converter/test_avg_pool2d_converter.py b/backends/nxp/tests/ir/converter/node_converter/test_avg_pool2d_converter.py index bcdbd955c71..7aed0236043 100644 --- a/backends/nxp/tests/ir/converter/node_converter/test_avg_pool2d_converter.py +++ b/backends/nxp/tests/ir/converter/node_converter/test_avg_pool2d_converter.py @@ -6,10 +6,11 @@ import numpy as np import pytest import torch - from executorch.backends.nxp.backend.edge_program_converter import ( EdgeProgramToIRConverter, ) + +from executorch.backends.nxp.backend.ir.conversion_config import ConversionConfig from executorch.backends.nxp.backend.ir.converter.builder.model_builder import ( ModelBuilder, ) @@ -91,6 +92,9 @@ def test_avg_pool_2d_conversion(input_shape, padding, count_include_pad): input_data, tflite_input_preprocess=ToNHWCPreprocess(), tflite_output_preprocess=ToNCHWPreprocess(), + conversion_config=ConversionConfig( + {"use_neutron_for_format_conversion": False} + ), ) @@ -145,7 +149,9 @@ def test_avg_pool_2d_quant_conversion(mocker, input_shape, padding, count_includ converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program") # Run conversion - _ = to_quantized_edge_program(model, input_shape) + _ = to_quantized_edge_program( + model, input_shape, use_neutron_for_format_conversion=False + ) # Capture generated model tflite_flatbuffers_model, io_formats = converter_spy.spy_return @@ -172,7 +178,9 @@ def test_avg_pool_2d_quant_conversion__padded(mocker): ops_spy = mocker.spy(ModelBuilder, "finish") # Run conversion - _ = to_quantized_edge_program(model, input_shape) + _ = to_quantized_edge_program( + model, input_shape, use_neutron_for_format_conversion=False + ) # Capture the converter operators. ops = ops_spy.spy_return.sub_graphs[0].operators.vector diff --git a/backends/nxp/tests/ir/converter/node_converter/test_cat_converter.py b/backends/nxp/tests/ir/converter/node_converter/test_cat_converter.py index 3df703f5bba..d9b58eda839 100644 --- a/backends/nxp/tests/ir/converter/node_converter/test_cat_converter.py +++ b/backends/nxp/tests/ir/converter/node_converter/test_cat_converter.py @@ -17,6 +17,8 @@ from executorch.backends.nxp.tests.executors import ( convert_run_compare, graph_contains_any_of_ops, + ToNCHWPreprocess, + ToNHWCPreprocess, ) from executorch.exir.dialects._ops import ops as exir_ops from torch.export import ExportedProgram @@ -126,6 +128,8 @@ def test_cat__channels_first__same_shapes(dim, num_inputs, mocker): exported_program, tfl_model=tflite_flatbuffers_model, input_data=input_data, + tflite_input_preprocess=ToNHWCPreprocess(), + tflite_output_preprocess=ToNCHWPreprocess(), atol=1, ) @@ -241,6 +245,8 @@ def test_cat__channels_first__different_shapes(dim, num_inputs, mocker): exported_program, tfl_model=tflite_flatbuffers_model, input_data=input_data, + tflite_input_preprocess=ToNHWCPreprocess(), + tflite_output_preprocess=ToNCHWPreprocess(), atol=1, ) @@ -290,3 +296,78 @@ def test_cat__force_delegate(): graph=quantized_program.graph, ops=[exir_ops.edge.aten.cat.default] ) assert any("lowered_module" in node.name for node in quantized_program.graph.nodes) + + +def test_cat__format_specific_support__formatless(mocker): + # The last dim will end up being the channels, as the format is `formatless`. + # Only the last dim satisfies the Neutron requirements for the channels. + input_shape = (3, 3, 3, 8) + num_inputs = 2 + dim = 2 + + input_shapes = [input_shape] * num_inputs + + converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program") + + quantized_program = to_quantized_edge_program( + CatModule(dim), input_shapes + ).exported_program() + + # Make sure the `Cat` was delegated. + assert not graph_contains_any_of_ops( + graph=quantized_program.graph, ops=[exir_ops.edge.aten.cat.default] + ) + assert any("lowered_module" in node.name for node in quantized_program.graph.nodes) + + tflite_flatbuffers_model, io_formats = converter_spy.spy_return + exported_program: ExportedProgram = converter_spy.call_args.args[1] + input_data = { + i: (np.random.random(shape) * 50).astype(np.int8) + for i, shape in enumerate(input_shapes) + } + convert_run_compare( + exported_program, + tfl_model=tflite_flatbuffers_model, + input_data=input_data, + atol=1, + ) + + +def test_cat__format_specific_support__channels_first(mocker): + # The second dim will end up being the channels, as the format is `formatless`. + # Only the second dim satisfies the Neutron requirements for the channels. + input_shape = (3, 8, 3, 3) + num_inputs = 2 + dim = 2 + + input_shapes = [input_shape] * num_inputs + + converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program") + + channels = ( + sum(shape[1] for shape in input_shapes) if dim in [1, -3] else input_shape[1] + ) + quantized_program = to_quantized_edge_program( + CatConvModule(dim, channels), input_shapes + ).exported_program() + + # Make sure the `Cat` was delegated. + assert not graph_contains_any_of_ops( + graph=quantized_program.graph, ops=[exir_ops.edge.aten.cat.default] + ) + assert any("lowered_module" in node.name for node in quantized_program.graph.nodes) + + tflite_flatbuffers_model, io_formats = converter_spy.spy_return + exported_program: ExportedProgram = converter_spy.call_args.args[1] + input_data = { + i: (np.random.random(shape) * 50).astype(np.int8) + for i, shape in enumerate(input_shapes) + } + convert_run_compare( + exported_program, + tfl_model=tflite_flatbuffers_model, + input_data=input_data, + tflite_input_preprocess=ToNHWCPreprocess(), + tflite_output_preprocess=ToNCHWPreprocess(), + atol=1, + ) diff --git a/backends/nxp/tests/ir/converter/node_converter/test_clone_converter.py b/backends/nxp/tests/ir/converter/node_converter/test_clone_converter.py index c02d184c5ae..989d97c622b 100644 --- a/backends/nxp/tests/ir/converter/node_converter/test_clone_converter.py +++ b/backends/nxp/tests/ir/converter/node_converter/test_clone_converter.py @@ -111,7 +111,7 @@ def test_conv_dropout_quant(self, inplace_dropout: bool, input_shape: tuple[int] EdgeProgramToIRConverter.convert_program, call_original=True ) as converter_spy: quantized_program = to_quantized_edge_program( - model, input_shape + model, input_shape, use_neutron_for_format_conversion=False ).exported_program() tflite_flatbuffers_model, _ = converter_spy.calls[-1].return_value diff --git a/backends/nxp/tests/ir/converter/node_converter/test_constant_pad_nd_converter.py b/backends/nxp/tests/ir/converter/node_converter/test_constant_pad_nd_converter.py index 47cd54c4efb..bd1f894001c 100644 --- a/backends/nxp/tests/ir/converter/node_converter/test_constant_pad_nd_converter.py +++ b/backends/nxp/tests/ir/converter/node_converter/test_constant_pad_nd_converter.py @@ -1,4 +1,4 @@ -# Copyright 2024 NXP +# Copyright 2024-2025 NXP # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -7,12 +7,14 @@ import pytest import torch +from executorch.backends.nxp.backend.ir.conversion_config import ConversionConfig from executorch.backends.nxp.tests.executorch_pipeline import ( to_edge_program, to_quantized_edge_program, ) from executorch.backends.nxp.tests.executors import ( convert_run_compare, + graph_contains_any_of_ops, ToNCHWPreprocess, ToNHWCPreprocess, ) @@ -20,6 +22,7 @@ ConstantPadNDConvModule, ConstantPadNDModule, ) +from executorch.exir.dialects._ops import ops as exir_ops @pytest.fixture(autouse=True) @@ -99,6 +102,9 @@ def test_constant_pad_nd_conversion__channels_first(input_shape, paddings): input_data, tflite_input_preprocess=ToNHWCPreprocess(), tflite_output_preprocess=ToNCHWPreprocess(), + conversion_config=ConversionConfig( + {"use_neutron_for_format_conversion": False} + ), ) @@ -121,3 +127,51 @@ def test_constant_pad_nd__unsupported_paddings(input_shape, paddings): nodes = list(exec_program.graph.nodes) # There is at least one non-delegated Pad node assert any(node.name == "aten_constant_pad_nd_default" for node in nodes) + + +def test_constant_pad_nd__delegation__formatless__supported_padding(): + input_shape = (2, 4, 6, 8) # Formatless -> the last dim (8) will be padded. + paddings = [0, 0, 1, 2, 3, 4] # The last dim is padded using the first 2 paddings. + model = ConstantPadNDModule(paddings) + exec_program = to_quantized_edge_program(model, input_shape).exported_program() + + # Make sure the `pad` was delegated. + assert not graph_contains_any_of_ops( + exec_program.graph, [exir_ops.edge.aten.constant_pad_nd.default] + ) + + +def test_constant_pad_nd__delegation__formatless__unsupported_padding(): + input_shape = (2, 4, 6, 8) # Formatless -> the last dim (8) will be padded. + paddings = [0, 1] # The last dim is padded using the first 2 paddings. + model = ConstantPadNDModule(paddings) + exec_program = to_quantized_edge_program(model, input_shape).exported_program() + + # Make sure the `pad` was NOT delegated. + assert graph_contains_any_of_ops( + exec_program.graph, [exir_ops.edge.aten.constant_pad_nd.default] + ) + + +def test_constant_pad_nd__delegation__channels_first__supported_padding(): + input_shape = (2, 4, 6, 8) # Channels first -> the second dim (4) will be padded. + paddings = [1, 2, 3, 4, 0, 0] # The second dim is padded using the paddings[4:6]. + model = ConstantPadNDConvModule(paddings) + exec_program = to_quantized_edge_program(model, input_shape).exported_program() + + # Make sure the `pad` was delegated. + assert not graph_contains_any_of_ops( + exec_program.graph, [exir_ops.edge.aten.constant_pad_nd.default] + ) + + +def test_constant_pad_nd__delegation__channels_first__unsupported_padding(): + input_shape = (2, 3, 6, 8) # Channels first -> the second dim (3) will be padded. + paddings = [0, 0, 0, 0, 1, 0] # The second dim is padded using the paddings[4:6]. + model = ConstantPadNDConvModule(paddings) + exec_program = to_quantized_edge_program(model, input_shape).exported_program() + + # Make sure the `pad` was NOT delegated. + assert graph_contains_any_of_ops( + exec_program.graph, [exir_ops.edge.aten.constant_pad_nd.default] + ) diff --git a/backends/nxp/tests/ir/converter/node_converter/test_conv_converter.py b/backends/nxp/tests/ir/converter/node_converter/test_conv_converter.py index d7a59cad6d6..a6d7f84000a 100644 --- a/backends/nxp/tests/ir/converter/node_converter/test_conv_converter.py +++ b/backends/nxp/tests/ir/converter/node_converter/test_conv_converter.py @@ -10,6 +10,7 @@ from executorch.backends.nxp.backend.edge_program_converter import ( EdgeProgramToIRConverter, ) +from executorch.backends.nxp.backend.ir.conversion_config import ConversionConfig from executorch.backends.nxp.backend.ir.converter.builder.model_builder import ( ModelBuilder, ) @@ -375,7 +376,9 @@ def test_conv2d_quant_conversion(mocker, model: torch.nn.Module, input_shape): converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program") # Run conversion - _ = to_quantized_edge_program(model, input_shape) + _ = to_quantized_edge_program( + model, input_shape, use_neutron_for_format_conversion=False + ) # Capture generated model tflite_flatbuffers_model, io_formats = converter_spy.spy_return @@ -451,6 +454,7 @@ def test_conv2d_conversion__depthwise__quantized( kernel_size=kernel_shape, ), tuple(input_shape), + use_neutron_for_format_conversion=False, ).exported_program() ops = spy.spy_return.sub_graphs[0].operators.vector @@ -485,6 +489,9 @@ def test_conv2d_conversion__depthwise__padded(padding, mocker): tflite_input_preprocess=ToChannelLastPreprocess(), tflite_output_preprocess=ToChannelFirstPreprocess(), atol=4e-7, + conversion_config=ConversionConfig( + {"use_neutron_for_format_conversion": False} + ), ) conversion_result = spy.spy_return ops = conversion_result.sub_graphs[0].operators.vector @@ -505,6 +512,7 @@ def test_conv2d_conversion__depthwise__padded__quantized(padding, mocker): group=group, in_channels=group, out_channels=group, padding=padding ), tuple(input_shape), + use_neutron_for_format_conversion=False, ).exported_program() ops = spy.spy_return.sub_graphs[0].operators.vector diff --git a/backends/nxp/tests/ir/converter/node_converter/test_hardtanh_converter.py b/backends/nxp/tests/ir/converter/node_converter/test_hardtanh_converter.py index c4bc559817b..dad8ce6a0e3 100644 --- a/backends/nxp/tests/ir/converter/node_converter/test_hardtanh_converter.py +++ b/backends/nxp/tests/ir/converter/node_converter/test_hardtanh_converter.py @@ -42,7 +42,9 @@ def test_relu6_quant(mocker, input_shape: tuple[int], inplace: bool): converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program") - quantized_program = to_quantized_edge_program(model, input_shape).exported_program() + quantized_program = to_quantized_edge_program( + model, input_shape, use_neutron_for_format_conversion=False + ).exported_program() tflite_flatbuffers_model, io_formats = converter_spy.spy_return exported_program: ExportedProgram = converter_spy.call_args.args[1] @@ -79,7 +81,9 @@ def test_custom_hardtanh_quant( converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program") - quantized_program = to_quantized_edge_program(model, input_shape).exported_program() + quantized_program = to_quantized_edge_program( + model, input_shape, use_neutron_for_format_conversion=False + ).exported_program() tflite_flatbuffers_model, io_formats = converter_spy.spy_return exported_program: ExportedProgram = converter_spy.call_args.args[1] diff --git a/backends/nxp/tests/ir/converter/node_converter/test_max_pool_2d_converter.py b/backends/nxp/tests/ir/converter/node_converter/test_max_pool_2d_converter.py index 50bbf100980..8b938ef7fff 100644 --- a/backends/nxp/tests/ir/converter/node_converter/test_max_pool_2d_converter.py +++ b/backends/nxp/tests/ir/converter/node_converter/test_max_pool_2d_converter.py @@ -6,10 +6,11 @@ import numpy as np import pytest import torch - from executorch.backends.nxp.backend.edge_program_converter import ( EdgeProgramToIRConverter, ) + +from executorch.backends.nxp.backend.ir.conversion_config import ConversionConfig from executorch.backends.nxp.neutron_pass_manager import NeutronPassManager from executorch.backends.nxp.tests.executorch_pipeline import ( to_edge_program, @@ -76,6 +77,9 @@ def test_max_pool_2d_conversion(input_shape, padding): input_data, tflite_input_preprocess=ToNHWCPreprocess(), tflite_output_preprocess=ToNCHWPreprocess(), + conversion_config=ConversionConfig( + {"use_neutron_for_format_conversion": False} + ), ) @@ -103,7 +107,11 @@ def test_max_pool_2d_quant_conversion(mocker, input_shape, padding): converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program") # Run conversion - _ = to_quantized_edge_program(MaxPool2dConvModule(padding=padding), input_shape) + _ = to_quantized_edge_program( + MaxPool2dConvModule(padding=padding), + input_shape, + use_neutron_for_format_conversion=False, + ) # Capture generated model tflite_flatbuffers_model, io_formats = converter_spy.spy_return diff --git a/backends/nxp/tests/ir/converter/node_converter/test_mean_dim_converter.py b/backends/nxp/tests/ir/converter/node_converter/test_mean_dim_converter.py index a634416f8a7..6dcca38d10f 100644 --- a/backends/nxp/tests/ir/converter/node_converter/test_mean_dim_converter.py +++ b/backends/nxp/tests/ir/converter/node_converter/test_mean_dim_converter.py @@ -33,7 +33,9 @@ def test_mean_dim_conv_quant_conversion(mocker, input_shape, dim, keeepdim=True) converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program") # Run conversion - _ = to_quantized_edge_program(model, input_shape) + _ = to_quantized_edge_program( + model, input_shape, use_neutron_for_format_conversion=False + ) # Capture generated model tflite_flatbuffers_model, io_formats = converter_spy.spy_return @@ -119,7 +121,9 @@ def test_mean_dim_conv_unsupported_quant_conversion(mocker, input_shape, dim, ke converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program") # Run conversion - edge_program = to_quantized_edge_program(model, input_shape).exported_program() + edge_program = to_quantized_edge_program( + model, input_shape, use_neutron_for_format_conversion=False + ).exported_program() nodes = list(edge_program.graph.nodes) # Last 2 dimensions are not used or keepdim is False, cannot be converted to MeanDim, node is not delegated diff --git a/backends/nxp/tests/ir/converter/node_converter/test_permute_copy_converter.py b/backends/nxp/tests/ir/converter/node_converter/test_permute_copy_converter.py index d25e2759cc8..57d15aefdc0 100644 --- a/backends/nxp/tests/ir/converter/node_converter/test_permute_copy_converter.py +++ b/backends/nxp/tests/ir/converter/node_converter/test_permute_copy_converter.py @@ -3,8 +3,10 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import unittest + +import kgb import numpy as np -import pytest import torch from executorch.backends.nxp.backend.edge_program_converter import ( @@ -13,52 +15,312 @@ from executorch.backends.nxp.tests.executorch_pipeline import to_quantized_edge_program from executorch.backends.nxp.tests.executors import ( convert_run_compare, - ToNCHWPreprocess, - ToNHWCPreprocess, + graph_contains_any_of_ops, ) from executorch.backends.nxp.tests.models import Conv2dModule +from executorch.exir.dialects._ops import ops as exir_ops +from parameterized import parameterized from torch.export import ExportedProgram -@pytest.fixture(autouse=True) -def reseed_model_per_test_run(): - torch.manual_seed(23) - np.random.seed(23) +class Conv2dTransposeModule(torch.nn.Module): + def __init__(self, in_channels: int, dim0: int, dim1: int): + super().__init__() + self.dim0 = dim0 + self.dim1 = dim1 + self.conv = Conv2dModule( + in_channels=in_channels, out_channels=in_channels, kernel_size=(1, 1) + ) + + def forward(self, x): + x = self.conv(x) + return torch.transpose(x, self.dim0, self.dim1) + + +class Conv2dPermuteModule(torch.nn.Module): + def __init__(self, in_channels: int, perm: tuple[int, ...]): + super().__init__() + self.perm = perm + self.conv = Conv2dModule( + in_channels=in_channels, + out_channels=in_channels, + stride=1, + kernel_size=3, + padding=1, + ) + + def forward(self, x): + x = self.conv(x) + return torch.permute(x, self.perm) + + +class PermuteConv2dModule(torch.nn.Module): + def __init__(self, in_channels: int, perm: tuple[int, ...]): + super().__init__() + self.perm = perm + self.conv = Conv2dModule( + in_channels=in_channels, + out_channels=in_channels, + stride=1, + kernel_size=3, + padding=1, + ) + + def forward(self, x): + x = torch.permute(x, self.perm) + return self.conv(x) -class Conv2dPermuteCopyModule(torch.nn.Module): - def __init__(self, new_dims: tuple[int, ...]): +class PermuteConv2dPermuteModule(torch.nn.Module): + def __init__( + self, in_channels: int, perm1: tuple[int, ...], perm2: tuple[int, ...] + ): super().__init__() - self.new_dims = new_dims - self.conv = Conv2dModule() + self.perm1 = perm1 + self.perm2 = perm2 + self.conv = Conv2dModule( + in_channels=in_channels, + out_channels=in_channels, + stride=1, + kernel_size=3, + padding=1, + ) def forward(self, x): + x = torch.permute(x, self.perm1) x = self.conv(x) - return torch.permute(x, self.new_dims) + x = torch.permute(x, self.perm2) + return x -def test_permute_copy_quant_conversion__with_bias(mocker): - input_shape = (1, 4, 8, 8) - new_dims = (0, 2, 3, 1) +class LinearPermuteModule(torch.nn.Module): + def __init__(self, in_features: int, perm: tuple[int, ...]): + super().__init__() + self.perm = perm + self.fc = torch.nn.Linear(in_features, in_features) + + def forward(self, x): + x = self.fc(x) + return torch.permute(x, self.perm) - converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program") - # Run conversion - _ = to_quantized_edge_program(Conv2dPermuteCopyModule(new_dims), input_shape) +class TestPermuteCopyConversion(kgb.SpyAgency, unittest.TestCase): + @classmethod + def setUpClass(cls): + torch.manual_seed(23) + np.random.seed(42) - # Capture generated model - tflite_flatbuffers_model, io_formats = converter_spy.spy_return + @parameterized.expand( + [ + ["To channel first permutation", (1, 16, 8, 8), (0, 3, 1, 2)], + ["To channel last permutation", (1, 16, 8, 8), (0, 2, 3, 1)], + ] + ) + def test_permute_copy_conversion__from_permute_4D__quantized__channels_first_input( + self, _: str, input_shape, perm + ): + with kgb.spy_on( + EdgeProgramToIRConverter.convert_program, call_original=True + ) as converter_spy: + model = Conv2dPermuteModule(input_shape[1], perm) - # Capture converted program - edge_program: ExportedProgram = converter_spy.call_args.args[1] + # Run conversion + edge_program = to_quantized_edge_program( + model, input_shape + ).exported_program() - input_data = (np.random.random(input_shape).astype(np.float32) * 50).astype(np.int8) + # Make sure the `Permute_copy` was delegated. + assert not graph_contains_any_of_ops( + graph=edge_program.graph, ops=[exir_ops.edge.aten.permute_copy.default] + ) + assert any( + "lowered_module" in node.name for node in edge_program.graph.nodes + ) - convert_run_compare( - edge_program, - input_data, - tfl_model=tflite_flatbuffers_model, - atol=1.0, - tflite_input_preprocess=ToNHWCPreprocess(), - tflite_output_preprocess=ToNCHWPreprocess(), + # Capture generated model + tflite_flatbuffers_model, io_formats = converter_spy.calls[-1].return_value + + # Capture converted program + exported_program: ExportedProgram = converter_spy.calls[-1].args[0] + + input_data = (np.random.random(input_shape).astype(np.float32) * 50).astype( + np.int8 + ) + + convert_run_compare( + exported_program, + input_data, + tfl_model=tflite_flatbuffers_model, + atol=1.0, + ) + + @parameterized.expand( + [ + ["To channel first permutation", (1, 8, 8, 8), (0, 3, 1, 2)], + ["To channel last permutation", (1, 8, 8, 8), (0, 2, 3, 1)], + ] ) + def test_permute_copy_conversion__from_permute_4D__quantized__channels_first_output( + self, _: str, input_shape, perm + ): + with kgb.spy_on( + EdgeProgramToIRConverter.convert_program, call_original=True + ) as converter_spy: + model = PermuteConv2dModule(input_shape[1], perm) + + # Run conversion + edge_program = to_quantized_edge_program( + model, input_shape + ).exported_program() + + # Make sure the `Permute_copy` was delegated. + assert not graph_contains_any_of_ops( + graph=edge_program.graph, ops=[exir_ops.edge.aten.permute_copy.default] + ) + assert any( + "lowered_module" in node.name for node in edge_program.graph.nodes + ) + + # Capture generated model + tflite_flatbuffers_model, io_formats = converter_spy.calls[-1].return_value + + # Capture converted program + exported_program: ExportedProgram = converter_spy.calls[-1].args[0] + + input_data = (np.random.random(input_shape).astype(np.float32) * 50).astype( + np.int8 + ) + + convert_run_compare( + exported_program, + input_data, + tfl_model=tflite_flatbuffers_model, + atol=1.0, + ) + + @parameterized.expand( + [ + ["nchw->nhwc ... nchw->nhwc", (1, 8, 8, 8), (0, 2, 3, 1), (0, 2, 3, 1)], + ["nchw->nhwc ... nhwc->nchw", (1, 8, 8, 8), (0, 2, 3, 1), (0, 3, 1, 2)], + ["nhwc->nchw ... nhwc->nchw", (1, 8, 8, 8), (0, 3, 1, 2), (0, 3, 1, 2)], + ["nhwc->nchw ... nchw->nhwc", (1, 8, 8, 8), (0, 3, 1, 2), (0, 2, 3, 1)], + ] + ) + def test_permute_copy_conversion__from_permute_4D__quantized__channels_first_io( + self, _: str, input_shape, perm1, perm2 + ): + with kgb.spy_on( + EdgeProgramToIRConverter.convert_program, call_original=True + ) as converter_spy: + model = PermuteConv2dPermuteModule(input_shape[1], perm1, perm2) + + # Run conversion + edge_program = to_quantized_edge_program( + model, input_shape + ).exported_program() + + # Make sure the `Permute_copy` was delegated. + assert not graph_contains_any_of_ops( + graph=edge_program.graph, ops=[exir_ops.edge.aten.permute_copy.default] + ) + assert any( + "lowered_module" in node.name for node in edge_program.graph.nodes + ) + + # Capture generated model + tflite_flatbuffers_model, io_formats = converter_spy.calls[-1].return_value + + # Capture converted program + exported_program: ExportedProgram = converter_spy.calls[-1].args[0] + + input_data = (np.random.random(input_shape).astype(np.float32) * 50).astype( + np.int8 + ) + + convert_run_compare( + exported_program, + input_data, + tfl_model=tflite_flatbuffers_model, + atol=1.0, + ) + + @parameterized.expand( + [ + ["Permutation can be replaced by reshapes", (10, 1, 8), (0, 2, 1)], + ["Permutation can be replaced by reshapes", (10, 1, 1), (2, 1, 0)], + ["Permutation is identical and can be removed", (10, 1, 8), (0, 1, 2)], + ] + ) + def test_permute_copy_conversion__from_permute_3D__quantized( + self, _: str, input_shape, perm + ): + with kgb.spy_on( + EdgeProgramToIRConverter.convert_program, call_original=True + ) as converter_spy: + # Run conversion + edge_program = to_quantized_edge_program( + LinearPermuteModule(input_shape[2], perm), input_shape + ).exported_program() + + # Make sure the `Permute_copy` was delegated. + assert not graph_contains_any_of_ops( + graph=edge_program.graph, ops=[exir_ops.edge.aten.permute_copy.default] + ) + assert any( + "lowered_module" in node.name for node in edge_program.graph.nodes + ) + + # Capture generated model + tflite_flatbuffers_model, io_formats = converter_spy.calls[-1].return_value + + # Capture converted program + exported_program: ExportedProgram = converter_spy.calls[-1].args[0] + + input_data = (np.random.random(input_shape).astype(np.float32) * 50).astype( + np.int8 + ) + + convert_run_compare( + exported_program, + input_data, + tfl_model=tflite_flatbuffers_model, + atol=1.0, + ) + + @parameterized.expand( + [ + ["Transpose dims 1 and 2", (1, 16, 8, 8), (0, 2, 1, 3)], + ["To (2, 0, 1, 3) permutation", (1, 16, 8, 8), (2, 0, 1, 3)], + ["To (3, 1, 2, 0) permutation", (1, 16, 8, 8), (3, 1, 2, 0)], + ["To (3, 1, 0, 2) permutation", (1, 16, 8, 8), (3, 1, 0, 2)], + ] + ) + def test_permute_copy_non_delegated_conversion__from_permute_4D__quantized( + self, _: str, input_shape, perm + ): + model = Conv2dPermuteModule(input_shape[1], perm) + edge_program = to_quantized_edge_program(model, input_shape).exported_program() + + nodes = list(edge_program.graph.nodes) + assert len(nodes) == 10 + assert ( + nodes[6].target == exir_ops.edge.aten.permute_copy.default + ) # PermuteCopy not delegated. + + @parameterized.expand( + [ + ["Transpose dims 1 and 2", (1, 16, 8, 8), 1, 2], + ["Transpose dims 2 and 3", (1, 16, 8, 8), 2, 3], + ] + ) + def test_permute_copy_non_delegated_conversion__from_transpose_4D__quantized( + self, _: str, input_shape, dim0, dim1 + ): + model = Conv2dTransposeModule(input_shape[1], dim0, dim1) + edge_program = to_quantized_edge_program(model, input_shape).exported_program() + + nodes = list(edge_program.graph.nodes) + assert len(nodes) == 10 + assert ( + nodes[6].target == exir_ops.edge.aten.permute_copy.default + ) # PermuteCopy not delegated. diff --git a/backends/nxp/tests/ir/converter/node_converter/test_relu_converter.py b/backends/nxp/tests/ir/converter/node_converter/test_relu_converter.py index 8d903e3e0b5..cf0e0135ffe 100644 --- a/backends/nxp/tests/ir/converter/node_converter/test_relu_converter.py +++ b/backends/nxp/tests/ir/converter/node_converter/test_relu_converter.py @@ -67,7 +67,9 @@ def test_relu_with_conv_quant_conversion(mocker): converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program") # Run conversion - _ = to_quantized_edge_program(ConvReLUModule(), input_shape) + _ = to_quantized_edge_program( + ConvReLUModule(), input_shape, use_neutron_for_format_conversion=False + ) # Capture generated model tflite_flatbuffers_model, _ = converter_spy.spy_return diff --git a/backends/nxp/tests/ir/converter/node_converter/test_sigmoid_converter.py b/backends/nxp/tests/ir/converter/node_converter/test_sigmoid_converter.py index c5d7d4d6a38..382266e9cb1 100644 --- a/backends/nxp/tests/ir/converter/node_converter/test_sigmoid_converter.py +++ b/backends/nxp/tests/ir/converter/node_converter/test_sigmoid_converter.py @@ -33,7 +33,9 @@ def test_conv_sigmoid(mocker, input_shape: tuple[int] = (1, 3, 112, 112)): converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program") - to_quantized_edge_program(model, input_shape).exported_program() + to_quantized_edge_program( + model, input_shape, use_neutron_for_format_conversion=False + ).exported_program() tflite_flatbuffers_model, io_formats = converter_spy.spy_return exported_program: ExportedProgram = converter_spy.call_args.args[1] diff --git a/backends/nxp/tests/ir/converter/node_converter/test_softmax_converter.py b/backends/nxp/tests/ir/converter/node_converter/test_softmax_converter.py index 92af90b923d..b2e00fefc5a 100644 --- a/backends/nxp/tests/ir/converter/node_converter/test_softmax_converter.py +++ b/backends/nxp/tests/ir/converter/node_converter/test_softmax_converter.py @@ -11,6 +11,7 @@ EdgeProgramToIRConverter, ) from executorch.backends.nxp.backend.ir.conversion_config import ConversionConfig +from executorch.backends.nxp.backend.node_format_inference import NodeFormatInference from executorch.backends.nxp.tests.executorch_pipeline import to_edge_program from executorch.backends.nxp.tests.executors import convert_run_compare from executorch.backends.nxp.tests.models import SoftmaxConvModule, SoftmaxModule @@ -56,6 +57,7 @@ def test_softmax_conversion__unknown_input_format(input_shape, dim: int): model = SoftmaxModule(dim) edge_program = to_edge_program(model, input_shape).exported_program() + NodeFormatInference(edge_program).identify_node_formats() # Currently this test not pass because the convertibility checker doesn't use tensor formats. with pytest.raises( @@ -78,6 +80,7 @@ def test_softmax_conversion_channel_last(input_shape, dim: int): model = SoftmaxConvModule(dim) edge_program = to_edge_program(model, input_shape).exported_program() + NodeFormatInference(edge_program).identify_node_formats() # TODO (Robert Kalmar) Currently this test not pass because the convertibility checker doesn't use tensor formats. with pytest.raises( @@ -104,6 +107,7 @@ def test_softmax_conversion_unsupported_dims(input_shape, dim: int): model = SoftmaxModule(dim) edge_program = to_edge_program(model, input_shape).exported_program() + NodeFormatInference(edge_program).identify_node_formats() with pytest.raises( AssertionError, match="`aten__softmax_default` is not convertible" diff --git a/backends/nxp/tests/ir/converter/node_converter/test_sub_tensor_converter.py b/backends/nxp/tests/ir/converter/node_converter/test_sub_tensor_converter.py index 98566ff1ad6..336c3cc9afd 100644 --- a/backends/nxp/tests/ir/converter/node_converter/test_sub_tensor_converter.py +++ b/backends/nxp/tests/ir/converter/node_converter/test_sub_tensor_converter.py @@ -118,7 +118,9 @@ def test_sub_tensor_w_conv_quant_conversion(mocker, x_input_shape): y_input_shape = (n, 8, h, w) # Run conversion - _ = to_quantized_edge_program(model, [x_input_shape, y_input_shape]) + _ = to_quantized_edge_program( + model, [x_input_shape, y_input_shape], use_neutron_for_format_conversion=False + ) # Capture generated model tflite_flatbuffers_model, io_formats = converter_spy.spy_return diff --git a/backends/nxp/tests/ir/converter/node_converter/test_tanh_converter.py b/backends/nxp/tests/ir/converter/node_converter/test_tanh_converter.py index bb4500bc1e2..fd6a95eef50 100644 --- a/backends/nxp/tests/ir/converter/node_converter/test_tanh_converter.py +++ b/backends/nxp/tests/ir/converter/node_converter/test_tanh_converter.py @@ -60,7 +60,7 @@ def test_conv_tanh( ) quantized_program = to_quantized_edge_program( - model, input_shape + model, input_shape, use_neutron_for_format_conversion=False ).exported_program() tflite_flatbuffers_model, io_formats = converter_spy.calls[-1].return_value exported_program: ExportedProgram = converter_spy.calls[-1].args[0] diff --git a/backends/nxp/tests/ir/converter/node_converter/test_view_copy_converter.py b/backends/nxp/tests/ir/converter/node_converter/test_view_copy_converter.py index 448a9753000..fac0a1fffee 100644 --- a/backends/nxp/tests/ir/converter/node_converter/test_view_copy_converter.py +++ b/backends/nxp/tests/ir/converter/node_converter/test_view_copy_converter.py @@ -12,6 +12,8 @@ from executorch.backends.nxp.backend.edge_program_converter import ( EdgeProgramToIRConverter, ) + +from executorch.backends.nxp.backend.ir.conversion_config import ConversionConfig from executorch.backends.nxp.backend.ir.converter.builder.model_builder import ( ModelBuilder, ) @@ -146,6 +148,9 @@ def test__channels_first_to_4d(mocker): input_data, tflite_input_preprocess=ToNHWCPreprocess(), atol=2.0e-7, + conversion_config=ConversionConfig( + {"use_neutron_for_format_conversion": False} + ), ) tflite_model = converter_spy.spy_return @@ -243,6 +248,7 @@ def test_view_w_conv_linear_quant_conversion(mocker, input_shape, channels_view_ channels=input_shape[1], channels_view_out=channels_view_out ), input_shape, + use_neutron_for_format_conversion=False, ) # Capture generated model diff --git a/backends/nxp/tests/ir/edge_passes/test_remove_io_quant_ops_pass.py b/backends/nxp/tests/ir/edge_passes/test_remove_io_quant_ops_pass.py index 17b040fbc3d..b5e701ab239 100644 --- a/backends/nxp/tests/ir/edge_passes/test_remove_io_quant_ops_pass.py +++ b/backends/nxp/tests/ir/edge_passes/test_remove_io_quant_ops_pass.py @@ -51,7 +51,10 @@ def test_remove_io_quant_ops_pass__cifarnet(): model = CifarNet().get_eager_model() input_shape = (1, 3, 32, 32) edge_program_manager = to_quantized_edge_program( - model, input_shape, remove_quant_io_ops=True + model, + input_shape, + remove_quant_io_ops=True, + use_neutron_for_format_conversion=False, ) exec_prog = edge_program_manager.to_executorch( diff --git a/backends/nxp/tests/test_integration.py b/backends/nxp/tests/test_integration.py index d31b22c9ce9..3bd5f3e1487 100644 --- a/backends/nxp/tests/test_integration.py +++ b/backends/nxp/tests/test_integration.py @@ -39,7 +39,9 @@ def test_conv_fc_softmax__to_executorch_program(): def test_cifarnet(): model = CifarNet().get_eager_model().eval() input_shape = (1, 3, 32, 32) - exec_prog = to_quantized_executorch_program(model, input_shape) + exec_prog = to_quantized_executorch_program( + model, input_shape, use_neutron_for_format_conversion=False + ) delegation_info = get_delegation_info(exec_prog.exported_program().graph_module) assert delegation_info.num_delegated_subgraphs == 1 diff --git a/backends/nxp/tests/test_neutron_backend.py b/backends/nxp/tests/test_neutron_backend.py index c9917651fbd..1db97b1cbfb 100644 --- a/backends/nxp/tests/test_neutron_backend.py +++ b/backends/nxp/tests/test_neutron_backend.py @@ -3,8 +3,30 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import numpy as np +import torch + +from executorch.backends.nxp.backend.edge_program_converter import ( + EdgeProgramToIRConverter, +) +from executorch.backends.nxp.backend.ir.lib.tflite.BuiltinOptions import BuiltinOptions +from executorch.backends.nxp.backend.ir.lib.tflite.Model import Model +from executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec +from executorch.backends.nxp.nxp_backend import PayloadComposer from executorch.backends.nxp.tests.executorch_pipeline import to_quantized_edge_program -from executorch.backends.nxp.tests.models import Conv2dModule, LinearSoftmaxModule +from executorch.backends.nxp.tests.executors import ( + convert_run_compare, + EdgeProgramExecutor, + graph_contains_any_of_ops, + TFLiteExecutor, + ToNHWCPreprocess, +) +from executorch.backends.nxp.tests.models import ( + Conv2dModule, + ConvFCSoftmaxModule, + LinearSoftmaxModule, +) +from torch.export import ExportedProgram def test_neutron_backend__single_conv_model(): @@ -21,7 +43,9 @@ def test_neutron_backend__single_conv_model(): def test_neutron_backend__single_conv_model__payload_header_channels_last(): edge_program_manager = to_quantized_edge_program( - Conv2dModule(bias=False), (1, 4, 32, 32) + Conv2dModule(bias=False), + (1, 4, 32, 32), + use_neutron_for_format_conversion=False, ) payload = ( edge_program_manager.exported_program().graph_module.lowered_module_0.processed_bytes @@ -53,3 +77,307 @@ def test_neutron_backend__linear_softmax_model__payload_header_formatless(): assert payload[6] == 0x0 # Map 0-th Neutron output to 0-th model output assert all(byte == 0x0 for byte in payload[7:16]) # Aligned to 16 bytes assert payload[17] != 0x0 # Followed by non-zero content + + +def test_lowered_program_and_tflite_output_match__conv2d__no_bias(mocker): + converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program") + + model = Conv2dModule(bias=False) + input_shape = (1, 4, 32, 32) + + # Run conversion + to_quantized_edge_program( + model, input_shape, use_neutron_for_format_conversion=False + ) + + # Capture generated model + tflite_flatbuffers_model, io_formats = converter_spy.spy_return + + tflite_model = Model.GetRootAs(tflite_flatbuffers_model) + sub_graph = tflite_model.Subgraphs(0) + + assert sub_graph.OperatorsLength() == 1 + assert sub_graph.Operators(0).BuiltinOptionsType() == BuiltinOptions.Conv2DOptions + + # Capture converted program + exported_program: ExportedProgram = converter_spy.call_args.args[1] + + input_data = ( + (torch.randn(input_shape, dtype=torch.float32) * 50) + .type(torch.int8) + .detach() + .numpy() + ) + input_data_tflite = np.transpose(input_data, [0, 2, 3, 1]) + + # Execute program and TFLite model + program_executor = EdgeProgramExecutor(exported_program) + tflite_executor = TFLiteExecutor(model_content=tflite_flatbuffers_model) + + output_edge = program_executor.inference(input_data) + output_tflite = tflite_executor.inference(input_data_tflite) + + output_tflite = np.transpose(output_tflite, [0, 3, 1, 2]) + + # Outputs difference is smaller than 1 (rounding error in quantization) + assert np.max(np.abs(output_edge - output_tflite)) <= 1 + + +def test_conv_fc__lowered_program_and_tflite_output_match(mocker): + converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program") + + model = ConvFCSoftmaxModule() + input_shape = (1, 4, 5, 5) + + # Run conversion + _ = to_quantized_edge_program(model, input_shape) + + # Capture converted program + exported_program: ExportedProgram = converter_spy.call_args.args[1] + + # Capture generated model + tflite_flatbuffers_model, _ = converter_spy.spy_return + + # No Transpose ops in produced TFLite model + tflite_subgraph = Model.GetRootAs(tflite_flatbuffers_model).Subgraphs(0) + + assert tflite_subgraph.OperatorsLength() == 3 + assert ( + tflite_subgraph.Operators(0).BuiltinOptionsType() + == BuiltinOptions.Conv2DOptions + ) + assert ( + tflite_subgraph.Operators(1).BuiltinOptionsType() + == BuiltinOptions.ReshapeOptions + ) + assert ( + tflite_subgraph.Operators(2).BuiltinOptionsType() + == BuiltinOptions.FullyConnectedOptions + ) + + # Verify outputs of program and TFLite model + input_data = ( + (torch.randn(input_shape, dtype=torch.float32)) + .type(torch.int8) + .detach() + .numpy() + ) + convert_run_compare( + exported_program, + input_data=input_data, + tflite_input_preprocess=ToNHWCPreprocess(), + ) + + +def test_delegating_format_related_transpose_operators__unsupported_shapes(mocker): + # This test focuses on the case when Neutron would not support the inserted Transpose operators, so they are not + # inserted, so the runtime will permute the data. + + # Make sure none of the dimensions are multiples of `num_macs` (8), for proper testing. + model = Conv2dModule(in_channels=3, out_channels=3, padding=1, stride=1) + input_shape = (1, 3, 3, 3) + + converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program") + payload_header_spy = mocker.spy(PayloadComposer, "_create_payload_header") + edge_program = to_quantized_edge_program( + model, + input_shape, + use_neutron_for_format_conversion=True, # Make sure the IR converter inserts the extra `Transpose` operators. + ).exported_program() + + # Make sure the edge_program only contains the 1 delegate call. + nodes = list(edge_program.graph.nodes) + assert len(nodes) == 7 + assert "call_delegate" in nodes[3].name + assert not graph_contains_any_of_ops( + edge_program.graph, [torch.ops.aten.convolution.default] + ) + assert not graph_contains_any_of_ops( + edge_program.graph, [torch.ops.aten.permute_copy.default] + ) + + # Capture the converted IR model. + tflite_flatbuffers_model, _ = converter_spy.spy_return + + # Make sure the `Transpose` ops are NOT in the IR model. + tflite_subgraph = Model.GetRootAs(tflite_flatbuffers_model).Subgraphs(0) + assert tflite_subgraph.OperatorsLength() == 2 + assert ( + tflite_subgraph.Operators(0).BuiltinOptionsType() == BuiltinOptions.PadV2Options + ) + assert ( + tflite_subgraph.Operators(1).BuiltinOptionsType() + == BuiltinOptions.Conv2DOptions + ) + + # Get the header of the payload for the delegated partition. + payload_header = payload_header_spy.spy_return + assert payload_header.size == 7 + # the 4th and 5th bytes indicate the format. `1` means `channels_last`, which means the runtime will transpose the data. + assert all(payload_header[3:5] == [1, 1]) # [, ] + + +def test_delegating_format_related_transpose_operators__supported_case(mocker): + # Make sure the output channels (channels for the trailing Transpose), and the last input dimension (channels for + # the leading Transpose) are multiples of `num_macs``. + + num_macs = NeutronTargetSpec("imxrt700", "SDK_25_09").get_num_macs() + model = Conv2dModule( + in_channels=num_macs, out_channels=num_macs, padding=1, stride=1 + ) + input_shape = (1, num_macs, num_macs, num_macs) + + converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program") + payload_header_spy = mocker.spy(PayloadComposer, "_create_payload_header") + edge_program = to_quantized_edge_program( + model, + input_shape, + use_neutron_for_format_conversion=True, # Make sure the IR converter inserts the extra `Transpose` operators. + ).exported_program() + + # Make sure the edge_program only contains the 1 delegate call. + nodes = list(edge_program.graph.nodes) + assert len(nodes) == 7 + assert "call_delegate" in nodes[3].name + assert not graph_contains_any_of_ops( + edge_program.graph, [torch.ops.aten.convolution.default] + ) + assert not graph_contains_any_of_ops( + edge_program.graph, [torch.ops.aten.permute_copy.default] + ) + + # Capture the converted IR model. + tflite_flatbuffers_model, _ = converter_spy.spy_return + + # Make sure the `Transpose` ops ARE in the IR model. + tflite_subgraph = Model.GetRootAs(tflite_flatbuffers_model).Subgraphs(0) + assert tflite_subgraph.OperatorsLength() == 4 + assert ( + tflite_subgraph.Operators(0).BuiltinOptionsType() + == BuiltinOptions.TransposeOptions + ) + assert ( + tflite_subgraph.Operators(1).BuiltinOptionsType() == BuiltinOptions.PadV2Options + ) + assert ( + tflite_subgraph.Operators(2).BuiltinOptionsType() + == BuiltinOptions.Conv2DOptions + ) + assert ( + tflite_subgraph.Operators(3).BuiltinOptionsType() + == BuiltinOptions.TransposeOptions + ) + + # Get the header of the payload for the delegated partition. + payload_header = payload_header_spy.spy_return + assert payload_header.size == 7 + # the 4th and 5th bytes indicate the format. `0` means `channels_last`, which means the runtime will NOT transpose the data. + assert all(payload_header[3:5] == [0, 0]) # [, ] + + +def test_delegating_format_related_transpose_operators__supported_output__unsupported_input( + mocker, +): + num_macs = NeutronTargetSpec("imxrt700", "SDK_25_09").get_num_macs() + model = Conv2dModule( + in_channels=num_macs, + out_channels=num_macs, # The output `Transpose` will be supported. + padding=1, + stride=1, + ) + input_shape = (1, num_macs, num_macs, 3) # The input `Transpose` is not supported. + + converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program") + payload_header_spy = mocker.spy(PayloadComposer, "_create_payload_header") + edge_program = to_quantized_edge_program( + model, + input_shape, + use_neutron_for_format_conversion=True, # Make sure the IR converter inserts the extra `Transpose` operators. + ).exported_program() + + # Make sure the edge_program only contains the 1 delegate call. + nodes = list(edge_program.graph.nodes) + assert len(nodes) == 7 + assert "call_delegate" in nodes[3].name + assert not graph_contains_any_of_ops( + edge_program.graph, [torch.ops.aten.convolution.default] + ) + assert not graph_contains_any_of_ops( + edge_program.graph, [torch.ops.aten.permute_copy.default] + ) + + # Capture the converted IR model. + tflite_flatbuffers_model, _ = converter_spy.spy_return + + # Make sure there is just the 1 `Transpose` in the model. + tflite_subgraph = Model.GetRootAs(tflite_flatbuffers_model).Subgraphs(0) + assert tflite_subgraph.OperatorsLength() == 3 + assert ( + tflite_subgraph.Operators(0).BuiltinOptionsType() == BuiltinOptions.PadV2Options + ) + assert ( + tflite_subgraph.Operators(1).BuiltinOptionsType() + == BuiltinOptions.Conv2DOptions + ) + assert ( + tflite_subgraph.Operators(2).BuiltinOptionsType() + == BuiltinOptions.TransposeOptions + ) + + # Get the header of the payload for the delegated partition. + payload_header = payload_header_spy.spy_return + assert payload_header.size == 7 + # the 4th and 5th bytes indicate the format. `1` means `channels_last`, which means the runtime will transpose the data. + assert all(payload_header[3:5] == [1, 0]) # [, ] + + +def test_delegating_format_related_transpose_operators__supported_input__unsupported_output( + mocker, +): + num_macs = NeutronTargetSpec("imxrt700", "SDK_25_09").get_num_macs() + model = Conv2dModule( + in_channels=num_macs, + out_channels=3, # The output `Transpose` will NOT be supported. + stride=1, + ) + input_shape = (1, num_macs, 3, num_macs) # The input `Transpose` is supported. + + converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program") + payload_header_spy = mocker.spy(PayloadComposer, "_create_payload_header") + edge_program = to_quantized_edge_program( + model, + input_shape, + use_neutron_for_format_conversion=True, # Make sure the IR converter inserts the extra `Transpose` operators. + ).exported_program() + + # Make sure the edge_program only contains the 1 delegate call. + nodes = list(edge_program.graph.nodes) + assert len(nodes) == 7 + assert "call_delegate" in nodes[3].name + assert not graph_contains_any_of_ops( + edge_program.graph, [torch.ops.aten.convolution.default] + ) + assert not graph_contains_any_of_ops( + edge_program.graph, [torch.ops.aten.permute_copy.default] + ) + + # Capture the converted IR model. + tflite_flatbuffers_model, _ = converter_spy.spy_return + + # Make sure there is just the 1 `Transpose` in the model. + tflite_subgraph = Model.GetRootAs(tflite_flatbuffers_model).Subgraphs(0) + assert tflite_subgraph.OperatorsLength() == 2 + assert ( + tflite_subgraph.Operators(0).BuiltinOptionsType() + == BuiltinOptions.TransposeOptions + ) + assert ( + tflite_subgraph.Operators(1).BuiltinOptionsType() + == BuiltinOptions.Conv2DOptions + ) + + # Get the header of the payload for the delegated partition. + payload_header = payload_header_spy.spy_return + assert payload_header.size == 7 + # the 4th and 5th bytes indicate the format. `1` means `channels_last`, which means the runtime will transpose the data. + assert all(payload_header[3:5] == [0, 1]) # [, ] diff --git a/backends/nxp/tests/test_neutron_converter_manager.py b/backends/nxp/tests/test_neutron_converter_manager.py index 2fcfd8cd987..5b105d7ef64 100644 --- a/backends/nxp/tests/test_neutron_converter_manager.py +++ b/backends/nxp/tests/test_neutron_converter_manager.py @@ -13,6 +13,7 @@ from executorch.backends.nxp.backend.neutron_converter_manager import ( NeutronConverterManager, ) +from executorch.backends.nxp.backend.node_format_inference import NodeFormatInference from executorch.backends.nxp.tests.models import Conv2dModule @@ -23,6 +24,7 @@ def test_conv2d_neutron_conversion__default_flavor(): exir_program = torch.export.export(model, example_input) edge_program_manager = exir.to_edge(exir_program) + NodeFormatInference(edge_program_manager.exported_program()).identify_node_formats() edge_program_converter = EdgeProgramToIRConverter() tflite_model, _ = edge_program_converter.convert_program( edge_program_manager.exported_program() @@ -43,6 +45,7 @@ def test__conv2d_neutron_conversion__invalid_flavor(): exir_program = torch.export.export(model, example_input) edge_program_manager = exir.to_edge(exir_program) + NodeFormatInference(edge_program_manager.exported_program()).identify_node_formats() edge_program_converter = EdgeProgramToIRConverter() tflite_model, _ = edge_program_converter.convert_program( edge_program_manager.exported_program() diff --git a/backends/nxp/tests/test_node_format_inference.py b/backends/nxp/tests/test_node_format_inference.py index e2796187ce8..d0a73328037 100644 --- a/backends/nxp/tests/test_node_format_inference.py +++ b/backends/nxp/tests/test_node_format_inference.py @@ -9,6 +9,7 @@ from executorch.backends.nxp.backend.node_format_inference import ( NodeFormat, NodeFormatInference, + NXP_NODE_FORMAT, ) from executorch.backends.nxp.neutron_pass_manager import NeutronPassManager from executorch.backends.nxp.tests.models import ( @@ -27,7 +28,7 @@ def test_convolution(): exir_program = torch.export.export(model, example_input) edge_program = exir.to_edge(exir_program).exported_program() - node_formats = NodeFormatInference(edge_program).identify_node_formats() + NodeFormatInference(edge_program).identify_node_formats() expected_mapping = { "p_conv_weight": NodeFormat.CHANNELS_FIRST, @@ -37,8 +38,8 @@ def test_convolution(): "output": NodeFormat.CHANNELS_FIRST, } - for node, node_format in node_formats.items(): - assert expected_mapping[node.name] == node_format + for node in edge_program.graph.nodes: + assert expected_mapping[node.name] == node.meta[NXP_NODE_FORMAT] def test_softmax(): @@ -48,7 +49,7 @@ def test_softmax(): exir_program = torch.export.export(model, example_input) edge_program = exir.to_edge(exir_program).exported_program() - node_formats = NodeFormatInference(edge_program).identify_node_formats() + NodeFormatInference(edge_program).identify_node_formats() expected_mapping = { "x": NodeFormat.FORMATLESS, @@ -56,8 +57,8 @@ def test_softmax(): "output": NodeFormat.FORMATLESS, } - for node, node_format in node_formats.items(): - assert expected_mapping[node.name] == node_format + for node in edge_program.graph.nodes: + assert expected_mapping[node.name] == node.meta[NXP_NODE_FORMAT] def test_maxpool2d(): @@ -78,7 +79,7 @@ def test_maxpool2d(): # Remove MaxPool-related "getitem" nodes from graph edge_program = NeutronPassManager(edge_program, [RemoveGetItemPass]).transform() - node_formats = NodeFormatInference(edge_program).identify_node_formats() + NodeFormatInference(edge_program).identify_node_formats() expected_mapping = { "x": NodeFormat.CHANNELS_FIRST, @@ -86,5 +87,5 @@ def test_maxpool2d(): "output": NodeFormat.CHANNELS_FIRST, } - for node, node_format in node_formats.items(): - assert expected_mapping[node.name] == node_format + for node in edge_program.graph.nodes: + assert expected_mapping[node.name] == node.meta[NXP_NODE_FORMAT] diff --git a/backends/nxp/tests/test_per_channel_conversion.py b/backends/nxp/tests/test_per_channel_conversion.py index 043ba8fc001..8a206b3e429 100644 --- a/backends/nxp/tests/test_per_channel_conversion.py +++ b/backends/nxp/tests/test_per_channel_conversion.py @@ -124,6 +124,7 @@ def test_per_channel_convolution(self): get_quantizer_fn=lambda: NeutronAtenQuantizer( Conv2dPatternPerChannel(is_per_channel=True), static_qconfig ), + use_neutron_for_format_conversion=False, ) tflite_flatbuffers_model, io_formats = converter_spy.calls[-1].return_value