Skip to content

Commit b6ee7e0

Browse files
skywall and robert-kalmar
authored and committed
NXP backend: Improve partitioning and conversion of Pad op
1 parent 2087aa4 commit b6ee7e0

File tree

2 files changed

+18
-12
lines changed

2 files changed

+18
-12
lines changed

backends/nxp/backend/ir/converter/node_converters/ops_converters/constant_pad_nd_converter.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
import numpy as np
1010

1111
from executorch.backends.nxp.backend.edge_helper import input_rank
12-
from executorch.backends.nxp.backend.ir.converter.conversion.common import OpsList
1312
from executorch.backends.nxp.backend.ir.converter.conversion.translator import (
1413
apply_permutation_to,
1514
create_channels_first_to_channels_last_permutation,
@@ -24,6 +23,7 @@
2423
)
2524
from executorch.backends.nxp.backend.ir.tflite_generator import tflite_model
2625
from executorch.backends.nxp.backend.ir.tflite_generator.builtin_options import (
26+
pad_options,
2727
pad_v2_options,
2828
)
2929
from torch.fx import Node
@@ -50,6 +50,10 @@ def _is_supported_in_IR(
5050
if not NodeConverter._has_shared_q_params_if_quantized(node):
5151
return False
5252

53+
if len(paddings) > 4 and paddings[4:6] != [0, 0]:
54+
# Attempt to Pad channels dimension -> currently not supported
55+
return False
56+
5357
return True
5458

5559
# noinspection PyMethodMayBeStatic
@@ -101,6 +105,15 @@ def convert(self, node: Node):
101105
np.asarray(paddings, "int32"), "paddings"
102106
)
103107

108+
if constant == 0.0:
109+
# We're padding with zeros, we can use traditional Pad op
110+
t_op.tmp_inputs = [x, paddings_tensor]
111+
t_op.tmp_outputs = [y]
112+
t_op.builtin_options = pad_options.Pad()
113+
114+
self.builder.append_operators([t_op])
115+
return
116+
104117
if x.quantization is None:
105118
constant_tensor = self.builder.create_tensor_for_data(
106119
np.array([constant], tf_lite_type_to_numpy(x.type)), "constant"
@@ -124,6 +137,4 @@ def convert(self, node: Node):
124137
t_op.tmp_outputs = [y]
125138
t_op.builtin_options = pad_v2_options.PadV2()
126139

127-
ops_to_add = OpsList(middle_op=t_op)
128-
129-
self.builder.append_operators(ops_to_add.flatten())
140+
self.builder.append_operators([t_op])

backends/nxp/tests/ir/converter/node_converter/test_constant_pad_nd_converter.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -63,16 +63,10 @@ def test_constant_pad_nd_conversion__default_constant():
6363
pytest.param((2, 4), tuple(range(4)), id="2D, padding N, H"),
6464
pytest.param((2, 4, 6), tuple(range(2)), id="3D, padding H"),
6565
pytest.param((2, 4, 6), tuple(range(4)), id="3D, padding C, H"),
66-
pytest.param((2, 4, 6), list(range(6)), id="3D, padding N, C, H"),
6766
pytest.param((2, 4, 6, 8), tuple(range(2)), id="4D, padding W"),
6867
pytest.param((2, 4, 6, 8), tuple(range(4)), id="4D, padding H, W"),
69-
pytest.param((2, 4, 6, 8), list(range(6)), id="4D, padding C, H, W"),
70-
pytest.param((2, 4, 6, 8), list(range(8)), id="4D, padding N, C, H, W"),
71-
pytest.param((1, 2, 3, 4, 5), list(range(2)), id="5D, padding D"),
68+
pytest.param((1, 2, 3, 4, 5), tuple(range(2)), id="5D, padding D"),
7269
pytest.param((1, 2, 3, 4, 5), tuple(range(4)), id="5D, padding W, D"),
73-
pytest.param((1, 2, 3, 4, 5), list(range(6)), id="5D, padding H, W, D"),
74-
pytest.param((1, 2, 3, 4, 5), tuple(range(8)), id="5D, padding C, H, W, D"),
75-
pytest.param((1, 2, 3, 4, 5), list(range(10)), id="5D, padding N, C, H, W, D"),
7670
],
7771
)
7872
def test_constant_pad_nd_conversion__format_less(input_shape, paddings):
@@ -93,8 +87,9 @@ def test_constant_pad_nd_conversion__format_less(input_shape, paddings):
9387
],
9488
)
9589
def test_constant_pad_nd_conversion__channels_first(input_shape, paddings):
90+
model = ConstantPadNDConvModule(paddings)
9691
edge_program = to_edge_program(
97-
ConstantPadNDConvModule(paddings), input_shape
92+
model, input_shape
9893
).exported_program() # Extra `Conv` after the padding.
9994

10095
input_data = np.random.random(input_shape).astype(np.float32)

0 commit comments

Comments (0)