Skip to content

Commit e1a0db8

Browse files
Add support for conversion and quantization of Mean Dim operator
1 parent 1036071 commit e1a0db8

File tree

6 files changed

+97
-89
lines changed

6 files changed

+97
-89
lines changed

backends/nxp/backend/edge_program_converter.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
exir_ops.edge.aten._softmax.default: SoftmaxConverter,
3232
exir_ops.edge.aten.view_copy.default: ViewCopyConverter,
3333
exir_ops.edge.aten.add.Tensor: AddTensorConverter,
34+
exir_ops.edge.aten.mean.dim: MeanDimConverter,
3435
}
3536

3637

backends/nxp/backend/ir/converter/node_converters/ops_converters/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,10 @@
2222
AddTensorConverter
2323
from executorch.backends.nxp.backend.ir.converter.node_converters.ops_converters.relu_converter import \
2424
ReLUConverter
25+
from executorch.backends.nxp.backend.ir.converter.node_converters.ops_converters.mean_dim_converter import \
26+
MeanDimConverter
2527
__all__ = [
2628
"AddMMConverter", "ConvolutionConverter", "MMConverter", "PermuteCopyConverter", "SoftmaxConverter",
2729
"ViewCopyConverter", "QDQDequantizeConverter", "QDQQuantizeConverter", "ConstantPadNDConverter", "ReLUConverter",
28-
"MaxPool2dConverter", "AvgPool2dConverter", "AddTensorConverter"
30+
"MaxPool2dConverter", "AvgPool2dConverter", "AddTensorConverter", "MeanDimConverter"
2931
]
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
# Copyright (c) 2025 NXP
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import torch
8+
from torch.fx import Node
9+
from torch.nn import Parameter
10+
11+
from executorch.backends.nxp.backend.ir.converter.conversion.translator import \
12+
create_channels_last_to_channels_first_permutation
13+
from executorch.backends.nxp.backend.ir.converter.node_converter import NodeConverter, Target
14+
from executorch.backends.nxp.backend.ir.converter.node_converters.shared.reduce_utils import \
15+
convert_axes_from_attribute
16+
from executorch.backends.nxp.backend.ir.tflite_generator.builtin_options import mean_options
17+
18+
19+
class MeanDimConverter(NodeConverter):
20+
supported_targets = [Target.RT700]
21+
22+
@staticmethod
23+
def _is_supported_in_IR(node: Node, parameters_mapping: dict[str, Parameter]) -> bool:
24+
dim = node.args[1]
25+
keepdim = node.args[2] if len(node.args) >= 3 else False
26+
rank = len(node.args[0].meta["val"].shape)
27+
to_neg_dim = lambda d: d - rank if d > 0 else d
28+
dim = [to_neg_dim(d) for d in dim]
29+
30+
# Only last 2 dimensions (H, W) and keepdim=True with rank=4 are supported on Neutron.
31+
if rank != 4 or dim not in [[-1, -2], [-2, -1]] or not keepdim:
32+
return False
33+
34+
if hasattr(node.kwargs, "dtype") and node.kwargs["dtype"] not in [torch.float32, torch.uint32, torch.uint8]:
35+
return False
36+
37+
if not NodeConverter._has_shared_q_params_if_quantized(node):
38+
return False
39+
40+
return True
41+
42+
@staticmethod
43+
def _normalize_and_to_channel_last_dim(dim: list[int], rank: int) -> list[int]:
44+
# convert negative index to positive
45+
to_pos_dim = lambda d: d + rank if d < 0 else d
46+
dim = [to_pos_dim(d) for d in dim]
47+
48+
perm = create_channels_last_to_channels_first_permutation(rank, True)
49+
dim = [perm[d] for d in dim]
50+
51+
return dim
52+
53+
# Mean Dim Node format: (Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None)
54+
def convert(self, node: Node):
55+
""" Convert 'mean.dim' operator to TFLite 'Mean'.
56+
"""
57+
self.assert_convertible(node)
58+
59+
dim = node.args[1]
60+
keepdim = node.args[2] if len(node.args) >= 3 else False
61+
62+
t_op = self._create_tflite_op_with_io_tensors(node)
63+
t_op.builtin_options = mean_options.Mean(keepdim)
64+
x = t_op.tmp_inputs[0]
65+
66+
if x.tensor_format.is_channels_last():
67+
dim = self._normalize_and_to_channel_last_dim(dim, x.rank)
68+
69+
convert_axes_from_attribute(t_op, self.builder, dim)
70+
self.builder.append_operators([t_op])

backends/nxp/backend/ir/converter/node_converters/shared/reduce_utils.py

Lines changed: 2 additions & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,17 @@
11
#
2-
# Copyright 2024 NXP
2+
# Copyright 2024-2025 NXP
33
#
44
# License: LA_OPT_NXP_Software_License
55
# See the LICENSE_LA_OPT_NXP_Software_License for more details.
66
#
77

88
import numpy as np
99

10-
from executorch.backends.nxp.backend.ir.lib.tflite.TensorType import TensorType
11-
from executorch.backends.nxp.backend.ir import logger
1210
from executorch.backends.nxp.backend.ir.converter.builder.model_builder import ModelBuilder
1311
from executorch.backends.nxp.backend.ir.converter.conversion import translator
14-
from executorch.backends.nxp.backend.ir.converter.conversion.common import OpsList, try_get_input
15-
from executorch.backends.nxp.backend.ir.converter.tensor_utils import tensor_has_data
12+
from executorch.backends.nxp.backend.ir.converter.conversion.common import OpsList
1613
from executorch.backends.nxp.backend.ir.tensor_formatting import TensorFormat
1714
from executorch.backends.nxp.backend.ir.tflite_generator import tflite_model
18-
from executorch.backends.nxp.backend.ir.tflite_generator.meta.types import name_for_type
1915

2016

2117
def convert_axes_from_attribute(t_op: tflite_model.Operator, builder: ModelBuilder, axes: list[int] | None):
@@ -38,88 +34,6 @@ def convert_axes_from_attribute(t_op: tflite_model.Operator, builder: ModelBuild
3834
t_op.tmp_inputs.append(axes_tensor)
3935

4036

41-
def convert_axes_from_input_tensor(t_op: tflite_model.Operator, builder: ModelBuilder, inspector: ONNXModelInspector,
42-
ops: OpsList, noop_with_empty_axes: int, op_type: str):
43-
""" Verify the `axes` tensor (on input index 1) of the `t_op`, which is expected to represent an ONNX reduction
44-
operator.
45-
"""
46-
x = t_op.tmp_inputs[0]
47-
rank = x.rank
48-
49-
if axes_tensor := try_get_input(t_op, 1):
50-
51-
# ONNX uses int64, while TFLite requires int32 for the `axes` tensor.
52-
if axes_tensor.type != TensorType.INT64:
53-
logger.e(logger.Code.INVALID_ONNX_OPERATOR,
54-
f'ONNX `{op_type}` has `axes` of type `{name_for_type(axes_tensor.type)}`, instead of INT64.')
55-
56-
# Try to get the inferred data for the `axes` input.
57-
if (axes_data := inspector.try_get_inferred_tensor_data(axes_tensor.name)) is not None:
58-
# The `axes` were inferred during shape inference.
59-
logger.d(f'Using inferred data for the `axes` input tensor of ONNX `{op_type}`.')
60-
61-
# Create a new tensor, in case the original `axes` tensor is used by multiple ops.
62-
axes_tensor = builder.create_tensor_for_data(axes_data.astype(np.int32), 'axes')
63-
64-
# Make sure the `axes` are int32.
65-
if tensor_has_data(axes_tensor):
66-
# Cast the `axes` to int32 statically.
67-
axes_tensor.tmp_buffer.data = axes_tensor.tmp_buffer.data.astype(np.int32)
68-
axes_tensor.type = TensorType.INT32
69-
70-
else:
71-
# The `axes` are dynamic and there is no inferred data for them. The shape inference is not possible in
72-
# this case, so it must have been skipped. If the `axes` are empty at runtime, ONNX will reduce over
73-
# all dimensions, whereas TFLite will not reduce at all. So the behavior is different, and it depends
74-
# on runtime data. Conversion could be implemented by adding multiple extra operators.
75-
# I don't thing that completely prohibiting the conversion here is ideal, since the issue arises only in
76-
# an edge case, which is hopefully not very common. Just print a warning message for now.
77-
logger.w(f'Conversion of ONNX `{op_type}` with a dynamic `axes` input will not be correct, if the `axes`'
78-
'are empty at runtime!')
79-
80-
# Insert a `Cast` op, to make the `axes` int32.
81-
cast_op = builder.create_cast_before(t_op, 1, TensorType.INT32)
82-
ops.add_pre(cast_op)
83-
84-
# For future references. Following code only cares about the final axes tensor.
85-
axes_tensor = cast_op.tmp_outputs[0]
86-
87-
# Assign the new `axes_tensor` to the ReduceX operator.
88-
t_op.tmp_inputs[1] = axes_tensor
89-
90-
else:
91-
# No axes specified.
92-
93-
if noop_with_empty_axes == 1:
94-
# ONNXRT: According to the documentation, the operator should do nothing in this situation. But that's
95-
# not what happens in ONNX Runtime. ORT seems to simply ignore the `noop_with_empty_axes` attribute.
96-
# https://github.com/microsoft/onnxruntime/issues/19147
97-
# For now, exit with error. If later ORT adds support for this attribute, simply uncomment the
98-
# following code.
99-
100-
# if self.builder.operator_can_be_skipped(t_op, self.inspector):
101-
# # Skip the operator.
102-
# self.builder.redirect_tensor(t_op.tmp_outputs[0], t_op.tmp_inputs[0])
103-
# return []
104-
#
105-
# else:
106-
# # Return an operator which does nothing.
107-
# self.builder.turn_operator_to_identity(t_op)
108-
# return [t_op]
109-
110-
logger.e(logger.Code.INVALID_ONNX_OPERATOR,
111-
f'ONNX `{op_type}` has `noop_with_empty_axes` == 1 and the `axes` are not specified, which'
112-
' indicates that the operator should do nothing. This is however not supported by ONNX'
113-
' Runtime, and therefore the conversion is also not supported.')
114-
115-
else:
116-
# Default is to reduce all axes.
117-
axes_tensor = builder.create_tensor_for_data(np.arange(rank).astype(np.int32), 'axes')
118-
119-
t_op.tmp_inputs[1:] = [] # If the optional input was passed with name "", remove it.
120-
t_op.tmp_inputs.append(axes_tensor)
121-
122-
12337
def ensure_reduce_transposition(builder, ops: OpsList):
12438
"""
12539
Ensure transposition of ReduceX operator is defined correctly based on tensor format.

backends/nxp/neutron_partitioner.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,7 @@ def tag_qdq_clusters(self, nodes: List[torch.fx.Node]):
190190
exir_ops.edge.aten._softmax.default: SoftmaxConverter,
191191
exir_ops.edge.aten.view_copy.default: ViewCopyConverter,
192192
exir_ops.edge.aten.add.Tensor: AddTensorConverter,
193+
exir_ops.edge.aten.mean.dim: MeanDimConverter,
193194
}
194195

195196

backends/nxp/quantizer/neutron_quantizer.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,15 @@ def partition_types(self):
169169
return [torch.ops.aten.view.default]
170170

171171

172+
class FlattenPattern(SharedSpecPattern):
173+
"""
174+
Quantizer for Flatten operator.
175+
"""
176+
177+
def partition_types(self):
178+
return [torch.ops.aten.flatten.using_ints]
179+
180+
172181
class PermutePattern(SharedSpecPattern):
173182
"""
174183
Quantizer for Permute operator.
@@ -178,6 +187,15 @@ def partition_types(self):
178187
return [torch.ops.aten.permute.default]
179188

180189

190+
class MeanDimPattern(SharedSpecPattern):
191+
"""
192+
Quantizer for Mean Dim operator.
193+
"""
194+
195+
def partition_types(self):
196+
return [torch.ops.aten.mean.dim]
197+
198+
181199
class SoftMaxPattern(QuantizationPattern):
182200
"""
183201
Quantizer for Softmax operator.
@@ -275,6 +293,8 @@ def __init__(self):
275293
CadenceAtenQuantizer(ReluInPlacePattern(), static_qconfig),
276294
CadenceAtenQuantizer(AvgPoolPattern(), static_qconfig),
277295
CadenceAtenQuantizer(ViewPattern(), static_qconfig),
296+
CadenceAtenQuantizer(MeanDimPattern(), static_qconfig),
297+
CadenceAtenQuantizer(FlattenPattern(), static_qconfig),
278298
]
279299
)
280300
self.op_to_quantizer = {pt: q for q in self.quantizers for pt in q.pattern.partition_types()}

0 commit comments

Comments
 (0)