Skip to content

Commit 72376bf

Browse files
roman-janik-nxp authored and
MartinPavella committed
NXP backend: Add support for aten_permute_copy_default operator
1 parent 6c2ba75 commit 72376bf

File tree

6 files changed

+218
-38
lines changed

6 files changed

+218
-38
lines changed

backends/nxp/backend/ir/converter/node_converters/ops_converters/permute_copy_converter.py

Lines changed: 27 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -4,28 +4,55 @@
44
# LICENSE file in the root directory of this source tree.
55

66
import numpy as np
7+
import torch
78

89
from executorch.backends.nxp.backend.ir.converter import quantization_utils
910
from executorch.backends.nxp.backend.ir.converter.conversion.common import OpsList
1011
from executorch.backends.nxp.backend.ir.converter.node_converter import (
1112
CustomDelegationOptions,
13+
NeutronTargetSpec,
1214
NodeConverter,
1315
)
1416
from executorch.backends.nxp.backend.ir.tflite_generator.builtin_options import (
1517
transpose_options,
1618
)
19+
from executorch.backends.nxp.backend.neutron_operator_support import (
20+
transposition_is_supported_on_neutron,
21+
)
1722
from torch.fx import Node
1823
from torch.nn import Parameter
1924

2025

26+
def _get_shape(node: torch.fx.Node) -> list[int]:
27+
return list(node.meta["val"].shape)
28+
29+
2130
class PermuteCopyConverter(NodeConverter):
31+
@staticmethod
def _is_supported_on_target(
    node: Node,
    neutron_target_spec: NeutronTargetSpec,
    parameters_mapping: dict[str, Parameter],
    custom_delegation_options: CustomDelegationOptions,
) -> bool:
    """Decide whether the permutation described by `node` is supported on the target.

    The shape of the first node argument and the permutation given as the
    second node argument are forwarded, together with `neutron_target_spec`,
    to `transposition_is_supported_on_neutron`, whose result is returned.

    `parameters_mapping` and `custom_delegation_options` are not used by this
    check.
    """
    permuted_input, new_order = node.args[0], node.args[1]

    # TODO Handle tensor formats properly.

    return transposition_is_supported_on_neutron(
        _get_shape(permuted_input), list(new_order), neutron_target_spec
    )
2246

2347
@staticmethod
def _is_supported_in_IR(
    node: Node,
    parameters_mapping: dict[str, Parameter],
    custom_delegation_options: CustomDelegationOptions,
) -> bool:
    """Decide whether `node` can be represented in the IR.

    The node is accepted exactly when
    `NodeConverter._has_shared_q_params_if_quantized(node)` is truthy.
    `parameters_mapping` and `custom_delegation_options` are not used by this
    check.
    """
    return bool(NodeConverter._has_shared_q_params_if_quantized(node))
3057

3158
def convert(self, node: Node):

backends/nxp/neutron_partitioner.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -210,12 +210,13 @@ def tag_qdq_clusters(self, nodes: list[torch.fx.Node]):
210210
exir_ops.edge.aten.max_pool2d_with_indices.default: MaxPool2dConverter, # noqa F405
211211
exir_ops.edge.aten.mean.dim: MeanDimConverter, # noqa F405
212212
exir_ops.edge.aten.mm.default: MMConverter, # noqa F405
213+
exir_ops.edge.aten.permute_copy.default: PermuteCopyConverter, # noqa F405
213214
exir_ops.edge.aten.relu.default: ReLUConverter, # noqa F405
214215
exir_ops.edge.aten._softmax.default: SoftmaxConverter, # noqa F405
216+
exir_ops.edge.aten.sigmoid.default: SigmoidConverter, # noqa F405
215217
exir_ops.edge.aten.sub.Tensor: SubTensorConverter, # noqa F405
216218
exir_ops.edge.aten.tanh.default: TanhConverter, # noqa F405
217219
exir_ops.edge.aten.view_copy.default: ViewCopyConverter, # noqa F405
218-
exir_ops.edge.aten.sigmoid.default: SigmoidConverter, # noqa F405
219220
}
220221

221222

backends/nxp/quantizer/neutron_quantizer.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -39,6 +39,7 @@
3939
SubTensorPattern,
4040
TanhInPlacePattern,
4141
TanhPattern,
42+
TransposeIntPattern,
4243
ViewPattern,
4344
)
4445
from executorch.backends.nxp.quantizer.utils import (
@@ -212,6 +213,7 @@ def __init__(self):
212213
NeutronAtenQuantizer(SubTensorPattern(), static_qconfig),
213214
NeutronAtenQuantizer(TanhPattern(), static_qconfig),
214215
NeutronAtenQuantizer(TanhInPlacePattern(), static_qconfig),
216+
NeutronAtenQuantizer(TransposeIntPattern(), static_qconfig),
215217
NeutronAtenQuantizer(ViewPattern(), static_qconfig),
216218
]
217219
)

backends/nxp/quantizer/patterns.py

Lines changed: 9 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -513,6 +513,15 @@ def partition_types(self):
513513
return [torch.ops.aten.permute.default]
514514

515515

516+
class TransposeIntPattern(SharedSpecPattern):
    """
    Quantizer pattern for the `aten.transpose.int` operator.
    """

    def partition_types(self) -> list[OpOverload]:
        # The single operator overload this pattern is registered for.
        transpose_int = torch.ops.aten.transpose.int
        return [transpose_int]
523+
524+
516525
class ReluPattern(SharedSpecPattern):
517526
"""
518527
Quantizer for Relu operator. Shared quantization spec is selected, as ReLU usually follows computation layer.

backends/nxp/tests/ir/converter/node_converter/test_permute_copy_converter.py

Lines changed: 165 additions & 28 deletions
Original file line number | Diff line number | Diff line change
@@ -3,8 +3,10 @@
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
55

6+
import unittest
7+
8+
import kgb
69
import numpy as np
7-
import pytest
810
import torch
911

1012
from executorch.backends.nxp.backend.edge_program_converter import (
@@ -13,52 +15,187 @@
1315
from executorch.backends.nxp.tests.executorch_pipeline import to_quantized_edge_program
1416
from executorch.backends.nxp.tests.executors import (
1517
convert_run_compare,
16-
ToNCHWPreprocess,
17-
ToNHWCPreprocess,
18+
graph_contains_any_of_ops,
19+
ToChannelFirstPreprocess,
20+
ToChannelLastPreprocess,
1821
)
1922
from executorch.backends.nxp.tests.models import Conv2dModule
23+
from executorch.exir.dialects._ops import ops as exir_ops
24+
from parameterized import parameterized
2025
from torch.export import ExportedProgram
2126

2227

23-
@pytest.fixture(autouse=True)
24-
def reseed_model_per_test_run():
25-
torch.manual_seed(23)
26-
np.random.seed(23)
28+
class Conv2dTransposeModule(torch.nn.Module):
    """Test model: a 1x1 `Conv2dModule` followed by `torch.transpose`.

    The convolution keeps the channel count (`out_channels == in_channels`);
    `dim0` and `dim1` are the two dimensions swapped on its output.
    """

    def __init__(self, in_channels: int, dim0: int, dim1: int):
        super().__init__()
        self.conv = Conv2dModule(
            in_channels=in_channels, out_channels=in_channels, kernel_size=(1, 1)
        )
        self.dim0 = dim0
        self.dim1 = dim1

    def forward(self, x):
        conv_out = self.conv(x)
        return torch.transpose(conv_out, self.dim0, self.dim1)
2840

29-
class Conv2dPermuteCopyModule(torch.nn.Module):
30-
def __init__(self, new_dims: tuple[int, ...]):
41+
42+
class Conv2dPermuteModule(torch.nn.Module):
    """Test model: a 3x3 `Conv2dModule` followed by `torch.permute`.

    The convolution is built with `out_channels == in_channels`, `stride=1`,
    `kernel_size=3` and `padding=1`; `new_dims` is the dimension order passed
    to `torch.permute` on its output.
    """

    def __init__(self, in_channels: int, new_dims: tuple[int, ...]):
        super().__init__()
        self.new_dims = new_dims
        # Single, explicitly configured convolution. The argument-less
        # `Conv2dModule()` construction from the previous revision is gone.
        self.conv = Conv2dModule(
            in_channels=in_channels,
            out_channels=in_channels,
            stride=1,
            kernel_size=3,
            padding=1,
        )

    def forward(self, x):
        x = self.conv(x)
        return torch.permute(x, self.new_dims)
3857

3958

40-
def test_permute_copy_quant_conversion__with_bias(mocker):
41-
input_shape = (1, 4, 8, 8)
42-
new_dims = (0, 2, 3, 1)
59+
class LinearPermuteModule(torch.nn.Module):
    """Test model: a square `torch.nn.Linear` layer followed by `torch.permute`.

    :param in_features: Size of the last input dimension; the linear layer
        maps `in_features` to `in_features`.
    :param new_dims: Dimension order passed to `torch.permute` on the output
        of the linear layer.
    """

    def __init__(self, in_features: int, new_dims: tuple[int, ...]):
        super().__init__()
        self.new_dims = new_dims
        self.fc = torch.nn.Linear(in_features, in_features)

    def forward(self, x):
        x = self.fc(x)
        return torch.permute(x, self.new_dims)
68+
69+
70+
class TestPermuteCopyConversion(kgb.SpyAgency, unittest.TestCase):
71+
@classmethod
def setUpClass(cls):
    """Seed the global numpy and torch random generators for this test class."""
    np.random.seed(42)
    torch.manual_seed(23)
75+
76+
@parameterized.expand(
    [
        ["To channel first permutation", (1, 16, 8, 8), (0, 3, 1, 2)],
        ["To channel last permutation", (1, 16, 8, 8), (0, 2, 3, 1)],
    ]
)
def test_permute_copy_conversion__from_permute_4D__quantized(
    self, _: str, input_shape, new_dims
):
    """Convert a quantized conv + 4D permute model and check the permute is delegated.

    `EdgeProgramToIRConverter.convert_program` is spied on (the original is
    still called), so the generated model and the program handed to the
    converter can be taken from its last call and passed to
    `convert_run_compare`.
    """
    with kgb.spy_on(
        EdgeProgramToIRConverter.convert_program, call_original=True
    ) as converter_spy:
        model = Conv2dPermuteModule(input_shape[1], new_dims)

        # Run conversion
        quantized = to_quantized_edge_program(model, input_shape)
        edge_program = quantized.exported_program()

        # Make sure the `Permute_copy` was delegated.
        permute_op = exir_ops.edge.aten.permute_copy.default
        assert not graph_contains_any_of_ops(
            graph=edge_program.graph, ops=[permute_op]
        )
        assert any(
            "lowered_module" in node.name for node in edge_program.graph.nodes
        )

        last_call = converter_spy.calls[-1]

        # Capture generated model
        tflite_flatbuffers_model, _io_formats = last_call.return_value

        # Capture converted program
        exported_program: ExportedProgram = last_call.args[0]

        random_values = np.random.random(input_shape).astype(np.float32)
        input_data = (random_values * 50).astype(np.int8)

        convert_run_compare(
            exported_program,
            input_data,
            tfl_model=tflite_flatbuffers_model,
            atol=1.0,
            tflite_input_preprocess=ToChannelLastPreprocess(),
            tflite_output_preprocess=ToChannelFirstPreprocess(),
        )
56121

57-
convert_run_compare(
58-
edge_program,
59-
input_data,
60-
tfl_model=tflite_flatbuffers_model,
61-
atol=1.0,
62-
tflite_input_preprocess=ToNHWCPreprocess(),
63-
tflite_output_preprocess=ToNCHWPreprocess(),
122+
@parameterized.expand(
    [
        ["Permutation can be replaced by reshapes", (10, 1, 8), (0, 2, 1)],
        ["Permutation can be replaced by reshapes", (10, 1, 1), (2, 1, 0)],
        ["Permutation is identical and can be removed", (10, 1, 8), (0, 1, 2)],
    ]
)
def test_permute_copy_conversion__from_permute_3D__quantized(
    self, _: str, input_shape, new_dims
):
    """Convert a quantized linear + 3D permute model and check the permute is delegated.

    `EdgeProgramToIRConverter.convert_program` is spied on (the original is
    still called), so the generated model and the program handed to the
    converter can be taken from its last call and passed to
    `convert_run_compare`.
    """
    with kgb.spy_on(
        EdgeProgramToIRConverter.convert_program, call_original=True
    ) as converter_spy:
        # Run conversion
        model = LinearPermuteModule(input_shape[2], new_dims)
        quantized = to_quantized_edge_program(model, input_shape)
        edge_program = quantized.exported_program()

        # Make sure the `Permute_copy` was delegated.
        permute_op = exir_ops.edge.aten.permute_copy.default
        assert not graph_contains_any_of_ops(
            graph=edge_program.graph, ops=[permute_op]
        )
        assert any(
            "lowered_module" in node.name for node in edge_program.graph.nodes
        )

        last_call = converter_spy.calls[-1]

        # Capture generated model
        tflite_flatbuffers_model, _io_formats = last_call.return_value

        # Capture converted program
        exported_program: ExportedProgram = last_call.args[0]

        random_values = np.random.random(input_shape).astype(np.float32)
        input_data = (random_values * 50).astype(np.int8)

        convert_run_compare(
            exported_program,
            input_data,
            tfl_model=tflite_flatbuffers_model,
            atol=1.0,
        )
164+
165+
@parameterized.expand(
    [
        ["Transpose dims 1 and 2", (1, 16, 8, 8), (0, 2, 1, 3)],
        ["To (2, 0, 1, 3) permutation", (1, 16, 8, 8), (2, 0, 1, 3)],
        ["To (3, 1, 2, 0) permutation", (1, 16, 8, 8), (3, 1, 2, 0)],
        ["To (3, 1, 0, 2) permutation", (1, 16, 8, 8), (3, 1, 0, 2)],
    ]
)
def test_permute_copy_non_delegated_conversion__from_permute_4D__quantized(
    self, _: str, input_shape, new_dims
):
    """Check that these 4D permutations stay in the edge program graph.

    The model is a conv followed by `torch.permute(new_dims)`. After
    `to_quantized_edge_program`, the graph must still contain a
    `permute_copy` node.
    """
    model = Conv2dPermuteModule(input_shape[1], new_dims)
    edge_program = to_quantized_edge_program(model, input_shape).exported_program()

    # Look the operator up by target instead of asserting a fixed node count
    # and a fixed node index: the position of the node depends on unrelated
    # details of the surrounding graph, which is not what this test is about.
    assert graph_contains_any_of_ops(
        graph=edge_program.graph, ops=[exir_ops.edge.aten.permute_copy.default]
    )  # PermuteCopy not delegated.
184+
185+
@parameterized.expand(
    [
        ["Transpose dims 1 and 2", (1, 16, 8, 8), 1, 2],
        ["Transpose dims 2 and 3", (1, 16, 8, 8), 2, 3],
    ]
)
def test_permute_copy_non_delegated_conversion__from_transpose_4D__quantized(
    self, _: str, input_shape, dim0, dim1
):
    """Check that these 4D transpositions stay in the edge program graph.

    The model is a conv followed by `torch.transpose(dim0, dim1)`. After
    `to_quantized_edge_program`, the graph must still contain a
    `permute_copy` node.
    """
    model = Conv2dTransposeModule(input_shape[1], dim0, dim1)
    edge_program = to_quantized_edge_program(model, input_shape).exported_program()

    # Look the operator up by target instead of asserting a fixed node count
    # and a fixed node index: the position of the node depends on unrelated
    # details of the surrounding graph, which is not what this test is about.
    assert graph_contains_any_of_ops(
        graph=edge_program.graph, ops=[exir_ops.edge.aten.permute_copy.default]
    )  # PermuteCopy not delegated.

0 commit comments

Comments
 (0)