1414from executorch .backends .nxp .backend .ir .converter .conversion import (
1515 aten_translator ,
1616 common ,
17+ translator ,
1718)
1819from executorch .backends .nxp .backend .ir .converter .conversion .common import try_get_input
1920from executorch .backends .nxp .backend .ir .converter .conversion .translator import (
4041from executorch .backends .nxp .backend .ir .tflite_generator .builtin_options import (
4142 conv_2d_options ,
4243 depthwise_conv_2d_options ,
44+ reshape_options ,
4345)
4446from torch .fx import Node
4547from torch .nn import Parameter
@@ -94,13 +96,15 @@ def _is_supported_in_IR(
9496 parameters_mapping : dict [str , Parameter ],
9597 custom_delegation_options : CustomDelegationOptions ,
9698 ) -> bool :
99+ input_tensor_rank = len (node .meta ["val" ].shape )
100+ dimensions = input_tensor_rank - 2
97101 is_transposed = node .args [6 ]
98102 output_padding = node .args [7 ]
99103
100104 if is_transposed :
101105 return False
102106
103- if output_padding != [0 , 0 ] :
107+ if output_padding != [0 ] * dimensions :
104108 return False
105109
106110 if input_tensor_safe (node , 2 ) is None :
@@ -125,7 +129,107 @@ def _get_convolution_arguments(
125129 _ , _ , _ , stride , padding , dilation , transposed , out_padding , groups = (
126130 conv_node .args
127131 )
128- return stride , padding , dilation , transposed , out_padding , groups
132+ return (
133+ list (stride ),
134+ list (padding ),
135+ list (dilation ),
136+ transposed ,
137+ out_padding ,
138+ groups ,
139+ )
140+
141+ def _convert_1d_conv (
142+ self , t_op : tflite_model .Operator , conv_params : ConvParameters
143+ ) -> list [tflite_model .Operator ]:
144+ """Convert the 'Conv' operator with a 1D kernel to TFLite 'Conv2D'.
145+ TFLite doesn't support 1D convolution, but this behaviour can be represented using
146+ Reshape -> Conv2D -> Reshape.
147+ The first reshape introduces a 4th dimension with size 1. The second Reshape removes the temporary dimension.
148+ """
149+ # -- Calculate the shapes for equivalent 2D convolution --
150+ conv_2d_input_shape = translator .nhc_dimensions_to_nhwc (
151+ t_op .tmp_inputs [0 ].shape .vector
152+ )
153+ conv_2d_weight_shape = translator .nhc_dimensions_to_nhwc (
154+ t_op .tmp_inputs [1 ].shape .vector
155+ )
156+ conv_2d_output_shape = translator .nhc_dimensions_to_nhwc (
157+ t_op .tmp_outputs [0 ].shape .vector
158+ )
159+
160+ # -- Generate tensors taking part in the conversion --
161+ reshape1_input = t_op .tmp_inputs [0 ]
162+
163+ reshape1_output = self .builder .duplicate_tensor (
164+ reshape1_input , name_suffix = "_4D_"
165+ )
166+ reshape1_output .shape = tflite_model .Shape (conv_2d_input_shape )
167+
168+ reshape2_input = self .builder .duplicate_tensor (
169+ t_op .tmp_outputs [0 ], name_suffix = "_4D_"
170+ )
171+ reshape2_input .shape = tflite_model .Shape (conv_2d_output_shape )
172+
173+ reshape2_output = t_op .tmp_outputs [0 ]
174+
175+ pre_reshapes = []
176+
177+ # Extend the weights tensor to 4D
178+ weights_tensor = t_op .tmp_inputs [1 ]
179+ if tensor_has_data (weights_tensor ):
180+ # Do it statically
181+ weights_tensor .shape = tflite_model .Shape (conv_2d_weight_shape )
182+ weights_tensor .tmp_buffer .data = weights_tensor .tmp_buffer .data .reshape (
183+ conv_2d_weight_shape
184+ )
185+
186+ else :
187+ # Add a Reshape before the weights tensor
188+ new_weights_tensor = self .builder .duplicate_tensor (
189+ weights_tensor , name_suffix = "_4D_"
190+ )
191+ new_weights_tensor .shape = tflite_model .Shape (conv_2d_weight_shape )
192+
193+ weight_reshape = tflite_model .Operator (
194+ builtin_options = reshape_options .Reshape (conv_2d_weight_shape )
195+ )
196+ weight_reshape .tmp_inputs = [weights_tensor ]
197+ weight_reshape .tmp_outputs = [new_weights_tensor ]
198+
199+ pre_reshapes .append (weight_reshape )
200+
201+ # Save the new weights tensor, to assign it later.
202+ weights_tensor = new_weights_tensor
203+
204+ # -- Create the new operators --
205+ reshape1 = tflite_model .Operator (
206+ builtin_options = reshape_options .Reshape (conv_2d_input_shape )
207+ )
208+ reshape1 .tmp_inputs = [reshape1_input ]
209+ reshape1 .tmp_outputs = [reshape1_output ]
210+ pre_reshapes .append (reshape1 )
211+
212+ reshape2 = tflite_model .Operator (
213+ builtin_options = reshape_options .Reshape (reshape2_output .shape .vector )
214+ )
215+ reshape2 .tmp_inputs = [reshape2_input ]
216+ reshape2 .tmp_outputs = [reshape2_output ]
217+
218+ # Assign the new input and output of the Conv2D
219+ t_op .tmp_inputs = [reshape1_output , weights_tensor ] + t_op .tmp_inputs [
220+ 2 :
221+ ] # Add bias as well, if present
222+ t_op .tmp_outputs = [reshape2_input ]
223+
224+ # Extend all Conv attributes to 2D
225+ common .extend_1d_stride_to_2d (conv_params .stride )
226+ common .extend_1d_dilation_to_2d (conv_params .dilation )
227+ common .extend_1d_padding_to_2d (conv_params .padding )
228+
229+ # Convert the now 2D Conv
230+ converted_conv_ops = self ._convert_2d_conv (t_op , conv_params )
231+
232+ return pre_reshapes + converted_conv_ops + [reshape2 ]
129233
130234 # noinspection PyPep8Naming
131235 def _convert_unpadded_2D (
@@ -266,7 +370,9 @@ def convert(self, node: Node):
266370 conv_params = ConvParameters (stride , padding , dilation , groups )
267371
268372 rank = t_op .tmp_inputs [1 ].shape .len ()
269- if rank == 4 : # Conv2D
373+ if rank == 3 : # Conv1D
374+ ops_to_add = self ._convert_1d_conv (t_op , conv_params )
375+ elif rank == 4 : # Conv2D
270376 ops_to_add = self ._convert_2d_conv (t_op , conv_params )
271377 else :
272378 raise NotImplementedError (
0 commit comments