
Commit 34f155b

Merge branch 'main-nxp' of ssh://bitbucket.sw.nxp.com/aitec/executorch into bugfix/EIEX-244
2 parents: abdba4c + 34c7933

File tree

18 files changed: +285 -14 lines changed

backends/nxp/backend/edge_program_converter.py

Lines changed: 1 addition & 0 deletions

@@ -29,6 +29,7 @@
     exir_ops.edge.aten.mm.default: MMConverter,
     exir_ops.edge.aten.permute_copy.default: PermuteCopyConverter,
     exir_ops.edge.aten.relu.default: ReLUConverter,
+    exir_ops.edge.aten.hardtanh.default: HardTanhConverter,
     exir_ops.edge.aten._softmax.default: SoftmaxConverter,
     exir_ops.edge.aten.view_copy.default: ViewCopyConverter,
     exir_ops.edge.aten.add.Tensor: AddTensorConverter,

backends/nxp/backend/ir/converter/node_converters/ops_converters/__init__.py

Lines changed: 3 additions & 1 deletion

@@ -30,9 +30,11 @@
     CloneConverter
 from executorch.backends.nxp.backend.ir.converter.node_converters.ops_converters.abs_converter import \
     AbsConverter
+from executorch.backends.nxp.backend.ir.converter.node_converters.ops_converters.hardtanh_converter import \
+    HardTanhConverter
 __all__ = [
     "AddMMConverter", "ConvolutionConverter", "MMConverter", "PermuteCopyConverter", "SoftmaxConverter",
     "ViewCopyConverter", "QDQDequantizeConverter", "QDQQuantizeConverter", "ConstantPadNDConverter", "ReLUConverter",
     "MaxPool2dConverter", "AvgPool2dConverter", "AddTensorConverter", "MeanDimConverter", "AdaptiveAvgPool2dConverter",
-    "CloneConverter", "AbsConverter"
+    "CloneConverter", "AbsConverter", "HardTanhConverter"
 ]
backends/nxp/backend/ir/converter/node_converters/ops_converters/hardtanh_converter.py

Lines changed: 41 additions & 0 deletions

@@ -0,0 +1,41 @@
+# Copyright (c) 2025 NXP
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from torch.fx import Node
+from torch.nn import Parameter
+
+from executorch.backends.nxp.backend.ir.converter.node_converter import NodeConverter, Target
+from executorch.backends.nxp.backend.ir.lib.tflite.BuiltinOperator import BuiltinOperator
+
+
+class HardTanhConverter(NodeConverter):
+    supported_targets = [Target.RT700]
+
+    # Maps possible input parameters of HardTanh to equivalent ReLU-based operators supported by TFLite.
+    supported_modes_map = {
+        ( 0., 6.): BuiltinOperator.RELU6,
+        (-1., 1.): BuiltinOperator.RELU_N1_TO_1,
+        ( 0., 1.): BuiltinOperator.RELU_0_TO_1,
+        ( 0., float('inf')): BuiltinOperator.RELU,
+    }
+
+    @staticmethod
+    def _is_supported_in_IR(node: Node, parameters_mapping: dict[str, Parameter]) -> bool:
+        _, min_value, max_value = node.args
+        return (min_value, max_value) in HardTanhConverter.supported_modes_map
+
+    def convert(self, node: Node):
+        """ Convert 'aten::hardtanh' to its supported ReLU equivalent. """
+        self.assert_convertible(node)
+
+        t_op = self._create_tflite_op_with_io_tensors(node)
+
+        _, min_value, max_value = node.args
+
+        op = self.supported_modes_map[(min_value, max_value)]
+        t_op.opcode_index = self.builder.op_code_index_for_op_type(op)
+
+        self.builder.append_operators([t_op])
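
The mapping works because hardtanh with exactly these (min, max) pairs is pointwise identical to the corresponding TFLite ReLU variants. A minimal standalone sanity check of that equivalence (illustration only, not part of the commit; uses only public torch.nn.functional ops):

    import torch
    import torch.nn.functional as F

    x = torch.randn(4, 8) * 10  # values well outside every clipping range

    # hardtanh(x, 0, 6) clips to [0, 6]: exactly what RELU6 computes.
    assert torch.equal(F.hardtanh(x, 0., 6.), F.relu6(x))

    # hardtanh(x, 0, inf) only clips from below, i.e. a plain RELU.
    assert torch.equal(F.hardtanh(x, 0., float('inf')), F.relu(x))

    # The default range [-1, 1] matches TFLite's RELU_N1_TO_1, and [0, 1]
    # matches RELU_0_TO_1; torch expresses both as a clamp.
    assert torch.equal(F.hardtanh(x), x.clamp(-1., 1.))
    assert torch.equal(F.hardtanh(x, 0., 1.), x.clamp(0., 1.))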

backends/nxp/neutron_partitioner.py

Lines changed: 1 addition & 0 deletions

@@ -187,6 +187,7 @@ def tag_qdq_clusters(self, nodes: List[torch.fx.Node]):
     exir_ops.edge.aten.max_pool2d_with_indices.default: MaxPool2dConverter,
     exir_ops.edge.aten.mm.default: MMConverter,
     exir_ops.edge.aten.relu.default: ReLUConverter,
+    exir_ops.edge.aten.hardtanh.default: HardTanhConverter,
     exir_ops.edge.aten._softmax.default: SoftmaxConverter,
     exir_ops.edge.aten.view_copy.default: ViewCopyConverter,
     exir_ops.edge.aten.add.Tensor: AddTensorConverter,

backends/nxp/quantizer/neutron_quantizer.py

Lines changed: 52 additions & 0 deletions

@@ -169,6 +169,56 @@ def partition_types(self):
         return [torch.ops.aten.relu_.default]


+class HardTanhPattern(QuantizationPattern):
+    """
+    Quantizer for the HardTanh operator. A shared quantization spec is selected, as activation functions usually
+    follow a computation layer.
+    """
+
+    def partition_types(self):
+        return [torch.ops.aten.hardtanh.default]
+
+    def get_anchors(
+        self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
+    ) -> PartitionAnchors | None:
+        node = fused_partition[0].nodes[-1]
+
+        return PartitionAnchors(
+            inputs=[(node, 0)],
+            weights=[],
+            biases=[],
+            output=[(node,)],
+        )
+
+    def replacement_op(self):
+        assert False
+
+
+class HardTanhInPlacePattern(QuantizationPattern):
+    """
+    Quantizer for the HardTanh operator with inplace=True. A shared quantization spec is selected, as activation
+    functions usually follow a computation layer.
+    """
+
+    def partition_types(self):
+        return [torch.ops.aten.hardtanh_.default]
+
+    def get_anchors(
+        self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
+    ) -> PartitionAnchors | None:
+        node = fused_partition[0].nodes[-1]
+
+        return PartitionAnchors(
+            inputs=[(node, 0)],
+            weights=[],
+            biases=[],
+            output=[(node,)],
+        )
+
+    def replacement_op(self):
+        assert False
+
+
 class ReshapePattern(SharedSpecPattern):
     """
     Quantizer for Reshape operator.
@@ -317,6 +367,8 @@ def __init__(self):
             CadenceAtenQuantizer(PermutePattern(), static_qconfig),
             CadenceAtenQuantizer(PadPattern(), static_qconfig),
             CadenceAtenQuantizer(ReluPattern(), static_qconfig),
+            CadenceAtenQuantizer(HardTanhPattern(), static_qconfig),
+            CadenceAtenQuantizer(HardTanhInPlacePattern(), static_qconfig),
            CadenceAtenQuantizer(ReluInPlacePattern(), static_qconfig),
            CadenceAtenQuantizer(AvgPoolPattern(), static_qconfig),
            CadenceAtenQuantizer(ViewPattern(), static_qconfig),
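
With both patterns registered, the NeutronQuantizer now annotates hardtanh in either form; the in-place pattern covers torch.nn.Hardtanh(inplace=True), which lowers to aten.hardtanh_.default. A minimal sketch of a model these patterns would now handle (module and names are illustrative only, not from the commit):

    import torch

    class ConvHardTanh(torch.nn.Module):
        """Conv2d followed by hardtanh(0, 6), which the backend converts to TFLite RELU6."""

        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 8, kernel_size=3)
            self.activation = torch.nn.Hardtanh(min_val=0., max_val=6.)

        def forward(self, x):
            return self.activation(self.conv(x))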

backends/nxp/tests/executorch_pipeline.py

Lines changed: 7 additions & 3 deletions

@@ -28,6 +28,12 @@ def _quantize_model(model, calibration_inputs: list[tuple[torch.Tensor]]):
     return m


+def get_random_float_data(input_shapes: tuple[int] | list[tuple[int]]):
+    # TODO(Lukas): Replace with something more robust.
+    return (torch.randn(input_shapes),) if type(input_shapes) is tuple \
+        else tuple(torch.randn(input_shape) for input_shape in input_shapes)
+
+
 def to_quantized_edge_program(model: torch.nn.Module, input_shapes: tuple[int] | list[tuple[int]],
                               operators_not_to_delegate: list[str] = None, target="imxrt700",
                               neutron_converter_flavor="wrapper", remove_quant_io_ops=False)\
@@ -36,9 +42,7 @@ def to_quantized_edge_program(model: torch.nn.Module, input_shapes: tuple[int] |
     assert all([isinstance(input_shape, tuple) for input_shape in input_shapes]), ("For multiple inputs, provide"
                                                                                    " list[tuple[int]].")

-    random_tensors = (torch.randn(input_shapes),) if type(input_shapes) is tuple \
-        else tuple(torch.randn(input_shape) for input_shape in input_shapes)
-    calibration_inputs = [random_tensors, random_tensors]
+    calibration_inputs = [get_random_float_data(input_shapes) for _ in range(4)]
     example_input = (torch.ones(input_shapes),) if type(input_shapes) is tuple \
         else tuple(torch.ones(input_shape) for input_shape in input_shapes)
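
Besides removing the duplication, this changes the calibration set from two identical batches to four independently drawn ones. The helper accepts either a single shape or a list of shapes; a quick illustration with hypothetical shapes:

    # Single input shape: returns a 1-tuple holding one random tensor.
    (single,) = get_random_float_data((1, 3, 32, 32))
    assert single.shape == (1, 3, 32, 32)

    # List of shapes: returns one random tensor per shape.
    a, b = get_random_float_data([(1, 10), (1, 20)])
    assert a.shape == (1, 10) and b.shape == (1, 20)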

backends/nxp/tests/executors.py

Lines changed: 2 additions & 2 deletions

@@ -267,8 +267,8 @@ def convert_run_compare(edge_program: ExportedProgram, input_data, rtol=1.e-5, a
     return tflite_executor, edge_program_executor


-def graph_contains_op(graph: Graph, op: object) -> bool:
-    return any(map(lambda node: node.target == op, graph.nodes))
+def graph_contains_any_of_ops(graph: Graph, ops: list) -> bool:
+    return any(map(lambda node: node.target in ops, graph.nodes))


 class OverrideSupportedTargets:
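
The generalized helper subsumes the old single-op check (pass a one-element list) and also lets tests assert that none of several ops survived delegation. Usage sketch, where 'program' stands in for any quantized ExportedProgram:

    # Old single-op usage, now expressed with a one-element list.
    assert not graph_contains_any_of_ops(graph=program.graph, ops=[exir_ops.edge.aten.abs.default])

    # Several ops checked at once.
    assert not graph_contains_any_of_ops(
        graph=program.graph,
        ops=[exir_ops.edge.aten.hardtanh.default, exir_ops.edge.aten.relu.default],
    )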

backends/nxp/tests/exported_program_vizualize.py

Lines changed: 1 addition & 1 deletion

@@ -102,7 +102,7 @@ def name_color(string):  # pseudo-randomization function
     label = ""
     if "val" in node.meta:
         tensor = node.meta["val"]
-        if isinstance(tensor, tuple):
+        if isinstance(tensor, (tuple, list)):
             tensor = tensor[0]  # Fake tensor
         label = f" ({list(tensor.shape)} | {tensor.dtype})"

backends/nxp/tests/ir/converter/node_converter/test_abs_converter.py

Lines changed: 3 additions & 3 deletions

@@ -5,7 +5,7 @@

 from executorch.backends.nxp.backend.edge_program_converter import EdgeProgramToIRConverter
 from executorch.backends.nxp.tests.executorch_pipeline import to_quantized_edge_program
-from executorch.backends.nxp.tests.executors import convert_run_compare, graph_contains_op, ToChannelLastPreprocess, \
+from executorch.backends.nxp.tests.executors import convert_run_compare, graph_contains_any_of_ops, ToChannelLastPreprocess, \
     ToChannelFirstPreprocess
 from executorch.exir.dialects._ops import ops as exir_ops

@@ -51,7 +51,7 @@ def test_conv_abs(mocker, input_shape: tuple[int] = (1, 3, 112, 112)):
     tflite_flatbuffers_model, io_formats = converter_spy.spy_return
     exported_program: ExportedProgram = converter_spy.call_args.args[1]

-    assert not graph_contains_op(graph=quantized_program.graph, op=exir_ops.edge.aten.abs.default)
+    assert not graph_contains_any_of_ops(graph=quantized_program.graph, ops=[exir_ops.edge.aten.abs.default])

     input_data = (np.random.random(input_shape) * 50).astype(np.int8)
     convert_run_compare(exported_program,
@@ -72,7 +72,7 @@ def test_abs_only(mocker, input_shape: tuple[int] = (1, 10)):
     tflite_flatbuffers_model, io_formats = converter_spy.spy_return
     exported_program: ExportedProgram = converter_spy.call_args.args[1]

-    assert not graph_contains_op(graph=quantized_program.graph, op=exir_ops.edge.aten.abs.default)
+    assert not graph_contains_any_of_ops(graph=quantized_program.graph, ops=[exir_ops.edge.aten.abs.default])

     input_data = (np.random.random(input_shape) * 50).astype(np.int8)
     convert_run_compare(exported_program,

backends/nxp/tests/ir/converter/node_converter/test_clone_converter.py

Lines changed: 3 additions & 3 deletions

@@ -12,7 +12,7 @@

 from executorch.backends.nxp.backend.edge_program_converter import EdgeProgramToIRConverter
 from executorch.backends.nxp.tests.executorch_pipeline import to_quantized_edge_program
-from executorch.backends.nxp.tests.executors import convert_run_compare, graph_contains_op, ToChannelLastPreprocess, \
+from executorch.backends.nxp.tests.executors import convert_run_compare, graph_contains_any_of_ops, ToChannelLastPreprocess, \
     ToChannelFirstPreprocess
 from executorch.exir.dialects._ops import ops as exir_ops

@@ -75,7 +75,7 @@ def test_conv_dropout_quant(mocker, inplace_dropout: bool, input_shape: tuple[in
     tflite_flatbuffers_model, io_formats = converter_spy.spy_return
     exported_program: ExportedProgram = converter_spy.call_args.args[1]

-    assert not graph_contains_op(graph=quantized_program.graph, op=exir_ops.edge.aten.clone.default)
+    assert not graph_contains_any_of_ops(graph=quantized_program.graph, ops=[exir_ops.edge.aten.clone.default])

     input_data = (np.random.random(input_shape) * 50).astype(np.int8)
     convert_run_compare(exported_program,
@@ -97,7 +97,7 @@ def test_clone_pool_view_copy_quant(mocker, inplace_dropout: bool, input_shape:
     tflite_flatbuffers_model, io_formats = converter_spy.spy_return
     exported_program: ExportedProgram = converter_spy.call_args.args[1]

-    assert not graph_contains_op(graph=quantized_program.graph, op=exir_ops.edge.aten.clone.default)
+    assert not graph_contains_any_of_ops(graph=quantized_program.graph, ops=[exir_ops.edge.aten.clone.default])

     input_data = (np.random.random(input_shape) * 50).astype(np.int8)
     convert_run_compare(exported_program,
