Skip to content

Commit 5182658

Browse files
NXP backend: Add support for conversion of Conv1D operator
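Additional changes in this commit:
- Fix the `operators_not_to_delegate` assignment in the partitioner.
- Fix the `input_shapes` type hint in `to_quantized_edge_program()`.
- Add test cases for the Conv1D operator.
- Fix explicit padding so that quantized inputs are padded with their zero-point instead of 0.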
1 parent bd92f1a commit 5182658

File tree

7 files changed, with 563 additions and 85 deletions.

backends/nxp/backend/ir/converter/conversion/common.py

Lines changed: 12 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -70,29 +70,22 @@ def try_get_input(t_op: tflite_model.Operator, idx: int) -> tflite_model.Tensor
7070
return tensor
7171

7272

73-
def extend_1d_pads_to_2d(onnx_1d_pads: MutableSequence):
74-
"""Extend the onnx 'pads' operator attribute that represents padding for a 1D kernel to 2D, by adding '0's."""
75-
if onnx_1d_pads is not None:
76-
onnx_1d_pads.insert(1, 0)
77-
onnx_1d_pads.append(0)
73+
def extend_1d_padding_to_2d(tflite_1d_padding: MutableSequence):
74+
"""Extend the PyTorch 'padding' operator attribute that represents padding for a 1D kernel to 2D, by adding '0's."""
75+
if tflite_1d_padding is not None:
76+
tflite_1d_padding.append(0)
7877

7978

80-
def extend_1d_strides_to_2d(onnx_1d_strides: MutableSequence):
81-
"""Extend the onnx 'strides' operator attribute that represents strides for a 1D kernel to 2D, by adding '1'."""
82-
if onnx_1d_strides is not None:
83-
onnx_1d_strides.append(1)
79+
def extend_1d_stride_to_2d(tflite_1d_stride: MutableSequence):
80+
"""Extend the PyTorch 'stride' operator attribute that represents stride for a 1D kernel to 2D, by adding '1'."""
81+
if tflite_1d_stride is not None:
82+
tflite_1d_stride.append(1)
8483

8584

86-
def extend_1d_dilations_to_2d(onnx_1d_dilations: MutableSequence):
87-
"""Extend the onnx 'dilations' operator attribute that represents dilations for a 1D kernel to 2D, by adding '1'."""
88-
if onnx_1d_dilations is not None:
89-
onnx_1d_dilations.append(1)
90-
91-
92-
def extend_1d_kernel_shape_to_2d(onnx_1d_kernel_shape: MutableSequence):
93-
"""Extend the onnx 1D 'kernel_shape' operator attribute to 2D, by adding '1'."""
94-
if onnx_1d_kernel_shape is not None:
95-
onnx_1d_kernel_shape.append(1)
85+
def extend_1d_dilation_to_2d(tflite_1d_dilation: MutableSequence):
86+
"""Extend the PyTorch 'dilation' operator attribute that represents dilation for a 1D kernel to 2D, by adding '1'."""
87+
if tflite_1d_dilation is not None:
88+
tflite_1d_dilation.append(1)
9689

9790

9891
StridedOptions = (

backends/nxp/backend/ir/converter/node_converters/ops_converters/convolution_converter.py

Lines changed: 136 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,12 @@
1414
from executorch.backends.nxp.backend.ir.converter.conversion import (
1515
aten_translator,
1616
common,
17+
translator,
1718
)
1819
from executorch.backends.nxp.backend.ir.converter.conversion.common import try_get_input
20+
from executorch.backends.nxp.backend.ir.converter.conversion.translator import (
21+
tf_lite_type_to_numpy,
22+
)
1923
from executorch.backends.nxp.backend.ir.converter.node_converter import (
2024
NodeConverter,
2125
Target,
@@ -36,6 +40,7 @@
3640
from executorch.backends.nxp.backend.ir.tflite_generator.builtin_options import (
3741
conv_2d_options,
3842
depthwise_conv_2d_options,
43+
reshape_options,
3944
)
4045
from torch.fx import Node
4146
from torch.nn import Parameter
@@ -85,13 +90,15 @@ def _is_supported_on_target(
8590
def _is_supported_in_IR(
8691
node: Node, parameters_mapping: dict[str, Parameter]
8792
) -> bool:
93+
input_tensor_rank = len(node.meta["val"].shape)
94+
dimensions = input_tensor_rank - 2
8895
is_transposed = node.args[6]
8996
output_padding = node.args[7]
9097

9198
if is_transposed:
9299
return False
93100

94-
if output_padding != [0, 0]:
101+
if output_padding != [0] * dimensions:
95102
return False
96103

97104
if input_tensor_safe(node, 2) is None:
@@ -116,7 +123,107 @@ def _get_convolution_arguments(
116123
_, _, _, stride, padding, dilation, transposed, out_padding, groups = (
117124
conv_node.args
118125
)
119-
return stride, padding, dilation, transposed, out_padding, groups
126+
return (
127+
list(stride),
128+
list(padding),
129+
list(dilation),
130+
transposed,
131+
out_padding,
132+
groups,
133+
)
134+
135+
def _convert_1d_conv(
136+
self, t_op: tflite_model.Operator, conv_params: ConvParameters
137+
) -> list[tflite_model.Operator]:
138+
"""Convert the 'Conv' operator with a 1D kernel to TFLite 'Conv2D'.
139+
TFLite doesn't support 1D convolution, but this behaviour can be represented using
140+
Reshape -> Conv2D -> Reshape.
141+
The first reshape introduces a 4th dimension with size 1. The second Reshape removes the temporary dimension.
142+
"""
143+
# -- Calculate the shapes for equivalent 2D convolution --
144+
conv_2d_input_shape = translator.nhc_dimensions_to_nhwc(
145+
t_op.tmp_inputs[0].shape.vector
146+
)
147+
conv_2d_weight_shape = translator.nhc_dimensions_to_nhwc(
148+
t_op.tmp_inputs[1].shape.vector
149+
)
150+
conv_2d_output_shape = translator.nhc_dimensions_to_nhwc(
151+
t_op.tmp_outputs[0].shape.vector
152+
)
153+
154+
# -- Generate tensors taking part in the conversion --
155+
reshape1_input = t_op.tmp_inputs[0]
156+
157+
reshape1_output = self.builder.duplicate_tensor(
158+
reshape1_input, name_suffix="_4D_"
159+
)
160+
reshape1_output.shape = tflite_model.Shape(conv_2d_input_shape)
161+
162+
reshape2_input = self.builder.duplicate_tensor(
163+
t_op.tmp_outputs[0], name_suffix="_4D_"
164+
)
165+
reshape2_input.shape = tflite_model.Shape(conv_2d_output_shape)
166+
167+
reshape2_output = t_op.tmp_outputs[0]
168+
169+
pre_reshapes = []
170+
171+
# Extend the weights tensor to 4D
172+
weights_tensor = t_op.tmp_inputs[1]
173+
if tensor_has_data(weights_tensor):
174+
# Do it statically
175+
weights_tensor.shape = tflite_model.Shape(conv_2d_weight_shape)
176+
weights_tensor.tmp_buffer.data = weights_tensor.tmp_buffer.data.reshape(
177+
conv_2d_weight_shape
178+
)
179+
180+
else:
181+
# Add a Reshape before the weights tensor
182+
new_weights_tensor = self.builder.duplicate_tensor(
183+
weights_tensor, name_suffix="_4D_"
184+
)
185+
new_weights_tensor.shape = tflite_model.Shape(conv_2d_weight_shape)
186+
187+
weight_reshape = tflite_model.Operator(
188+
builtin_options=reshape_options.Reshape(conv_2d_weight_shape)
189+
)
190+
weight_reshape.tmp_inputs = [weights_tensor]
191+
weight_reshape.tmp_outputs = [new_weights_tensor]
192+
193+
pre_reshapes.append(weight_reshape)
194+
195+
# Save the new weights tensor, to assign it later.
196+
weights_tensor = new_weights_tensor
197+
198+
# -- Create the new operators --
199+
reshape1 = tflite_model.Operator(
200+
builtin_options=reshape_options.Reshape(conv_2d_input_shape)
201+
)
202+
reshape1.tmp_inputs = [reshape1_input]
203+
reshape1.tmp_outputs = [reshape1_output]
204+
pre_reshapes.append(reshape1)
205+
206+
reshape2 = tflite_model.Operator(
207+
builtin_options=reshape_options.Reshape(reshape2_output.shape.vector)
208+
)
209+
reshape2.tmp_inputs = [reshape2_input]
210+
reshape2.tmp_outputs = [reshape2_output]
211+
212+
# Assign the new input and output of the Conv2D
213+
t_op.tmp_inputs = [reshape1_output, weights_tensor] + t_op.tmp_inputs[
214+
2:
215+
] # Add bias as well, if present
216+
t_op.tmp_outputs = [reshape2_input]
217+
218+
# Extend all Conv attributes to 2D
219+
common.extend_1d_stride_to_2d(conv_params.stride)
220+
common.extend_1d_dilation_to_2d(conv_params.dilation)
221+
common.extend_1d_padding_to_2d(conv_params.padding)
222+
223+
# Convert the now 2D Conv
224+
converted_conv_ops = self._convert_2d_conv(t_op, conv_params)
225+
226+
return pre_reshapes + converted_conv_ops + [reshape2]
120227

121228
# noinspection PyPep8Naming
122229
def _convert_unpadded_2D(
@@ -182,9 +289,19 @@ def _convert_2d_conv(
182289
aten_translator.convert_padding(conv_params.padding)
183290
)
184291
if explicit_padding is not None:
185-
# Need to prepend a 'Pad' operator, which adds 0s.
292+
# Need to prepend a 'Pad' operator, which adds 0s (or `zero_point` for the quantized case).
293+
input_quantization = t_op.tmp_inputs[0].quantization
294+
pad_value = (
295+
None
296+
if input_quantization is None
297+
else np.array(input_quantization.zero_point[0]).astype(
298+
tf_lite_type_to_numpy(t_op.tmp_inputs[0].type)
299+
)
300+
)
186301
conversion_result.ops_list.add_pre(
187-
self.builder.create_pad_operator_before(t_op, 0, explicit_padding)
302+
self.builder.create_pad_operator_before(
303+
t_op, 0, explicit_padding, constant_value=pad_value
304+
)
188305
)
189306

190307
# DepthwiseConv2D expects weights in format [kernel_channels, kernel_height, kernel_width, output_channels]
@@ -221,9 +338,19 @@ def _convert_2d_conv(
221338
aten_translator.convert_padding(conv_params.padding)
222339
)
223340
if explicit_padding is not None:
224-
# Need to prepend a 'Pad' operator, which adds 0s.
341+
# Need to prepend a 'Pad' operator, which adds 0s (or `zero_point` for the quantized case).
342+
input_quantization = t_op.tmp_inputs[0].quantization
343+
pad_value = (
344+
None
345+
if input_quantization is None
346+
else np.array(input_quantization.zero_point[0]).astype(
347+
tf_lite_type_to_numpy(t_op.tmp_inputs[0].type)
348+
)
349+
)
225350
conversion_result.ops_list.add_pre(
226-
self.builder.create_pad_operator_before(t_op, 0, explicit_padding)
351+
self.builder.create_pad_operator_before(
352+
t_op, 0, explicit_padding, constant_value=pad_value
353+
)
227354
)
228355

229356
return conversion_result.ops_list.flatten()
@@ -237,7 +364,9 @@ def convert(self, node: Node):
237364
conv_params = ConvParameters(stride, padding, dilation, groups)
238365

239366
rank = t_op.tmp_inputs[1].shape.len()
240-
if rank == 4: # Conv2D
367+
if rank == 3: # Conv1D
368+
ops_to_add = self._convert_1d_conv(t_op, conv_params)
369+
elif rank == 4: # Conv2D
241370
ops_to_add = self._convert_2d_conv(t_op, conv_params)
242371
else:
243372
raise NotImplementedError(

backends/nxp/backend/ir/converter/node_converters/shared/conv_utils.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,9 @@
1414
)
1515
from executorch.backends.nxp.backend.ir.converter.conversion import aten_translator
1616
from executorch.backends.nxp.backend.ir.converter.conversion.common import OpsList
17+
from executorch.backends.nxp.backend.ir.converter.conversion.translator import (
18+
tf_lite_type_to_numpy,
19+
)
1720
from executorch.backends.nxp.backend.ir.converter.tensor_utils import tensor_has_data
1821
from executorch.backends.nxp.backend.ir.lib.tflite.Padding import Padding
1922
from executorch.backends.nxp.backend.ir.tflite_generator import tflite_model
@@ -289,9 +292,17 @@ def build_input_tensor_padding(
289292

290293
tfl_padding, explicit_padding = aten_translator.convert_padding(conv_params.padding)
291294
if explicit_padding is not None:
292-
# Must add extra 'Pad' operator
295+
# Must add extra 'Pad' operator, which adds 0s (or `zero_point` for the quantized case).
296+
input_quantization = t_op.tmp_inputs[0].quantization
297+
pad_value = (
298+
None
299+
if input_quantization is None
300+
else np.array(input_quantization.zero_point[0]).astype(
301+
tf_lite_type_to_numpy(t_op.tmp_inputs[0].type)
302+
)
303+
)
293304
return tfl_padding, builder.create_pad_operator_before(
294-
t_op, input_idx, explicit_padding
305+
t_op, input_idx, explicit_padding, pad_value
295306
)
296307

297308
return tfl_padding, None

backends/nxp/neutron_partitioner.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -297,15 +297,10 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult:
297297
qdq_cluster_recognizer = QDQClusterRecognizer()
298298
qdq_cluster_recognizer.tag_qdq_clusters(nodes)
299299
graph_module.recompile()
300+
target = self.delegation_spec[1][2].value
301+
target = Target(target.decode())
300302

301-
target = None
302-
operators_not_to_delegate = ""
303-
for spec in self.delegation_spec.compile_specs:
304-
if spec.key == "target":
305-
target = Target(spec.value.decode())
306-
if spec.key == "operators_not_to_delegate":
307-
operators_not_to_delegate = spec.value.decode().split(",")
308-
assert target is not None
303+
operators_not_to_delegate = self.delegation_spec[1][4].value.decode().split(",")
309304
logging.info(f"Operators not to delegate: {operators_not_to_delegate}")
310305

311306
parameters_mapping = EdgeProgramToIRConverter.map_inputs_to_parameters(

backends/nxp/tests/executorch_pipeline.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ def get_random_float_data(input_shapes: tuple[int] | list[tuple[int]]):
4848

4949
def to_quantized_edge_program(
5050
model: torch.nn.Module,
51-
input_shapes: tuple[int] | list[tuple[int]],
51+
input_shapes: tuple[int, ...] | list[tuple[int, ...]],
5252
operators_not_to_delegate: list[str] = None,
5353
target="imxrt700",
5454
neutron_converter_flavor="SDK_25_03",
@@ -100,7 +100,7 @@ def to_quantized_edge_program(
100100

101101

102102
def to_quantized_executorch_program(
103-
model: torch.nn.Module, input_shapes: tuple[int] | list[tuple[int]]
103+
model: torch.nn.Module, input_shapes: tuple[int, ...] | list[tuple[int, ...]]
104104
) -> ExecutorchProgramManager:
105105
edge_program_manager = to_quantized_edge_program(model, input_shapes)
106106

@@ -110,7 +110,7 @@ def to_quantized_executorch_program(
110110

111111

112112
def to_edge_program(
113-
model: nn.Module, input_shapes: tuple[int] | list[tuple[int]]
113+
model: nn.Module, input_shapes: tuple[int, ...] | list[tuple[int, ...]]
114114
) -> EdgeProgramManager:
115115
if isinstance(input_shapes, list):
116116
assert all(isinstance(input_shape, tuple) for input_shape in input_shapes), (

0 commit comments

Comments
 (0)