@@ -383,6 +383,7 @@ def conv_convert_inputs(ctx, node, with_kernel=False, new_kernel_shape=None,
             transpose = ctx.insert_new_node_on_input(node, "Transpose", input_name)
             transpose.set_attr("perm", NHWC_TO_NCHW)
             transpose.inserted_nchw = True
+            transpose.skip_conversion = True
             shape = ctx.get_shape(input_name)
             new_shape = spatial_map(shape, NHWC_TO_NCHW)
             ctx.set_shape(transpose.output[0], new_shape)
@@ -399,6 +400,7 @@ def conv_convert_inputs(ctx, node, with_kernel=False, new_kernel_shape=None,
         transpose = ctx.insert_new_node_on_input(node, "Transpose", input_name)
         transpose.set_attr("perm", HWCN_TO_NCHW)
         transpose.inserted_nchw = True
+        transpose.skip_conversion = True
         ctx.copy_shape(input_name, transpose.output[0])
         new_shape = spatial_map(ctx.get_shape(input_name), HWCN_TO_NCHW)
         ctx.set_shape(transpose.output[0], new_shape)
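Aside on the permutations used in these two hunks (a sketch, assuming tf2onnx's constants are NHWC_TO_NCHW = [0, 3, 1, 2] and HWCN_TO_NCHW = [3, 2, 0, 1]): a Transpose with permutation p maps an input shape s to [s[p[0]], s[p[1]], ...], which is the shape spatial_map computes for ctx.set_shape. For example:

    import numpy as np

    NHWC_TO_NCHW = [0, 3, 1, 2]  # assumed value of the tf2onnx constant
    HWCN_TO_NCHW = [3, 2, 0, 1]  # assumed value of the tf2onnx constant

    nhwc = np.zeros((1, 28, 28, 3))                # batch, height, width, channels
    print(np.transpose(nhwc, NHWC_TO_NCHW).shape)  # (1, 3, 28, 28) -> NCHW input

    hwcn = np.zeros((3, 3, 3, 8))                  # kH, kW, in_ch, out_ch (TF kernel)
    print(np.transpose(hwcn, HWCN_TO_NCHW).shape)  # (8, 3, 3, 3) -> ONNX kernel (M, C, kH, kW)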
@@ -412,13 +414,15 @@ def conv_convert_inputs(ctx, node, with_kernel=False, new_kernel_shape=None,
                 input_name = node.input[1]
                 reshape = ctx.insert_new_node_on_input(node, "Reshape", input_name)
                 reshape.set_attr("shape", new_kernel_shape)
+                reshape.skip_conversion = True
             else:
                 # new reshape takes new shape as input[1]
                 shape_name = utils.make_name(node.name)
                 nodes.append(ctx.make_const(shape_name, np.array(new_kernel_shape, dtype=np.int64)))
                 input_name = node.input[1]
                 reshape = ctx.insert_new_node_on_input(node, "Reshape", input_name)
                 reshape.input.append(shape_name)
+                reshape.skip_conversion = True
             ctx.set_shape(reshape.output[0], new_kernel_shape)
             nodes.append(reshape)
 
@@ -433,6 +437,7 @@ def conv_convert_inputs(ctx, node, with_kernel=False, new_kernel_shape=None,
             transpose = ctx.insert_new_node_on_output("Transpose", output_name, name=op_name)
             transpose.set_attr("perm", NCHW_TO_NHWC)
             transpose.inserted_nchw = True
+            transpose.skip_conversion = True
             ctx.set_shape(transpose.output[0], ctx.get_shape(node.output[idx]))
             nodes.append(transpose)
     node.data_format = "NCHW"
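The skip_conversion flag set on each inserted Transpose and Reshape above marks the node as already ONNX-compatible, so the later op-mapping pass can pass over it instead of converting it a second time. A minimal sketch of that idea (not tf2onnx's actual loop; Node, map_ops, and handlers here are illustrative stand-ins):

    class Node:
        def __init__(self, type_, skip_conversion=False):
            self.type = type_
            self.skip_conversion = skip_conversion

    def map_ops(ops, handlers):
        for op in ops:
            if getattr(op, "skip_conversion", False):
                continue  # inserted by a rewriter, already a valid ONNX op
            handlers[op.type](op)  # otherwise run the TF-to-ONNX handler

    map_ops([Node("Transpose", skip_conversion=True), Node("Conv2D")],
            {"Conv2D": lambda n: print("converting", n.type)})  # only Conv2D converts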
@@ -2280,6 +2285,47 @@ def rewrite_incomplete_type_support_rs6(g, ops):
     return rewrite_incomplete_type_support(g, ops, ["Div", "ReduceSum", "Slice", "Split", "Tile", "Transpose"])
 
 
+def rewrite_conv2d_with_pad(g, ops):
+    pattern = \
+        OpTypePattern("Conv2D", name="conv", inputs=[
+            OpTypePattern("Pad", name="pad"),
+            OpTypePattern("*")
+        ])
+    matcher = GraphMatcher(pattern)
+    match_results = list(matcher.match_ops(ops))
+    for match in match_results:
+        conv = match.get_op("conv")
+        pad = match.get_op("pad")
+        paddings = pad.inputs[1]
+
+        if not paddings.is_const():
+            return ops
+        mode = pad.get_attr("mode")
+        if mode:
+            mode = mode.s.decode("utf-8").lower()
+        if mode not in [None, "constant"] or len(pad.input) >= 3:
+            return ops
+        # Conv2D already has a pad
+        if conv.get_attr("padding") == "SAME":
+            return ops
+
+        log.debug("merge pad [%s] into conv [%s]", pad.name, conv.name)
+        paddings_val = np.array(paddings.get_tensor_value())
+        # can't pad on batch or channel dimensions
+        if np.any(paddings_val[0]) or np.any(paddings_val[3]):
+            return ops
+        paddings_val = paddings_val[1:3]
+        paddings_val = paddings_val.transpose().flatten()
+        g.replace_input(conv, conv.input[0], pad.input[0])
+        # convert Conv2D
+        conv.type = "Conv"
+        ops.extend(conv_op(g, conv, conv.name, []))
+        conv.skip_conversion = True
+        conv.set_attr("auto_pad", "NOTSET")
+        conv.set_attr("pads", paddings_val)
+    return ops
+
+
 def tensorflow_onnx_mapping(g, continue_on_error, custom_op_handlers):
     mapped_op = collections.Counter()
     unmapped_op = collections.Counter()
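The slice/transpose/flatten sequence in the new rewriter converts TensorFlow's (4, 2) NHWC paddings tensor into ONNX Conv's pads attribute, which lists all the begin values first and then all the end values ([top, left, bottom, right] in 2D). A quick check with made-up values:

    import numpy as np

    # TF Pad paddings for an NHWC input: one [before, after] row per dimension
    paddings = np.array([[0, 0],    # batch   (must be zero for the merge)
                         [1, 2],    # height  (top, bottom)
                         [3, 4],    # width   (left, right)
                         [0, 0]])   # channel (must be zero for the merge)
    pads = paddings[1:3].transpose().flatten()
    print(pads)  # [1 3 2 4] == [top, left, bottom, right]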
@@ -2476,7 +2522,8 @@ def process_tf_graph(tf_graph, continue_on_error=False, verbose=False, target=No
     # bi-directional re-writer should be placed after single directional re-writer
     rewriters = [rewrite_transpose, rewrite_flatten,
                  rewrite_random_uniform, rewrite_random_uniform_fold_const,
-                 rewrite_random_normal, rewrite_dropout, rewrite_leakyrelu,
+                 rewrite_random_normal, rewrite_dropout,
+                 rewrite_leakyrelu, rewrite_conv2d_with_pad,
                  rewrite_single_direction_lstm, rewrite_bi_direction_lstm,
                  rewrite_single_direction_gru, rewrite_single_direction_grublock,
                  rewrite_bi_direction_gru, rewrite_logical_compare_with_equal,
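For reference, a hypothetical TF 1.x snippet that produces the Pad -> Conv2D pattern the rewriter registered above targets: constant, spatial-only paddings in the default CONSTANT mode feeding a VALID convolution (SAME convolutions are skipped, since the Conv2D then pads on its own):

    import numpy as np
    import tensorflow as tf

    x = tf.placeholder(tf.float32, [1, 28, 28, 3], name="input")
    w = tf.constant(np.zeros([3, 3, 3, 8], dtype=np.float32))
    padded = tf.pad(x, [[0, 0], [1, 1], [1, 1], [0, 0]])  # constant paddings, spatial dims only
    y = tf.nn.conv2d(padded, w, strides=[1, 1, 1, 1], padding="VALID", name="output")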