@@ -106,20 +106,6 @@ def getNodeArgs(node: Node, tosa_spec: TosaSpecification) -> list[TosaArg]:
106
106
raise ValueError (f"Failed processing args to op:\n { node } " ) from e
107
107
108
108
109
- def get_output_node (node : Node ) -> Node :
110
- return list (node .users )[0 ]
111
-
112
-
113
- """ TOSA reshape returns a tensor with the same type/values as the input.
114
- No data conversion happens during a reshape operation. """
115
-
116
-
117
- def build_reshape (tosa_fb , input_name , new_shape , output_name ):
118
- attr = ts .TosaSerializerAttribute ()
119
- attr .ReshapeAttribute (new_shape )
120
- tosa_fb .addOperator (ts .TosaOp .Op ().RESHAPE , [input_name ], [output_name ], attr )
121
-
122
-
123
109
def are_fake_tensors_broadcastable (
124
110
fake_tensors : list [FakeTensor ],
125
111
) -> tuple [bool , list [int ]]:
@@ -260,45 +246,6 @@ def build_reshape_tosa_1_0(
260
246
)
261
247
262
248
263
- def reshape_for_broadcast (tosa_fb , inputs , dim_order = None ):
264
- assert len (inputs ) == 2
265
- input1 = inputs [0 ]
266
- input2 = inputs [1 ]
267
-
268
- def get_new_shape (l_rank_in , h_rank_in ):
269
- rank_diff = len (h_rank_in .shape ) - len (l_rank_in .shape )
270
- new_shape = list (l_rank_in .shape )
271
-
272
- for _ in range (rank_diff ):
273
- new_shape .insert (0 , 1 )
274
- return tuple (new_shape )
275
-
276
- if len (input1 .shape ) == len (input2 .shape ):
277
- return input1 , input2
278
- elif len (input1 .shape ) > len (input2 .shape ):
279
- l_rank_in = input2
280
- h_rank_in = input1
281
- elif len (input1 .shape ) < len (input2 .shape ):
282
- l_rank_in = input1
283
- h_rank_in = input2
284
-
285
- new_shape = get_new_shape (l_rank_in , h_rank_in )
286
- dim_order = h_rank_in .dim_order if dim_order is None else dim_order
287
- new_shape = tosa_shape (new_shape , dim_order )
288
-
289
- reshaped = tosa_fb .addIntermediate (
290
- new_shape ,
291
- inputs [0 ].dtype ,
292
- )
293
-
294
- build_reshape (tosa_fb , l_rank_in .name , new_shape , reshaped .name )
295
-
296
- if len (input1 .shape ) > len (input2 .shape ):
297
- return input1 , reshaped
298
- else :
299
- return reshaped , input2
300
-
301
-
302
249
def is_consumer_node_depthwise_conv2d (node : Node ):
303
250
consumer_node = list (node .users )[0 ]
304
251
if consumer_node .target == exir_ops .edge .aten .convolution .default :
@@ -322,35 +269,6 @@ def tosa_shape(shape, dim_order):
322
269
return removed_symints
323
270
324
271
325
- def expand_dims (
326
- tosa_graph : ts .TosaSerializer ,
327
- input_node : TosaArg ,
328
- dtype : int ,
329
- dim : int ,
330
- ) -> Any :
331
- """Inserts TOSA operators into the tosa_graph, that perform the equivalent
332
- of the expand_dims (a.k.a unsqueeze) operation. A new axis is created at the
333
- dim location.
334
-
335
- Args:
336
- tosa_graph (ts.TosaSerializer): The TOSA graph to manipulate.
337
- input_node (TosaArg): The parent node of the expand dim operations.
338
- dtype (ts.DType): The data type expand dims operations.
339
- dim (int): The dimension to expand.
340
-
341
- Returns:
342
- Any: The output tensor of the inserted operation in the TOSA graph.
343
- """
344
- new_shape = list (input_node .shape )
345
- new_shape .insert (dim , 1 )
346
-
347
- intermediate = tosa_graph .addIntermediate (new_shape , dtype )
348
-
349
- build_reshape (tosa_graph , input_node .name , new_shape , intermediate .name )
350
-
351
- return intermediate
352
-
353
-
354
272
def get_resize_parameters_1d (
355
273
input_size : int | torch .SymInt ,
356
274
output_size : int | torch .SymInt ,
0 commit comments