Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions backends/nxp/backend/edge_program_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
exir_ops.edge.aten.permute_copy.default: PermuteCopyConverter, # noqa F405
exir_ops.edge.aten.relu.default: ReLUConverter, # noqa F405
exir_ops.edge.aten._softmax.default: SoftmaxConverter, # noqa F405
exir_ops.edge.aten.tanh.default: TanhConverter, # noqa F405
exir_ops.edge.aten.view_copy.default: ViewCopyConverter, # noqa F405
exir_ops.edge.aten.sigmoid.default: SigmoidConverter, # noqa F405
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,9 @@
from executorch.backends.nxp.backend.ir.converter.node_converters.ops_converters.softmax_converter import (
SoftmaxConverter,
)
from executorch.backends.nxp.backend.ir.converter.node_converters.ops_converters.tanh_converter import (
TanhConverter,
)
from executorch.backends.nxp.backend.ir.converter.node_converters.ops_converters.view_copy_converter import (
ViewCopyConverter,
)
Expand All @@ -76,4 +79,5 @@
"AdaptiveAvgPool2dConverter",
"HardTanhConverter",
"SigmoidConverter",
"TanhConverter",
]
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# Copyright 2025 NXP
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from executorch.backends.nxp.backend.ir.converter.node_converter import NodeConverter
from executorch.backends.nxp.backend.ir.lib.tflite.BuiltinOperator import (
BuiltinOperator,
)
from torch.fx import Node
from torch.nn import Parameter


class TanhConverter(NodeConverter):
    """Node converter that emits the TFLite `TANH` builtin operator."""

    @staticmethod
    def _is_supported_in_IR(
        node: Node,
        parameters_mapping: dict[str, Parameter],
    ) -> bool:
        """Report IR support for `node`; every node is accepted unconditionally."""
        return True

    def convert(self, node: Node):
        """Create one TFLite operator for `node` and append it to the builder."""
        self.assert_convertible(node)

        # The operator is created with its input/output tensors first; only the
        # opcode index has to be filled in afterwards.
        tflite_op = self._create_tflite_op_with_io_tensors(node)
        tanh_opcode_index = self.builder.op_code_index_for_op_type(
            BuiltinOperator.TANH
        )
        tflite_op.opcode_index = tanh_opcode_index

        self.builder.append_operators([tflite_op])
1 change: 1 addition & 0 deletions backends/nxp/neutron_partitioner.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,7 @@ def tag_qdq_clusters(self, nodes: List[torch.fx.Node]):
exir_ops.edge.aten.mm.default: MMConverter, # noqa F405
exir_ops.edge.aten.relu.default: ReLUConverter, # noqa F405
exir_ops.edge.aten._softmax.default: SoftmaxConverter, # noqa F405
exir_ops.edge.aten.tanh.default: TanhConverter, # noqa F405
exir_ops.edge.aten.view_copy.default: ViewCopyConverter, # noqa F405
exir_ops.edge.aten.sigmoid.default: SigmoidConverter, # noqa F405
}
Expand Down
4 changes: 4 additions & 0 deletions backends/nxp/quantizer/neutron_quantizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@
SharedSpecPattern,
SigmoidPattern,
SoftMaxPattern,
TanhInPlacePattern,
TanhPattern,
ViewPattern,
)
from executorch.backends.nxp.quantizer.utils import (
Expand Down Expand Up @@ -221,6 +223,8 @@ def __init__(self):
NeutronAtenQuantizer(ReshapePattern(), static_qconfig),
NeutronAtenQuantizer(SigmoidPattern(), static_qconfig),
NeutronAtenQuantizer(SoftMaxPattern(), static_qconfig),
NeutronAtenQuantizer(TanhPattern(), static_qconfig),
NeutronAtenQuantizer(TanhInPlacePattern(), static_qconfig),
NeutronAtenQuantizer(ViewPattern(), static_qconfig),
]
)
Expand Down
100 changes: 72 additions & 28 deletions backends/nxp/quantizer/patterns.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,35 @@ def get_anchors(
)


def get_anchors_for_fixed_quant_specs(
    fused_partition: list[fx.GraphModule],
    scale: float,
    zero_point: int,
    quant_min: int = -128,
    quant_max: int = 127,
) -> PartitionAnchors:
    """
    Build anchors for a single-input partition whose output has fixed quantization parameters.

    The last node of the first fused partition serves both as the input anchor
    (argument index 0) and as the output anchor. The output is annotated with an
    int8, per-tensor affine `FixedQParamsQuantizationSpec`.

    :param fused_partition: Fused partitions; only the first one is inspected.
    :param scale: Fixed output scale.
    :param zero_point: Fixed output zero point.
    :param quant_min: Lower bound of the quantized range.
    :param quant_max: Upper bound of the quantized range.
    :return: Anchors with one input, no weights, no biases and one output.
    """
    partition = fused_partition[0]
    anchor_node = partition.nodes[-1]
    assert len(partition.input_nodes) == 1

    fixed_output_qspec = FixedQParamsQuantizationSpec(
        dtype=torch.int8,
        scale=scale,
        zero_point=zero_point,
        quant_min=quant_min,
        quant_max=quant_max,
        qscheme=torch.per_tensor_affine,
    )

    input_anchors = [(anchor_node, 0)]
    output_anchors = [(anchor_node, fixed_output_qspec)]
    return PartitionAnchors(
        inputs=input_anchors,
        weights=[],
        biases=[],
        output=output_anchors,
    )


class AbsPattern(SharedSpecPattern):
"""
Quantizer for Abs operator.
Expand Down Expand Up @@ -438,31 +467,6 @@ def partition_types(self):
return [torch.ops.aten.view.default]


def get_anchors_for_softmax_like_operators(
    fused_partition: List[fx.GraphModule],
) -> PartitionAnchors:
    """
    Build anchors for a single-input partition with the fixed softmax-style output quantization.

    The output spec is int8, per-tensor affine, scale 1/256, zero point -128 and
    quantized range [-128, 127]. The last node of the first fused partition is
    used as both the input anchor (argument index 0) and the output anchor.
    """
    partition = fused_partition[0]
    anchor_node = partition.nodes[-1]
    assert len(partition.input_nodes) == 1

    fixed_output_qspec = FixedQParamsQuantizationSpec(
        dtype=torch.int8,
        scale=1.0 / 256.0,
        zero_point=-128,
        quant_min=-128,
        quant_max=127,
        qscheme=torch.per_tensor_affine,
    )

    input_anchors = [(anchor_node, 0)]
    output_anchors = [(anchor_node, fixed_output_qspec)]
    return PartitionAnchors(
        inputs=input_anchors,
        weights=[],
        biases=[],
        output=output_anchors,
    )


class SoftMaxPattern(QuantizationPattern):
"""
Quantizer for Softmax operator.
Expand All @@ -474,9 +478,47 @@ def partition_types(self) -> List[OpOverload]:
return [torch.ops.aten.softmax.int]

def get_anchors(
self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
self, gm: fx.GraphModule, fused_partition: list[fx.GraphModule]
) -> PartitionAnchors:
return get_anchors_for_softmax_like_operators(fused_partition)
return get_anchors_for_fixed_quant_specs(
fused_partition, scale=1.0 / 256.0, zero_point=-128
)


class TanhPattern(QuantizationPattern):
    """
    Quantization pattern for the Tanh operator.

    The Tanh output quantization is fixed: scale 1/128, zero point 0, dtype int8.
    """

    def partition_types(self):
        """Return the ATen overloads this pattern matches."""
        matched_overloads = [torch.ops.aten.tanh.default]
        return matched_overloads

    def get_anchors(
        self, gm: fx.GraphModule, fused_partition: list[fx.GraphModule]
    ) -> PartitionAnchors:
        """Return anchors carrying the fixed Tanh output quantization; `gm` is not used."""
        output_scale = 1.0 / 128.0
        output_zero_point = 0
        return get_anchors_for_fixed_quant_specs(
            fused_partition, scale=output_scale, zero_point=output_zero_point
        )


class TanhInPlacePattern(QuantizationPattern):
    """
    Quantization pattern for the in-place Tanh operator (torch.tanh_).

    The Tanh output quantization is fixed: scale 1/128, zero point 0, dtype int8.
    """

    def partition_types(self):
        """Return the ATen overloads this pattern matches."""
        matched_overloads = [torch.ops.aten.tanh_.default]
        return matched_overloads

    def get_anchors(
        self, gm: fx.GraphModule, fused_partition: list[fx.GraphModule]
    ) -> PartitionAnchors:
        """Return anchors carrying the fixed Tanh output quantization; `gm` is not used."""
        output_scale = 1.0 / 128.0
        output_zero_point = 0
        return get_anchors_for_fixed_quant_specs(
            fused_partition, scale=output_scale, zero_point=output_zero_point
        )


class SigmoidPattern(QuantizationPattern):
Expand All @@ -492,4 +534,6 @@ def partition_types(self) -> List[OpOverload]:
def get_anchors(
self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
) -> PartitionAnchors:
return get_anchors_for_softmax_like_operators(fused_partition)
return get_anchors_for_fixed_quant_specs(
fused_partition, scale=1.0 / 256.0, zero_point=-128
)
2 changes: 2 additions & 0 deletions backends/nxp/run_unittests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,5 @@ cd $EXECUTORCH_DIR

# '-c /dev/null' is used to ignore root level pytest.ini.
pytest -c /dev/null backends/nxp/tests/

# Additionally discover and run the `unittest`-based tests from the same directory, with verbose output.
python -m unittest discover -s backends/nxp/tests/ -v
Empty file.
Empty file.
Empty file.
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
# Copyright 2025 NXP
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np
import pytest
import torch
Expand All @@ -15,6 +20,7 @@
ToNCHWPreprocess,
ToNHWCPreprocess,
)
from executorch.backends.nxp.tests.models import Conv2dWithActivation
from executorch.exir.dialects._ops import ops as exir_ops
from torch.export import ExportedProgram

Expand All @@ -25,48 +31,14 @@ def reseed_model_per_test_run():
np.random.seed(23)


class Relu6ConvBlock(torch.nn.Module):
    """Test model: a Conv2d (64 output channels, 4x4 kernel) followed by ReLU6."""

    def __init__(self, conv_in_channels: int = 3, inplace: bool = False):
        """
        :param conv_in_channels: Number of input channels of the convolution.
        :param inplace: Forwarded to `torch.nn.ReLU6(inplace=...)`.
        """
        super().__init__()
        convolution = torch.nn.Conv2d(
            in_channels=conv_in_channels, out_channels=64, kernel_size=(4, 4)
        )
        activation = torch.nn.ReLU6(inplace=inplace)
        self.block = torch.nn.Sequential(convolution, activation)

    def forward(self, x):
        """Apply the convolution and then the activation to `x`."""
        return self.block(x)


class ConvHardTanhBlock(torch.nn.Module):
    """Test model: a Conv2d (64 output channels, 4x4 kernel) followed by Hardtanh."""

    def __init__(
        self,
        conv_in_channels: int = 3,
        min_act_val: float = -1.0,
        max_act_val: float = 1.0,
        inplace: bool = False,
    ):
        """
        :param conv_in_channels: Number of input channels of the convolution.
        :param min_act_val: Forwarded to `torch.nn.Hardtanh` as `min_val`.
        :param max_act_val: Forwarded to `torch.nn.Hardtanh` as `max_val`.
        :param inplace: Forwarded to `torch.nn.Hardtanh(inplace=...)`.
        """
        super().__init__()
        convolution = torch.nn.Conv2d(
            in_channels=conv_in_channels, out_channels=64, kernel_size=(4, 4)
        )
        activation = torch.nn.Hardtanh(
            min_val=min_act_val, max_val=max_act_val, inplace=inplace
        )
        self.block = torch.nn.Sequential(convolution, activation)

    def forward(self, x):
        """Apply the convolution and then the activation to `x`."""
        return self.block(x)


@pytest.mark.parametrize("input_shape", [(1, 3, 128, 128), (1, 3, 256, 256)])
@pytest.mark.parametrize("input_shape", [(1, 3, 128, 128)])
@pytest.mark.parametrize("inplace", [True, False])
def test_relu6_quant(mocker, input_shape: tuple[int], inplace: bool):
# The torch.nn.Relu6 inherits from torch.nn.Hardtanh, and hence represented as HardTanh in ATen.
# Testing the hardtanh originated from torch.nn.Relu6 op.
model = Relu6ConvBlock(conv_in_channels=input_shape[1], inplace=inplace)
model = Conv2dWithActivation(
activation=torch.nn.ReLU6(inplace=inplace), in_channels=input_shape[1]
)

converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program")

Expand Down Expand Up @@ -100,11 +72,9 @@ def test_custom_hardtanh_quant(
# TODO(13063): This test suffers from non-ideal testing random quantization, because we always use range <0,1>.
# We should update (decrease atol) when the Conv/Linear + Activation fuse at quantization is in place.
min_val, max_val = activation_range
model = ConvHardTanhBlock(
conv_in_channels=input_shape[1],
min_act_val=min_val,
max_act_val=max_val,
inplace=inplace,
model = Conv2dWithActivation(
activation=torch.nn.Hardtanh(min_val=min_val, max_val=max_val, inplace=inplace),
in_channels=input_shape[1],
)

converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program")
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
# Copyright 2025 NXP
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import kgb
import numpy as np
import torch

from executorch.backends.nxp.nxp_backend import EdgeProgramToIRConverter
from executorch.backends.nxp.tests.executorch_pipeline import to_quantized_edge_program
from executorch.backends.nxp.tests.executors import (
convert_run_compare,
graph_contains_any_of_ops,
ToChannelFirstPreprocess,
ToChannelLastPreprocess,
)
from executorch.backends.nxp.tests.models import Conv2dWithActivation
from executorch.exir.dialects._ops import ops as exir_ops
from parameterized import parameterized
from torch.export import ExportedProgram


class TestTanhConverter(unittest.TestCase):
    """Conversion test for Conv2d followed by Tanh (regular and in-place variants)."""

    __test__ = False  # Prevent interfering with PyTest tests

    @classmethod
    def setUpClass(cls):
        """Seed torch and numpy so the model weights and input data are reproducible."""
        torch.manual_seed(23)
        np.random.seed(23)

    @parameterized.expand(
        input=[
            ("inplace", True),
            ("not_inplace", False),
        ]
    )
    def test_conv_tanh(
        self, _: str, inplace: bool, input_shape: tuple[int] = (1, 3, 112, 112)
    ):
        """Quantize and convert Conv2d + Tanh, then compare the converted model with the original."""
        with kgb.spy_on(
            EdgeProgramToIRConverter.convert_program, call_original=True
        ) as converter_spy:
            # Select the functional activation matching the tested variant.
            activation = torch.tanh_ if inplace else torch.tanh
            model = Conv2dWithActivation(
                activation=activation, in_channels=input_shape[1]
            )

            edge_program_manager = to_quantized_edge_program(model, input_shape)
            quantized_program = edge_program_manager.exported_program()

            # The spy recorded the conversion: its return value is the converted
            # model (plus IO formats), its first argument the program that was converted.
            last_call = converter_spy.calls[-1]
            tflite_flatbuffers_model, _io_formats = last_call.return_value
            exported_program: ExportedProgram = last_call.args[0]

            lowered_module = quantized_program.graph_module.lowered_module_0
            lowered_module_graph = lowered_module.original_module.graph
            tanh_ops = [
                exir_ops.edge.aten.tanh.default,
                exir_ops.edge.aten.tanh_.default,
            ]
            assert graph_contains_any_of_ops(graph=lowered_module_graph, ops=tanh_ops)

            random_values = np.random.random(input_shape) * 50
            input_data = random_values.astype(np.int8)
            convert_run_compare(
                exported_program,
                tfl_model=tflite_flatbuffers_model,
                tflite_input_preprocess=ToChannelLastPreprocess(),
                tflite_output_preprocess=ToChannelFirstPreprocess(),
                input_data=input_data,
                atol=1.0,
            )
Empty file.
Loading
Loading