@@ -1,5 +1,4 @@
-# Copyright (c) 2025 NXP
-# All rights reserved.
+# Copyright 2025 NXP
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
@@ -20,6 +19,7 @@
     mean_options,
 )
 from executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec
+from executorch.backends.nxp.backend.node_format import NXP_NODE_FORMAT
 from torch.fx import Node
 from torch.nn import Parameter

@@ -32,15 +32,33 @@ def _is_supported_on_target(
         parameters_mapping: dict[str, Parameter],
         custom_delegation_options: CustomDelegationOptions,
     ) -> bool:
-        dim = node.args[1]
         keepdim = node.args[2] if len(node.args) >= 3 else False
         rank = len(node.args[0].meta["val"].shape)
-        dim = [d - rank if d > 0 else d for d in dim]
+        dim = [MeanDimConverter._to_pos_dim(d, rank) for d in node.args[1]]
 
+        if rank != 4 or not keepdim:
+            # neutron-converter/src/OperatorC/GlobalAvgPoolPlugin.cpp#74-77
+            return False
+
-        # Only last 2 dimensions (H, W) and keepdim=True with rank=4 are supported on Neutron.
-        if rank != 4 or dim not in [[-1, -2], [-2, -1]] or not keepdim:
+        # The `mean.dim` gets converted to AveragePool by the NeutronConverter, so the channels must be a
+        # multiple of `num_macs`.
+        # neutron-converter/src/OperatorC/GlobalAvgPoolPlugin.cpp#59-85
+        num_macs = neutron_target_spec.get_num_macs()
+        channels_dim = 1 if node.meta[NXP_NODE_FORMAT].is_channels_first() else -1
+        if (node.meta["val"].shape[channels_dim] % num_macs) != 0:
             return False
 
+        # Neutron only supports reduction over the spatial dimensions H, W.
+        if node.meta[NXP_NODE_FORMAT].is_channels_first():
+            # The input is NCHW. H and W are at indices 2 and 3.
+            if dim not in [[2, 3], [3, 2]]:
+                return False
+        else:
+            # The input is formatless. It can be considered as NHWC, as this is the way Neutron will look at
+            # the dimensions. So H and W are the middle dimensions.
+            if dim not in [[1, 2], [2, 1]]:
+                return False
+
         return True
 
     @staticmethod
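Note on the rewritten guard: it first maps every entry of `dim` onto a non-negative axis index, so `(-1, -2)` and `(2, 3)` compare equal on a rank-4 input. `MeanDimConverter._to_pos_dim` itself is not part of this diff; a minimal sketch of what such a helper presumably does:

# Hypothetical sketch -- the real `_to_pos_dim` is defined elsewhere in the converter.
def to_pos_dim(d: int, rank: int) -> int:
    # Map a possibly negative PyTorch dim to its non-negative index,
    # e.g. d=-1 with rank=4 -> 3; non-negative dims pass through unchanged.
    return d + rank if d < 0 else d

After normalization, the channels-first branch accepts only [2, 3]/[3, 2] (H and W of NCHW), while the formatless branch accepts only [1, 2]/[2, 1] (H and W when Neutron reads the tensor as NHWC).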
@@ -1,3 +1,8 @@
+# Copyright 2025 NXP
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
 import numpy as np
 import pytest
 import torch
@@ -8,10 +13,12 @@
 from executorch.backends.nxp.tests.executorch_pipeline import to_quantized_edge_program
 from executorch.backends.nxp.tests.executors import (
     convert_run_compare,
+    graph_contains_any_of_ops,
     ToChannelFirstPreprocess,
     ToChannelLastPreprocess,
 )
 from executorch.backends.nxp.tests.models import MeanDimConvModule, MeanDimLinearModule
+from executorch.exir.dialects._ops import ops as exir_ops
 from torch.export import ExportedProgram


@@ -21,19 +28,36 @@ def reseed_model_per_test_run():
     np.random.seed(23)


+class MeanDimModule(torch.nn.Module):
+    def __init__(self, dim, keepdim):
+        super().__init__()
+        self.dim = dim
+        self.keepdim = keepdim
+
+    def forward(self, x):
+        return torch.mean(x, dim=self.dim, keepdim=self.keepdim)
+
+
 @pytest.mark.parametrize(
     "input_shape, dim",
     [
         pytest.param((1, 4, 8, 8), (-1, -2), id="Dim -1, -2."),
         pytest.param((1, 4, 8, 8), (-2, -1), id="Dim -2, -1."),
+        pytest.param((1, 4, 8, 8), (2, 3), id="Dim 2, 3."),
+        pytest.param((1, 4, 8, 8), (3, 2), id="Dim 3, 2."),
     ],
 )
-def test_mean_dim_conv_quant_conversion(mocker, input_shape, dim, keeepdim=True):
-    model = MeanDimConvModule(dim, keeepdim)
+def test_mean_dim_conv_quant_conversion(mocker, input_shape, dim, keepdim=True):
+    model = MeanDimConvModule(dim, keepdim)

     converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program")
 
     # Run conversion
-    _ = to_quantized_edge_program(model, input_shape)
+    ep = to_quantized_edge_program(model, input_shape).exported_program()
 
+    # Make sure the `mean.dim` was delegated.
+    assert not graph_contains_any_of_ops(ep.graph, [exir_ops.edge.aten.mean.dim])
+    assert any("lowered_module" in n.name for n in ep.graph.nodes)
+
     # Capture generated model
     tflite_flatbuffers_model, io_formats = converter_spy.spy_return
@@ -61,16 +85,16 @@ def test_mean_dim_conv_quant_conversion(mocker, input_shape, dim, keeepdim=True)
     ],
 )
 @pytest.mark.parametrize(
-    "keeepdim",
+    "keepdim",
     [
         pytest.param(False, id="Don't keep dim."),
         pytest.param(True, id="Keep dim."),
     ],
 )
 def test_mean_dim_linear_unsupported_quant_conversion(
-    mocker, input_shape, dim, keeepdim
+    mocker, input_shape, dim, keepdim
 ):
-    model = MeanDimLinearModule(dim, keeepdim)
+    model = MeanDimLinearModule(dim, keepdim)
 
     converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program")

@@ -107,14 +131,14 @@ def test_mean_dim_linear_unsupported_quant_conversion(
     ],
 )
 @pytest.mark.parametrize(
-    "keeepdim",
+    "keepdim",
     [
         pytest.param(False, id="Don't keep dim."),
         pytest.param(True, id="Keep dim."),
     ],
 )
-def test_mean_dim_conv_unsupported_quant_conversion(mocker, input_shape, dim, keeepdim):
-    model = MeanDimConvModule(dim, keeepdim)
+def test_mean_dim_conv_unsupported_quant_conversion(mocker, input_shape, dim, keepdim):
+    model = MeanDimConvModule(dim, keepdim)
 
     converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program")

@@ -140,3 +164,93 @@ def test_mean_dim_conv_unsupported_quant_conversion(mocker, input_shape, dim, keeepdim):
         tflite_output_preprocess=ToChannelFirstPreprocess(),
         tfl_model=tflite_flatbuffers_model,
     )
+
+
+@pytest.mark.parametrize(
+    "input_shape, dim",
+    [
+        pytest.param((1, 2, 3, 8), (1, 2), id="Dim 1, 2."),
+        pytest.param((1, 2, 3, 8), (2, 1), id="Dim 2, 1."),
+        pytest.param((1, 2, 3, 8), (-3, -2), id="Dim -3, -2."),
+        pytest.param((1, 2, 3, 8), (-2, -3), id="Dim -2, -3."),
+    ],
+)
+def test_mean_dim__formatless__supported(mocker, input_shape, dim, keepdim=True):
+    model = MeanDimModule(dim, keepdim)
+
+    converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program")
+
+    ep = to_quantized_edge_program(model, input_shape).exported_program()
+
+    # Make sure the `mean.dim` was delegated.
+    assert not graph_contains_any_of_ops(ep.graph, [exir_ops.edge.aten.mean.dim])
+    assert any("lowered_module" in n.name for n in ep.graph.nodes)
+
+    # Capture generated model
+    tflite_flatbuffers_model, io_formats = converter_spy.spy_return
+
+    # Capture converted program
+    exported_program: ExportedProgram = converter_spy.call_args.args[1]
+
+    input_data = (np.random.random(input_shape).astype(np.float32) * 50).astype(np.int8)
+
+    convert_run_compare(
+        exported_program,
+        input_data=input_data,
+        tfl_model=tflite_flatbuffers_model,
+        atol=1,
+    )
+
+
+@pytest.mark.parametrize(
+    "input_shape, dim",
+    [
+        pytest.param((1, 2, 3, 8), (2, 3), id="Dim 2, 3."),
+    ],
+)
+def test_mean_dim__formatless__unsupported(input_shape, dim, keepdim=True):
+    model = MeanDimModule(dim, keepdim)
+
+    ep = to_quantized_edge_program(model, input_shape).exported_program()
+
+    # Make sure the `mean.dim` was NOT delegated.
+    assert graph_contains_any_of_ops(ep.graph, [exir_ops.edge.aten.mean.dim])
+    assert not any("lowered_module" in n.name for n in ep.graph.nodes)
+
+
+@pytest.mark.parametrize(
+    "input_shape, dim",
+    [
+        pytest.param(
+            (1, 8, 8, 4), (1, 2), id="Dim 1, 2 (supported), channels = 4 (unsupported)."
+        ),
+    ],
+)
+def test_mean_dim__formatless__unsupported_channels(input_shape, dim, keepdim=True):
+    model = MeanDimModule(dim, keepdim)
+
+    ep = to_quantized_edge_program(model, input_shape).exported_program()
+
+    # Make sure the `mean.dim` was NOT delegated.
+    assert graph_contains_any_of_ops(ep.graph, [exir_ops.edge.aten.mean.dim])
+    assert not any("lowered_module" in n.name for n in ep.graph.nodes)
+
+
+@pytest.mark.parametrize(
+    "input_shape, dim",
+    [
+        pytest.param(
+            (1, 4, 8, 8), (2, 3), id="Dim 2, 3 (supported), channels = 5 (unsupported)."
+        ),
+    ],
+)
+def test_mean_dim__channels_first__unsupported_channels(input_shape, dim, keepdim=True):
+    model = MeanDimConvModule(
+        dim, keepdim, out_channels=5
+    )  # Only multiples of 8 (num_macs) are supported.
+
+    # Run conversion
+    ep = to_quantized_edge_program(model, input_shape).exported_program()
+
+    # Make sure the `mean.dim` was NOT delegated.
+    assert graph_contains_any_of_ops(ep.graph, [exir_ops.edge.aten.mean.dim])
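The two `unsupported_channels` tests exercise the `num_macs` guard added in the converter: because `mean.dim` lowers to an average pool on Neutron, the channel count must be an exact multiple of the MAC-array width. A small illustration of the arithmetic, assuming `num_macs = 8` as the test comment states:

# Illustration only; the real value comes from NeutronTargetSpec.get_num_macs().
num_macs = 8
for channels in (4, 5, 8, 16):
    # 4 and 5 fail the modulo check, so those graphs keep `mean.dim` in ATen;
    # 8 and 16 pass and the op can be delegated to Neutron.
    print(channels, "->", "delegated" if channels % num_macs == 0 else "not delegated")

Note that the channel position differs between the two tests: the formatless input (1, 8, 8, 4) has its channels last (4), while the conv model's output is channels first (out_channels=5).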
4 changes: 2 additions & 2 deletions backends/nxp/tests/models.py
@@ -494,9 +494,9 @@ def forward(self, x):


 class MeanDimConvModule(torch.nn.Module):
-    def __init__(self, dim, keepdim):
+    def __init__(self, dim, keepdim, out_channels=8):
         super().__init__()
-        self.conv = Conv2dModule(stride=1, padding=1)
+        self.conv = Conv2dModule(stride=1, padding=1, out_channels=out_channels)
         self.dim = dim
         self.keepdim = keepdim

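The delegation asserts throughout these tests rely on `graph_contains_any_of_ops` from `executorch.backends.nxp.tests.executors`. Its body is not shown in this diff; a plausible sketch, assuming it simply scans the FX graph for matching `call_function` targets:

# Hypothetical sketch of the helper used by the asserts above.
import torch

def graph_contains_any_of_ops(graph: torch.fx.Graph, ops: list) -> bool:
    # True if any call_function node in the graph targets one of `ops`.
    return any(n.op == "call_function" and n.target in ops for n in graph.nodes)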