@@ -246,7 +246,7 @@ def tensorflow_onnx_mapping(g, ops_mapping, initialized_tables=None, is_tflite=F
246
246
exceptions = []
247
247
if initialized_tables is None :
248
248
initialized_tables = {}
249
-
249
+
250
250
ops = list (g .get_nodes ())
251
251
for node in ops :
252
252
logger .debug ("Process node: %s\n %s" , node .name , node .summary )
@@ -263,7 +263,7 @@ def tensorflow_onnx_mapping(g, ops_mapping, initialized_tables=None, is_tflite=F
263
263
logger .error ("Tensorflow op [%s: %s] is not supported" , node .name , op )
264
264
continue
265
265
mapped_op [op ] += 1
266
-
266
+
267
267
func , kwargs = map_info
268
268
if kwargs :
269
269
# if there is a tf_op/onnx_op key we'll map the old type to a new type
@@ -273,6 +273,7 @@ def tensorflow_onnx_mapping(g, ops_mapping, initialized_tables=None, is_tflite=F
273
273
kwargs ["tfl_op" if is_tflite else "tf_op" ] = op
274
274
node .type = converted_op
275
275
body_graphs = node .get_body_graphs ()
276
+
276
277
if body_graphs :
277
278
for attr , b_g in body_graphs .items ():
278
279
logger .debug ("start handling subgraph of %s's attribute %s" , node .name , attr )
@@ -287,7 +288,7 @@ def tensorflow_onnx_mapping(g, ops_mapping, initialized_tables=None, is_tflite=F
287
288
b_g .topological_sort (b_g .get_nodes ())
288
289
exceptions .extend (body_exceptions )
289
290
logger .debug ("finish handling subgraph of %s's attribute %s" , node .name , attr )
290
-
291
+
291
292
try :
292
293
func (g , node , ** kwargs , initialized_tables = initialized_tables , dequantize = dequantize )
293
294
if not is_tflite :
@@ -302,7 +303,6 @@ def tensorflow_onnx_mapping(g, ops_mapping, initialized_tables=None, is_tflite=F
302
303
logger .error ("Failed to convert node %r (fct=%r)\n %r" ,
303
304
node .name , func , summary , exc_info = 1 )
304
305
exceptions .append (ex )
305
-
306
306
return mapped_op , unmapped_op , exceptions
307
307
308
308
def transpose_outputs(ctx, outputs_as_nchw):
    """Insert a transpose from NHWC to NCHW on model output on users request.

    Args:
        ctx: the graph being converted; nodes are rewritten in place and the
            node list is replaced at the end via ``ctx.reset_nodes()``.
        outputs_as_nchw: collection of model output tensor names the user wants
            emitted in NCHW layout instead of the graph-native NHWC.
    """
    ops = []

    # First pass: handle the edge case where a tensor feeds BOTH a requested
    # NCHW model output and other internal nodes. We must not simply retarget
    # the tensor: internal consumers need the original NHWC value while the
    # model output needs the transposed NCHW value.
    edge_case_handled = set()

    for node in ctx.get_nodes():
        for output_name in node.output:
            consumers = ctx.find_output_consumers(output_name)

            # Split consumers into those that produce a requested NCHW model
            # output versus everything else.
            model_output_consumers = []
            other_consumers = []
            for consumer in consumers:
                if consumer.output and any(out in outputs_as_nchw for out in consumer.output):
                    model_output_consumers.append(consumer)
                else:
                    other_consumers.append(consumer)

            # Edge case: output goes to both a model output and other layers.
            if model_output_consumers and other_consumers:
                shape = ctx.get_shape(output_name)
                # NOTE(review): get_shape presumably returns None for unknown
                # shapes — guard it so len() cannot raise TypeError. A
                # NHWC_TO_NCHW perm only makes sense for rank-4 tensors.
                if shape is None or len(shape) != len(constants.NHWC_TO_NCHW):
                    continue

                # Step 1: insert an Identity on the original output. This
                # redirects ALL consumers to the Identity, giving us a stable
                # fan-out point that still carries the NHWC value.
                identity_name = utils.make_name(node.name + "_identity")
                identity = ctx.make_node("Identity", [output_name],
                                         outputs=[identity_name + ":0"], name=identity_name)
                ctx.copy_shape(output_name, identity.output[0])
                ctx.set_shape(identity.output[0], shape)
                ctx.insert_node_on_output(identity, output_name)

                # Step 2: add a Transpose (NHWC -> NCHW) fed by the Identity.
                transpose_name = utils.make_name(identity.name + "_transpose")
                transpose = ctx.make_node("Transpose", [identity.output[0]],
                                          outputs=[transpose_name + ":0"], name=transpose_name)
                transpose.set_attr("perm", constants.NHWC_TO_NCHW)
                ctx.copy_shape(identity.output[0], transpose.output[0])
                ctx.set_shape(transpose.output[0], np.array(shape)[constants.NHWC_TO_NCHW])

                # Step 3: redirect ONLY the model-output consumers to the
                # transposed value; internal consumers keep the Identity.
                for consumer in model_output_consumers:
                    ctx.replace_all_inputs(identity.output[0], transpose.output[0], ops=[consumer])

                # Mark this output as handled so the second pass skips it.
                edge_case_handled.add(output_name)

                ops.append(node)
                ops.append(identity)
                ops.append(transpose)
                break  # Only handle one edge case per node

        # If no edge case was handled for this node, carry it over unchanged.
        if not any(out in edge_case_handled for out in node.output):
            ops.append(node)

    # Second pass: the plain case where a node's output IS a requested model
    # output — insert the transpose directly on that output.
    final_ops = []
    for node in ops:
        handled = False
        for output_name in node.output:
            if output_name in outputs_as_nchw and output_name not in edge_case_handled:
                shape = ctx.get_shape(output_name)
                # Guard unknown (None) shapes as well as wrong rank; warn and
                # fall through so the node is still kept via `handled = False`.
                if shape is None or len(shape) != len(constants.NHWC_TO_NCHW):
                    logger.warning("transpose_output for %s: shape must be rank 4, ignored" % output_name)
                    continue

                # insert transpose
                op_name = utils.make_name(node.name)
                transpose = ctx.insert_new_node_on_output("Transpose", output_name, name=op_name)
                transpose.set_attr("perm", constants.NHWC_TO_NCHW)
                ctx.copy_shape(output_name, transpose.output[0])
                ctx.set_shape(output_name, np.array(shape)[constants.NHWC_TO_NCHW])
                final_ops.append(transpose)
                final_ops.append(node)
                handled = True
                break

        if not handled:
            final_ops.append(node)

    ctx.reset_nodes(final_ops)
356
426
def topological_sort (g , continue_on_error ):
357
427
ops = g .get_nodes ()
@@ -522,7 +592,7 @@ def process_parsed_graph(g, custom_op_handlers, inputs_as_nchw, outputs_as_nchw,
522
592
initialized_tables , is_tflite = False , dequantize = False ):
523
593
524
594
op_cnt , attr_cnt = g .dump_node_statistics (include_attrs = True , include_subgraphs = False )
525
-
595
+
526
596
if is_tflite :
527
597
tfl_rewriters = []
528
598
if dequantize :
@@ -531,13 +601,16 @@ def process_parsed_graph(g, custom_op_handlers, inputs_as_nchw, outputs_as_nchw,
531
601
tfl_rewriters .append (rewrite_tfl_select_zero )
532
602
tfl_rewriters .append (rewrite_tfl_rfft )
533
603
run_rewriters (g , tfl_rewriters , continue_on_error )
604
+
534
605
tfl_ops_mapping = handler .tfl_op .create_tfl_to_tf_mapping ()
535
606
_ , _ , exceptions = tensorflow_onnx_mapping (g , tfl_ops_mapping , is_tflite = True , dequantize = False )
607
+
536
608
if exceptions and not continue_on_error :
537
609
raise exceptions [0 ]
538
610
539
611
# create ops mapping for the desired opsets
540
612
ops_mapping = handler .tf_op .create_mapping (g .opset , g .extra_opset )
613
+
541
614
542
615
# apply custom ops on top of the assembled opset. We can either complement the opset
543
616
# or override existing ops with a custom op.
0 commit comments