Skip to content

Commit d8c7ee9

Browse files
skywallStrycekSimon
authored andcommitted
NXP backend: Improve partitioning and conversion of Pad op
1 parent 3117b63 commit d8c7ee9

File tree

2 files changed

+42
-14
lines changed

2 files changed

+42
-14
lines changed

backends/nxp/backend/ir/converter/node_converters/ops_converters/constant_pad_nd_converter.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
import numpy as np
1010

1111
from executorch.backends.nxp.backend.edge_helper import input_rank
12-
from executorch.backends.nxp.backend.ir.converter.conversion.common import OpsList
1312
from executorch.backends.nxp.backend.ir.converter.conversion.translator import (
1413
apply_permutation_to,
1514
create_channels_first_to_channels_last_permutation,
@@ -24,6 +23,7 @@
2423
)
2524
from executorch.backends.nxp.backend.ir.tflite_generator import tflite_model
2625
from executorch.backends.nxp.backend.ir.tflite_generator.builtin_options import (
26+
pad_options,
2727
pad_v2_options,
2828
)
2929
from torch.fx import Node
@@ -50,6 +50,10 @@ def _is_supported_in_IR(
5050
if not NodeConverter._has_shared_q_params_if_quantized(node):
5151
return False
5252

53+
if len(paddings) > 4 and paddings[4:6] != [0, 0]:
54+
# Attempt to Pad channels dimension -> currently not supported
55+
return False
56+
5357
return True
5458

5559
# noinspection PyMethodMayBeStatic
@@ -101,6 +105,15 @@ def convert(self, node: Node):
101105
np.asarray(paddings, "int32"), "paddings"
102106
)
103107

108+
if constant == 0.0:
109+
# We're padding with zeros, we can use traditional Pad op
110+
t_op.tmp_inputs = [x, paddings_tensor]
111+
t_op.tmp_outputs = [y]
112+
t_op.builtin_options = pad_options.Pad()
113+
114+
self.builder.append_operators([t_op])
115+
return
116+
104117
if x.quantization is None:
105118
constant_tensor = self.builder.create_tensor_for_data(
106119
np.array([constant], tf_lite_type_to_numpy(x.type)), "constant"
@@ -124,6 +137,4 @@ def convert(self, node: Node):
124137
t_op.tmp_outputs = [y]
125138
t_op.builtin_options = pad_v2_options.PadV2()
126139

127-
ops_to_add = OpsList(middle_op=t_op)
128-
129-
self.builder.append_operators(ops_to_add.flatten())
140+
self.builder.append_operators([t_op])

backends/nxp/tests/ir/converter/node_converter/test_constant_pad_nd_converter.py

Lines changed: 27 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,10 @@
77
import pytest
88
import torch
99

10-
from executorch.backends.nxp.tests.executorch_pipeline import to_edge_program
10+
from executorch.backends.nxp.tests.executorch_pipeline import (
11+
to_edge_program,
12+
to_quantized_edge_program,
13+
)
1114
from executorch.backends.nxp.tests.executors import (
1215
convert_run_compare,
1316
ToNCHWPreprocess,
@@ -60,16 +63,10 @@ def test_constant_pad_nd_conversion__default_constant():
6063
pytest.param([2, 4], list(range(4)), id="2D, padding N, H"),
6164
pytest.param([2, 4, 6], list(range(2)), id="3D, padding H"),
6265
pytest.param([2, 4, 6], list(range(4)), id="3D, padding C, H"),
63-
pytest.param([2, 4, 6], list(range(6)), id="3D, padding N, C, H"),
6466
pytest.param([2, 4, 6, 8], list(range(2)), id="4D, padding W"),
6567
pytest.param([2, 4, 6, 8], list(range(4)), id="4D, padding H, W"),
66-
pytest.param([2, 4, 6, 8], list(range(6)), id="4D, padding C, H, W"),
67-
pytest.param([2, 4, 6, 8], list(range(8)), id="4D, padding N, C, H, W"),
6868
pytest.param([1, 2, 3, 4, 5], list(range(2)), id="5D, padding D"),
6969
pytest.param([1, 2, 3, 4, 5], list(range(4)), id="5D, padding W, D"),
70-
pytest.param([1, 2, 3, 4, 5], list(range(6)), id="5D, padding H, W, D"),
71-
pytest.param([1, 2, 3, 4, 5], list(range(8)), id="5D, padding C, H, W, D"),
72-
pytest.param([1, 2, 3, 4, 5], list(range(10)), id="5D, padding N, C, H, W, D"),
7370
],
7471
)
7572
def test_constant_pad_nd_conversion__format_less(input_shape, paddings):
@@ -87,13 +84,12 @@ def test_constant_pad_nd_conversion__format_less(input_shape, paddings):
8784
[
8885
pytest.param([1, 4, 6, 8], list(range(2)), id="4D, padding W"),
8986
pytest.param([1, 4, 6, 8], list(range(4)), id="4D, padding H, W"),
90-
pytest.param([1, 1, 6, 8], [1, 2, 3, 4, 2, 1], id="4D, padding C, H, W"),
91-
# pytest.param([1, 1, 6, 8], [1, 2, 3, 4, 2, 1, 5, 6], id='4D, padding N, C, H, W'), # Batch size must stay 0.
9287
],
9388
)
9489
def test_constant_pad_nd_conversion__channels_first(input_shape, paddings):
90+
model = ConstantPadNDConvModule(paddings)
9591
edge_program = to_edge_program(
96-
ConstantPadNDConvModule(paddings), input_shape
92+
model, input_shape
9793
).exported_program() # Extra `Conv` after the padding.
9894

9995
input_data = np.random.random(input_shape).astype(np.float32)
@@ -104,3 +100,24 @@ def test_constant_pad_nd_conversion__channels_first(input_shape, paddings):
104100
tflite_input_preprocess=ToNHWCPreprocess(),
105101
tflite_output_preprocess=ToNCHWPreprocess(),
106102
)
103+
104+
105+
@pytest.mark.parametrize(
106+
"input_shape, paddings",
107+
[
108+
pytest.param([2, 4, 6], list(range(6)), id="3D, padding N, C, H"),
109+
pytest.param([2, 4, 6, 8], list(range(6)), id="4D, padding C, H, W"),
110+
pytest.param([2, 4, 6, 8], list(range(8)), id="4D, padding N, C, H, W"),
111+
pytest.param([1, 2, 3, 4, 5], list(range(6)), id="5D, padding H, W, D"),
112+
pytest.param([1, 2, 3, 4, 5], list(range(8)), id="5D, padding C, H, W, D"),
113+
pytest.param([1, 2, 3, 4, 5], list(range(10)), id="5D, padding N, C, H, W, D"),
114+
pytest.param([1, 1, 6, 8], [1, 2, 3, 4, 2, 1], id="4D, padding C, H, W"),
115+
],
116+
)
117+
def test_constant_pad_nd__unsupported_paddings(input_shape, paddings):
118+
model = ConstantPadNDModule(paddings)
119+
exec_program = to_quantized_edge_program(model, input_shape).exported_program()
120+
121+
nodes = list(exec_program.graph.nodes)
122+
# There is at least one non-delegated Pad node
123+
assert any(node.name == "aten_constant_pad_nd_default" for node in nodes)

0 commit comments

Comments
 (0)