Skip to content

Commit 8dcca72

Browse files
committed
fix: applied feedback from PR
1 parent 8db6b8a commit 8dcca72

File tree

5 files changed

+29
-6
lines changed

5 files changed

+29
-6
lines changed

backends/nxp/backend/edge_program_converter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@
3030
exir_ops.edge.aten._adaptive_avg_pool2d.default: AdaptiveAvgPool2dConverter, # noqa F405
3131
exir_ops.edge.aten.addmm.default: AddMMConverter, # noqa F405
3232
exir_ops.edge.aten.add.Tensor: AddTensorConverter, # noqa F405
33-
exir_ops.edge.aten.sub.Tensor: SubTensorConverter, # noqa F405
3433
exir_ops.edge.aten.avg_pool2d.default: AvgPool2dConverter, # noqa F405
3534
exir_ops.edge.aten.cat.default: CatConverter, # noqa F405
3635
exir_ops.edge.aten.clone.default: CloneConverter, # noqa F405
@@ -43,6 +42,7 @@
4342
exir_ops.edge.aten.permute_copy.default: PermuteCopyConverter, # noqa F405
4443
exir_ops.edge.aten.relu.default: ReLUConverter, # noqa F405
4544
exir_ops.edge.aten._softmax.default: SoftmaxConverter, # noqa F405
45+
exir_ops.edge.aten.sub.Tensor: SubTensorConverter, # noqa F405
4646
exir_ops.edge.aten.tanh.default: TanhConverter, # noqa F405
4747
exir_ops.edge.aten.view_copy.default: ViewCopyConverter, # noqa F405
4848
exir_ops.edge.aten.sigmoid.default: SigmoidConverter, # noqa F405

backends/nxp/backend/ir/converter/node_converters/ops_converters/sub_tensor_converter.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,14 +46,16 @@ def _is_supported_in_IR(
4646
if len(node.args) != 2:
4747
return False
4848

49+
# The `alpha` attribute can be represented by adding an extra `Mul` operator.
50+
# However, this is not implemented as `alpha` is rarely used.
4951
if hasattr(node.kwargs, "alpha"):
5052
return False
5153

5254
return True
5355

5456
# sub.Tensor Node format: (Tensor self, Tensor other, *, Scalar alpha=1)
5557
def convert(self, node: Node):
56-
"""Convert 'sub_tensor' operator to TFLite 'sub'."""
58+
"""Convert 'sub_tensor' operator to NeutronIR 'Sub'."""
5759
self.assert_convertible(node)
5860

5961
t_op = self._create_tflite_op_with_io_tensors(node)

backends/nxp/neutron_partitioner.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -194,7 +194,6 @@ def tag_qdq_clusters(self, nodes: List[torch.fx.Node]):
194194
exir_ops.edge.aten._adaptive_avg_pool2d.default: AdaptiveAvgPool2dConverter, # noqa F405
195195
exir_ops.edge.aten.addmm.default: AddMMConverter, # noqa F405
196196
exir_ops.edge.aten.add.Tensor: AddTensorConverter, # noqa F405
197-
exir_ops.edge.aten.sub.Tensor: SubTensorConverter, # noqa F405
198197
exir_ops.edge.aten.avg_pool2d.default: AvgPool2dConverter, # noqa F405
199198
exir_ops.edge.aten.cat.default: CatConverter, # noqa F405
200199
exir_ops.edge.aten.clone.default: CloneConverter, # noqa F405
@@ -207,6 +206,7 @@ def tag_qdq_clusters(self, nodes: List[torch.fx.Node]):
207206
exir_ops.edge.aten.mm.default: MMConverter, # noqa F405
208207
exir_ops.edge.aten.relu.default: ReLUConverter, # noqa F405
209208
exir_ops.edge.aten._softmax.default: SoftmaxConverter, # noqa F405
209+
exir_ops.edge.aten.sub.Tensor: SubTensorConverter, # noqa F405
210210
exir_ops.edge.aten.tanh.default: TanhConverter, # noqa F405
211211
exir_ops.edge.aten.view_copy.default: ViewCopyConverter, # noqa F405
212212
exir_ops.edge.aten.sigmoid.default: SigmoidConverter, # noqa F405

backends/nxp/tests/ir/converter/node_converter/test_add_tensor_converter.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
# Copyright 2025 NXP
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
15
import numpy as np
26
import pytest
37
import torch

backends/nxp/tests/ir/converter/node_converter/test_sub_tensor_converter.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
SubTensorModule,
1717
SubTensorOneInputModule,
1818
)
19+
from executorch.exir.dialects._ops import ops as exir_ops
1920
from torch.export import ExportedProgram
2021

2122

@@ -48,8 +49,16 @@ def test_sub_tensor_quant_conversion(mocker, input_shape):
4849
# Capture converted program
4950
exported_program: ExportedProgram = converter_spy.call_args.args[1]
5051

51-
input_data = (np.random.random(input_shape).astype(np.float32) * 50).astype(np.int8)
52-
input_data = {0: input_data, 1: input_data}
52+
input_data_1 = (np.random.random(input_shape).astype(np.float32) * 50).astype(
53+
np.int8
54+
)
55+
input_data_2 = (np.random.random(input_shape).astype(np.float32) * 50).astype(
56+
np.int8
57+
)
58+
input_data = {0: input_data_1, 1: input_data_2}
59+
60+
nodes = list(exported_program.graph.nodes)
61+
assert nodes[4].name == "aten_sub_tensor"
5362

5463
convert_run_compare(
5564
exported_program, tfl_model=tflite_flatbuffers_model, input_data=input_data
@@ -81,6 +90,9 @@ def test_sub_tensor_one_input_quant_conversion(mocker, input_shape):
8190

8291
input_data = (np.random.random(input_shape).astype(np.float32) * 50).astype(np.int8)
8392

93+
nodes = list(exported_program.graph.nodes)
94+
assert nodes[2].name == "aten_sub_tensor"
95+
8496
convert_run_compare(
8597
exported_program, tfl_model=tflite_flatbuffers_model, input_data=input_data
8698
)
@@ -109,6 +121,9 @@ def test_sub_tensor_w_conv_quant_conversion(mocker, input_shape):
109121

110122
input_data = (np.random.random(input_shape).astype(np.float32) * 50).astype(np.int8)
111123

124+
nodes = list(exported_program.graph.nodes)
125+
assert nodes[9].name == "aten_sub_tensor"
126+
112127
convert_run_compare(
113128
exported_program,
114129
input_data,
@@ -142,4 +157,6 @@ def test_sub_tensor_broadcasting_unsupported_quant_conversion(
142157
nodes = list(edge_program.graph.nodes)
143158

144159
# Broadcast is not supported, node is not converted
145-
assert nodes[6].target.__name__ == "aten.sub.Tensor" # Sub Tensor is not delegated.
160+
assert (
161+
nodes[6].target == exir_ops.edge.aten.sub.Tensor
162+
) # Sub Tensor is not delegated.

0 commit comments

Comments
 (0)