Skip to content

Commit dc8e7ea

Browse files
skywallrobert-kalmar
authored andcommitted
NXP Backend: Add pass to remove IO de/quantize nodes
1 parent 5499f4e commit dc8e7ea

File tree

5 files changed

+230
-0
lines changed

5 files changed

+230
-0
lines changed

backends/nxp/backend/ir/edge_passes/__init__.py

Whitespace-only changes.
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
# Copyright 2025 NXP
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
import torch
7+
8+
from executorch.exir import EdgeProgramManager
9+
from executorch.exir.dialects._ops import ops as exir_ops
10+
from executorch.exir.pass_base import ExportPass
11+
from executorch.exir.passes.quantize_io_pass import QuantizeInputs, QuantizeOutputs
12+
from torch.fx.passes.infra.pass_base import PassResult
13+
14+
15+
class RemoveIOQuantOpsPass(ExportPass):
16+
17+
def __init__(self, edge_program_manager: EdgeProgramManager):
18+
super().__init__()
19+
self._edge_program_manager = edge_program_manager
20+
21+
def _get_quantizable_input_indices(self):
22+
exported_program = self._edge_program_manager.exported_program()
23+
24+
graph = exported_program.graph_module.graph
25+
user_inputs = exported_program.graph_signature.user_inputs
26+
27+
inputs_to_quantization = []
28+
29+
for input_index, user_input in enumerate(user_inputs):
30+
placeholders = [
31+
n for n in graph.nodes if n.op == "placeholder" and n.name == user_input
32+
]
33+
assert placeholders
34+
target_placeholder = placeholders[0]
35+
36+
if len(target_placeholder.users) != 1:
37+
raise ValueError(f"Input {input_index} has more than one users")
38+
39+
quantize = next(iter(target_placeholder.users))
40+
if (
41+
quantize.target
42+
!= exir_ops.edge.quantized_decomposed.quantize_per_tensor.default
43+
):
44+
continue
45+
46+
inputs_to_quantization.append(input_index)
47+
48+
return inputs_to_quantization
49+
50+
def _get_quantizable_output_indices(self):
51+
exported_program = self._edge_program_manager.exported_program()
52+
53+
graph = exported_program.graph_module.graph
54+
outputs = [n for n in graph.nodes if n.op == "output"]
55+
if len(outputs) != 1:
56+
raise NotImplementedError("Only 1 output node is supported.")
57+
58+
outputs_to_quantization = []
59+
60+
user_outputs = list(outputs[0].args[0])
61+
for output_index, user_output in enumerate(user_outputs):
62+
if (
63+
user_output.target
64+
!= exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default
65+
):
66+
continue
67+
68+
outputs_to_quantization.append(output_index)
69+
70+
return outputs_to_quantization
71+
72+
def call(self, graph_module: torch.fx.GraphModule):
73+
input_indices = self._get_quantizable_input_indices()
74+
output_indices = self._get_quantizable_output_indices()
75+
76+
QuantizeInputs(self._edge_program_manager, input_indices).call(graph_module)
77+
QuantizeOutputs(self._edge_program_manager, output_indices).call(graph_module)
78+
79+
return PassResult(graph_module, True)

backends/nxp/tests/executorch_pipeline.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,9 @@
66
import torch
77

88
from executorch import exir
9+
from executorch.backends.nxp.backend.ir.edge_passes.remove_io_quant_ops_pass import (
10+
RemoveIOQuantOpsPass,
11+
)
912
from executorch.backends.nxp.neutron_partitioner import NeutronPartitioner
1013
from executorch.backends.nxp.nxp_backend import generate_neutron_compile_spec
1114
from executorch.backends.nxp.quantizer.neutron_quantizer import NeutronQuantizer
@@ -37,6 +40,7 @@ def to_quantized_edge_program(
3740
operators_not_to_delegate: list[str] = None,
3841
target="imxrt700",
3942
neutron_converter_flavor="SDK_25_03",
43+
remove_quant_io_ops=False,
4044
) -> EdgeProgramManager:
4145
if isinstance(input_shapes, list):
4246
assert all(isinstance(input_shape, tuple) for input_shape in input_shapes), (
@@ -77,6 +81,11 @@ def to_quantized_edge_program(
7781
compile_config=EdgeCompileConfig(_check_ir_validity=False),
7882
)
7983

84+
if remove_quant_io_ops:
85+
edge_program_manager = edge_program_manager.transform(
86+
[RemoveIOQuantOpsPass(edge_program_manager=edge_program_manager)]
87+
)
88+
8089
return edge_program_manager
8190

8291

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
# Copyright 2025 NXP
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
import itertools
6+
7+
import executorch.kernels.quantized # noqa F401
8+
import torch
9+
from executorch.backends.nxp.tests.executorch_pipeline import to_quantized_edge_program
10+
from executorch.backends.nxp.tests.models import Conv2dReLUModule
11+
from executorch.examples.nxp.experimental.cifar_net.cifar_net import CifarNet
12+
from executorch.exir import ExecutorchBackendConfig
13+
from executorch.exir.passes.quantize_io_pass import get_config_method_name
14+
15+
16+
def test_remove_io_quant_ops_pass__conv_relu():
17+
model = Conv2dReLUModule()
18+
model.eval()
19+
20+
input_shape = (1, 4, 32, 32)
21+
edge_program_manager = to_quantized_edge_program(
22+
model, input_shape, remove_quant_io_ops=True
23+
)
24+
25+
exec_prog = edge_program_manager.to_executorch(
26+
config=ExecutorchBackendConfig(extract_delegate_segments=False)
27+
)
28+
29+
nodes = list(exec_prog.exported_program().graph.nodes)
30+
assert (
31+
nodes[0].meta["val"].dtype == torch.int8
32+
), "Input tensor doesn't have type INT8."
33+
assert nodes[2].name == "executorch_call_delegate"
34+
assert (
35+
nodes[4].meta["val"][0].dtype == torch.int8
36+
), "Output tensor doesn't have type INT8."
37+
38+
assert (
39+
get_config_method_name(None, "input", 0, "scale") in exec_prog._config_methods
40+
)
41+
assert get_config_method_name(None, "input", 0, "zp") in exec_prog._config_methods
42+
assert (
43+
get_config_method_name(None, "output", 0, "scale") in exec_prog._config_methods
44+
)
45+
assert get_config_method_name(None, "output", 0, "zp") in exec_prog._config_methods
46+
47+
48+
def test_remove_io_quant_ops_pass__cifarnet():
49+
model = CifarNet().get_eager_model()
50+
input_shape = (1, 3, 32, 32)
51+
edge_program_manager = to_quantized_edge_program(
52+
model, input_shape, remove_quant_io_ops=True
53+
)
54+
55+
exec_prog = edge_program_manager.to_executorch(
56+
config=ExecutorchBackendConfig(extract_delegate_segments=False)
57+
)
58+
59+
nodes = list(exec_prog.exported_program().graph.nodes)
60+
assert len(nodes) == 17
61+
assert (
62+
nodes[0].meta["val"].dtype == torch.int8
63+
), "Input tensor doesn't have type INT8."
64+
assert (
65+
nodes[16].meta["val"][0].dtype == torch.int8
66+
), "Output tensor doesn't have type INT8."
67+
68+
assert (
69+
get_config_method_name(None, "input", 0, "scale") in exec_prog._config_methods
70+
)
71+
assert get_config_method_name(None, "input", 0, "zp") in exec_prog._config_methods
72+
assert (
73+
get_config_method_name(None, "output", 0, "scale") in exec_prog._config_methods
74+
)
75+
assert get_config_method_name(None, "output", 0, "zp") in exec_prog._config_methods
76+
77+
78+
class MultiInputOutputModule(torch.nn.Module):
79+
def __init__(self):
80+
super().__init__()
81+
82+
self.conv = torch.nn.Conv2d(4, 64, 2, bias=False)
83+
self.relu = torch.nn.ReLU()
84+
85+
def forward(self, x, y):
86+
z = self.relu(x)
87+
x = self.conv(z)
88+
return x + y, z
89+
90+
91+
def test_multiple_inputs__multiple_outputs():
92+
model = MultiInputOutputModule()
93+
model.eval()
94+
95+
input_shape = [(1, 4, 32, 32), (1, 1, 1, 31)]
96+
edge_program_manager = to_quantized_edge_program(
97+
model, input_shape, remove_quant_io_ops=True
98+
)
99+
100+
exec_prog = edge_program_manager.to_executorch(
101+
config=ExecutorchBackendConfig(extract_delegate_segments=False)
102+
)
103+
104+
nodes = list(exec_prog.exported_program().graph.nodes)
105+
print(nodes)
106+
assert (
107+
nodes[0].meta["val"].dtype == torch.int8
108+
), "Input tensor doesn't have type INT8."
109+
assert nodes[3].name == "executorch_call_delegate"
110+
assert (
111+
nodes[-1].meta["val"][0].dtype == torch.int8
112+
), "Output tensor doesn't have type INT8."
113+
114+
quant_method_variants = itertools.product(
115+
["input", "output"], [0, 1], ["scale", "zp"]
116+
)
117+
118+
expected_methods = [
119+
get_config_method_name(None, arg_type, index, key)
120+
for arg_type, index, key in quant_method_variants
121+
]
122+
assert all(method in exec_prog._config_methods for method in expected_methods)

examples/nxp/aot_neutron_compile.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,9 @@
1616

1717
import torch
1818

19+
from executorch.backends.nxp.backend.ir.edge_passes.remove_io_quant_ops_pass import (
20+
RemoveIOQuantOpsPass,
21+
)
1922
from executorch.backends.nxp.neutron_partitioner import NeutronPartitioner
2023
from executorch.backends.nxp.nxp_backend import generate_neutron_compile_spec
2124
from executorch.backends.nxp.quantizer.neutron_quantizer import NeutronQuantizer
@@ -191,6 +194,15 @@ def _get_batch_size(data):
191194
default=False,
192195
help="Test the selected model and print the accuracy between 0 and 1.",
193196
)
197+
parser.add_argument(
198+
"-r",
199+
"--remove-quant-io-ops",
200+
action="store_true",
201+
required=False,
202+
default=False,
203+
help="Remove I/O De/Quantize nodes. Model will start to accept quantized "
204+
"inputs and produce quantized outputs.",
205+
)
194206
parser.add_argument(
195207
"--operators_not_to_delegate",
196208
required=False,
@@ -266,6 +278,14 @@ def _get_batch_size(data):
266278
)
267279
logging.debug(f"Exported graph:\n{edge_program.exported_program().graph}")
268280

281+
if args.remove_quant_io_ops:
282+
edge_program = edge_program.transform(
283+
[RemoveIOQuantOpsPass(edge_program_manager=edge_program)]
284+
)
285+
logging.debug(
286+
f"Exported graph (RemoveIOQuantOpsPass):\n{edge_program.exported_program().graph}"
287+
)
288+
269289
# 6. Export to ExecuTorch program
270290
try:
271291
exec_prog = edge_program.to_executorch(

0 commit comments

Comments
 (0)