Commit bbaaf9a

NXP backend: added aten.sub operator support
1 parent 7e228ee commit bbaaf9a

8 files changed: +269 −0 lines changed

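For orientation, a minimal sketch of how the new operator support is exercised end to end. It uses the `to_quantized_edge_program` helper that the new tests below rely on; the module and shapes here are illustrative, not part of the commit:

import torch

from executorch.backends.nxp.tests.executorch_pipeline import to_quantized_edge_program


class Sub(torch.nn.Module):
    def forward(self, x, y):
        return x - y  # exported as aten.sub.Tensor


# Quantize and lower. With this commit, the quantized sub node can be delegated
# to the Neutron backend, provided the operands need no shape broadcasting.
edge_program = to_quantized_edge_program(Sub(), [(1, 4, 8, 8), (1, 4, 8, 8)])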

backends/nxp/backend/edge_program_converter.py

Lines changed: 1 addition & 0 deletions
@@ -30,6 +30,7 @@
     exir_ops.edge.aten._adaptive_avg_pool2d.default: AdaptiveAvgPool2dConverter,  # noqa F405
     exir_ops.edge.aten.addmm.default: AddMMConverter,  # noqa F405
     exir_ops.edge.aten.add.Tensor: AddTensorConverter,  # noqa F405
+    exir_ops.edge.aten.sub.Tensor: SubTensorConverter,  # noqa F405
     exir_ops.edge.aten.avg_pool2d.default: AvgPool2dConverter,  # noqa F405
     exir_ops.edge.aten.cat.default: CatConverter,  # noqa F405
     exir_ops.edge.aten.clone.default: CloneConverter,  # noqa F405

backends/nxp/backend/ir/converter/node_converters/ops_converters/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -55,6 +55,9 @@
 from executorch.backends.nxp.backend.ir.converter.node_converters.ops_converters.softmax_converter import (
     SoftmaxConverter,
 )
+from executorch.backends.nxp.backend.ir.converter.node_converters.ops_converters.sub_tensor_converter import (
+    SubTensorConverter,
+)
 from executorch.backends.nxp.backend.ir.converter.node_converters.ops_converters.tanh_converter import (
     TanhConverter,
 )

@@ -78,6 +81,7 @@
     "MaxPool2dConverter",
     "AvgPool2dConverter",
     "AddTensorConverter",
+    "SubTensorConverter",
     "CloneConverter",
     "AbsConverter",
     "AdaptiveAvgPool2dConverter",
backends/nxp/backend/ir/converter/node_converters/ops_converters/sub_tensor_converter.py

Lines changed: 62 additions & 0 deletions
@@ -0,0 +1,62 @@
# Copyright 2025 NXP
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from executorch.backends.nxp.backend.ir.converter.conversion.common import (
    node_uses_shape_broadcasting,
)
from executorch.backends.nxp.backend.ir.converter.node_converter import (
    CustomDelegationOptions,
    NodeConverter,
    Target,
)
from executorch.backends.nxp.backend.ir.tflite_generator.builtin_options import (
    sub_options,
)
from torch.fx import Node
from torch.nn import Parameter


class SubTensorConverter(NodeConverter):
    @staticmethod
    def _is_supported_on_target(
        node: Node,
        target: Target,
        parameters_mapping: dict[str, Parameter],
        custom_delegation_options: CustomDelegationOptions,
    ) -> bool:
        match target:
            case Target.RT700:
                if node_uses_shape_broadcasting(node):
                    # Shape broadcasting may require the addition of `Transpose` ops during conversion.
                    return False

                return True

            case _:
                return False

    @staticmethod
    def _is_supported_in_IR(
        node: Node,
        parameters_mapping: dict[str, Parameter],
        custom_delegation_options: CustomDelegationOptions,
    ) -> bool:
        if len(node.args) != 2:
            return False

        if "alpha" in node.kwargs:  # kwargs is a dict; hasattr() would never find the key
            return False

        return True

    # sub.Tensor Node format: (Tensor self, Tensor other, *, Scalar alpha=1)
    def convert(self, node: Node):
        """Convert 'sub_tensor' operator to TFLite 'sub'."""
        self.assert_convertible(node)

        t_op = self._create_tflite_op_with_io_tensors(node)

        t_op.builtin_options = sub_options.Sub()
        self.builder.append_operators([t_op])
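A short sketch of what the `_is_supported_in_IR` check means in practice (the modules here are illustrative, not part of the commit): a `torch.sub` call with a non-default `alpha` keeps the keyword argument on the exported `aten.sub.Tensor` node, so the node is rejected and stays on the CPU, while a plain `x - y` remains eligible:

import torch


class SubWithAlpha(torch.nn.Module):
    def forward(self, x, y):
        # Computes x - 2 * y; the `alpha` kwarg survives export, so this
        # node fails `_is_supported_in_IR` and is not delegated.
        return torch.sub(x, y, alpha=2)


class PlainSub(torch.nn.Module):
    def forward(self, x, y):
        # `alpha` defaults to 1 and is omitted from kwargs; this node can be
        # converted to the TFLite `sub` operator.
        return x - y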

backends/nxp/neutron_partitioner.py

Lines changed: 1 addition & 0 deletions
@@ -194,6 +194,7 @@ def tag_qdq_clusters(self, nodes: List[torch.fx.Node]):
     exir_ops.edge.aten._adaptive_avg_pool2d.default: AdaptiveAvgPool2dConverter,  # noqa F405
     exir_ops.edge.aten.addmm.default: AddMMConverter,  # noqa F405
     exir_ops.edge.aten.add.Tensor: AddTensorConverter,  # noqa F405
+    exir_ops.edge.aten.sub.Tensor: SubTensorConverter,  # noqa F405
     exir_ops.edge.aten.avg_pool2d.default: AvgPool2dConverter,  # noqa F405
     exir_ops.edge.aten.cat.default: CatConverter,  # noqa F405
     exir_ops.edge.aten.clone.default: CloneConverter,  # noqa F405

backends/nxp/quantizer/neutron_quantizer.py

Lines changed: 2 additions & 0 deletions
@@ -16,6 +16,7 @@
     AdaptiveAvgPoolPattern,
     AddmmPattern,
     AddTensorPattern,
+    SubTensorPattern,
     AvgPoolPattern,
     CatPattern,
     Conv1dPattern,

@@ -207,6 +208,7 @@ def __init__(self):
             NeutronAtenQuantizer(AbsPattern(), static_qconfig),
             NeutronAtenQuantizer(AdaptiveAvgPoolPattern(), static_qconfig),
             NeutronAtenQuantizer(AddTensorPattern(), static_qconfig),
+            NeutronAtenQuantizer(SubTensorPattern(), static_qconfig),
             NeutronAtenQuantizer(AddmmPattern(), static_fc_qconfig),
             NeutronAtenQuantizer(AvgPoolPattern(), static_qconfig),
             NeutronAtenQuantizer(CatPattern(), static_qconfig),

backends/nxp/quantizer/patterns.py

Lines changed: 26 additions & 0 deletions
@@ -209,6 +209,32 @@ def get_anchors(
     )


+class SubTensorPattern(QuantizationPattern):
+    """
+    Quantization pattern for Sub Tensor quantization. Accepts 1 or 2 input nodes.
+
+    Basic quantization for all inputs and output.
+    """
+
+    def partition_types(self) -> List[Type[torch.nn.Module]]:
+        return [torch.ops.aten.sub.Tensor]
+
+    def get_anchors(
+        self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
+    ) -> PartitionAnchors | None:
+        node = fused_partition[0].nodes[-1]
+        inputs = [(node, 0)]
+        if len(fused_partition[0].input_nodes) == 2:
+            inputs = [(node, 0), (node, 1)]
+
+        return PartitionAnchors(
+            inputs=inputs,
+            weights=[],
+            biases=[],
+            output=[(node,)],
+        )
+
+
 class AvgPoolPattern(SharedSpecPattern):
     """
     Quantizer for AvgPool2D operator.
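The "1 or 2 input nodes" case split in `get_anchors` mirrors the two module shapes exercised by the tests: `x - y` has two distinct producers, one per argument, while `x - x` has a single producer feeding both arguments, so one input anchor suffices. A small sketch (assuming standard `torch.export` behavior; the module mirrors `SubTensorOneInputModule` from the tests) of why counting input nodes rather than arguments matters:

import torch


class OneInputSub(torch.nn.Module):
    def forward(self, x):
        return x - x


ep = torch.export.export(OneInputSub(), (torch.ones(4),))
sub_node = [n for n in ep.graph.nodes if n.op == "call_function"][-1]
print(sub_node.args)  # the same placeholder node appears twice, e.g. (x, x)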
Lines changed: 145 additions & 0 deletions
@@ -0,0 +1,145 @@
import numpy as np
import pytest
import torch

from executorch.backends.nxp.backend.edge_program_converter import (
    EdgeProgramToIRConverter,
)
from executorch.backends.nxp.tests.executorch_pipeline import to_quantized_edge_program
from executorch.backends.nxp.tests.executors import (
    convert_run_compare,
    ToChannelFirstPreprocess,
    ToChannelLastPreprocess,
)
from executorch.backends.nxp.tests.models import (
    SubTensorConvModule,
    SubTensorModule,
    SubTensorOneInputModule,
)
from torch.export import ExportedProgram


@pytest.fixture(autouse=True)
def reseed_model_per_test_run():
    torch.manual_seed(23)
    np.random.seed(23)


@pytest.mark.parametrize(
    "input_shape",
    [
        pytest.param((4,), id="1D."),
        pytest.param((6, 6), id="2D."),
        pytest.param((1, 4, 8), id="3D."),
        pytest.param((1, 4, 8, 8), id="4D."),
    ],
)
def test_sub_tensor_quant_conversion(mocker, input_shape):
    model = SubTensorModule()

    converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program")

    # Run conversion
    _ = to_quantized_edge_program(model, [input_shape, input_shape])

    # Capture generated model
    tflite_flatbuffers_model, io_formats = converter_spy.spy_return

    # Capture converted program
    exported_program: ExportedProgram = converter_spy.call_args.args[1]

    input_data = (np.random.random(input_shape).astype(np.float32) * 50).astype(np.int8)
    input_data = {0: input_data, 1: input_data}

    convert_run_compare(
        exported_program, tfl_model=tflite_flatbuffers_model, input_data=input_data
    )


@pytest.mark.parametrize(
    "input_shape",
    [
        pytest.param((4,), id="1D."),
        pytest.param((6, 6), id="2D."),
        pytest.param((1, 4, 8), id="3D."),
        pytest.param((1, 4, 8, 8), id="4D."),
    ],
)
def test_sub_tensor_one_input_quant_conversion(mocker, input_shape):
    model = SubTensorOneInputModule()

    converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program")

    # Run conversion
    _ = to_quantized_edge_program(model, input_shape)

    # Capture generated model
    tflite_flatbuffers_model, io_formats = converter_spy.spy_return

    # Capture converted program
    exported_program: ExportedProgram = converter_spy.call_args.args[1]

    input_data = (np.random.random(input_shape).astype(np.float32) * 50).astype(np.int8)

    convert_run_compare(
        exported_program, tfl_model=tflite_flatbuffers_model, input_data=input_data
    )


@pytest.mark.parametrize(
    "input_shape",
    [
        pytest.param((1, 4, 8, 8), id="4D."),
        pytest.param((1, 4, 5, 5), id="4D, product of dims is not a multiple of 8."),
    ],
)
def test_sub_tensor_w_conv_quant_conversion(mocker, input_shape):
    model = SubTensorConvModule()

    converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program")

    # Run conversion
    _ = to_quantized_edge_program(model, input_shape)

    # Capture generated model
    tflite_flatbuffers_model, io_formats = converter_spy.spy_return

    # Capture converted program
    exported_program: ExportedProgram = converter_spy.call_args.args[1]

    input_data = (np.random.random(input_shape).astype(np.float32) * 50).astype(np.int8)

    convert_run_compare(
        exported_program,
        input_data,
        tflite_input_preprocess=ToChannelLastPreprocess(),
        tfl_model=tflite_flatbuffers_model,
        tflite_output_preprocess=ToChannelFirstPreprocess(),
    )


@pytest.mark.parametrize(
    "x_input_shape, y_input_shape",
    [
        pytest.param((1, 4, 7), (4, 7), id="3D -> 2D."),
        pytest.param((1, 4, 8), (1, 4, 4, 8), id="3D -> 4D."),
        pytest.param((1, 1, 4, 4, 8), (1, 4, 4, 8), id="5D -> 4D."),
        pytest.param((4,), (4, 4), id="1D -> 2D."),
        pytest.param((4,), (4, 4, 4), id="1D -> 3D."),
        pytest.param((6, 6), (1, 8, 6, 6), id="2D -> 4D."),
        pytest.param((6, 6), (6,), id="2D -> 1D."),
    ],
)
def test_sub_tensor_broadcasting_unsupported_quant_conversion(
    x_input_shape, y_input_shape
):
    model = SubTensorModule()

    # Run conversion
    edge_program = to_quantized_edge_program(
        model, [x_input_shape, y_input_shape]
    ).exported_program()
    nodes = list(edge_program.graph.nodes)

    # Broadcast is not supported, node is not converted
    assert nodes[6].target.__name__ == "aten.sub.Tensor"  # Sub Tensor is not delegated.
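For contrast, a hedged sketch (not part of the commit) of how the positive case could be verified: when no broadcasting is involved, the `aten.sub.Tensor` node should be absorbed into the delegate, which typically appears as an `executorch_call_delegate` call in the lowered graph:

def sub_is_delegated(edge_program) -> bool:
    # Heuristic check: a delegated partition shows up as a call_delegate node,
    # and no top-level aten.sub.Tensor node remains.
    targets = [str(node.target) for node in edge_program.graph.nodes]
    has_delegate = any("executorch_call_delegate" in t for t in targets)
    sub_remains = any("aten.sub.Tensor" in t for t in targets)
    return has_delegate and not sub_remains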

backends/nxp/tests/models.py

Lines changed: 28 additions & 0 deletions
@@ -424,6 +424,34 @@ def forward(x):
         return x + x


+class SubTensorModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @staticmethod
+    def forward(x, y):
+        return x - y
+
+
+class SubTensorConvModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.conv = Conv2dModule(padding=1, stride=1)
+
+    def forward(self, x):
+        x = self.conv(x)
+        return x - x
+
+
+class SubTensorOneInputModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @staticmethod
+    def forward(x):
+        return x - x
+
+
 class MeanDimLinearModule(torch.nn.Module):
     def __init__(self, dim, keepdim):
         super().__init__()
