@@ -106,20 +106,6 @@ def getNodeArgs(node: Node, tosa_spec: TosaSpecification) -> list[TosaArg]:
         raise ValueError(f"Failed processing args to op:\n{node}") from e


-def get_output_node(node: Node) -> Node:
-    return list(node.users)[0]
-
-
-""" TOSA reshape returns a tensor with the same type/values as the input.
-No data conversion happens during a reshape operation. """
-
-
-def build_reshape(tosa_fb, input_name, new_shape, output_name):
-    attr = ts.TosaSerializerAttribute()
-    attr.ReshapeAttribute(new_shape)
-    tosa_fb.addOperator(ts.TosaOp.Op().RESHAPE, [input_name], [output_name], attr)
-
-
 def are_fake_tensors_broadcastable(
     fake_tensors: list[FakeTensor],
 ) -> tuple[bool, list[int]]:
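The comment removed above build_reshape notes that a TOSA RESHAPE returns a tensor with the same type and values as its input, with no data conversion. A minimal sketch of that property, using torch purely for illustration (the removed helper itself emitted a serializer RESHAPE operator, not torch calls):

    import torch

    x = torch.arange(6, dtype=torch.int8)   # six values, int8
    y = x.reshape(2, 3)                      # reshape only rearranges shape metadata

    assert y.dtype == x.dtype                # no data conversion happens
    assert torch.equal(y.flatten(), x)       # same values as the input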
@@ -260,45 +246,6 @@ def build_reshape_tosa_1_0(
     )


-def reshape_for_broadcast(tosa_fb, inputs, dim_order=None):
-    assert len(inputs) == 2
-    input1 = inputs[0]
-    input2 = inputs[1]
-
-    def get_new_shape(l_rank_in, h_rank_in):
-        rank_diff = len(h_rank_in.shape) - len(l_rank_in.shape)
-        new_shape = list(l_rank_in.shape)
-
-        for _ in range(rank_diff):
-            new_shape.insert(0, 1)
-        return tuple(new_shape)
-
-    if len(input1.shape) == len(input2.shape):
-        return input1, input2
-    elif len(input1.shape) > len(input2.shape):
-        l_rank_in = input2
-        h_rank_in = input1
-    elif len(input1.shape) < len(input2.shape):
-        l_rank_in = input1
-        h_rank_in = input2
-
-    new_shape = get_new_shape(l_rank_in, h_rank_in)
-    dim_order = h_rank_in.dim_order if dim_order is None else dim_order
-    new_shape = tosa_shape(new_shape, dim_order)
-
-    reshaped = tosa_fb.addIntermediate(
-        new_shape,
-        inputs[0].dtype,
-    )
-
-    build_reshape(tosa_fb, l_rank_in.name, new_shape, reshaped.name)
-
-    if len(input1.shape) > len(input2.shape):
-        return input1, reshaped
-    else:
-        return reshaped, input2
-
-
 def is_consumer_node_depthwise_conv2d(node: Node):
     consumer_node = list(node.users)[0]
     if consumer_node.target == exir_ops.edge.aten.convolution.default:
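For context, the removed reshape_for_broadcast aligned two operands of different rank by prepending leading 1s to the lower-rank shape, mirroring NumPy/PyTorch broadcasting, and then emitted a RESHAPE to that padded shape. A standalone sketch of just the shape rule (the function name and example shapes below are illustrative, not part of the codebase):

    def pad_to_rank(low_rank_shape, high_rank_shape):
        # Prepend 1s until the lower-rank shape matches the higher rank.
        rank_diff = len(high_rank_shape) - len(low_rank_shape)
        return tuple([1] * rank_diff + list(low_rank_shape))

    # A (3,) operand broadcast against a (2, 4, 3) operand is padded to (1, 1, 3);
    # the removed helper then reshaped the low-rank tensor to that shape.
    assert pad_to_rank((3,), (2, 4, 3)) == (1, 1, 3)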
@@ -322,35 +269,6 @@ def tosa_shape(shape, dim_order):
     return removed_symints


-def expand_dims(
-    tosa_graph: ts.TosaSerializer,
-    input_node: TosaArg,
-    dtype: int,
-    dim: int,
-) -> Any:
-    """Inserts TOSA operators into the tosa_graph, that perform the equivalent
-    of the expand_dims (a.k.a unsqueeze) operation. A new axis is created at the
-    dim location.
-
-    Args:
-        tosa_graph (ts.TosaSerializer): The TOSA graph to manipulate.
-        input_node (TosaArg): The parent node of the expand dim operations.
-        dtype (ts.DType): The data type expand dims operations.
-        dim (int): The dimension to expand.
-
-    Returns:
-        Any: The output tensor of the inserted operation in the TOSA graph.
-    """
-    new_shape = list(input_node.shape)
-    new_shape.insert(dim, 1)
-
-    intermediate = tosa_graph.addIntermediate(new_shape, dtype)
-
-    build_reshape(tosa_graph, input_node.name, new_shape, intermediate.name)
-
-    return intermediate
-
-
 def get_resize_parameters_1d(
     input_size: int | torch.SymInt,
     output_size: int | torch.SymInt,
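The removed expand_dims implemented unsqueeze semantics by inserting a new axis of size 1 at dim and emitting a RESHAPE to the resulting shape. A small torch-based sketch of that equivalence (the tensor and dim below are illustrative):

    import torch

    x = torch.zeros(2, 3)
    dim = 1

    new_shape = list(x.shape)
    new_shape.insert(dim, 1)                 # the helper's core step: add a size-1 axis at dim

    assert tuple(new_shape) == (2, 1, 3)
    assert torch.unsqueeze(x, dim).shape == torch.Size(new_shape)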