13 changes: 13 additions & 0 deletions backends/nxp/TARGETS
@@ -32,6 +32,18 @@ runtime.python_library(
],
)

runtime.python_library(
name = "_passes",
srcs = glob([
"_passes/*.py",
]),
deps = [
"//caffe2:torch",
"//executorch/exir:lib",
"//executorch/exir:pass_manager",
],
)

runtime.python_library(
name = "quantizer",
srcs = [
@@ -65,6 +77,7 @@ runtime.python_library(
deps = [
":neutron_sdk",
":aten_passes",
":_passes",
Contributor comment: Can you add the neutron_backend build to https://github.com/pytorch/executorch/blob/main/.ci/scripts/unittest-buck2.sh? That's the one that was failing last time. This way we get signal on the PR itself.

":quantizer",
"fbsource//third-party/pypi/flatbuffers:flatbuffers",
"fbsource//third-party/pypi/ml-dtypes:ml-dtypes",
108 changes: 108 additions & 0 deletions backends/nxp/_passes/remove_getitem_pass.py
@@ -0,0 +1,108 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# Copyright 2025 NXP
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch

from executorch.backends.nxp.backend.node_format_inference import (
NodeFormat,
NXP_NODE_FORMAT,
)
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult


class RemoveGetItemPass(ExportPass):
"""
This pass removes the getitem operator that follows the max_pool2d_with_indices.default operator, and replaces the pair with a single
operator that extracts only the first output. More specifically, we only keep the first output (the max values) of the aten::max_pool2d operator.
Before Pass:
MaxPool2d ---> GetItem[max_values, max_indexes]
After Pass:
MaxPool2d -> max_values
"""

def call(self, graph_module: torch.fx.GraphModule):
module = graph_module
for node in module.graph.nodes:
if node.op == "call_function":
if (
node.target.__name__ == "aten.max_pool2d_with_indices.default"
or node.target.__name__ == "aten.max.dim"
):
users = list(node.users.keys())

if len(users) != 1:
if len(users) == 2 and node.target.__name__ == "aten.max.dim":
# Two users are allowed for max.dim. For that case,
# rather than removing the getitem node in this
# pass, we handle the getitem nodes in the op's
# visitor when serializing
continue
else:
raise AssertionError(
f"Invalid number of users for {node.target.__name__}: {len(users)}"
)

getitem_node = list(node.users.keys())[0]

if getitem_node.target.__name__ != "getitem":
raise AssertionError(
f"Expected max node's user to be getitem, got {getitem_node.target.__name__}"
)

getitem_index = getitem_node.args[1]

with module.graph.inserting_before(node):
if (
node.target.__name__
== "aten.max_pool2d_with_indices.default"
):
if getitem_index != 0:
raise AssertionError(
f"Expected second argument of getitem node for {node.target.__name__} to be 0, got "
f"{getitem_index}. XNNPACK delegate currently only supports getting just the max "
"values from the op but not getting the corresponding indices."
)
new_max_wd = module.graph.create_node(
"call_function",
exir_ops.edge.aten.max_pool2d.default,
args=node.args,
kwargs=node.kwargs,
)

else:
if getitem_index != 0:
raise AssertionError(
f"Expected second argument of getitem node for {node.target.__name__} to be 0, got "
f"{getitem_index}. XNNPACK delegate currently only supports getting just the max "
"values or getting both the max values and their corresponding indices from the "
"op, but not getting the indices alone."
)
new_max_wd = module.graph.create_node(
"call_function",
exir_ops.edge.aten.amax.default,
args=node.args,
kwargs=node.kwargs,
)

# MODIFIED PART START
# Make sure to preserve the inferred node format.
new_max_wd.meta[NXP_NODE_FORMAT] = node.meta.get(
NXP_NODE_FORMAT, NodeFormat.NONE
)
# MODIFIED PART END

getitem_node.replace_all_uses_with(new_max_wd)

module.graph.erase_node(getitem_node)
module.graph.erase_node(node)

graph_module.recompile()
# Propagate metadata and retrace module
graph_module = super().call(graph_module).graph_module

return PassResult(graph_module, True)
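
A minimal usage sketch (not part of the PR), assuming the standard ExecuTorch export flow: in the edge dialect, `max_pool2d` is decomposed into `max_pool2d_with_indices` plus a `getitem`, which is exactly the pattern this pass folds away.

```python
import torch

from executorch.backends.nxp._passes.remove_getitem_pass import RemoveGetItemPass
from executorch.exir import to_edge


class Pool(torch.nn.Module):
    def forward(self, x):
        return torch.nn.functional.max_pool2d(x, kernel_size=2)


ep = torch.export.export(Pool(), (torch.randn(1, 8, 16, 16),))
gm = to_edge(ep).exported_program().graph_module

# Folds `max_pool2d_with_indices -> getitem[0]` into a plain `max_pool2d`,
# copying NXP_NODE_FORMAT onto the replacement node when present.
gm = RemoveGetItemPass()(gm).graph_module
```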
13 changes: 3 additions & 10 deletions backends/nxp/backend/edge_program_converter.py
@@ -19,10 +19,7 @@
from torch.nn.parameter import Parameter
from executorch.backends.nxp.backend.ir.converter.node_converters.ops_converters import * # noqa F403
from executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec
from executorch.backends.nxp.backend.node_format_inference import (
NodeFormat,
NodeFormatInference,
)
from executorch.backends.nxp.backend.node_format import NXP_NODE_FORMAT
from executorch.exir.dialects._ops import ops as exir_ops

# noinspection PyProtectedMember
@@ -76,12 +73,10 @@ def convert_program(
:param custom_delegation_options: Custom user options which affect node delegation.
:return: TFLite flatbuffers as bytes.
"""
node_formats = NodeFormatInference(edge_program).identify_node_formats()
parameters_mapping = self.map_inputs_to_parameters(edge_program)

cc = self.build_conversion_context(
parameters_mapping,
node_formats,
neutron_target_spec,
conversion_config,
custom_delegation_options,
@@ -108,7 +103,7 @@ def convert_program(
def append_placeholders_and_tensors(nodes: list[Node], context: ConversionContext):
for node in nodes:
if node.op == "placeholder":
node_format = context.node_formats[node]
node_format = node.meta[NXP_NODE_FORMAT]

if node.name in context.parameters_mapping:
# Node is placeholder and has data -> append as static tensor with data
@@ -121,7 +116,7 @@ def append_placeholders_and_tensors(nodes: list[Node], context: ConversionContex
context.tflite_builder.append_as_fake_tensor(node, node_format)
elif node.op == "call_function":
# Node is call function -> append only output as a tensor
node_format = context.node_formats[node]
node_format = node.meta[NXP_NODE_FORMAT]
context.tflite_builder.append_as_fake_tensor(node, node_format)
elif node.op == "output":
# Nothing to do
@@ -179,7 +174,6 @@ def map_inputs_to_parameters(edge_program: ExportedProgram) -> dict[str, Paramet
@staticmethod
def build_conversion_context(
parameters_mapping: dict,
node_formats: dict[Node, NodeFormat],
neutron_target_spec: NeutronTargetSpec,
conversion_config: ConversionConfig = _default_conversion_config,
custom_delegation_options: CustomDelegationOptions = _default_delegation_options,
@@ -195,7 +189,6 @@ def build_conversion_context(
tflite_builder,
conversion_config,
parameters_mapping,
node_formats,
custom_delegation_options,
)

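The net effect of this file's change: node formats now travel on the FX nodes themselves under the `NXP_NODE_FORMAT` meta key, instead of in a side `node_formats` table, so any pass that rewrites the graph must copy the key onto replacement nodes, as `RemoveGetItemPass` does above. A minimal sketch of that contract (the helper name is hypothetical):

```python
from executorch.backends.nxp.backend.node_format import NXP_NODE_FORMAT, NodeFormat


def copy_node_format(old_node, new_node):
    # Preserve the inferred format when a pass replaces a node; fall back
    # to NodeFormat.NONE if the producer never stamped the original node.
    new_node.meta[NXP_NODE_FORMAT] = old_node.meta.get(NXP_NODE_FORMAT, NodeFormat.NONE)
```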
5 changes: 0 additions & 5 deletions backends/nxp/backend/ir/conversion_context.py
@@ -10,24 +10,20 @@
from executorch.backends.nxp.backend.ir.converter.builder.aten_model_builder_director import (
AtenModelBuilderDirector,
)
from executorch.backends.nxp.backend.node_format_inference import NodeFormat
from torch import Node
from torch.nn import Parameter


class ConversionContext:
tflite_builder: AtenModelBuilderDirector
conversion_config: ConversionConfig
parameters_mapping: dict[str, Parameter]
node_formats: dict[Node, NodeFormat]
custom_delegation_options: CustomDelegationOptions

def __init__(
self,
tflite_builder: AtenModelBuilderDirector,
conversion_config: ConversionConfig,
parameters_mapping: dict,
node_formats: dict[Node, NodeFormat],
custom_delegation_options: CustomDelegationOptions,
):
"""
@@ -39,5 +35,4 @@ def __init__(
self.tflite_builder = tflite_builder
self.conversion_config = conversion_config
self.parameters_mapping = parameters_mapping
self.node_formats = node_formats
self.custom_delegation_options = custom_delegation_options
@@ -9,7 +9,7 @@
from executorch.backends.nxp.backend.ir.converter.conversion import translator
from executorch.backends.nxp.backend.ir.tensor_formatting import TensorFormat
from executorch.backends.nxp.backend.ir.tflite_generator import tflite_model
from executorch.backends.nxp.backend.node_format_inference import NodeFormat
from executorch.backends.nxp.backend.node_format import NodeFormat
from torch.fx import Node
from torch.nn import Parameter

Expand Down
@@ -9,6 +9,9 @@
CustomDelegationOptions,
)
from executorch.backends.nxp.backend.ir.converter.conversion import translator
from executorch.backends.nxp.backend.ir.converter.conversion.translator import (
create_channels_first_to_channels_last_permutation,
)
from executorch.backends.nxp.backend.ir.converter.node_converter import (
_is_dequant_node,
_is_quant_node,
@@ -18,6 +21,7 @@
Concatenation,
)
from executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec
from executorch.backends.nxp.backend.node_format import NXP_NODE_FORMAT
from torch.fx import Node
from torch.nn import Parameter

@@ -85,39 +89,48 @@ def _is_supported_on_target(
if dim == 0:
return False

# If all input shapes are equal, Neutron is able to pad the last dimension of inputs and outputs.
input_shapes = [_get_shape(input_) for input_ in node.all_input_nodes]
if input_shapes.count(input_shapes[0]) == len(input_shapes):
if dim == len(input_shapes[0]) - 1:
return True
# Neutron requires the channels to be a multiple of `num_macs`. The channels could either be the second or the
# last dimension, depending on the formats of the node.
if node.meta[NXP_NODE_FORMAT].is_channels_first():
# During conversion to IR, the shape will be permuted to channels last, and the dimension at index
# `1` will end up being the channels (last dim in NHWC).
channels_index = 1
to_nhwc_perm = create_channels_first_to_channels_last_permutation(
len(node.meta["val"].shape), True
)
dim = to_nhwc_perm.index(
dim
) # Make sure the dim points to the NHWC dimension.
else:
# The shape will not be permuted during conversion, so the channels will remain the last dimension.
channels_index = -1

# Neutron requires the channels to be a multiple of numMacs. The channels could either be the second or the
# last dimension, depending on the formats of the node. The format, however, cannot be determined
# during conversion, as it depends on what other nodes are delegated.
input_channels = [
# The second dimension is the channels in PyTorch. If the inputs/output are not channels first, it
# will still be the channels in the IR.
_get_shape(input_)[1]
for input_ in node.all_input_nodes
] + [
# If the inputs/outputs are channels first, the last dimension will be the channels.
_get_shape(input_)[-1]
for input_ in node.all_input_nodes
_get_shape(input_)[channels_index] for input_ in node.all_input_nodes
]
if any(
(input_channel % neutron_target_spec.get_num_macs()) != 0
for input_channel in input_channels
):
output_channels = _get_shape(node)[channels_index]

num_macs = neutron_target_spec.get_num_macs()
input_shapes = [_get_shape(input_) for input_ in node.all_input_nodes]
if any((input_channel % num_macs) != 0 for input_channel in input_channels):
# neutron-library/src/utils/NeutronLibraryInterrogation.cpp#1492
return False

output_channels = [_get_shape(node)[1], _get_shape(node)[-1]]
# neutron-library/src/utils/NeutronLibraryInterrogation.cpp#1493
if any(
(out_c % neutron_target_spec.get_num_macs()) != 0
for out_c in output_channels
):
return False
# If all input shapes are equal, Neutron is able to pad the last dimension of the inputs.
if not (
input_shapes.count(input_shapes[0]) == len(input_shapes)
and dim == len(input_shapes[0]) - 1
):
return False

if (output_channels % num_macs) != 0:
# neutron-library/src/utils/NeutronLibraryInterrogation.cpp#1493

# If all input shapes are equal, Neutron is able to pad the last dimension of the output.
if not (
input_shapes.count(input_shapes[0]) == len(input_shapes)
and dim == len(input_shapes[0]) - 1
):
return False

if len(node.all_input_nodes) < 2: # Not supported on Neutron
# TODO Try to skip the operator if this case is realistic.
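For intuition, a worked sketch of the dim remapping in `_is_supported_on_target` above, assuming the rank-4 channels-first-to-channels-last permutation is the conventional `(0, 2, 3, 1)`; `perm.index(d)` gives the position the original dimension `d` occupies after the permutation:

```python
perm = (0, 2, 3, 1)        # NCHW -> NHWC
assert perm.index(0) == 0  # batch stays in place
assert perm.index(1) == 3  # channels dim 1 becomes the last (NHWC) dim
assert perm.index(2) == 1  # height moves up one position

# The channel-multiple check: with e.g. num_macs = 8 (the real value comes
# from neutron_target_spec.get_num_macs()), 16 channels pass, 12 do not.
num_macs = 8
assert 16 % num_macs == 0
assert 12 % num_macs != 0
```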
@@ -27,6 +27,8 @@
pad_v2_options,
)
from executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec

from executorch.backends.nxp.backend.node_format import NXP_NODE_FORMAT
from torch.fx import Node
from torch.nn import Parameter

@@ -40,9 +42,16 @@ def _is_supported_on_target(
custom_delegation_options: CustomDelegationOptions,
) -> bool:
paddings = node.args[1]
if len(paddings) > 4 and paddings[4:6] != [0, 0]:
# Attempt to Pad channels dimension, which is not supported on Neutron.
return False
if node.meta[NXP_NODE_FORMAT].is_channels_first():
# Dim `1` will end up being the channels. It is padded by paddings[4:6].
if len(paddings) > 4 and paddings[4:6] != [0, 0]:
# Attempt to Pad channels dimension -> currently not supported
return False
else:
# Dim `-1` will end up being the channels. It is padded by paddings[:2].
if len(paddings) > 0 and paddings[:2] != [0, 0]:
# Attempt to Pad channels dimension -> currently not supported
return False

return True

@@ -65,10 +74,6 @@ def _is_supported_in_IR(
if not NodeConverter._has_shared_q_params_if_quantized(node):
return False

if len(paddings) > 4 and paddings[4:6] != [0, 0]:
# Attempt to Pad channels dimension -> currently not supported
return False

return True

# noinspection PyMethodMayBeStatic
Expand Down
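For reference, a worked sketch of the padding convention these checks rely on: the paddings list pairs up from the last dimension backwards, so `paddings[:2]` pads dim -1 and `paddings[4:6]` pads dim -3, which for a 4D NCHW tensor is the channels dim 1.

```python
import torch

x = torch.zeros(1, 3, 4, 5)  # NCHW
# Pairs apply from the last dim backwards:
# (W_left, W_right, H_top, H_bottom, C_front, C_back).
y = torch.nn.functional.pad(x, [0, 0, 0, 0, 1, 1])
assert y.shape == (1, 5, 4, 5)  # paddings[4:6] grew the channels dim
```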
2 changes: 1 addition & 1 deletion backends/nxp/backend/ir/tensor_formatting.py
@@ -6,7 +6,7 @@
#
from enum import Enum

from executorch.backends.nxp.backend.node_format_inference import NodeFormat
from executorch.backends.nxp.backend.node_format import NodeFormat


class TensorFormat(Enum):