Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions backends/nxp/backend/edge_program_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
exir_ops.edge.aten.permute_copy.default: PermuteCopyConverter, # noqa F405
exir_ops.edge.aten.relu.default: ReLUConverter, # noqa F405
exir_ops.edge.aten._softmax.default: SoftmaxConverter, # noqa F405
exir_ops.edge.aten.sub.Tensor: SubTensorConverter, # noqa F405
exir_ops.edge.aten.tanh.default: TanhConverter, # noqa F405
exir_ops.edge.aten.view_copy.default: ViewCopyConverter, # noqa F405
exir_ops.edge.aten.sigmoid.default: SigmoidConverter, # noqa F405
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,9 @@
from executorch.backends.nxp.backend.ir.converter.node_converters.ops_converters.softmax_converter import (
SoftmaxConverter,
)
from executorch.backends.nxp.backend.ir.converter.node_converters.ops_converters.sub_tensor_converter import (
SubTensorConverter,
)
from executorch.backends.nxp.backend.ir.converter.node_converters.ops_converters.tanh_converter import (
TanhConverter,
)
Expand All @@ -80,6 +83,7 @@
"MaxPool2dConverter",
"AvgPool2dConverter",
"AddTensorConverter",
"SubTensorConverter",
"CloneConverter",
"AbsConverter",
"AdaptiveAvgPool2dConverter",
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
# Copyright 2025 NXP
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from executorch.backends.nxp.backend.ir.converter.conversion.common import (
node_uses_shape_broadcasting,
)
from executorch.backends.nxp.backend.ir.converter.node_converter import (
CustomDelegationOptions,
NodeConverter,
)
from executorch.backends.nxp.backend.ir.tflite_generator.builtin_options import (
sub_options,
)
from executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec
from torch.fx import Node
from torch.nn import Parameter


class SubTensorConverter(NodeConverter):
@staticmethod
def _is_supported_on_target(
node: Node,
neutron_target_spec: NeutronTargetSpec,
parameters_mapping: dict[str, Parameter],
custom_delegation_options: CustomDelegationOptions,
) -> bool:
if node_uses_shape_broadcasting(node):
# Shape broadcasting may require the addition of `Transpose` ops during conversion.
return False

return True

@staticmethod
def _is_supported_in_IR(
node: Node,
parameters_mapping: dict[str, Parameter],
custom_delegation_options: CustomDelegationOptions,
) -> bool:
if len(node.args) != 2:
return False

# The `alpha` attribute can be represented by adding an extra `Mul` operator.
# However, this is not implemented as `alpha` is rarely used.
if hasattr(node.kwargs, "alpha"):
return False

return True

# sub.Tensor Node format: (Tensor self, Tensor other, *, Scalar alpha=1)
def convert(self, node: Node):
"""Convert 'sub_tensor' operator to NeutronIR 'Sub'."""
self.assert_convertible(node)

t_op = self._create_tflite_op_with_io_tensors(node)

t_op.builtin_options = sub_options.Sub()
self.builder.append_operators([t_op])
1 change: 1 addition & 0 deletions backends/nxp/neutron_partitioner.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,7 @@ def tag_qdq_clusters(self, nodes: list[torch.fx.Node]):
exir_ops.edge.aten.mm.default: MMConverter, # noqa F405
exir_ops.edge.aten.relu.default: ReLUConverter, # noqa F405
exir_ops.edge.aten._softmax.default: SoftmaxConverter, # noqa F405
exir_ops.edge.aten.sub.Tensor: SubTensorConverter, # noqa F405
exir_ops.edge.aten.tanh.default: TanhConverter, # noqa F405
exir_ops.edge.aten.view_copy.default: ViewCopyConverter, # noqa F405
exir_ops.edge.aten.sigmoid.default: SigmoidConverter, # noqa F405
Expand Down
2 changes: 2 additions & 0 deletions backends/nxp/quantizer/neutron_quantizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
SharedSpecPattern,
SigmoidPattern,
SoftMaxPattern,
SubTensorPattern,
TanhInPlacePattern,
TanhPattern,
ViewPattern,
Expand Down Expand Up @@ -208,6 +209,7 @@ def __init__(self):
NeutronAtenQuantizer(ReshapePattern(), static_qconfig),
NeutronAtenQuantizer(SigmoidPattern(), static_qconfig),
NeutronAtenQuantizer(SoftMaxPattern(), static_qconfig),
NeutronAtenQuantizer(SubTensorPattern(), static_qconfig),
NeutronAtenQuantizer(TanhPattern(), static_qconfig),
NeutronAtenQuantizer(TanhInPlacePattern(), static_qconfig),
NeutronAtenQuantizer(ViewPattern(), static_qconfig),
Expand Down
26 changes: 26 additions & 0 deletions backends/nxp/quantizer/patterns.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,32 @@ def get_anchors(
)


class SubTensorPattern(QuantizationPattern):
"""
Quantization pattern for Sub Tensor quantization. Accepts 1 or 2 input nodes.

Basic quantization for all inputs and output.
"""

def partition_types(self) -> list[torch.nn.Module]:
return [torch.ops.aten.sub.Tensor]

def get_anchors(
self, gm: fx.GraphModule, fused_partition: list[fx.GraphModule]
) -> PartitionAnchors | None:
node = fused_partition[0].nodes[-1]
inputs = [(node, NodeArgsIdx(0))]
if len(fused_partition[0].input_nodes) == 2:
inputs = [(node, NodeArgsIdx(0)), (node, NodeArgsIdx(1))]

return PartitionAnchors(
inputs=inputs,
weights=[],
biases=[],
output=[(node,)],
)


class AvgPoolPattern(SharedSpecPattern):
"""
Quantizer for AvgPool2D operator.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
# Copyright 2025 NXP
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import numpy as np
import pytest
import torch
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
# Copyright 2025 NXP
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import numpy as np
import pytest
import torch

from executorch.backends.nxp.backend.edge_program_converter import (
EdgeProgramToIRConverter,
)
from executorch.backends.nxp.tests.executorch_pipeline import to_quantized_edge_program
from executorch.backends.nxp.tests.executors import (
convert_run_compare,
ToChannelFirstPreprocess,
ToChannelLastPreprocess,
)
from executorch.backends.nxp.tests.models import (
SubTensorConvModule,
SubTensorModule,
SubTensorOneInputModule,
)
from executorch.exir.dialects._ops import ops as exir_ops
from torch.export import ExportedProgram


@pytest.fixture(autouse=True)
def reseed_model_per_test_run():
torch.manual_seed(23)
np.random.seed(23)


@pytest.mark.parametrize(
"input_shape",
[
pytest.param((4,), id="1D."),
pytest.param((6, 6), id="2D."),
pytest.param((1, 4, 8), id="3D."),
pytest.param((1, 4, 8, 8), id="4D."),
],
)
def test_sub_tensor_quant_conversion(mocker, input_shape):
model = SubTensorModule()

converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program")

# Run conversion
_ = to_quantized_edge_program(model, [input_shape, input_shape])

# Capture generated model
tflite_flatbuffers_model, io_formats = converter_spy.spy_return

# Capture converted program
exported_program: ExportedProgram = converter_spy.call_args.args[1]

input_data_1 = (np.random.random(input_shape).astype(np.float32) * 50).astype(
np.int8
)
input_data_2 = (np.random.random(input_shape).astype(np.float32) * 50).astype(
np.int8
)
input_data = {0: input_data_1, 1: input_data_2}

nodes = list(exported_program.graph.nodes)
assert nodes[4].target == exir_ops.edge.aten.sub.Tensor

convert_run_compare(
exported_program, tfl_model=tflite_flatbuffers_model, input_data=input_data
)


@pytest.mark.parametrize(
"input_shape",
[
pytest.param((4,), id="1D."),
pytest.param((6, 6), id="2D."),
pytest.param((1, 4, 8), id="3D."),
pytest.param((1, 4, 8, 8), id="4D."),
],
)
def test_sub_tensor_one_input_quant_conversion(mocker, input_shape):
model = SubTensorOneInputModule()

converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program")

# Run conversion
_ = to_quantized_edge_program(model, input_shape)

# Capture generated model
tflite_flatbuffers_model, io_formats = converter_spy.spy_return

# Capture converted program
exported_program: ExportedProgram = converter_spy.call_args.args[1]

input_data = (np.random.random(input_shape).astype(np.float32) * 50).astype(np.int8)

nodes = list(exported_program.graph.nodes)
assert nodes[2].target == exir_ops.edge.aten.sub.Tensor

convert_run_compare(
exported_program, tfl_model=tflite_flatbuffers_model, input_data=input_data
)


@pytest.mark.parametrize(
"x_input_shape",
[
pytest.param((1, 4, 8, 8), id="4D."),
pytest.param((1, 4, 5, 5), id="4D, product of dims is not a multiple of 8."),
],
)
def test_sub_tensor_w_conv_quant_conversion(mocker, x_input_shape):
model = SubTensorConvModule()

converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program")

n, c, h, w = x_input_shape
y_input_shape = (n, 8, h, w)

# Run conversion
_ = to_quantized_edge_program(model, [x_input_shape, y_input_shape])

# Capture generated model
tflite_flatbuffers_model, io_formats = converter_spy.spy_return

# Capture converted program
exported_program: ExportedProgram = converter_spy.call_args.args[1]

input_data_1 = (np.random.random(x_input_shape).astype(np.float32) * 50).astype(
np.int8
)
input_data_2 = (np.random.random(y_input_shape).astype(np.float32) * 50).astype(
np.int8
)
input_data = {0: input_data_1, 1: input_data_2}

nodes = list(exported_program.graph.nodes)
assert nodes[15].target == exir_ops.edge.aten.sub.Tensor

convert_run_compare(
exported_program,
input_data=input_data,
tflite_input_preprocess=ToChannelLastPreprocess(),
tfl_model=tflite_flatbuffers_model,
tflite_output_preprocess=ToChannelFirstPreprocess(),
)


@pytest.mark.parametrize(
"x_input_shape, y_input_shape",
[
pytest.param((1, 4, 7), (4, 7), id="3D -> 2D."),
pytest.param((1, 4, 8), (1, 4, 4, 8), id="3D -> 4D."),
pytest.param((1, 1, 4, 4, 8), (1, 4, 4, 8), id="5D -> 4D."),
pytest.param((4,), (4, 4), id="1D -> 2D."),
pytest.param((4,), (4, 4, 4), id="1D -> 3D."),
pytest.param((6, 6), (1, 8, 6, 6), id="2D -> 4D."),
pytest.param((6, 6), (6,), id="2D -> 1D."),
],
)
def test_sub_tensor_broadcasting_unsupported_quant_conversion(
x_input_shape, y_input_shape
):
model = SubTensorModule()

# Run conversion
edge_program = to_quantized_edge_program(
model, [x_input_shape, y_input_shape]
).exported_program()
nodes = list(edge_program.graph.nodes)

# Broadcast is not supported, node is not converted
assert (
nodes[6].target == exir_ops.edge.aten.sub.Tensor
) # Sub Tensor is not delegated.
28 changes: 28 additions & 0 deletions backends/nxp/tests/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -451,6 +451,34 @@ def forward(x):
return x + x


class SubTensorModule(torch.nn.Module):
def __init__(self):
super().__init__()

@staticmethod
def forward(x, y):
return x - y


class SubTensorConvModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv = Conv2dModule(padding=1, stride=1)

def forward(self, x, y):
x = self.conv(x)
return x - y


class SubTensorOneInputModule(torch.nn.Module):
def __init__(self):
super().__init__()

@staticmethod
def forward(x):
return x - x


class MeanDimLinearModule(torch.nn.Module):
def __init__(self, dim, keepdim):
super().__init__()
Expand Down
Loading