13 changes: 13 additions & 0 deletions backends/nxp/TARGETS
@@ -32,6 +32,18 @@ runtime.python_library(
],
)

runtime.python_library(
name = "_passes",
srcs = glob([
"_passes/*.py",
]),
deps = [
"//caffe2:torch",
"//executorch/exir:lib",
"//executorch/exir:pass_manager",
],
)

runtime.python_library(
name = "quantizer",
srcs = [
@@ -65,6 +77,7 @@ runtime.python_library(
deps = [
":neutron_sdk",
":aten_passes",
":_passes",
":quantizer",
"fbsource//third-party/pypi/flatbuffers:flatbuffers",
"fbsource//third-party/pypi/ml-dtypes:ml-dtypes",
103 changes: 103 additions & 0 deletions backends/nxp/_passes/remove_getitem_pass.py
@@ -0,0 +1,103 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# Copyright 2025 NXP
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch

from executorch.backends.nxp.backend.node_format_inference import (
NodeFormat,
NXP_NODE_FORMAT,
)
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult


class RemoveGetItemPass(ExportPass):
"""
This pass removes the getitem operator that follows a max_pool2d_with_indices.default (or max.dim) node,
and replaces the pair with a single operator that produces only the first output, i.e. the max values. The indices are discarded.
Before Pass:
MaxPool2d ---> GetItem[max_values, max_indexes]
After Pass:
MaxPool2d -> max_values
"""

def call(self, graph_module: torch.fx.GraphModule):
module = graph_module
for node in module.graph.nodes:
if node.op == "call_function":
if (
node.target.__name__ == "aten.max_pool2d_with_indices.default"
or node.target.__name__ == "aten.max.dim"
):
users = list(node.users.keys())

if len(users) != 1:
if len(users) == 2 and node.target.__name__ == "aten.max.dim":
# Two users are allowed for max.dim. In that case,
# rather than removing the getitem node in this
# pass, we handle the getitem nodes in the op's
# visitor when serializing.
continue
else:
raise AssertionError(
f"Invalid number of users for {node.target.__name__}: {len(users)}"
)

getitem_node = users[0]

if getitem_node.target.__name__ != "getitem":
raise AssertionError(
f"Expected max node's user to be getitem, got {getitem_node.target.__name__}"
)

getitem_index = getitem_node.args[1]

with module.graph.inserting_before(node):
if (
node.target.__name__
== "aten.max_pool2d_with_indices.default"
):
if getitem_index != 0:
raise AssertionError(
f"Expected second argument of getitem node for {node.target.__name__} to be 0, got {getitem_index}. XNNPACK delegate currently only supports getting just the max values from the op but not getting the corresponding indices."
)
new_max_wd = module.graph.create_node(
"call_function",
exir_ops.edge.aten.max_pool2d.default,
args=node.args,
kwargs=node.kwargs,
)

else:
if getitem_index != 0:
raise AssertionError(
f"Expected second argument of getitem node for {node.target.__name__} to be 0, got {getitem_index}. XNNPACK delegate currently only supports getting just the max values or getting both the max values and their corresponding indices from the op, but not getting the indices alone."
)
new_max_wd = module.graph.create_node(
"call_function",
exir_ops.edge.aten.amax.default,
args=node.args,
kwargs=node.kwargs,
)

# MODIFIED PART START
# Make sure to preserve the inferred node format.
new_max_wd.meta[NXP_NODE_FORMAT] = node.meta.get(
NXP_NODE_FORMAT, NodeFormat.NONE
)
# MODIFIED PART END

getitem_node.replace_all_uses_with(new_max_wd)

module.graph.erase_node(getitem_node)
module.graph.erase_node(node)

graph_module.recompile()
# Propagate metadata and retrace module
graph_module = super().call(graph_module).graph_module

return PassResult(graph_module, True)
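
For context, here is a minimal usage sketch (not part of this PR) showing how the pass can be exercised; the toy module and the export-to-edge flow are assumptions:

```python
# Minimal usage sketch for RemoveGetItemPass (toy model and flow are
# assumptions, not part of this PR). With return_indices=True, export
# produces max_pool2d_with_indices followed by the getitem this pass removes.
import torch

from executorch.backends.nxp._passes.remove_getitem_pass import RemoveGetItemPass
from executorch.exir import to_edge


class MaxPoolModel(torch.nn.Module):
    def forward(self, x):
        values, _indices = torch.nn.functional.max_pool2d(
            x, kernel_size=2, return_indices=True
        )
        return values  # only the max values are used


edge = to_edge(torch.export.export(MaxPoolModel(), (torch.randn(1, 3, 8, 8),)))
graph_module = edge.exported_program().graph_module
result = RemoveGetItemPass()(graph_module)  # ExportPass instances are callable
print(result.graph_module.graph)  # a max_pool2d node; the getitem is gone
```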
28 changes: 12 additions & 16 deletions backends/nxp/backend/edge_program_converter.py
@@ -19,10 +19,7 @@
from torch.nn.parameter import Parameter
from executorch.backends.nxp.backend.ir.converter.node_converters.ops_converters import * # noqa F403
from executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec
from executorch.backends.nxp.backend.node_format_inference import (
NodeFormat,
NodeFormatInference,
)
from executorch.backends.nxp.backend.node_format import NodeFormat, NXP_NODE_FORMAT
from executorch.exir.dialects._ops import ops as exir_ops

# noinspection PyProtectedMember
@@ -66,7 +63,7 @@ def convert_program(
conversion_config: ConversionConfig = _default_conversion_config,
neutron_target_spec: NeutronTargetSpec = _default_target_spec,
custom_delegation_options: CustomDelegationOptions = _default_delegation_options,
) -> (bytes, dict):
) -> (bytes, dict[str, NodeFormat]):
"""
Convert ExportedProgram in Edge dialect to IR (TFLite flatbuffers) as bytes.

@@ -76,12 +73,10 @@
:param custom_delegation_options: Custom user options which affect node delegation.
:return: TFLite flatbuffers as bytes, together with a mapping of the model's IO tensor names to their formats.
"""
node_formats = NodeFormatInference(edge_program).identify_node_formats()
parameters_mapping = self.map_inputs_to_parameters(edge_program)

cc = self.build_conversion_context(
parameters_mapping,
node_formats,
neutron_target_spec,
conversion_config,
custom_delegation_options,
@@ -92,13 +87,16 @@
self._convert_qdq_cluster_q_dq_nodes(edge_program.graph.nodes, cc)
self._process_nodes(edge_program.graph.nodes, cc)

# Assign output
io_formats = cc.tflite_builder.assign_model_io_to_subgraph_and_get_io_formats(
edge_program.graph_signature
)
# Assign the model its inputs and outputs.
cc.tflite_builder.assign_model_io_to_subgraph(edge_program.graph_signature)

# TFLite model generation
# Apply optimizations and finalize the model.
internal_tflite_model = cc.tflite_builder.finish()

# Extract the formats of the model's inputs and outputs.
io_formats = cc.tflite_builder.get_io_formats(edge_program.graph_signature)

# TFLite model generation
flatbuffers_builder = flatbuffers.Builder()
internal_tflite_model.gen_tflite(flatbuffers_builder)

@@ -108,7 +106,7 @@
def append_placeholders_and_tensors(nodes: list[Node], context: ConversionContext):
for node in nodes:
if node.op == "placeholder":
node_format = context.node_formats[node]
node_format = node.meta[NXP_NODE_FORMAT]

if node.name in context.parameters_mapping:
# Node is placeholder and has data -> append as static tensor with data
@@ -121,7 +119,7 @@ def append_placeholders_and_tensors(nodes: list[Node], context: ConversionContext)
context.tflite_builder.append_as_fake_tensor(node, node_format)
elif node.op == "call_function":
# Node is call function -> append only output as a tensor
node_format = context.node_formats[node]
node_format = node.meta[NXP_NODE_FORMAT]
context.tflite_builder.append_as_fake_tensor(node, node_format)
elif node.op == "output":
# Nothing to do
@@ -179,7 +177,6 @@ def map_inputs_to_parameters(edge_program: ExportedProgram) -> dict[str, Parameter]:
@staticmethod
def build_conversion_context(
parameters_mapping: dict,
node_formats: dict[Node, NodeFormat],
neutron_target_spec: NeutronTargetSpec,
conversion_config: ConversionConfig = _default_conversion_config,
custom_delegation_options: CustomDelegationOptions = _default_delegation_options,
@@ -195,7 +192,6 @@
tflite_builder,
conversion_config,
parameters_mapping,
node_formats,
custom_delegation_options,
)

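
The net effect of this file's changes: the `node_formats` dictionary that was threaded through the conversion context is gone, and each node now carries its format in `node.meta` under `NXP_NODE_FORMAT`. A short sketch of the access pattern (the helper name is ours; the key, enum, and default mirror this diff and RemoveGetItemPass above):

```python
# Sketch of the per-node format storage (helper name is hypothetical;
# NXP_NODE_FORMAT and NodeFormat are taken from this diff).
from executorch.backends.nxp.backend.node_format import NodeFormat, NXP_NODE_FORMAT
from torch.fx import Node


def node_format_of(node: Node) -> NodeFormat:
    # Passes that rewrite the graph should copy this key onto any nodes
    # they create (as RemoveGetItemPass does); default to NONE otherwise.
    return node.meta.get(NXP_NODE_FORMAT, NodeFormat.NONE)
```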
2 changes: 1 addition & 1 deletion backends/nxp/backend/ir/conversion_config.py
@@ -13,7 +13,7 @@ def __init__(self, args: dict | None = None):

:param args: Optional dictionary with conversion arguments. Unknown arguments are ignored.
"""
self.keep_io_format: bool = False
self.use_neutron_for_format_conversion: bool = False
self.allow_inputs_stripping: bool = True
self.qdq_aware_conversion: bool = True
self.symbolic_dimensions_mapping: dict[str, int] | None = None
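
The rename makes the flag's purpose explicit. A sketch of setting it (the attribute name and default come from this diff; constructing with no arguments follows the optional-args signature shown above):

```python
# Hedged sketch: enabling the renamed flag (formerly `keep_io_format`).
from executorch.backends.nxp.backend.ir.conversion_config import ConversionConfig

config = ConversionConfig()  # args dict is optional
config.use_neutron_for_format_conversion = True  # default is False
```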
5 changes: 0 additions & 5 deletions backends/nxp/backend/ir/conversion_context.py
@@ -10,24 +10,20 @@
from executorch.backends.nxp.backend.ir.converter.builder.aten_model_builder_director import (
AtenModelBuilderDirector,
)
from executorch.backends.nxp.backend.node_format_inference import NodeFormat
from torch import Node
from torch.nn import Parameter


class ConversionContext:
tflite_builder: AtenModelBuilderDirector
conversion_config: ConversionConfig
parameters_mapping: dict[str, Parameter]
node_formats: dict[Node, NodeFormat]
custom_delegation_options: CustomDelegationOptions

def __init__(
self,
tflite_builder: AtenModelBuilderDirector,
conversion_config: ConversionConfig,
parameters_mapping: dict,
node_formats: dict[Node, NodeFormat],
custom_delegation_options: CustomDelegationOptions,
):
"""
@@ -39,5 +35,4 @@ def __init__(
self.tflite_builder = tflite_builder
self.conversion_config = conversion_config
self.parameters_mapping = parameters_mapping
self.node_formats = node_formats
self.custom_delegation_options = custom_delegation_options
@@ -9,7 +9,7 @@
from executorch.backends.nxp.backend.ir.converter.conversion import translator
from executorch.backends.nxp.backend.ir.tensor_formatting import TensorFormat
from executorch.backends.nxp.backend.ir.tflite_generator import tflite_model
from executorch.backends.nxp.backend.node_format_inference import NodeFormat
from executorch.backends.nxp.backend.node_format import NodeFormat
from torch.fx import Node
from torch.nn import Parameter

@@ -88,19 +88,40 @@ def append_operators(self, ops_to_add: list[tflite_model.Operator]):

self.check_and_append_operator(op)

def assign_model_io_to_subgraph_and_get_io_formats(
self, graph_signature
) -> dict[str, dict]:
"""
Assign model's inputs/outputs to SubGraph.
def get_io_formats(self, graph_signature) -> dict[str, dict[str, TensorFormat]]:
"""Get a mapping from tensor names to their formats.

:param graph_signature: Instance of GraphSignature.
:param graph_signature: Instance of GraphSignature.
:returns: Mapping between IO tensors' names and their formats.
"""
io_formats = {
"inputs": {},
"outputs": {},
}
for input_name in graph_signature.user_inputs:
tensor = self.tensor_for_name(input_name)
assert input_name == tensor.name, (
"Program's input name doesn't match with tensor name in TFLite. "
"Input was probably redirected."
)
io_formats["inputs"][tensor.name] = tensor.tensor_format

for output_name in graph_signature.user_outputs:
tensor = self.tensor_for_name(output_name)
assert output_name == tensor.name, (
"Program's output name doesn't match with tensor name in TFLite. "
"Output was probably redirected."
)
io_formats["outputs"][tensor.name] = tensor.tensor_format

return io_formats

def assign_model_io_to_subgraph(self, graph_signature):
"""
Assign model's inputs/outputs to SubGraph.

:param graph_signature: Instance of GraphSignature.
"""

self.get_sub_graph().inputs = tflite_model.SubGraphInputs()
for input_name in graph_signature.user_inputs:
@@ -110,7 +131,6 @@ def assign_model_io_to_subgraph_and_get_io_formats(
"Input was probably redirected."
)
self.get_sub_graph().inputs.tmp_inputs.append(tensor)
io_formats["inputs"][tensor.name] = tensor.tensor_format

self.get_sub_graph().outputs = tflite_model.SubGraphOutputs()
for output_name in graph_signature.user_outputs:
@@ -120,7 +140,3 @@
"Output was probably redirected."
)
self.get_sub_graph().outputs.tmp_outputs.append(tensor)

io_formats["outputs"][tensor.name] = tensor.tensor_format

return io_formats
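
The combined `assign_model_io_to_subgraph_and_get_io_formats()` is thus split into a mutator and a query. A sketch of the resulting call order, mirroring edge_program_converter.py above (the wrapper function is ours); reading the formats after `finish()` presumably lets them reflect any conversions applied during finalization:

```python
# Hedged sketch of the new two-step IO API (wrapper function is ours;
# `builder` is an AtenModelBuilderDirector, as in edge_program_converter.py).
def build_and_collect_io_formats(builder, graph_signature):
    builder.assign_model_io_to_subgraph(graph_signature)  # wire up subgraph IO
    model = builder.finish()  # apply optimizations and finalize the model
    io_formats = builder.get_io_formats(graph_signature)  # read final formats
    return model, io_formats
```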