diff --git a/backends/nxp/backend/edge_helper.py b/backends/nxp/backend/edge_helper.py index 9b584d5166b..3b74d86f599 100644 --- a/backends/nxp/backend/edge_helper.py +++ b/backends/nxp/backend/edge_helper.py @@ -5,6 +5,7 @@ import torch from torch.fx import Node +from torch.nn import Parameter def input_tensor(node: Node, input_index: int) -> torch.Tensor: @@ -38,3 +39,35 @@ def input_tensor_safe(node: Node, input_index: int) -> torch.Tensor | None: return None return input_tensor(node, input_index) + + +def node_is_static_tensor(node: Node, parameters_mapping: dict[str, Parameter]) -> bool: + """Return `True` if the given `node` has static data in the `parameters_mapping` dict. + :param node: Tensor node to check for data. + :param parameters_mapping: Dict mapping tensor names to their static data. Should be inferred from the + `state_dict` attribute of an edge program. + """ + return node.name in parameters_mapping.keys() + + +def node_is_effectively_static_tensor( + node: Node, parameters_mapping: dict[str, Parameter] +) -> bool: + """Return `True` if the given `node` has static data, or follows after a `Dequantize` node with a static input. + In the IR, the `node` will be turned into a static quantized tensor. + :param node: Tensor node to check for data. + :param parameters_mapping: Dict mapping tensor names to their static data. Should be inferred from the + `state_dict` attribute of an edge program. + """ + if node_is_static_tensor(node, parameters_mapping): + return True + + def _is_dequantize(node_: Node) -> bool: + return node_.target.__name__ in { + "quantized_decomposed.dequantize_per_tensor.default", + "quantized_decomposed.dequantize_per_channel.default", + } + + return _is_dequantize(node) and node_is_static_tensor( + node.args[0], parameters_mapping + ) diff --git a/backends/nxp/backend/ir/converter/node_converters/ops_converters/convolution_converter.py b/backends/nxp/backend/ir/converter/node_converters/ops_converters/convolution_converter.py index efecebfc783..6aac32649d3 100644 --- a/backends/nxp/backend/ir/converter/node_converters/ops_converters/convolution_converter.py +++ b/backends/nxp/backend/ir/converter/node_converters/ops_converters/convolution_converter.py @@ -6,26 +6,36 @@ import numpy as np import torch -from executorch.backends.nxp.backend.edge_helper import input_tensor, input_tensor_safe +from executorch.backends.nxp.backend.edge_helper import ( + input_tensor, + input_tensor_safe, + node_is_effectively_static_tensor, +) from executorch.backends.nxp.backend.ir.converter.conversion import ( aten_translator, common, ) -from executorch.backends.nxp.backend.ir.converter.conversion.common import ( - OpsList, - try_get_input, -) +from executorch.backends.nxp.backend.ir.converter.conversion.common import try_get_input from executorch.backends.nxp.backend.ir.converter.node_converter import ( NodeConverter, Target, ) +from executorch.backends.nxp.backend.ir.converter.node_converters.shared import ( + conv_utils, +) +from executorch.backends.nxp.backend.ir.converter.node_converters.shared.conv_utils import ( + ConvConversionResult, + ConvParameters, +) from executorch.backends.nxp.backend.ir.converter.quantization_utils import ( set_quantization_parameters_to_tensor, ) +from executorch.backends.nxp.backend.ir.converter.tensor_utils import tensor_has_data from executorch.backends.nxp.backend.ir.lib.tflite.TensorType import TensorType from executorch.backends.nxp.backend.ir.tflite_generator import tflite_model from 
executorch.backends.nxp.backend.ir.tflite_generator.builtin_options import ( conv_2d_options, + depthwise_conv_2d_options, ) from torch.fx import Node from torch.nn import Parameter @@ -48,7 +58,29 @@ def _is_supported_in_IR( if output_padding != [0, 0]: return False - if groups != 1: + if groups == 1: + # Regular (pointwise) convolution. + pass + + elif conv_utils.group_conv_convertible_as_depthwise( + node, groups + ) and node_is_effectively_static_tensor(node.args[1], parameters_mapping): + # Depthwise convolution. + # Only supported if the weights are static, because TFLite `DepthwiseConv2D` uses permuted weights. In case + # the weights are dynamic, a Transpose operator would have to be added, which is not supported on Neutron. + pass + + elif conv_utils.group_conv_convertible_into_multiple_convolutions(node, groups): + # Group Separable convolution. + # Not supported natively by the eIQ Neutron. In practice, it can be computed by splitting the Group + # Separable Convolution into multiple Pointwise Convolutions using the Split and Concat operations. + # However, the Concat operation in Neutron Converter SDK 25.03 requires the number of channels to be a + # multiple of the number of MAC units in the eIQ Neutron. For this reason, Group Separable Convolution + # is not delegated by default at this moment. + return False + + else: + # All conversion options related to the `group` attribute have been checked and none of them can be used. return False if input_tensor_safe(node, 2) is None: @@ -57,71 +89,152 @@ def _is_supported_in_IR( if weight_tensor.dtype not in [torch.float32, torch.int8, torch.uint8]: return False - return True - - def _convert_2d_conv( - self, stride, padding, dilation, t_op: tflite_model.Operator - ) -> list[tflite_model.Operator]: - ops = OpsList(middle_op=t_op) - t_op.builtin_options = conv_2d_options.Conv2D() - common.assign_2d_strides(t_op.builtin_options, stride) - common.assign_2d_dilations(t_op.builtin_options, dilation) - t_op.builtin_options.padding, explicit_padding = ( - aten_translator.convert_padding(padding) - ) + if node.args[0].meta["val"].shape[0] != 1: + # Only batch size 1 is supported on Neutron. + return False - if explicit_padding is not None: - # Need to prepend a 'Pad' operator, which adds 0s. But these will be included in the computation! - ops.add_pre( - self.builder.create_pad_operator_before(t_op, 0, explicit_padding) - ) + return True - input_tensor: tflite_model.Tensor = t_op.tmp_inputs[0] - weight_tensor: tflite_model.Tensor = t_op.tmp_inputs[1] - output_tensor: tflite_model.Tensor = t_op.tmp_outputs[0] + Stride = Padding = Dilation = OutPadding = list[int] + Transposed = bool + Groups = int - if (bias_tensor := try_get_input(t_op, 2)) is None: + @staticmethod + def _get_convolution_arguments( + conv_node: Node, + ) -> (Stride, Padding, Dilation, Transposed, OutPadding, Groups): + # The arguments of the conv are: + # [x, w, b, stride, padding, dilation, transposed, output padding, groups] + # https://github.com/pytorch/pytorch/blob/v2.6.0/aten/src/ATen/native/Convolution.cpp#L286-L291 + _, _, _, stride, padding, dilation, transposed, out_padding, groups = ( + conv_node.args + ) + return stride, padding, dilation, transposed, out_padding, groups + + # noinspection PyPep8Naming + def _convert_unpadded_2D( + self, t_op: tflite_model.Operator, conv_params: ConvParameters + ) -> conv_utils.ConvConversionResult: + """Convert the `aten.convolution` into TFLite. The `padding` and `builtin_options` must be converted by the + caller.
+ """ + common.assign_2d_strides(t_op.builtin_options, conv_params.stride) + common.assign_2d_dilations(t_op.builtin_options, conv_params.dilation) + + x: tflite_model.Tensor = t_op.tmp_inputs[0] + w: tflite_model.Tensor = t_op.tmp_inputs[1] + y: tflite_model.Tensor = t_op.tmp_outputs[0] + + if (b := try_get_input(t_op, 2)) is None: # Operator has no bias. Convolution aten op can omit it, TFLite can't. - output_channels = weight_tensor.shape.vector[0] + output_channels = w.shape.vector[0] - if weight_tensor.type == TensorType.FLOAT32: + if w.type == TensorType.FLOAT32: bias_type = np.dtype(np.float32) - elif weight_tensor.type in [TensorType.INT8, TensorType.UINT8]: + elif w.type in [TensorType.INT8, TensorType.UINT8]: bias_type = np.dtype(np.int32) else: # Should never happen. raise NotImplementedError( - f"Convolution node with unsupported weight type: {weight_tensor.type}" + f"Convolution node with unsupported weight type: {w.type}" ) - bias_tensor = self.builder.create_zeros_tensor( + b = self.builder.create_zeros_tensor( [output_channels], "zero_bias", bias_type, True ) # Compute scale and zero point for bias tensor - input_scale = np.array(input_tensor.quantization.scale.vector) - weight_scale = np.array(weight_tensor.quantization.scale.vector) + input_scale = np.array(x.quantization.scale.vector) + weight_scale = np.array(w.quantization.scale.vector) bias_scale = input_scale * weight_scale bias_zero_point = np.zeros(weight_scale.shape, dtype=np.int64) set_quantization_parameters_to_tensor( - bias_tensor, bias_scale, bias_zero_point, quantized_dimension=0 + b, bias_scale, bias_zero_point, quantized_dimension=0 ) # Assign the operator its TFLite inputs and outputs - t_op.tmp_inputs = [input_tensor, weight_tensor, bias_tensor] - t_op.tmp_outputs = [output_tensor] + t_op.tmp_inputs = [x, w, b] + t_op.tmp_outputs = [y] + + conversion_result = ConvConversionResult(x, w, b, y) + conversion_result.ops_list.middle_op = t_op + + return conversion_result + + def _convert_2d_conv( + self, t_op: tflite_model.Operator, conv_params: ConvParameters + ) -> list[tflite_model.Operator]: + if conv_utils.group_conv_convertible_as_depthwise( + t_op, conv_params.groups + ): # Convert to `DepthwiseConv2D`. + t_op.builtin_options = depthwise_conv_2d_options.DepthwiseConv2D() + + conversion_result = self._convert_unpadded_2D(t_op, conv_params) + t_op.builtin_options.padding, explicit_padding = ( + aten_translator.convert_padding(conv_params.padding) + ) + if explicit_padding is not None: + # Need to prepend a 'Pad' operator, which adds 0s. 
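# Illustrative sketch, not part of this patch: why the prepended `Pad` operator can be needed.
# TFLite `Conv2D`/`DepthwiseConv2D` only support SAME/VALID padding, so an arbitrary aten padding such as
# [1, 2] (height, width) is realized as a VALID convolution preceded by an explicit zero-`Pad` over the
# NHWC input. The shapes below are hypothetical:
import numpy as np
x = np.zeros((1, 4, 4, 3), dtype=np.float32)            # NHWC input (hypothetical shape)
padded = np.pad(x, [[0, 0], [1, 1], [2, 2], [0, 0]])    # pad batch, height, width, channels with zeros
assert padded.shape == (1, 6, 8, 3)                     # zeros added around H and W only
# End of illustration.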
+ conversion_result.ops_list.add_pre( + self.builder.create_pad_operator_before(t_op, 0, explicit_padding) + ) + + # DepthwiseConv2D expects weights in format [kernel_channels, kernel_height, kernel_width, output_channels] + perm = [3, 1, 2, 0] + weight_tensor = conversion_result.conv_weight_tensor + if tensor_has_data(weight_tensor): + # Transpose cloned tensor statically + t_op.tmp_inputs[1] = self.builder.create_transposed_tensor( + weight_tensor, perm + ) + else: + raise NotImplementedError("Dynamic Depthwise Conv weights.") + + elif conv_utils.group_conv_convertible_into_multiple_convolutions( + t_op, conv_params.groups + ): + # Note: by default the Group Separable Convolution is rejected by the Neutron Partitioner, see the + # ConvolutionConverter._is_supported_in_IR() + t_op.builtin_options = conv_2d_options.Conv2D() + + return conv_utils.create_separated_convolutions_based_on_group( + t_op, + conv_params, + self.builder, + self._convert_unpadded_2D, + conv_utils.conv_op_factory, + ) + + else: + # Convert to regular `Conv2D`. + t_op.builtin_options = conv_2d_options.Conv2D() + conversion_result = self._convert_unpadded_2D(t_op, conv_params) + t_op.builtin_options.padding, explicit_padding = ( + aten_translator.convert_padding(conv_params.padding) + ) + if explicit_padding is not None: + # Need to prepend a 'Pad' operator, which adds 0s. + conversion_result.ops_list.add_pre( + self.builder.create_pad_operator_before(t_op, 0, explicit_padding) + ) - return ops.flatten() + return conversion_result.ops_list.flatten() def convert(self, node: Node): self.assert_convertible(node) - stride = node.args[3] - padding = node.args[4] - dilation = node.args[5] + stride, padding, dilation, _, _, groups = self._get_convolution_arguments(node) t_op = self._create_tflite_op_with_io_tensors(node) - ops_to_add = self._convert_2d_conv(stride, padding, dilation, t_op) + conv_params = ConvParameters(stride, padding, dilation, groups) + + rank = t_op.tmp_inputs[1].shape.len() + if rank == 4: # Conv2D + ops_to_add = self._convert_2d_conv(t_op, conv_params) + else: + raise NotImplementedError( + f"{rank - 2}D convolution is not supported." + ) # Should never get here. self.builder.append_operators(ops_to_add) diff --git a/backends/nxp/backend/ir/converter/node_converters/shared/conv_utils.py b/backends/nxp/backend/ir/converter/node_converters/shared/conv_utils.py new file mode 100755 index 00000000000..73bf76a830d --- /dev/null +++ b/backends/nxp/backend/ir/converter/node_converters/shared/conv_utils.py @@ -0,0 +1,400 @@ +# Copyright 2023-2025 NXP +# +# License: LA_OPT_NXP_Software_License +# See the LICENSE_LA_OPT_NXP_Software_License for more details.
+# + +from copy import copy +from dataclasses import dataclass +from typing import Callable, cast + +import numpy as np + +from executorch.backends.nxp.backend.ir.converter.builder.model_builder import ( + ModelBuilder, +) +from executorch.backends.nxp.backend.ir.converter.conversion import aten_translator +from executorch.backends.nxp.backend.ir.converter.conversion.common import OpsList +from executorch.backends.nxp.backend.ir.converter.tensor_utils import tensor_has_data +from executorch.backends.nxp.backend.ir.lib.tflite.Padding import Padding +from executorch.backends.nxp.backend.ir.tflite_generator import tflite_model +from executorch.backends.nxp.backend.ir.tflite_generator.builtin_options import ( + concatenation_options, + conv_2d_options, + split_options, +) +from torch.fx import Node + + +@dataclass +class ConvParameters: + stride: list[int] + padding: list[int] + dilation: list[int] + groups: int + + +# noinspection PyPep8Naming +def _get_IO_channels(node: Node | tflite_model.Operator) -> (int, int): + if isinstance(node, Node): + input_channels = ( + node.args[0].meta["val"].shape[1] + ) # Channels of the main input. + output_channels = ( + node.args[1].meta["val"].shape[0] + ) # Output channels of the weights. + else: + input_channels = node.tmp_inputs[0].shape[-1] # Channels of the main input. + output_channels = node.tmp_inputs[1].shape[0] # Output channels of the weights. + + return input_channels, output_channels + + +def group_conv_convertible_as_depthwise(node: Node | tflite_model.Operator, group: int): + input_channels, output_channels = _get_IO_channels(node) + + return input_channels == output_channels == group + + +def group_conv_convertible_into_multiple_convolutions( + node: Node | tflite_model.Operator, group: int +) -> bool: + if group == 1: + return False + + _, output_channels = _get_IO_channels(node) + if output_channels % group != 0: + return False # Unable to split group Conv into separated convolutions because out_channels % group != 0. + + # 10 is an empirical value. The `group` directly dictates how many branches will be created. + return 2 <= group <= 10 + + +class ConvConversionResult: + """ + Holds references to the direct I/O tensors of the Conv operator + and list of surrounding operators (Quantize, Transpose, etc.). + """ + + def __init__( + self, + input_tensor: tflite_model.Tensor, + weight_tensor: tflite_model.Tensor, + bias_tensor: tflite_model.Tensor, + output_tensor: tflite_model.Tensor, + ): + self.conv_input_tensor = input_tensor + self.conv_weight_tensor = weight_tensor + self.conv_bias_tensor = bias_tensor + self.conv_output_tensor = output_tensor + self.ops_list = OpsList() + + +ConvBuiltinOptions = conv_2d_options.Conv2D +ConvOpFactory = Callable[ + [ + ConvParameters, + tflite_model.Tensor, + tflite_model.Tensor, + tflite_model.Tensor, + tflite_model.Tensor, + ModelBuilder, + ConvBuiltinOptions, + ], + OpsList, +] +ConvConversionFn = Callable[ + [tflite_model.Operator, ConvParameters], ConvConversionResult +] + + +class _InputTensorsSplitter: + """Splits the tensors of a `Conv2D` operator. Static tensors are split statically, and for dynamic tensors, a + TFLite `Split` operator is added. 
+ """ + + input_tensors: list[tflite_model.Tensor] + weight_tensors: list[tflite_model.Tensor] + bias_tensors: list[tflite_model.Tensor] + split_ops: list[tflite_model.Operator] + + def __init__( + self, + input_tensor: tflite_model.Tensor, + weight_tensor: tflite_model.Tensor, + bias_tensor: tflite_model.Tensor, + groups: int, + builder: ModelBuilder, + ): + self.input_tensors = [] + self.weight_tensors = [] + self.bias_tensors = [] + self.split_ops = [] + + inputs = [ + # input tensor, split by axis, output tensors container + (input_tensor, -1, self.input_tensors), + (weight_tensor, 0, self.weight_tensors), + (bias_tensor, 0, self.bias_tensors), + ] + + for i in inputs: + if tensor_has_data(i[0]): + self._generate_static_tensors(builder, groups, i[0], i[1], i[2]) + else: + self._generate_dynamic_tensors(builder, groups, i[0], i[1], i[2]) + + def _generate_dynamic_tensors( + self, builder, groups, split_tensor, axis, target_list + ): + quantization = None + if split_tensor.quantization is not None: + if split_tensor.quantization.is_per_channel(): + scale = np.split( + np.array(split_tensor.quantization.scale.vector, "float32"), groups + ) + zero_point = np.split( + np.array(split_tensor.quantization.zero_point.vector, "int32"), + groups, + ) + quantization = [ + tflite_model.Quantization( + scale=tflite_model.Scale(s), + zero_point=tflite_model.ZeroPoint(zp), + ) + for s, zp in zip(scale, zero_point) + ] + else: + quantization = [split_tensor.quantization] * groups + + split_op = self._create_split_op(builder, groups, split_tensor, axis) + + new_tensor_shape = split_tensor.shape.vector.copy() + new_tensor_shape[axis] = new_tensor_shape[axis] // groups + + for i in range(groups): + conv_split_tensor = builder.duplicate_tensor( + split_tensor, name_suffix="_group_" + str(i) + ) + conv_split_tensor.shape = tflite_model.Shape(new_tensor_shape) + if quantization is not None: + conv_split_tensor.quantization = copy(quantization[i]) + + split_op.tmp_outputs.append(conv_split_tensor) + target_list.append(conv_split_tensor) + self.split_ops.append(split_op) + + # noinspection PyMethodMayBeStatic + def _generate_static_tensors( + self, builder, groups, split_tensor, axis, target_list + ): + quantization = None + if split_tensor.quantization is not None: + if split_tensor.quantization.is_per_channel(): + scale = np.split( + np.array(split_tensor.quantization.scale.vector, "float32"), groups + ) + zero_point = np.split( + np.array(split_tensor.quantization.zero_point.vector, "int32"), + groups, + ) + quantization = [ + tflite_model.Quantization( + scale=tflite_model.Scale(s), + zero_point=tflite_model.ZeroPoint(zp), + ) + for s, zp in zip(scale, zero_point) + ] + else: + quantization = [split_tensor.quantization] * groups + + input_data = np.split(split_tensor.tmp_buffer.data, groups, axis) + + for i in range(len(input_data)): + tensor_name = split_tensor.name + "_group_" + str(i) + conv_input_tensor = builder.create_tensor_for_data( + input_data[i], tensor_name + ) + if quantization is not None: + conv_input_tensor.quantization = copy(quantization[i]) + + target_list.append(conv_input_tensor) + + # noinspection PyMethodMayBeStatic + def _create_split_op(self, builder, groups, input_tensor, axis): + axis_tensor = builder.create_tensor_for_data( + np.asarray([axis], np.int32), "split_dim_" + ) + input_split_op = tflite_model.Operator( + builtin_options=split_options.Split(groups) + ) + input_split_op.tmp_inputs = [axis_tensor, input_tensor] + + return input_split_op + + def get_input_tensor(self, 
idx) -> tflite_model.Tensor: + return self.input_tensors[idx] + + def get_weight_tensor(self, idx) -> tflite_model.Tensor: + return self.weight_tensors[idx] + + def get_bias_tensor(self, idx) -> tflite_model.Tensor: + return self.bias_tensors[idx] + + def get_ops(self) -> list[tflite_model.Operator]: + return self.split_ops + + +class _OutputTensorsCombiner: + """Handles creation and aggregation of the TFLite Conv2D output tensors. + Aggregation is done with `Concatenation` op. + """ + + output_tensors: list[tflite_model.Tensor] + concat_op: tflite_model.Operator + + def __init__(self, output_tensor, groups, builder): + self.output_tensors = [] + combine_axis = -1 + + new_conv_output_shape = output_tensor.shape.vector.copy() + new_conv_output_shape[combine_axis] = ( + new_conv_output_shape[combine_axis] // groups + ) + conv_output_shape = tflite_model.Shape(new_conv_output_shape) + + self.concat_op = tflite_model.Operator( + builtin_options=concatenation_options.Concatenation(combine_axis) + ) + self.concat_op.tmp_outputs = [output_tensor] + + for i in range(groups): + tensor_name = output_tensor.name + "_group_" + str(i) + output_tensor = builder.duplicate_tensor(output_tensor, tensor_name) + output_tensor.shape = conv_output_shape + + self.output_tensors.append(output_tensor) + self.concat_op.tmp_inputs.append(output_tensor) + + def get_output_tensor(self, idx): + return self.output_tensors[idx] + + def get_ops(self): + return [self.concat_op] + + +def build_input_tensor_padding( + t_op, conv_params: ConvParameters, builder, input_idx=0 +) -> (Padding, tflite_model.Operator | None): + """Build padding for input tensor of Conv2D op 't_op'.""" + + tfl_padding, explicit_padding = aten_translator.convert_padding(conv_params.padding) + if explicit_padding is not None: + # Must add extra 'Pad' operator + return tfl_padding, builder.create_pad_operator_before( + t_op, input_idx, explicit_padding + ) + + return tfl_padding, None + + +def conv_op_factory( + conv_params: ConvParameters, + input_tensor: tflite_model.Tensor, + weight_tensor: tflite_model.Tensor, + bias_tensor: tflite_model.Tensor, + output_tensor: tflite_model.Tensor, + builder, + builtin_options, +) -> OpsList: + """Build padded 'Conv2D' TFLite operator. Padding is realized by 'builtin_options.padding' definition and by + optional prepended 'Pad' operator. + """ + + conv_op = tflite_model.Operator(builtin_options=copy(builtin_options)) + conv_op.tmp_inputs = [input_tensor, weight_tensor, bias_tensor] + conv_op.tmp_outputs = [output_tensor] + + padding, pad_op = build_input_tensor_padding(conv_op, conv_params, builder) + conv_op.builtin_options.padding = padding + + if pad_op is not None: + return OpsList(pre_ops=[pad_op], middle_op=conv_op) + else: + return OpsList(middle_op=conv_op) + + +# noinspection GrazieInspection +def create_separated_convolutions_based_on_group( + t_op: tflite_model.Operator, + conv_params: ConvParameters, + builder: ModelBuilder, + conv_conversion_fn: ConvConversionFn, + conv_op_factory_fn: ConvOpFactory, +) -> list[tflite_model.Operator]: + """Build a subgraph with multiple TFLite Conv2D operators that replace an `aten.convolution` operator with 'group' + attribute higher than one. The number of new Conv2D operators corresponds to the number of groups. Input + tensors of the Aten operator are split and distributed into related convolution operators. Outputs are then + concatenated back together. 
+ + Example: 'aten.convolution' operator with group=2 converted into TFLite subgraph will have + the following structure (tensor dimensions are just for illustrative purposes): + + │ (1,4,4,48) + ┌───▼──┐ + │Split │ + └┬────┬┘ + (1,4,4,24) │ │ (1,4,4,24) + ┌─────▼┐ ┌▼─────┐ + │Conv2D│ │Conv2D│ + └────┬─┘ └─┬────┘ + (1,4,4,18)│ │(1,4,4,18) + ┌─▼──────▼──┐ + │Concatenate│ + └─────┬─────┘ + │ (1,4,4,36) + ▼ + """ + + conversion_result = conv_conversion_fn(t_op, conv_params) + + splitter = _InputTensorsSplitter( + conversion_result.conv_input_tensor, + conversion_result.conv_weight_tensor, + conversion_result.conv_bias_tensor, + conv_params.groups, + builder, + ) + combiner = _OutputTensorsCombiner( + conversion_result.conv_output_tensor, conv_params.groups, builder + ) + + conv_ops = [] + for i in range(conv_params.groups): + input_tensor = splitter.get_input_tensor(i) + weight_tensor = splitter.get_weight_tensor(i) + bias_tensor = splitter.get_bias_tensor(i) + output_tensor = combiner.get_output_tensor(i) + + conv_builtin_options = cast( + ConvBuiltinOptions, conversion_result.ops_list.middle_op.builtin_options + ) + conv_ops_list = conv_op_factory_fn( + conv_params, + input_tensor, + weight_tensor, + bias_tensor, + output_tensor, + builder, + conv_builtin_options, + ) + + conv_ops.extend(conv_ops_list.flatten()) + + return ( + conversion_result.ops_list.pre_ops # `Pad` operator + + splitter.get_ops() + + conv_ops + + combiner.get_ops() # Split, Conv2D, Concatenate ops + + conversion_result.ops_list.post_ops + ) # Currently not used diff --git a/backends/nxp/tests/ir/converter/node_converter/test_constant_pad_nd_converter.py b/backends/nxp/tests/ir/converter/node_converter/test_constant_pad_nd_converter.py index d6030ebae7f..3fa2a3239dd 100644 --- a/backends/nxp/tests/ir/converter/node_converter/test_constant_pad_nd_converter.py +++ b/backends/nxp/tests/ir/converter/node_converter/test_constant_pad_nd_converter.py @@ -124,12 +124,10 @@ def test_constant_pad_nd_conversion__format_less(input_shape, paddings): @pytest.mark.parametrize( "input_shape, paddings", [ - pytest.param([2, 4, 6, 8], list(range(2)), id="4D, padding W"), - pytest.param([2, 4, 6, 8], list(range(4)), id="4D, padding H, W"), - pytest.param([2, 1, 6, 8], [1, 2, 3, 4, 2, 1], id="4D, padding C, H, W"), - pytest.param( - [2, 1, 6, 8], [1, 2, 3, 4, 2, 1, 5, 6], id="4D, padding N, C, H, W" - ), + pytest.param([1, 4, 6, 8], list(range(2)), id="4D, padding W"), + pytest.param([1, 4, 6, 8], list(range(4)), id="4D, padding H, W"), + pytest.param([1, 1, 6, 8], [1, 2, 3, 4, 2, 1], id="4D, padding C, H, W"), + # pytest.param([1, 1, 6, 8], [1, 2, 3, 4, 2, 1, 5, 6], id='4D, padding N, C, H, W'), # Batch size must stay 0. 
], ) def test_constant_pad_nd_conversion__channels_first(input_shape, paddings): diff --git a/backends/nxp/tests/ir/converter/node_converter/test_conv_converter.py b/backends/nxp/tests/ir/converter/node_converter/test_conv_converter.py index 1eceacbf060..767bb0dfff0 100644 --- a/backends/nxp/tests/ir/converter/node_converter/test_conv_converter.py +++ b/backends/nxp/tests/ir/converter/node_converter/test_conv_converter.py @@ -10,6 +10,12 @@ from executorch.backends.nxp.backend.edge_program_converter import ( EdgeProgramToIRConverter, ) +from executorch.backends.nxp.backend.ir.converter.builder.model_builder import ( + ModelBuilder, +) +from executorch.backends.nxp.backend.ir.lib.tflite.BuiltinOperator import ( + BuiltinOperator, +) from executorch.backends.nxp.tests.executorch_pipeline import ( to_edge_program, to_quantized_edge_program, @@ -204,3 +210,310 @@ def test_conv2d_quant_conversion(mocker, model: torch.nn.Module, input_shape): input_data=input_data, atol=1.0, ) + + +@pytest.mark.parametrize("stride", [1, 2]) +@pytest.mark.parametrize("dilation", [1, 2]) +@pytest.mark.parametrize("kernel_shape", [[1, 2], [3, 3], [4, 1]]) +def test_conv2d_conversion__depthwise(stride, dilation, kernel_shape, mocker): + input_shape = [1, 3, 12, 16] + group = input_shape[1] + edge_program = to_edge_program( + Conv2dModule( + group=group, + in_channels=group, + out_channels=group, + stride=stride, + dilation=dilation, + kernel_size=kernel_shape, + ), + input_shape, + ).exported_program() + + input_data = np.random.random(input_shape).astype(np.float32) + + spy = mocker.spy(ModelBuilder, "finish") + + convert_run_compare( + edge_program, + input_data, + tflite_input_preprocess=ToNHWCPreprocess(), + tflite_output_preprocess=ToNCHWPreprocess(), + atol=4e-7, + ) + conversion_result = spy.spy_return + ops = conversion_result.sub_graphs[0].operators.vector + + assert len(ops) == 1 + assert ops[0].builtin_options.operator_type == BuiltinOperator.DEPTHWISE_CONV_2D + + +@pytest.mark.parametrize("stride", [1, 2]) +@pytest.mark.parametrize("dilation", [1, 2]) +@pytest.mark.parametrize("kernel_shape", [[1, 2], [3, 3], [4, 1]]) +def test_conv2d_conversion__depthwise__quantized( + stride, dilation, kernel_shape, mocker +): + input_shape = [1, 4, 12, 12] + group = input_shape[1] + spy = mocker.spy(ModelBuilder, "finish") + + edge_program = to_quantized_edge_program( + Conv2dModule( + group=group, + in_channels=group, + out_channels=group, + stride=stride, + dilation=dilation, + kernel_size=kernel_shape, + ), + tuple(input_shape), + ).exported_program() + + ops = spy.spy_return.sub_graphs[0].operators.vector + assert len(ops) == 1 + assert ops[0].builtin_options.operator_type == BuiltinOperator.DEPTHWISE_CONV_2D + + nodes = list(edge_program.graph.nodes) + assert ( + len(nodes) == 7 + ) # input, Quant, lowered_module, delegate_call, getitem, Deq, output + assert nodes[2].target == "lowered_module_0" + + +@pytest.mark.parametrize("padding", [1, 2]) +def test_conv2d_conversion__depthwise__padded(padding, mocker): + input_shape = [1, 3, 13, 15] + group = input_shape[1] + edge_program = to_edge_program( + Conv2dModule( + group=group, in_channels=group, out_channels=group, padding=padding + ), + input_shape, + ).exported_program() + + input_data = np.random.random(input_shape).astype(np.float32) + + spy = mocker.spy(ModelBuilder, "finish") + + convert_run_compare( + edge_program, + input_data, + tflite_input_preprocess=ToNHWCPreprocess(), + tflite_output_preprocess=ToNCHWPreprocess(), + atol=4e-7, + ) + conversion_result 
= spy.spy_return + ops = conversion_result.sub_graphs[0].operators.vector + + assert len(ops) == 2 + assert ops[0].builtin_options.operator_type == BuiltinOperator.PAD + assert ops[1].builtin_options.operator_type == BuiltinOperator.DEPTHWISE_CONV_2D + + +@pytest.mark.parametrize("padding", [1, 2]) +def test_conv2d_conversion__depthwise__padded__quantized(padding, mocker): + input_shape = [1, 4, 12, 12] + group = input_shape[1] + spy = mocker.spy(ModelBuilder, "finish") + + edge_program = to_quantized_edge_program( + Conv2dModule( + group=group, in_channels=group, out_channels=group, padding=padding + ), + tuple(input_shape), + ).exported_program() + + ops = spy.spy_return.sub_graphs[0].operators.vector + assert len(ops) == 2 + assert ops[0].builtin_options.operator_type == BuiltinOperator.PAD + assert ops[1].builtin_options.operator_type == BuiltinOperator.DEPTHWISE_CONV_2D + + nodes = list(edge_program.graph.nodes) + assert ( + len(nodes) == 7 + ) # input, Quant, lowered_module, delegate_call, getitem, Deq, output + assert nodes[2].target == "lowered_module_0" + + +@pytest.mark.parametrize("stride", [1, 2]) +@pytest.mark.parametrize("dilation", [1, 2]) +@pytest.mark.parametrize( + "input_shape, group, out_channels", + [([1, 4, 12, 12], 2, 2), ([2, 3, 8, 15], 3, 6), ([11, 16, 9, 8], 4, 16)], +) +def test_conv2d_conversion__separated( + input_shape, group, out_channels, stride, dilation +): + edge_program = to_edge_program( + Conv2dModule( + group=group, + in_channels=input_shape[1], + out_channels=out_channels, + stride=stride, + dilation=dilation, + ), + input_shape, + ).exported_program() + + input_data = np.random.random(input_shape).astype(np.float32) + + # Note: The generic group convolution is not yet supported by Neutron Converter. Once supported, the + # commented-out code allows the usual testing flow for this test-case. + + # spy = mocker.spy(ModelBuilder, 'finish') + + # The convert_run_compare skips the partitioner call, hence a conversion failure indicated by an exception + # is the expected behavior for now. + with pytest.raises(AssertionError) as e: + convert_run_compare( + edge_program, + input_data, + tflite_input_preprocess=ToNHWCPreprocess(), + tflite_output_preprocess=ToNCHWPreprocess(), + atol=3.0e-7, + ) + assert ( + "`aten_convolution_default` is not convertible to the intermediate representation" + in str(e) + ) + + # ops = spy.spy_return.sub_graphs[0].operators.vector + # assert len(ops) == 1 + group + 1 # Split -> Conv (group times) -> Concat + # assert ops[0].builtin_options.operator_type == BuiltinOperator.SPLIT + # for op in ops[1:-1]: + # assert op.builtin_options.operator_type == BuiltinOperator.CONV_2D + # assert ops[-1].builtin_options.operator_type == BuiltinOperator.CONCATENATION + + +@pytest.mark.parametrize("stride", [1, 2]) +@pytest.mark.parametrize("dilation", [1, 2]) +@pytest.mark.parametrize( + "input_shape, group, out_channels", + [([1, 4, 12, 12], 2, 2), ([2, 3, 17, 9], 3, 6), ([11, 16, 9, 8], 4, 16)], +) +def test_conv2d_conversion__separated__quantized( + input_shape, group, out_channels, stride, dilation +): + + # Note: The generic group convolution is not yet supported by Neutron Converter. Once supported, the + # commented-out code allows the usual testing flow for this test-case. + + # spy = mocker.spy(ModelBuilder, 'finish') + + # The convert_run_compare skips the partitioner call, hence a conversion failure indicated by an exception + # is the expected behavior for now.
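# Illustrative sketch, not part of this patch: the "separated" path exercised by these tests is gated by
# `conv_utils.group_conv_convertible_into_multiple_convolutions` (added in this diff), which requires
# group > 1, out_channels divisible by group, and at most 10 groups (an empirical limit). A standalone
# re-statement of that check, with hypothetical values and a hypothetical helper name:
def _splittable(out_channels: int, group: int) -> bool:
    if group == 1:
        return False
    if out_channels % group != 0:
        return False
    return 2 <= group <= 10  # each group becomes one Conv2D branch

assert _splittable(out_channels=6, group=3)        # e.g. the ([2, 3, 17, 9], 3, 6) case above
assert not _splittable(out_channels=22, group=11)  # too many branches for the empirical limit
# End of illustration.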
+ edge_program = to_quantized_edge_program( + Conv2dModule( + group=group, + in_channels=input_shape[1], + out_channels=out_channels, + stride=stride, + dilation=dilation, + ), + tuple(input_shape), + ).exported_program() + + # ops = spy.spy_return.sub_graphs[0].operators.vector + # assert len(ops) == 1 + group + 1 # Split -> Conv (group times) -> Concat + # assert ops[0].builtin_options.operator_type == BuiltinOperator.SPLIT + # for op in ops[1:-1]: + # assert op.builtin_options.operator_type == BuiltinOperator.CONV_2D + # assert ops[-1].builtin_options.operator_type == BuiltinOperator.CONCATENATION + + nodes = list(edge_program.graph.nodes) + assert len(nodes) == 11 + assert ( + nodes[7].target.__name__ == "aten.convolution.default" + ) # Convolution not delegated. + + +@pytest.mark.parametrize("padding", [1, 2]) +@pytest.mark.parametrize( + "input_shape, group, out_channels", + [([1, 4, 12, 12], 2, 2), ([2, 3, 4, 5], 3, 6), ([11, 16, 9, 8], 4, 16)], +) +def test_conv2d_conversion__separated__padded( + input_shape, group, out_channels, padding +): + edge_program = to_edge_program( + Conv2dModule( + group=group, + in_channels=input_shape[1], + out_channels=out_channels, + padding=padding, + ), + input_shape, + ).exported_program() + + input_data = np.random.random(input_shape).astype(np.float32) + + # Note: The generic group convolution is not yet supported by Neutron Converter. Once supported, the + # commented-out code allows the usual testing flow for this test-case. + + # spy = mocker.spy(ModelBuilder, 'finish') + + # The convert_run_compare skips the partitioner call, hence a conversion failure indicated by an exception + # is the expected behavior for now. + with pytest.raises(AssertionError) as e: + convert_run_compare( + edge_program, + input_data, + tflite_input_preprocess=ToNHWCPreprocess(), + tflite_output_preprocess=ToNCHWPreprocess(), + atol=3.0e-7, + ) + assert ( + "`aten_convolution_default` is not convertible to the intermediate representation" + in str(e) + ) + + # conversion_result = spy.spy_return + # ops = conversion_result.sub_graphs[0].operators.vector + # assert len(ops) == 1 + 2 * group + 1 # Split -> Pad + Conv (group times) -> Concat + # assert ops[0].builtin_options.operator_type == BuiltinOperator.SPLIT + # for op in ops[1:-2:2]: + # assert op.builtin_options.operator_type == BuiltinOperator.PAD + # for op in ops[2:-1:2]: + # assert op.builtin_options.operator_type == BuiltinOperator.CONV_2D + # assert ops[-1].builtin_options.operator_type == BuiltinOperator.CONCATENATION + + +@pytest.mark.parametrize("padding", [1, 2]) +@pytest.mark.parametrize( + "input_shape, group, out_channels", + [([1, 4, 12, 12], 2, 2), ([2, 3, 4, 5], 3, 6), ([11, 16, 9, 8], 4, 16)], +) +def test_conv2d_conversion__separated__padded__quantized( + input_shape, group, out_channels, padding +): + + # Note: The generic group convolution is not yet supported by Neutron Converter. Once supported, the + # commented-out code allows the usual testing flow for this test-case.
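# Illustrative sketch, not part of this patch: the commented-out assertions in these "separated, padded"
# tests anticipate the operator layout that `create_separated_convolutions_based_on_group` builds once the
# Neutron Converter supports it: one Split, then a Pad + Conv2D pair per group, then one Concatenation.
# For a hypothetical group=2:
expected_sequence = ["SPLIT", "PAD", "CONV_2D", "PAD", "CONV_2D", "CONCATENATION"]
assert len(expected_sequence) == 1 + 2 * 2 + 1  # matches the `1 + 2 * group + 1` count asserted above
# End of illustration.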
+ + # spy = mocker.spy(ModelBuilder, 'finish') + + edge_program = to_quantized_edge_program( + Conv2dModule( + group=group, + in_channels=input_shape[1], + out_channels=out_channels, + padding=padding, + ), + tuple(input_shape), + ).exported_program() + + # ops = spy.spy_return.sub_graphs[0].operators.vector + # assert len(ops) == 1 + 2 * group + 1 # Split -> Pad + Conv (group times) -> Concat + # assert ops[0].builtin_options.operator_type == BuiltinOperator.SPLIT + # for op in ops[1:-2:2]: + # assert op.builtin_options.operator_type == BuiltinOperator.PAD + # for op in ops[2:-1:2]: + # assert op.builtin_options.operator_type == BuiltinOperator.CONV_2D + # assert ops[-1].builtin_options.operator_type == BuiltinOperator.CONCATENATION + + nodes = list(edge_program.graph.nodes) + assert len(nodes) == 11 + assert ( + nodes[7].target.__name__ == "aten.convolution.default" + ) # Convolution not delegated. diff --git a/backends/nxp/tests/ir/converter/node_converter/test_softmax_converter.py b/backends/nxp/tests/ir/converter/node_converter/test_softmax_converter.py index c3eecc04adc..92af90b923d 100644 --- a/backends/nxp/tests/ir/converter/node_converter/test_softmax_converter.py +++ b/backends/nxp/tests/ir/converter/node_converter/test_softmax_converter.py @@ -70,8 +70,8 @@ def test_softmax_conversion__unknown_input_format(input_shape, dim: int): @pytest.mark.parametrize( "input_shape,dim", [ - pytest.param((10, 4, 32, 32), 1, id="4D,dim=1"), - pytest.param((10, 4, 16, 16), -3, id="4D,dim=-3"), + pytest.param((1, 4, 32, 32), 1, id="4D,dim=1"), + pytest.param((1, 4, 16, 16), -3, id="4D,dim=-3"), ], ) def test_softmax_conversion_channel_last(input_shape, dim: int): diff --git a/backends/nxp/tests/ir/converter/node_converter/test_view_copy_converter.py b/backends/nxp/tests/ir/converter/node_converter/test_view_copy_converter.py index 9863c8acc41..12709dab6b9 100644 --- a/backends/nxp/tests/ir/converter/node_converter/test_view_copy_converter.py +++ b/backends/nxp/tests/ir/converter/node_converter/test_view_copy_converter.py @@ -90,8 +90,8 @@ def forward(self, x): def test__channels_first_to_2d(mocker): - input_shape = [2, 4, 7, 9] - new_shape = [12, 32] # Mix up the dimensions for a thorough test. + input_shape = [1, 4, 7, 9] + new_shape = [6, 32] # Mix up the dimensions for a thorough test. torch_model = ConvReshapeModule(channels=input_shape[1], new_shape=new_shape) edge_program = to_edge_program(torch_model, input_shape).exported_program() @@ -113,7 +113,7 @@ def test__channels_first_to_2d(mocker): def test__channels_first_to_4d(mocker): - input_shape = [2, 4, 6, 8] + input_shape = [1, 8, 6, 8] new_shape = [7, 4, 2, 5] torch_model = ConvReshapeModule(channels=input_shape[1], new_shape=new_shape) @@ -124,7 +124,10 @@ def test__channels_first_to_4d(mocker): converter_spy = mocker.spy(ModelBuilder, "finish") convert_run_compare( - edge_program, input_data, tflite_input_preprocess=ToNHWCPreprocess() + edge_program, + input_data, + tflite_input_preprocess=ToNHWCPreprocess(), + atol=2.0e-7, ) tflite_model = converter_spy.spy_return @@ -137,7 +140,7 @@ def test__channels_first_to_4d(mocker): def test__formatless_to_channels_first(mocker): input_shape = [12, 32] - new_shape = [2, 4, 6, 8] # Mix up the dimensions for a thorough test. + new_shape = [1, 4, 12, 8] # Mix up the dimensions for a thorough test. 
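# Illustrative sketch, not part of this patch: sanity check on the reshape targets chosen above.
# `view_copy` is only valid when the new shape preserves the element count of the input, e.g. for the
# [12, 32] input used in this test:
import math
assert math.prod([12, 32]) == math.prod([1, 4, 12, 8]) == 384
# End of illustration.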
torch_model = FormatlessToChannelsFirstModule( channels=new_shape[1], new_shape=new_shape @@ -149,7 +152,10 @@ def test__formatless_to_channels_first(mocker): converter_spy = mocker.spy(ModelBuilder, "finish") convert_run_compare( - edge_program, input_data, tflite_output_preprocess=ToNCHWPreprocess() + edge_program, + input_data, + tflite_output_preprocess=ToNCHWPreprocess(), + atol=2.0e-7, ) tflite_model = converter_spy.spy_return @@ -162,7 +168,7 @@ def test__formatless_to_channels_first(mocker): def test__formatless_to_formatless(mocker): input_shape = [12, 32] - new_shape = [2, 4, 6, 8] + new_shape = [1, 4, 6, 16] torch_model = FormatlessToFormatlessModule(new_shape=new_shape) edge_program = to_edge_program(torch_model, input_shape).exported_program() diff --git a/backends/nxp/tests/models.py b/backends/nxp/tests/models.py index 741e64a28a1..90e550d2fc5 100644 --- a/backends/nxp/tests/models.py +++ b/backends/nxp/tests/models.py @@ -1,4 +1,5 @@ -# Copyright 2024 NXP +# Copyright (c) 2024-2025 NXP +# All rights reserved. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -18,6 +19,7 @@ def __init__( out_channels: int = 8, padding: Union[str, int, Collection[int]] = 0, stride: Union[int, tuple[int, int]] = 2, + group: int = 1, ): super().__init__() @@ -29,6 +31,7 @@ def __init__( padding=padding, dilation=dilation, bias=bias, + groups=group, ) def forward(self, x):
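# Illustrative sketch, not part of this patch: how the new `group` parameter of `Conv2dModule` is used by
# the test cases above. A depthwise configuration has group == in_channels == out_channels, which the
# converter lowers to a TFLite DEPTHWISE_CONV_2D operator; the values below are hypothetical, and the
# import path is assumed from the repository layout (backends/nxp/tests/models.py).
from executorch.backends.nxp.tests.models import Conv2dModule

depthwise_module = Conv2dModule(
    group=3, in_channels=3, out_channels=3, kernel_size=(3, 3), stride=1, dilation=1
)
grouped_module = Conv2dModule(group=2, in_channels=4, out_channels=2)  # group-separable case, currently not delegated
# End of illustration.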