@@ -383,6 +383,7 @@ def conv_convert_inputs(ctx, node, with_kernel=False, new_kernel_shape=None,
                 transpose = ctx.insert_new_node_on_input(node, "Transpose", input_name)
                 transpose.set_attr("perm", NHWC_TO_NCHW)
                 transpose.inserted_nchw = True
+                transpose.skip_conversion = True
                 shape = ctx.get_shape(input_name)
                 new_shape = spatial_map(shape, NHWC_TO_NCHW)
                 ctx.set_shape(transpose.output[0], new_shape)
@@ -399,6 +400,7 @@ def conv_convert_inputs(ctx, node, with_kernel=False, new_kernel_shape=None,
             transpose = ctx.insert_new_node_on_input(node, "Transpose", input_name)
             transpose.set_attr("perm", HWCN_TO_NCHW)
             transpose.inserted_nchw = True
+            transpose.skip_conversion = True
             ctx.copy_shape(input_name, transpose.output[0])
             new_shape = spatial_map(ctx.get_shape(input_name), HWCN_TO_NCHW)
             ctx.set_shape(transpose.output[0], new_shape)
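Note: the perm attributes above rely on tf2onnx's layout permutation constants, which are not shown in this diff. A minimal numpy sketch, assuming the conventional values NHWC_TO_NCHW = [0, 3, 1, 2] and HWCN_TO_NCHW = [3, 2, 0, 1], illustrates what the inserted Transpose nodes do:

import numpy as np

NHWC_TO_NCHW = [0, 3, 1, 2]   # assumed: to batch, channel, height, width
HWCN_TO_NCHW = [3, 2, 0, 1]   # assumed: to out-ch, in-ch, height, width

x = np.zeros((8, 28, 28, 3))            # feature map in TF's NHWC layout
print(x.transpose(NHWC_TO_NCHW).shape)  # (8, 3, 28, 28) -> ONNX NCHW

w = np.zeros((5, 5, 3, 16))             # TF conv kernel in HWCN layout
print(w.transpose(HWCN_TO_NCHW).shape)  # (16, 3, 5, 5) -> ONNX kernel layout

spatial_map presumably applies the same permutation to a shape list, which keeps the shape recorded for each inserted Transpose consistent with its perm.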
@@ -412,13 +414,15 @@ def conv_convert_inputs(ctx, node, with_kernel=False, new_kernel_shape=None,
                 input_name = node.input[1]
                 reshape = ctx.insert_new_node_on_input(node, "Reshape", input_name)
                 reshape.set_attr("shape", new_kernel_shape)
+                reshape.skip_conversion = True
             else:
                 # new reshape takes new shape as input[1]
                 shape_name = utils.make_name(node.name)
                 nodes.append(ctx.make_const(shape_name, np.array(new_kernel_shape, dtype=np.int64)))
                 input_name = node.input[1]
                 reshape = ctx.insert_new_node_on_input(node, "Reshape", input_name)
                 reshape.input.append(shape_name)
+                reshape.skip_conversion = True
             ctx.set_shape(reshape.output[0], new_kernel_shape)
             nodes.append(reshape)
 
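The opset branch above exists because ONNX moved Reshape's target shape from a node attribute (opset < 5) to a second int64 input (opset >= 5), which is why the newer path first appends a shape constant. A sketch of the two node forms using the onnx helper API (tensor and node names are illustrative):

from onnx import TensorProto, helper

# opset < 5: target shape is carried as an attribute
reshape_v1 = helper.make_node("Reshape", ["kernel"], ["kernel_r"], shape=[16, 1, 5, 5])

# opset >= 5: target shape arrives as a second int64 input
shape_const = helper.make_node(
    "Constant", [], ["new_shape"],
    value=helper.make_tensor("value", TensorProto.INT64, [4], [16, 1, 5, 5]))
reshape_v5 = helper.make_node("Reshape", ["kernel", "new_shape"], ["kernel_r"])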
@@ -433,6 +437,7 @@ def conv_convert_inputs(ctx, node, with_kernel=False, new_kernel_shape=None,
             transpose = ctx.insert_new_node_on_output("Transpose", output_name, name=op_name)
             transpose.set_attr("perm", NCHW_TO_NHWC)
             transpose.inserted_nchw = True
+            transpose.skip_conversion = True
             ctx.set_shape(transpose.output[0], ctx.get_shape(node.output[idx]))
             nodes.append(transpose)
         node.data_format = "NCHW"
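Every Transpose/Reshape this function inserts is already a valid ONNX node, so each one now also sets skip_conversion. A minimal sketch of the apparent contract, assuming the later mapping pass checks the flag (graph_ops and convert_op are hypothetical names, not tf2onnx's actual loop):

for op in graph_ops:
    if getattr(op, "skip_conversion", False):
        continue  # node was emitted directly in ONNX form; don't map it again
    convert_op(op)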
@@ -2300,6 +2305,47 @@ def rewrite_incomplete_type_support_rs6(g, ops):
     return rewrite_incomplete_type_support(g, ops, ["Div", "IsNaN", "ReduceSum", "Slice", "Split", "Tile", "Transpose"])
 
 
+def rewrite_conv2d_with_pad(g, ops):
+    pattern = \
+        OpTypePattern("Conv2D", name="conv", inputs=[
+            OpTypePattern("Pad", name="pad"),
+            OpTypePattern("*")
+        ])
+    matcher = GraphMatcher(pattern)
+    match_results = list(matcher.match_ops(ops))
+    for match in match_results:
+        conv = match.get_op("conv")
+        pad = match.get_op("pad")
+        paddings = pad.inputs[1]
+
+        if not paddings.is_const():
+            continue
+        mode = pad.get_attr("mode")
+        if mode:
+            mode = mode.s.decode("utf-8").lower()
+        if mode not in [None, "constant"] or len(pad.input) >= 3:
+            continue
+        # Conv2D already has a pad
+        if conv.get_attr("padding") == "SAME":
+            continue
+
+        log.debug("merge pad [%s] into conv [%s]", pad.name, conv.name)
+        paddings_val = np.array(paddings.get_tensor_value())
+        # can't pad on batch or channel dimensions
+        if np.any(paddings_val[0]) or np.any(paddings_val[3]):
+            continue
+        paddings_val = paddings_val[1:3]
+        paddings_val = paddings_val.transpose().flatten()
+        g.replace_input(conv, conv.input[0], pad.input[0])
+        # convert Conv2D
+        conv.type = "Conv"
+        ops.extend(conv_op(g, conv, conv.name, []))
+        conv.skip_conversion = True
+        conv.set_attr("auto_pad", "NOTSET")
+        conv.set_attr("pads", paddings_val)
+    return ops
+
+
 def tensorflow_onnx_mapping(g, continue_on_error, custom_op_handlers):
     mapped_op = collections.Counter()
     unmapped_op = collections.Counter()
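The paddings handling in rewrite_conv2d_with_pad converts TensorFlow's per-dimension (before, after) pairs into ONNX Conv's pads layout, [x1_begin, x2_begin, x1_end, x2_end]. A worked example with hypothetical values, mirroring the steps above:

import numpy as np

# TF Pad paddings for NHWC input: one (before, after) pair per dimension
paddings_val = np.array([[0, 0],   # batch   - must be all zero or the rewrite bails out
                         [1, 1],   # height: 1 row on top, 1 on the bottom
                         [2, 2],   # width:  2 cols on the left, 2 on the right
                         [0, 0]])  # channel - must be all zero or the rewrite bails out

paddings_val = paddings_val[1:3]                   # keep spatial dims: [[1, 1], [2, 2]]
paddings_val = paddings_val.transpose().flatten()  # all begins first, then all ends
print(paddings_val)  # [1 2 1 2] == [top, left, bottom, right]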
@@ -2496,7 +2542,8 @@ def process_tf_graph(tf_graph, continue_on_error=False, verbose=False, target=None,
     # bi-directional re-writer should be placed after single directional re-writer
     rewriters = [rewrite_transpose, rewrite_flatten,
                  rewrite_random_uniform, rewrite_random_uniform_fold_const,
-                 rewrite_random_normal, rewrite_dropout, rewrite_leakyrelu,
+                 rewrite_random_normal, rewrite_dropout,
+                 rewrite_leakyrelu, rewrite_conv2d_with_pad,
                  rewrite_single_direction_lstm, rewrite_bi_direction_lstm,
                  rewrite_single_direction_gru, rewrite_single_direction_grublock,
                  rewrite_bi_direction_gru, rewrite_logical_compare_with_equal,