1414from executorch .backends .nxp .backend .ir .converter .conversion import (
1515 aten_translator ,
1616 common ,
17+ translator ,
1718)
1819from executorch .backends .nxp .backend .ir .converter .conversion .common import try_get_input
1920from executorch .backends .nxp .backend .ir .converter .node_converter import (
3637from executorch .backends .nxp .backend .ir .tflite_generator .builtin_options import (
3738 conv_2d_options ,
3839 depthwise_conv_2d_options ,
40+ reshape_options ,
3941)
4042from torch .fx import Node
4143from torch .nn import Parameter
@@ -85,13 +87,15 @@ def _is_supported_on_target(
8587 def _is_supported_in_IR (
8688 node : Node , parameters_mapping : dict [str , Parameter ]
8789 ) -> bool :
90+ input_tensor_rank = len (node .meta ["val" ].shape )
91+ dimensions = input_tensor_rank - 2
8892 is_transposed = node .args [6 ]
8993 output_padding = node .args [7 ]
9094
9195 if is_transposed :
9296 return False
9397
94- if output_padding != [0 , 0 ] :
98+ if output_padding != [0 ] * dimensions :
9599 return False
96100
97101 if input_tensor_safe (node , 2 ) is None :
@@ -116,7 +120,107 @@ def _get_convolution_arguments(
116120 _ , _ , _ , stride , padding , dilation , transposed , out_padding , groups = (
117121 conv_node .args
118122 )
119- return stride , padding , dilation , transposed , out_padding , groups
123+ return (
124+ list (stride ),
125+ list (padding ),
126+ list (dilation ),
127+ transposed ,
128+ out_padding ,
129+ groups ,
130+ )
131+
132+ def _convert_1d_conv (
133+ self , t_op : tflite_model .Operator , conv_params : ConvParameters
134+ ) -> list [tflite_model .Operator ]:
135+ """Convert the 'Conv' operator with a 1D kernel to TFLite 'Conv2D'.
136+ TFLite doesn't support 1D convolution, but this behaviour can be represented using
137+ Reshape -> Conv2D -> Reshape.
138+ The first reshape introduces a 4th dimension with size 1. The second Reshape removes the temporary dimension.
139+ """
140+ # -- Calculate the shapes for equivalent 2D convolution --
141+ conv_2d_input_shape = translator .nhc_dimensions_to_nhwc (
142+ t_op .tmp_inputs [0 ].shape .vector
143+ )
144+ conv_2d_weight_shape = translator .nhc_dimensions_to_nhwc (
145+ t_op .tmp_inputs [1 ].shape .vector
146+ )
147+ conv_2d_output_shape = translator .nhc_dimensions_to_nhwc (
148+ t_op .tmp_outputs [0 ].shape .vector
149+ )
150+
151+ # -- Generate tensors taking part in the conversion --
152+ reshape1_input = t_op .tmp_inputs [0 ]
153+
154+ reshape1_output = self .builder .duplicate_tensor (
155+ reshape1_input , name_suffix = "_4D_"
156+ )
157+ reshape1_output .shape = tflite_model .Shape (conv_2d_input_shape )
158+
159+ reshape2_input = self .builder .duplicate_tensor (
160+ t_op .tmp_outputs [0 ], name_suffix = "_4D_"
161+ )
162+ reshape2_input .shape = tflite_model .Shape (conv_2d_output_shape )
163+
164+ reshape2_output = t_op .tmp_outputs [0 ]
165+
166+ pre_reshapes = []
167+
168+ # Extend the weights tensor to 4D
169+ weights_tensor = t_op .tmp_inputs [1 ]
170+ if tensor_has_data (weights_tensor ):
171+ # Do it statically
172+ weights_tensor .shape = tflite_model .Shape (conv_2d_weight_shape )
173+ weights_tensor .tmp_buffer .data = weights_tensor .tmp_buffer .data .reshape (
174+ conv_2d_weight_shape
175+ )
176+
177+ else :
178+ # Add a Reshape before the weights tensor
179+ new_weights_tensor = self .builder .duplicate_tensor (
180+ weights_tensor , name_suffix = "_4D_"
181+ )
182+ new_weights_tensor .shape = tflite_model .Shape (conv_2d_weight_shape )
183+
184+ weight_reshape = tflite_model .Operator (
185+ builtin_options = reshape_options .Reshape (conv_2d_weight_shape )
186+ )
187+ weight_reshape .tmp_inputs = [weights_tensor ]
188+ weight_reshape .tmp_outputs = [new_weights_tensor ]
189+
190+ pre_reshapes .append (weight_reshape )
191+
192+ # Save the new weights tensor, to assign it later.
193+ weights_tensor = new_weights_tensor
194+
195+ # -- Create the new operators --
196+ reshape1 = tflite_model .Operator (
197+ builtin_options = reshape_options .Reshape (conv_2d_input_shape )
198+ )
199+ reshape1 .tmp_inputs = [reshape1_input ]
200+ reshape1 .tmp_outputs = [reshape1_output ]
201+ pre_reshapes .append (reshape1 )
202+
203+ reshape2 = tflite_model .Operator (
204+ builtin_options = reshape_options .Reshape (reshape2_output .shape .vector )
205+ )
206+ reshape2 .tmp_inputs = [reshape2_input ]
207+ reshape2 .tmp_outputs = [reshape2_output ]
208+
209+ # Assign the new input and output of the Conv2D
210+ t_op .tmp_inputs = [reshape1_output , weights_tensor ] + t_op .tmp_inputs [
211+ 2 :
212+ ] # Add bias as well, if present
213+ t_op .tmp_outputs = [reshape2_input ]
214+
215+ # Extend all Conv attributes to 2D
216+ common .extend_1d_stride_to_2d (conv_params .stride )
217+ common .extend_1d_dilation_to_2d (conv_params .dilation )
218+ common .extend_1d_padding_to_2d (conv_params .padding )
219+
220+ # Convert the now 2D Conv
221+ converted_conv_ops = self ._convert_2d_conv (t_op , conv_params )
222+
223+ return pre_reshapes + converted_conv_ops + [reshape2 ]
120224
121225 # noinspection PyPep8Naming
122226 def _convert_unpadded_2D (
@@ -237,7 +341,9 @@ def convert(self, node: Node):
237341 conv_params = ConvParameters (stride , padding , dilation , groups )
238342
239343 rank = t_op .tmp_inputs [1 ].shape .len ()
240- if rank == 4 : # Conv2D
344+ if rank == 3 : # Conv1D
345+ ops_to_add = self ._convert_1d_conv (t_op , conv_params )
346+ elif rank == 4 : # Conv2D
241347 ops_to_add = self ._convert_2d_conv (t_op , conv_params )
242348 else :
243349 raise NotImplementedError (
0 commit comments