Skip to content

Commit 392b3b2

Browse files
skywall authored and robert-kalmar committed
NXP Backend: Add pass to remove IO de/quantize nodes
1 parent afc12eb commit 392b3b2

File tree

5 files changed

+189
-0
lines changed

5 files changed

+189
-0
lines changed

backends/nxp/backend/ir/edge_passes/__init__.py

Whitespace-only changes.
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
# Copyright 2025 NXP
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
import torch
7+
8+
from executorch.exir import EdgeProgramManager
9+
from executorch.exir.dialects._ops import ops as exir_ops
10+
from executorch.exir.pass_base import ExportPass
11+
from executorch.exir.passes.quantize_io_pass import QuantizeInputs, QuantizeOutputs
12+
from torch.fx.passes.infra.pass_base import PassResult
13+
14+
15+
class RemoveIOQuantOpsPass(ExportPass):
    """Edge pass that handles quantize/dequantize nodes at the program's I/O.

    The pass inspects the exported program held by the given
    ``EdgeProgramManager``, collects the indices of user inputs that feed
    directly into a ``quantize_per_tensor`` node and of outputs produced
    directly by a ``dequantize_per_tensor`` node, and hands those indices to
    ``QuantizeInputs`` / ``QuantizeOutputs`` from
    ``executorch.exir.passes.quantize_io_pass``.

    NOTE(review): the actual graph rewrite is performed by those two passes;
    their exact behavior is not visible here - confirm against their docs.
    """

    def __init__(self, edge_program_manager: EdgeProgramManager):
        """Store the manager whose exported program is analysed by the pass."""
        super().__init__()
        self._edge_program_manager = edge_program_manager

    def _get_quantizable_input_indices(self):
        """Return indices of user inputs consumed only by a quantize node.

        Returns:
            List of positions (within ``graph_signature.user_inputs``) whose
            placeholder has exactly one user and that user is
            ``quantize_per_tensor``.

        Raises:
            ValueError: if a user input has no matching placeholder, or if its
                placeholder does not have exactly one user.
        """
        exported_program = self._edge_program_manager.exported_program()

        graph = exported_program.graph_module.graph
        user_inputs = exported_program.graph_signature.user_inputs

        # Index the placeholders once instead of re-scanning the whole graph
        # for every user input. `setdefault` keeps the first node per name,
        # matching a first-match linear search.
        placeholders_by_name = {}
        for node in graph.nodes:
            if node.op == "placeholder":
                placeholders_by_name.setdefault(node.name, node)

        quantize_op = exir_ops.edge.quantized_decomposed.quantize_per_tensor.default
        inputs_to_quantization = []

        for input_index, user_input in enumerate(user_inputs):
            target_placeholder = placeholders_by_name.get(user_input)
            if target_placeholder is None:
                # Explicit raise (not `assert`) so the check survives `-O`.
                raise ValueError(
                    f"Input {input_index} ({user_input!r}) has no placeholder "
                    "node in the graph"
                )

            num_users = len(target_placeholder.users)
            if num_users != 1:
                # Covers both unused inputs (0 users) and fan-out (>1 users).
                raise ValueError(
                    f"Input {input_index} has {num_users} users; "
                    "exactly one user is required"
                )

            only_user = next(iter(target_placeholder.users))
            if only_user.target != quantize_op:
                continue

            inputs_to_quantization.append(input_index)

        return inputs_to_quantization

    def _get_quantizable_output_indices(self):
        """Return indices of outputs produced directly by a dequantize node.

        Returns:
            List of positions within the output node's first argument whose
            producer is ``dequantize_per_tensor``.

        Raises:
            NotImplementedError: if the graph does not contain exactly one
                output node.
        """
        exported_program = self._edge_program_manager.exported_program()

        graph = exported_program.graph_module.graph
        outputs = [n for n in graph.nodes if n.op == "output"]
        if len(outputs) != 1:
            raise NotImplementedError("Only 1 output node is supported.")

        dequantize_op = (
            exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default
        )

        # `getattr` with a default: entries that are not graph nodes (e.g.
        # `None` or a plain constant) have no `target` and are simply skipped
        # instead of raising AttributeError.
        return [
            output_index
            for output_index, user_output in enumerate(outputs[0].args[0])
            if getattr(user_output, "target", None) == dequantize_op
        ]

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        """Run ``QuantizeInputs`` and ``QuantizeOutputs`` on the detected I/O.

        Args:
            graph_module: Graph module forwarded to both underlying passes.

        Returns:
            ``PassResult`` wrapping ``graph_module``; always flagged as
            modified.
        """
        input_indices = self._get_quantizable_input_indices()
        output_indices = self._get_quantizable_output_indices()

        QuantizeInputs(self._edge_program_manager, input_indices).call(graph_module)
        QuantizeOutputs(self._edge_program_manager, output_indices).call(graph_module)

        return PassResult(graph_module, True)

backends/nxp/tests/executorch_pipeline.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,9 @@
66
import torch
77

88
from executorch import exir
9+
from executorch.backends.nxp.backend.ir.edge_passes.remove_io_quant_ops_pass import (
10+
RemoveIOQuantOpsPass,
11+
)
912
from executorch.backends.nxp.neutron_partitioner import NeutronPartitioner
1013
from executorch.backends.nxp.nxp_backend import generate_neutron_compile_spec
1114
from executorch.backends.nxp.quantizer.neutron_quantizer import NeutronQuantizer
@@ -37,6 +40,7 @@ def to_quantized_edge_program(
3740
operators_not_to_delegate: list[str] = None,
3841
target="imxrt700",
3942
neutron_converter_flavor="SDK_25_03",
43+
remove_quant_io_ops=False,
4044
) -> EdgeProgramManager:
4145
if isinstance(input_shapes, list):
4246
assert all(isinstance(input_shape, tuple) for input_shape in input_shapes), (
@@ -77,6 +81,11 @@ def to_quantized_edge_program(
7781
compile_config=EdgeCompileConfig(_check_ir_validity=False),
7882
)
7983

84+
if remove_quant_io_ops:
85+
edge_program_manager = edge_program_manager.transform(
86+
[RemoveIOQuantOpsPass(edge_program_manager=edge_program_manager)]
87+
)
88+
8089
return edge_program_manager
8190

8291

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
# Copyright 2025 NXP
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
import executorch.extension.pybindings.portable_lib
7+
import executorch.kernels.quantized # noqa F401
8+
import torch
9+
10+
from executorch.backends.nxp.tests.executorch_pipeline import to_quantized_edge_program
11+
from executorch.backends.nxp.tests.exported_program_vizualize import (
12+
exported_program_to_dot,
13+
)
14+
from executorch.backends.nxp.tests.models import Conv2dReLUModule
15+
from executorch.examples.nxp.experimental.cifar_net.cifar_net import CifarNet
16+
from executorch.exir import ExecutorchBackendConfig
17+
from executorch.exir.passes.quantize_io_pass import get_config_method_name
18+
19+
20+
def test_remove_io_quant_ops_pass__conv_relu():
    """Conv2d+ReLU lowered with ``remove_quant_io_ops=True`` has INT8 I/O.

    Checks the dtype recorded on graph nodes 0 and 4, the name of node 2, and
    that scale / zero-point config methods exist for input 0 and output 0.
    """
    module = Conv2dReLUModule()
    module.eval()

    epm = to_quantized_edge_program(
        module, (1, 4, 32, 32), remove_quant_io_ops=True
    )
    backend_config = ExecutorchBackendConfig(extract_delegate_segments=False)
    exec_prog = epm.to_executorch(config=backend_config)

    # Dump the lowered graph to a dot file.
    exported_program_to_dot(exec_prog.exported_program(), "conv_relu.dot")

    graph_nodes = list(exec_prog.exported_program().graph.nodes)
    assert (
        graph_nodes[0].meta["val"].dtype == torch.int8
    ), "Input tensor doesn't have type INT8."
    assert graph_nodes[2].name == "executorch_call_delegate"
    assert (
        graph_nodes[4].meta["val"][0].dtype == torch.int8
    ), "Output tensor doesn't have type INT8."

    # Quantization parameters must be exposed as config methods.
    for io_kind in ("input", "output"):
        for param in ("scale", "zp"):
            assert (
                get_config_method_name(None, io_kind, 0, param)
                in exec_prog._config_methods
            )
52+
53+
54+
def test_remove_io_quant_ops_pass__cifarnet():
    """CifarNet lowered with ``remove_quant_io_ops=True`` has INT8 I/O.

    Expects a 17-node graph whose first and last nodes record INT8 values,
    plus scale / zero-point config methods for input 0 and output 0.
    """
    eager_model = CifarNet().get_eager_model()

    epm = to_quantized_edge_program(
        eager_model, (1, 3, 32, 32), remove_quant_io_ops=True
    )
    backend_config = ExecutorchBackendConfig(extract_delegate_segments=False)
    exec_prog = epm.to_executorch(config=backend_config)

    graph_nodes = list(exec_prog.exported_program().graph.nodes)
    assert len(graph_nodes) == 17
    assert (
        graph_nodes[0].meta["val"].dtype == torch.int8
    ), "Input tensor doesn't have type INT8."
    assert (
        graph_nodes[16].meta["val"][0].dtype == torch.int8
    ), "Output tensor doesn't have type INT8."

    # Quantization parameters must be exposed as config methods.
    for io_kind in ("input", "output"):
        for param in ("scale", "zp"):
            assert (
                get_config_method_name(None, io_kind, 0, param)
                in exec_prog._config_methods
            )

examples/nxp/aot_neutron_compile.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,9 @@
1616

1717
import torch
1818

19+
from executorch.backends.nxp.backend.ir.edge_passes.remove_io_quant_ops_pass import (
20+
RemoveIOQuantOpsPass,
21+
)
1922
from executorch.backends.nxp.neutron_partitioner import NeutronPartitioner
2023
from executorch.backends.nxp.nxp_backend import generate_neutron_compile_spec
2124
from executorch.backends.nxp.quantizer.neutron_quantizer import NeutronQuantizer
@@ -194,6 +197,15 @@ def _get_batch_size(data):
194197
default=False,
195198
help="Test the selected model and print the accuracy between 0 and 1.",
196199
)
200+
parser.add_argument(
201+
"-r",
202+
"--remove-quant-io-ops",
203+
action="store_true",
204+
required=False,
205+
default=False,
206+
help="Remove I/O De/Quantize nodes. Model will start to accept quantized "
207+
"inputs and produce quantized outputs.",
208+
)
197209
parser.add_argument(
198210
"--operators_not_to_delegate",
199211
required=False,
@@ -269,6 +281,14 @@ def _get_batch_size(data):
269281
)
270282
logging.debug(f"Exported graph:\n{edge_program.exported_program().graph}")
271283

284+
if args.remove_quant_io_ops:
285+
edge_program = edge_program.transform(
286+
[RemoveIOQuantOpsPass(edge_program_manager=edge_program)]
287+
)
288+
logging.debug(
289+
f"Exported graph (RemoveIOQuantOpsPass):\n{edge_program.exported_program().graph}"
290+
)
291+
272292
# 6. Export to ExecuTorch program
273293
try:
274294
exec_prog = edge_program.to_executorch(

0 commit comments

Comments
 (0)