
Commit 05799c9

NXP backend: added aten.sub operator support (pytorch#14514)
### Summary

Adds support for the aten.sub operator.

### Test plan

Tests can be run manually with `pytest -c /dev/null backends/nxp/tests/`.

Co-authored-by: Martin Pavella <[email protected]>

1 parent 70ea661 · commit 05799c9
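A minimal sketch of exercising the new operator end to end, built on the `to_quantized_edge_program` test helper used in this commit's tests (`SubModule` here is illustrative, not part of the commit):

```python
import torch

from executorch.backends.nxp.tests.executorch_pipeline import to_quantized_edge_program


class SubModule(torch.nn.Module):
    def forward(self, x, y):
        return x - y  # exported as aten.sub.Tensor in the edge dialect


# Same-shape inputs avoid broadcasting, which this commit does not delegate.
input_shape = (1, 4, 8, 8)
_ = to_quantized_edge_program(SubModule(), [input_shape, input_shape])
```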

File tree

9 files changed: +300 −0 lines

backends/nxp/backend/edge_program_converter.py

Lines changed: 1 addition & 0 deletions

@@ -43,6 +43,7 @@
     exir_ops.edge.aten.permute_copy.default: PermuteCopyConverter,  # noqa F405
     exir_ops.edge.aten.relu.default: ReLUConverter,  # noqa F405
     exir_ops.edge.aten._softmax.default: SoftmaxConverter,  # noqa F405
+    exir_ops.edge.aten.sub.Tensor: SubTensorConverter,  # noqa F405
     exir_ops.edge.aten.tanh.default: TanhConverter,  # noqa F405
     exir_ops.edge.aten.view_copy.default: ViewCopyConverter,  # noqa F405
     exir_ops.edge.aten.sigmoid.default: SigmoidConverter,  # noqa F405

backends/nxp/backend/ir/converter/node_converters/ops_converters/__init__.py

Lines changed: 4 additions & 0 deletions

@@ -56,6 +56,9 @@
 from executorch.backends.nxp.backend.ir.converter.node_converters.ops_converters.softmax_converter import (
     SoftmaxConverter,
 )
+from executorch.backends.nxp.backend.ir.converter.node_converters.ops_converters.sub_tensor_converter import (
+    SubTensorConverter,
+)
 from executorch.backends.nxp.backend.ir.converter.node_converters.ops_converters.tanh_converter import (
     TanhConverter,
 )
@@ -80,6 +83,7 @@
     "MaxPool2dConverter",
     "AvgPool2dConverter",
     "AddTensorConverter",
+    "SubTensorConverter",
     "CloneConverter",
     "AbsConverter",
     "AdaptiveAvgPool2dConverter",
backends/nxp/backend/ir/converter/node_converters/ops_converters/sub_tensor_converter.py

Lines changed: 59 additions & 0 deletions

@@ -0,0 +1,59 @@
+# Copyright 2025 NXP
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from executorch.backends.nxp.backend.ir.converter.conversion.common import (
+    node_uses_shape_broadcasting,
+)
+from executorch.backends.nxp.backend.ir.converter.node_converter import (
+    CustomDelegationOptions,
+    NodeConverter,
+)
+from executorch.backends.nxp.backend.ir.tflite_generator.builtin_options import (
+    sub_options,
+)
+from executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec
+from torch.fx import Node
+from torch.nn import Parameter
+
+
+class SubTensorConverter(NodeConverter):
+    @staticmethod
+    def _is_supported_on_target(
+        node: Node,
+        neutron_target_spec: NeutronTargetSpec,
+        parameters_mapping: dict[str, Parameter],
+        custom_delegation_options: CustomDelegationOptions,
+    ) -> bool:
+        if node_uses_shape_broadcasting(node):
+            # Shape broadcasting may require the addition of `Transpose` ops during conversion.
+            return False
+
+        return True
+
+    @staticmethod
+    def _is_supported_in_IR(
+        node: Node,
+        parameters_mapping: dict[str, Parameter],
+        custom_delegation_options: CustomDelegationOptions,
+    ) -> bool:
+        if len(node.args) != 2:
+            return False
+
+        # The `alpha` attribute can be represented by adding an extra `Mul` operator.
+        # However, this is not implemented as `alpha` is rarely used.
+        if "alpha" in node.kwargs:
+            return False
+
+        return True
+
+    # sub.Tensor Node format: (Tensor self, Tensor other, *, Scalar alpha=1)
+    def convert(self, node: Node):
+        """Convert the 'sub_tensor' operator to NeutronIR 'Sub'."""
+        self.assert_convertible(node)
+
+        t_op = self._create_tflite_op_with_io_tensors(node)
+
+        t_op.builtin_options = sub_options.Sub()
+        self.builder.append_operators([t_op])
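For context on the `alpha` check in `_is_supported_in_IR` above: `aten.sub.Tensor` computes `self - alpha * other`, so a node carrying `alpha` could in principle be lowered as a `Mul` feeding the plain `Sub`. A quick PyTorch illustration of that equivalence (illustrative only, not part of this commit):

```python
import torch

x, y, alpha = torch.randn(4), torch.randn(4), 2.0

# torch.sub with alpha computes x - alpha * y, which is exactly the
# `Mul` + `Sub` decomposition the converter comment refers to.
assert torch.allclose(torch.sub(x, y, alpha=alpha), x - alpha * y)
```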

backends/nxp/neutron_partitioner.py

Lines changed: 1 addition & 0 deletions

@@ -210,6 +210,7 @@ def tag_qdq_clusters(self, nodes: list[torch.fx.Node]):
     exir_ops.edge.aten.mm.default: MMConverter,  # noqa F405
     exir_ops.edge.aten.relu.default: ReLUConverter,  # noqa F405
     exir_ops.edge.aten._softmax.default: SoftmaxConverter,  # noqa F405
+    exir_ops.edge.aten.sub.Tensor: SubTensorConverter,  # noqa F405
     exir_ops.edge.aten.tanh.default: TanhConverter,  # noqa F405
     exir_ops.edge.aten.view_copy.default: ViewCopyConverter,  # noqa F405
     exir_ops.edge.aten.sigmoid.default: SigmoidConverter,  # noqa F405

backends/nxp/quantizer/neutron_quantizer.py

Lines changed: 2 additions & 0 deletions

@@ -36,6 +36,7 @@
     SharedSpecPattern,
     SigmoidPattern,
     SoftMaxPattern,
+    SubTensorPattern,
     TanhInPlacePattern,
     TanhPattern,
     ViewPattern,
@@ -208,6 +209,7 @@ def __init__(self):
                 NeutronAtenQuantizer(ReshapePattern(), static_qconfig),
                 NeutronAtenQuantizer(SigmoidPattern(), static_qconfig),
                 NeutronAtenQuantizer(SoftMaxPattern(), static_qconfig),
+                NeutronAtenQuantizer(SubTensorPattern(), static_qconfig),
                 NeutronAtenQuantizer(TanhPattern(), static_qconfig),
                 NeutronAtenQuantizer(TanhInPlacePattern(), static_qconfig),
                 NeutronAtenQuantizer(ViewPattern(), static_qconfig),

backends/nxp/quantizer/patterns.py

Lines changed: 26 additions & 0 deletions

@@ -224,6 +224,32 @@ def get_anchors(
         )


+class SubTensorPattern(QuantizationPattern):
+    """
+    Quantization pattern for sub.Tensor. Accepts 1 or 2 input nodes.
+
+    Basic quantization for all inputs and output.
+    """
+
+    def partition_types(self) -> list[torch.nn.Module]:
+        return [torch.ops.aten.sub.Tensor]
+
+    def get_anchors(
+        self, gm: fx.GraphModule, fused_partition: list[fx.GraphModule]
+    ) -> PartitionAnchors | None:
+        node = fused_partition[0].nodes[-1]
+        inputs = [(node, NodeArgsIdx(0))]
+        if len(fused_partition[0].input_nodes) == 2:
+            inputs = [(node, NodeArgsIdx(0)), (node, NodeArgsIdx(1))]
+
+        return PartitionAnchors(
+            inputs=inputs,
+            weights=[],
+            biases=[],
+            output=[(node,)],
+        )
+
+
 class AvgPoolPattern(SharedSpecPattern):
     """
     Quantizer for AvgPool2D operator.
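The 1-vs-2 input handling in `SubTensorPattern.get_anchors` above exists because in expressions like `x - x` both arguments of the node come from the same producer, so the partition has a single input node. A minimal torch.fx sketch of that distinction (traced at the plain `operator.sub` level here, not the aten-level graph the quantizer actually sees):

```python
import torch
from torch import fx


def sub_two(x, y):
    return x - y


def sub_same(x):
    return x - x


for fn in (sub_two, sub_same):
    graph = fx.symbolic_trace(fn).graph
    node = next(n for n in graph.nodes if n.op == "call_function")
    # sub_two: two distinct placeholders; sub_same: the same placeholder twice.
    print(fn.__name__, node.args[0] is node.args[1])  # False, then True
```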

backends/nxp/tests/ir/converter/node_converter/test_add_tensor_converter.py

Lines changed: 4 additions & 0 deletions

@@ -1,3 +1,7 @@
+# Copyright 2025 NXP
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
 import numpy as np
 import pytest
 import torch
backends/nxp/tests/ir/converter/node_converter/test_sub_tensor_converter.py

Lines changed: 175 additions & 0 deletions

@@ -0,0 +1,175 @@
+# Copyright 2025 NXP
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+import numpy as np
+import pytest
+import torch
+
+from executorch.backends.nxp.backend.edge_program_converter import (
+    EdgeProgramToIRConverter,
+)
+from executorch.backends.nxp.tests.executorch_pipeline import to_quantized_edge_program
+from executorch.backends.nxp.tests.executors import (
+    convert_run_compare,
+    ToChannelFirstPreprocess,
+    ToChannelLastPreprocess,
+)
+from executorch.backends.nxp.tests.models import (
+    SubTensorConvModule,
+    SubTensorModule,
+    SubTensorOneInputModule,
+)
+from executorch.exir.dialects._ops import ops as exir_ops
+from torch.export import ExportedProgram
+
+
+@pytest.fixture(autouse=True)
+def reseed_model_per_test_run():
+    torch.manual_seed(23)
+    np.random.seed(23)
+
+
+@pytest.mark.parametrize(
+    "input_shape",
+    [
+        pytest.param((4,), id="1D."),
+        pytest.param((6, 6), id="2D."),
+        pytest.param((1, 4, 8), id="3D."),
+        pytest.param((1, 4, 8, 8), id="4D."),
+    ],
+)
+def test_sub_tensor_quant_conversion(mocker, input_shape):
+    model = SubTensorModule()
+
+    converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program")
+
+    # Run conversion
+    _ = to_quantized_edge_program(model, [input_shape, input_shape])
+
+    # Capture generated model
+    tflite_flatbuffers_model, io_formats = converter_spy.spy_return
+
+    # Capture converted program
+    exported_program: ExportedProgram = converter_spy.call_args.args[1]
+
+    input_data_1 = (np.random.random(input_shape).astype(np.float32) * 50).astype(
+        np.int8
+    )
+    input_data_2 = (np.random.random(input_shape).astype(np.float32) * 50).astype(
+        np.int8
+    )
+    input_data = {0: input_data_1, 1: input_data_2}
+
+    nodes = list(exported_program.graph.nodes)
+    assert nodes[4].target == exir_ops.edge.aten.sub.Tensor
+
+    convert_run_compare(
+        exported_program, tfl_model=tflite_flatbuffers_model, input_data=input_data
+    )
+
+
+@pytest.mark.parametrize(
+    "input_shape",
+    [
+        pytest.param((4,), id="1D."),
+        pytest.param((6, 6), id="2D."),
+        pytest.param((1, 4, 8), id="3D."),
+        pytest.param((1, 4, 8, 8), id="4D."),
+    ],
+)
+def test_sub_tensor_one_input_quant_conversion(mocker, input_shape):
+    model = SubTensorOneInputModule()
+
+    converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program")
+
+    # Run conversion
+    _ = to_quantized_edge_program(model, input_shape)
+
+    # Capture generated model
+    tflite_flatbuffers_model, io_formats = converter_spy.spy_return
+
+    # Capture converted program
+    exported_program: ExportedProgram = converter_spy.call_args.args[1]
+
+    input_data = (np.random.random(input_shape).astype(np.float32) * 50).astype(np.int8)
+
+    nodes = list(exported_program.graph.nodes)
+    assert nodes[2].target == exir_ops.edge.aten.sub.Tensor
+
+    convert_run_compare(
+        exported_program, tfl_model=tflite_flatbuffers_model, input_data=input_data
+    )
+
+
+@pytest.mark.parametrize(
+    "x_input_shape",
+    [
+        pytest.param((1, 4, 8, 8), id="4D."),
+        pytest.param((1, 4, 5, 5), id="4D, product of dims is not a multiple of 8."),
+    ],
+)
+def test_sub_tensor_w_conv_quant_conversion(mocker, x_input_shape):
+    model = SubTensorConvModule()
+
+    converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program")
+
+    n, c, h, w = x_input_shape
+    y_input_shape = (n, 8, h, w)
+
+    # Run conversion
+    _ = to_quantized_edge_program(model, [x_input_shape, y_input_shape])
+
+    # Capture generated model
+    tflite_flatbuffers_model, io_formats = converter_spy.spy_return
+
+    # Capture converted program
+    exported_program: ExportedProgram = converter_spy.call_args.args[1]
+
+    input_data_1 = (np.random.random(x_input_shape).astype(np.float32) * 50).astype(
+        np.int8
+    )
+    input_data_2 = (np.random.random(y_input_shape).astype(np.float32) * 50).astype(
+        np.int8
+    )
+    input_data = {0: input_data_1, 1: input_data_2}
+
+    nodes = list(exported_program.graph.nodes)
+    assert nodes[15].target == exir_ops.edge.aten.sub.Tensor
+
+    convert_run_compare(
+        exported_program,
+        input_data=input_data,
+        tflite_input_preprocess=ToChannelLastPreprocess(),
+        tfl_model=tflite_flatbuffers_model,
+        tflite_output_preprocess=ToChannelFirstPreprocess(),
+    )
+
+
+@pytest.mark.parametrize(
+    "x_input_shape, y_input_shape",
+    [
+        pytest.param((1, 4, 7), (4, 7), id="3D -> 2D."),
+        pytest.param((1, 4, 8), (1, 4, 4, 8), id="3D -> 4D."),
+        pytest.param((1, 1, 4, 4, 8), (1, 4, 4, 8), id="5D -> 4D."),
+        pytest.param((4,), (4, 4), id="1D -> 2D."),
+        pytest.param((4,), (4, 4, 4), id="1D -> 3D."),
+        pytest.param((6, 6), (1, 8, 6, 6), id="2D -> 4D."),
+        pytest.param((6, 6), (6,), id="2D -> 1D."),
+    ],
+)
+def test_sub_tensor_broadcasting_unsupported_quant_conversion(
+    x_input_shape, y_input_shape
+):
+    model = SubTensorModule()
+
+    # Run conversion
+    edge_program = to_quantized_edge_program(
+        model, [x_input_shape, y_input_shape]
+    ).exported_program()
+    nodes = list(edge_program.graph.nodes)
+
+    # Broadcast is not supported, node is not converted
+    assert (
+        nodes[6].target == exir_ops.edge.aten.sub.Tensor
+    )  # Sub Tensor is not delegated.
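The shape pairs above follow standard PyTorch broadcasting, which `SubTensorConverter._is_supported_on_target` declines to delegate; a quick check of one such case (illustrative only):

```python
import torch

x = torch.randn(6, 6)
y = torch.randn(6)
# The 1D tensor broadcasts across rows, so the sub node uses shape
# broadcasting and is left undelegated rather than handed to Neutron.
print((x - y).shape)  # torch.Size([6, 6])
```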

backends/nxp/tests/models.py

Lines changed: 28 additions & 0 deletions

@@ -451,6 +451,34 @@ def forward(x):
         return x + x


+class SubTensorModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @staticmethod
+    def forward(x, y):
+        return x - y
+
+
+class SubTensorConvModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.conv = Conv2dModule(padding=1, stride=1)
+
+    def forward(self, x, y):
+        x = self.conv(x)
+        return x - y
+
+
+class SubTensorOneInputModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @staticmethod
+    def forward(x):
+        return x - x
+
+
 class MeanDimLinearModule(torch.nn.Module):
     def __init__(self, dim, keepdim):
         super().__init__()
