@@ -146,7 +146,6 @@ def tensorflow_to_onnx(graph, shape_override):
146
146
147
147
def _convert_shapenode_to_int64 (ctx , node , input_number ):
148
148
"""cast int32 shape into int64 shape."""
149
- shape_node = node .inputs [input_number ]
150
149
name = node .input [input_number ]
151
150
152
151
cast_node = ctx .insert_new_node_on_input (node , "Cast" , name )
@@ -382,7 +381,6 @@ def conv_convert_inputs(ctx, node, with_kernel=False, new_kernel_shape=None,
382
381
input_name = node .input [idx ]
383
382
transpose = ctx .insert_new_node_on_input (node , "Transpose" , input_name )
384
383
transpose .set_attr ("perm" , NHWC_TO_NCHW )
385
- transpose .inserted_nchw = True
386
384
transpose .skip_conversion = True
387
385
shape = ctx .get_shape (input_name )
388
386
new_shape = spatial_map (shape , NHWC_TO_NCHW )
@@ -393,19 +391,30 @@ def conv_convert_inputs(ctx, node, with_kernel=False, new_kernel_shape=None,
393
391
# kernel must be transposed
394
392
if with_kernel :
395
393
parent = node .inputs [1 ]
396
- # note: kernel may be used by multiple nodes,
397
- # so even kernel is a const, transposing kernel can't be done statically.
398
- # so "transpose" op is inserted here and will consider to remove it in later optimization phase if possible.
399
- input_name = node .input [1 ]
400
- transpose = ctx .insert_new_node_on_input (node , "Transpose" , input_name )
401
- transpose .set_attr ("perm" , HWCN_TO_NCHW )
402
- transpose .inserted_nchw = True
403
- transpose .skip_conversion = True
404
- ctx .copy_shape (input_name , transpose .output [0 ])
405
- new_shape = spatial_map (ctx .get_shape (input_name ), HWCN_TO_NCHW )
406
- ctx .set_shape (transpose .output [0 ], new_shape )
407
- nodes .append (transpose )
408
- parent .data_format = "NCHW"
394
+
395
+ need_transpose = True
396
+ if node .inputs [1 ].is_const ():
397
+ # kernel is const - transpose the const if we are the only consumer of const
398
+ # TODO: maybe we should make a copy of the const, or look at the other consumers
399
+ # if they'd want a transpose as well.
400
+ consumers = ctx .find_output_consumers (node .input [1 ])
401
+ if len (consumers ) == 1 :
402
+ val = parent .get_tensor_value (as_list = False )
403
+ val = val .transpose (HWCN_TO_NCHW )
404
+ parent .set_tensor_value (val )
405
+ parent .data_format = "NCHW"
406
+ need_transpose = False
407
+
408
+ if need_transpose :
409
+ input_name = node .input [1 ]
410
+ transpose = ctx .insert_new_node_on_input (node , "Transpose" , input_name )
411
+ transpose .set_attr ("perm" , HWCN_TO_NCHW )
412
+ transpose .skip_conversion = True
413
+ ctx .copy_shape (input_name , transpose .output [0 ])
414
+ new_shape = spatial_map (ctx .get_shape (input_name ), HWCN_TO_NCHW )
415
+ ctx .set_shape (transpose .output [0 ], new_shape )
416
+ nodes .append (transpose )
417
+ parent .data_format = "NCHW"
409
418
410
419
# some onnx conv ops require the reshape the kernel (ie. depthwise_conv2d)
411
420
if new_kernel_shape :
@@ -436,7 +445,6 @@ def conv_convert_inputs(ctx, node, with_kernel=False, new_kernel_shape=None,
436
445
op_name = utils .make_name (node .name )
437
446
transpose = ctx .insert_new_node_on_output ("Transpose" , output_name , name = op_name )
438
447
transpose .set_attr ("perm" , NCHW_TO_NHWC )
439
- transpose .inserted_nchw = True
440
448
transpose .skip_conversion = True
441
449
ctx .set_shape (transpose .output [0 ], ctx .get_shape (node .output [idx ]))
442
450
nodes .append (transpose )
@@ -2434,7 +2442,6 @@ def transpose_inputs(ctx, inputs_as_nchw):
2434
2442
op_name = utils .make_name (node .name )
2435
2443
transpose = ctx .insert_new_node_on_output ("Transpose" , output_name , name = op_name )
2436
2444
transpose .set_attr ("perm" , NCHW_TO_NHWC )
2437
- transpose .inserted_nchw = True
2438
2445
ctx .copy_shape (output_name , transpose .output [0 ])
2439
2446
ctx .set_shape (output_name , np .array (shape )[NHWC_TO_NCHW ])
2440
2447
ops .append (transpose )
@@ -2527,8 +2534,9 @@ def process_tf_graph(tf_graph, continue_on_error=False, verbose=False, target=No
2527
2534
# check output existence in case user passed in wrong output ids
2528
2535
non_exists = set (io_to_check ) - set (output_shapes .keys ())
2529
2536
if non_exists :
2530
- log .error ("\n Failed to convert: inputs/outputs specified do not exist, make sure your passed" \
2531
- " in format: input/output_node_name:port_id. Problematical inputs/outputs are: %s \n " , non_exists )
2537
+ log .error ("\n Failed to convert: inputs/outputs specified do not exist, make sure your passed"
2538
+ "in format: input/output_node_name:port_id. Problematical inputs/outputs are: %s \n " ,
2539
+ non_exists )
2532
2540
raise ValueError ("Inputs/Outputs Not Found" )
2533
2541
2534
2542
g = Graph (onnx_nodes , output_shapes , dtypes , target , opset , extra_opset , output_names )
0 commit comments