@@ -9,7 +9,6 @@
import numpy as np

from executorch.backends.nxp.backend.edge_helper import input_rank
from executorch.backends.nxp.backend.ir.converter.conversion.common import OpsList
from executorch.backends.nxp.backend.ir.converter.conversion.translator import (
apply_permutation_to,
create_channels_first_to_channels_last_permutation,
@@ -24,6 +23,7 @@
)
from executorch.backends.nxp.backend.ir.tflite_generator import tflite_model
from executorch.backends.nxp.backend.ir.tflite_generator.builtin_options import (
pad_options,
pad_v2_options,
)
from torch.fx import Node
@@ -50,6 +50,10 @@ def _is_supported_in_IR(
if not NodeConverter._has_shared_q_params_if_quantized(node):
return False

if len(paddings) > 4 and paddings[4:6] != [0, 0]:
# Padding the channels dimension is currently not supported.
return False

return True

# noinspection PyMethodMayBeStatic
@@ -101,6 +105,15 @@ def convert(self, node: Node):
np.asarray(paddings, "int32"), "paddings"
)

if constant == 0.0:
# We're padding with zeros, so we can use the regular Pad op.
t_op.tmp_inputs = [x, paddings_tensor]
t_op.tmp_outputs = [y]
t_op.builtin_options = pad_options.Pad()

self.builder.append_operators([t_op])
return

if x.quantization is None:
constant_tensor = self.builder.create_tensor_for_data(
np.array([constant], tf_lite_type_to_numpy(x.type)), "constant"
@@ -124,6 +137,4 @@ def convert(self, node: Node):
t_op.tmp_outputs = [y]
t_op.builtin_options = pad_v2_options.PadV2()

ops_to_add = OpsList(middle_op=t_op)

self.builder.append_operators(ops_to_add.flatten())
self.builder.append_operators([t_op])
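
For reference, a minimal illustration of the two changes above (not part of the diff; shapes and values are illustrative). aten.constant_pad_nd orders its padding pairs starting from the last dimension, so for a 4-D NCHW input the entries at paddings[4:6] pad the channel dimension, which _is_supported_in_IR now rejects; a padding constant of zero lets the converter emit a plain Pad instead of PadV2.

import torch

x = torch.zeros(1, 3, 8, 8)  # NCHW

# [W_left, W_right, H_top, H_bottom]: pads only H and W, channels untouched.
hw_only = torch.nn.functional.pad(x, [1, 1, 2, 2], value=0.0)
print(hw_only.shape)  # torch.Size([1, 3, 12, 10])

# Six values: paddings[4:6] == [1, 1] pads the channel dimension,
# which the new check above reports as unsupported.
with_channels = torch.nn.functional.pad(x, [1, 1, 2, 2, 1, 1], value=0.0)
print(with_channels.shape)  # torch.Size([1, 5, 12, 10])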
Empty file.
79 changes: 79 additions & 0 deletions backends/nxp/backend/ir/edge_passes/remove_io_quant_ops_pass.py
@@ -0,0 +1,79 @@
# Copyright 2025 NXP
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch

from executorch.exir import EdgeProgramManager
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass
from executorch.exir.passes.quantize_io_pass import QuantizeInputs, QuantizeOutputs
from torch.fx.passes.infra.pass_base import PassResult


class RemoveIOQuantOpsPass(ExportPass):

def __init__(self, edge_program_manager: EdgeProgramManager):
super().__init__()
self._edge_program_manager = edge_program_manager

def _get_quantizable_input_indices(self):
Contributor: Feel free to improve quantize_io_pass and move these utils there if you think they can be useful elsewhere.

exported_program = self._edge_program_manager.exported_program()

graph = exported_program.graph_module.graph
user_inputs = exported_program.graph_signature.user_inputs

inputs_to_quantization = []

for input_index, user_input in enumerate(user_inputs):
placeholders = [
n for n in graph.nodes if n.op == "placeholder" and n.name == user_input
]
assert placeholders
target_placeholder = placeholders[0]

if len(target_placeholder.users) != 1:
raise ValueError(f"Input {input_index} has more than one users")

quantize = next(iter(target_placeholder.users))
if (
quantize.target
!= exir_ops.edge.quantized_decomposed.quantize_per_tensor.default
):
continue

inputs_to_quantization.append(input_index)

return inputs_to_quantization

def _get_quantizable_output_indices(self):
exported_program = self._edge_program_manager.exported_program()

graph = exported_program.graph_module.graph
outputs = [n for n in graph.nodes if n.op == "output"]
if len(outputs) != 1:
raise NotImplementedError("Only 1 output node is supported.")

outputs_to_quantization = []

user_outputs = list(outputs[0].args[0])
for output_index, user_output in enumerate(user_outputs):
if (
user_output.target
!= exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default
):
continue

outputs_to_quantization.append(output_index)

return outputs_to_quantization

def call(self, graph_module: torch.fx.GraphModule):
input_indices = self._get_quantizable_input_indices()
output_indices = self._get_quantizable_output_indices()

QuantizeInputs(self._edge_program_manager, input_indices).call(graph_module)
QuantizeOutputs(self._edge_program_manager, output_indices).call(graph_module)

return PassResult(graph_module, True)
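
A short usage sketch for the new pass (the manager variable is illustrative; the wiring mirrors the executorch_pipeline.py change further down): applied via EdgeProgramManager.transform after quantization and partitioning, it strips the leading quantize_per_tensor and trailing dequantize_per_tensor nodes, so the program consumes and produces int8 tensors directly while the removed scales and zero-points are recorded as config methods by QuantizeInputs/QuantizeOutputs.

from executorch.backends.nxp.backend.ir.edge_passes.remove_io_quant_ops_pass import (
    RemoveIOQuantOpsPass,
)

# `edge_program_manager` is assumed to be an already quantized and partitioned
# EdgeProgramManager.
edge_program_manager = edge_program_manager.transform(
    [RemoveIOQuantOpsPass(edge_program_manager=edge_program_manager)]
)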
23 changes: 23 additions & 0 deletions backends/nxp/quantizer/neutron_quantizer.py
@@ -41,6 +41,7 @@
no_outside_users,
)
from torch import fx
from torch.ao.quantization.quantizer.utils import _annotate_output_qspec
from torchao.quantization.pt2e import HistogramObserver, MinMaxObserver
from torchao.quantization.pt2e.quantizer import (
ComposableQuantizer,
@@ -237,6 +238,8 @@ def transform_for_annotation(
return pass_runner(model).graph_module

def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
self._annotate_inputs(model)

nodes = list(model.graph.nodes)
for node in nodes:
if (
@@ -252,5 +255,25 @@ def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:

return model

def _is_input_annotated(self, node: fx.Node) -> bool:
return (
"quantization_annotation" in node.meta
and node.meta["quantization_annotation"]._annotated
)

def _mark_input_node_as_annotated(self, node: fx.Node) -> None:
if "quantization_annotation" not in node.meta:
node.meta["quantization_annotation"] = QuantizationAnnotation()
node.meta["quantization_annotation"]._annotated = True

def _annotate_inputs(self, model: fx.GraphModule):
for node in model.graph.nodes:
if self._is_input_annotated(node):
continue

if node.op == "placeholder" and len(node.users) > 0:
_annotate_output_qspec(node, act_qspec)
self._mark_input_node_as_annotated(node)

def validate(self, model: torch.fx.GraphModule) -> None:
return super().validate(model)
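
A side note on the new _annotate_inputs hook: only placeholders that feed other nodes receive the activation qspec, which the len(node.users) > 0 guard enforces. A minimal sketch of that distinction on an FX graph (the module below is illustrative, not from the PR):

import torch

class TwoInput(torch.nn.Module):
    def forward(self, x, unused):
        return x + 1.0

gm = torch.fx.symbolic_trace(TwoInput())
placeholders = [n for n in gm.graph.nodes if n.op == "placeholder"]
print([(n.name, len(n.users)) for n in placeholders])
# [('x', 1), ('unused', 0)] -> only `x` would be annotated with act_qspec.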
2 changes: 1 addition & 1 deletion backends/nxp/run_unittests.sh
@@ -11,4 +11,4 @@ EXECUTORCH_DIR=$(dirname $(dirname $SCRIPT_DIR))
cd $EXECUTORCH_DIR

# '-c /dev/null' is used to ignore root level pytest.ini.
PYTHONPATH=`cd ..; pwd` pytest -c /dev/null backends/nxp/tests/
pytest -c /dev/null backends/nxp/tests/
9 changes: 9 additions & 0 deletions backends/nxp/tests/executorch_pipeline.py
@@ -6,6 +6,9 @@
import torch

from executorch import exir
from executorch.backends.nxp.backend.ir.edge_passes.remove_io_quant_ops_pass import (
RemoveIOQuantOpsPass,
)
from executorch.backends.nxp.neutron_partitioner import NeutronPartitioner
from executorch.backends.nxp.nxp_backend import generate_neutron_compile_spec
from executorch.backends.nxp.quantizer.neutron_quantizer import NeutronQuantizer
@@ -37,6 +40,7 @@ def to_quantized_edge_program(
operators_not_to_delegate: list[str] = None,
target="imxrt700",
neutron_converter_flavor="SDK_25_03",
remove_quant_io_ops=False,
) -> EdgeProgramManager:
if isinstance(input_shapes, list):
assert all(isinstance(input_shape, tuple) for input_shape in input_shapes), (
@@ -77,6 +81,11 @@
compile_config=EdgeCompileConfig(_check_ir_validity=False),
)

if remove_quant_io_ops:
edge_program_manager = edge_program_manager.transform(
[RemoveIOQuantOpsPass(edge_program_manager=edge_program_manager)]
)

return edge_program_manager


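For completeness, a hedged usage sketch of the new remove_quant_io_ops flag (the model and input shape match the tests added below):

import torch
from executorch.backends.nxp.tests.executorch_pipeline import to_quantized_edge_program
from executorch.backends.nxp.tests.models import Conv2dReLUModule

model = Conv2dReLUModule().eval()
# With remove_quant_io_ops=True, the RemoveIOQuantOpsPass is applied and the
# resulting edge program takes and returns int8 tensors directly.
edge_program_manager = to_quantized_edge_program(
    model, (1, 4, 32, 32), remove_quant_io_ops=True
)
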
@@ -63,16 +63,10 @@ def test_constant_pad_nd_conversion__default_constant():
pytest.param((2, 4), tuple(range(4)), id="2D, padding N, H"),
pytest.param((2, 4, 6), tuple(range(2)), id="3D, padding H"),
pytest.param((2, 4, 6), tuple(range(4)), id="3D, padding C, H"),
pytest.param((2, 4, 6), list(range(6)), id="3D, padding N, C, H"),
Contributor: Curious why remove these tests?

Collaborator (@skywall, Jul 25, 2025): These tests are no longer relevant, because ConstantPad nodes with these params will not be delegated. It is related to this restriction: https://github.com/pytorch/executorch/pull/12586/files#diff-e01d426046aa644b4e18ffa510b42e50e1b18b8f5407bcfb0d210f701d95b16aR53

We are still able to convert them into the intermediate model representation, but Neutron conversion will fail.

pytest.param((2, 4, 6, 8), tuple(range(2)), id="4D, padding W"),
pytest.param((2, 4, 6, 8), tuple(range(4)), id="4D, padding H, W"),
pytest.param((2, 4, 6, 8), list(range(6)), id="4D, padding C, H, W"),
pytest.param((2, 4, 6, 8), list(range(8)), id="4D, padding N, C, H, W"),
pytest.param((1, 2, 3, 4, 5), list(range(2)), id="5D, padding D"),
pytest.param((1, 2, 3, 4, 5), tuple(range(2)), id="5D, padding D"),
pytest.param((1, 2, 3, 4, 5), tuple(range(4)), id="5D, padding W, D"),
pytest.param((1, 2, 3, 4, 5), list(range(6)), id="5D, padding H, W, D"),
pytest.param((1, 2, 3, 4, 5), tuple(range(8)), id="5D, padding C, H, W, D"),
pytest.param((1, 2, 3, 4, 5), list(range(10)), id="5D, padding N, C, H, W, D"),
],
)
def test_constant_pad_nd_conversion__format_less(input_shape, paddings):
Expand All @@ -93,8 +87,9 @@ def test_constant_pad_nd_conversion__format_less(input_shape, paddings):
],
)
def test_constant_pad_nd_conversion__channels_first(input_shape, paddings):
model = ConstantPadNDConvModule(paddings)
edge_program = to_edge_program(
ConstantPadNDConvModule(paddings), input_shape
model, input_shape
).exported_program() # Extra `Conv` after the padding.

input_data = np.random.random(input_shape).astype(np.float32)
122 changes: 122 additions & 0 deletions backends/nxp/tests/ir/edge_passes/test_remove_io_quant_ops_pass.py
@@ -0,0 +1,122 @@
# Copyright 2025 NXP
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import itertools

import executorch.kernels.quantized # noqa F401
import torch
from executorch.backends.nxp.tests.executorch_pipeline import to_quantized_edge_program
from executorch.backends.nxp.tests.models import Conv2dReLUModule
from executorch.examples.nxp.experimental.cifar_net.cifar_net import CifarNet
from executorch.exir import ExecutorchBackendConfig
from executorch.exir.passes.quantize_io_pass import get_config_method_name


def test_remove_io_quant_ops_pass__conv_relu():
model = Conv2dReLUModule()
Contributor: Since you are calculating indices, do you want to test a model which has more than one input and output?

Collaborator: Test for a multi-input/output model added.

model.eval()

input_shape = (1, 4, 32, 32)
edge_program_manager = to_quantized_edge_program(
model, input_shape, remove_quant_io_ops=True
)

exec_prog = edge_program_manager.to_executorch(
config=ExecutorchBackendConfig(extract_delegate_segments=False)
)

nodes = list(exec_prog.exported_program().graph.nodes)
assert (
nodes[0].meta["val"].dtype == torch.int8
), "Input tensor doesn't have type INT8."
assert nodes[2].name == "executorch_call_delegate"
assert (
nodes[4].meta["val"][0].dtype == torch.int8
), "Output tensor doesn't have type INT8."

assert (
get_config_method_name(None, "input", 0, "scale") in exec_prog._config_methods
)
assert get_config_method_name(None, "input", 0, "zp") in exec_prog._config_methods
assert (
get_config_method_name(None, "output", 0, "scale") in exec_prog._config_methods
)
assert get_config_method_name(None, "output", 0, "zp") in exec_prog._config_methods


def test_remove_io_quant_ops_pass__cifarnet():
model = CifarNet().get_eager_model()
input_shape = (1, 3, 32, 32)
edge_program_manager = to_quantized_edge_program(
model, input_shape, remove_quant_io_ops=True
)

exec_prog = edge_program_manager.to_executorch(
config=ExecutorchBackendConfig(extract_delegate_segments=False)
)

nodes = list(exec_prog.exported_program().graph.nodes)
assert len(nodes) == 17
assert (
nodes[0].meta["val"].dtype == torch.int8
), "Input tensor doesn't have type INT8."
assert (
nodes[16].meta["val"][0].dtype == torch.int8
), "Output tensor doesn't have type INT8."

assert (
get_config_method_name(None, "input", 0, "scale") in exec_prog._config_methods
)
assert get_config_method_name(None, "input", 0, "zp") in exec_prog._config_methods
assert (
get_config_method_name(None, "output", 0, "scale") in exec_prog._config_methods
)
assert get_config_method_name(None, "output", 0, "zp") in exec_prog._config_methods


class MultiInputOutputModule(torch.nn.Module):
def __init__(self):
super().__init__()

self.conv = torch.nn.Conv2d(4, 64, 2, bias=False)
self.relu = torch.nn.ReLU()

def forward(self, x, y):
z = self.relu(x)
x = self.conv(z)
return x + y, z


def test_multiple_inputs__multiple_outputs():
model = MultiInputOutputModule()
model.eval()

input_shape = [(1, 4, 32, 32), (1, 1, 1, 31)]
edge_program_manager = to_quantized_edge_program(
model, input_shape, remove_quant_io_ops=True
)

exec_prog = edge_program_manager.to_executorch(
config=ExecutorchBackendConfig(extract_delegate_segments=False)
)

nodes = list(exec_prog.exported_program().graph.nodes)
print(nodes)
assert (
nodes[0].meta["val"].dtype == torch.int8
), "Input tensor doesn't have type INT8."
assert nodes[3].name == "executorch_call_delegate"
assert (
nodes[-1].meta["val"][0].dtype == torch.int8
), "Output tensor doesn't have type INT8."

quant_method_variants = itertools.product(
["input", "output"], [0, 1], ["scale", "zp"]
)

expected_methods = [
get_config_method_name(None, arg_type, index, key)
for arg_type, index, key in quant_method_variants
]
assert all(method in exec_prog._config_methods for method in expected_methods)
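
Once the IO quant/dequant ops are removed, the caller becomes responsible for quantizing inputs and dequantizing outputs with the scale/zero-point values recorded as config methods (asserted above). A generic affine-quantization sketch with illustrative values, not tied to any particular runtime API:

import numpy as np

def quantize_input(x: np.ndarray, scale: float, zero_point: int) -> np.ndarray:
    # float32 -> int8, following quantize_per_tensor semantics.
    q = np.round(x / scale) + zero_point
    return np.clip(q, -128, 127).astype(np.int8)

def dequantize_output(q: np.ndarray, scale: float, zero_point: int) -> np.ndarray:
    # int8 -> float32, following dequantize_per_tensor semantics.
    return (q.astype(np.float32) - zero_point) * scale

x = np.random.rand(1, 3, 32, 32).astype(np.float32)
print(quantize_input(x, scale=0.02, zero_point=-128).dtype)  # int8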