@@ -429,7 +429,18 @@ def version_1(cls, ctx, node, **kwargs):
429
429
430
430
node .set_attr ("output_shape" , new_output_shape )
431
431
else :
432
- # FIXME: This case fails in edge cases where strides > 1
432
+ utils .make_sure (ctx .opset >= 10 , "Opset 10 needed for Conv Backprop Input with non-constant shape" )
433
+ strides = parse_dims_attr (node , node .get_attr ('strides' ).ints , spatial )
434
+ use_strides_workaround = any (d > 1 for d in strides )
435
+ if use_strides_workaround and ctx .opset < 12 :
436
+ # When strides > 1, ONNX and TF have an implementation difference in ConvTranspose. ONNX outputs a
437
+ # slightly smaller tensor which must be padded with a row of 0s. Pad with dynamic shape requires
438
+ # opset >= 11 and Max of int64 needs opset >= 12. Depending on the output_shape, this row of 0s might
439
+ # be shaved off, in which case TF and ONNX agree. When output_shape is dynamic it is impossible to
440
+ # know at conversion time whether this is the case and the workaround is needed.
441
+ logger .warning ("Conv Backprop Input with strides > 1 and non-constant shape has known bug. "
442
+ "Workaround requires opset 12." )
443
+ use_strides_workaround = False
433
444
input_shape = ctx .make_node ("Cast" , [node .input [0 ]], attr = {'to' : TensorProto .INT64 })
434
445
output_shape = ctx .make_node ("Shape" , [node .output [0 ]])
435
446
output_h = GraphBuilder (ctx ).make_slice (
@@ -442,9 +453,17 @@ def version_1(cls, ctx, node, **kwargs):
442
453
{"data" : input_shape .output [0 ], "ends" : [3 ], "starts" : [2 ], "axes" : [0 ]})
443
454
diff_h = ctx .make_node ("Sub" , [output_h , expect_h ])
444
455
diff_w = ctx .make_node ("Sub" , [output_w , expect_w ])
456
+ nonneg_diff_h = diff_h
457
+ nonneg_diff_w = diff_w
458
+
459
+ if use_strides_workaround :
460
+ const_zero = ctx .make_const (utils .make_name (node .name + "_const_zero" ), np .array ([0 ], dtype = np .int64 ))
461
+ nonneg_diff_h = ctx .make_node ("Max" , [diff_h .output [0 ], const_zero .output [0 ]])
462
+ nonneg_diff_w = ctx .make_node ("Max" , [diff_w .output [0 ], const_zero .output [0 ]])
463
+
445
464
const_two = ctx .make_const (utils .make_name (node .name + "_const_two" ), np .array ([2 ], dtype = np .int64 ))
446
- start_h = ctx .make_node ("Div" , [diff_h .output [0 ], const_two .output [0 ]])
447
- start_w = ctx .make_node ("Div" , [diff_w .output [0 ], const_two .output [0 ]])
465
+ start_h = ctx .make_node ("Div" , [nonneg_diff_h .output [0 ], const_two .output [0 ]])
466
+ start_w = ctx .make_node ("Div" , [nonneg_diff_w .output [0 ], const_two .output [0 ]])
448
467
end_h = ctx .make_node ("Add" , [start_h .output [0 ], expect_h ])
449
468
end_w = ctx .make_node ("Add" , [start_w .output [0 ], expect_w ])
450
469
if spatial == 3 :
@@ -453,7 +472,10 @@ def version_1(cls, ctx, node, **kwargs):
453
472
expect_d = GraphBuilder (ctx ).make_slice (
454
473
{"data" : input_shape .output [0 ], "ends" : [4 ], "starts" : [3 ], "axes" : [0 ]})
455
474
diff_d = ctx .make_node ("Sub" , [output_d , expect_d ])
456
- start_d = ctx .make_node ("Div" , [diff_d .output [0 ], const_two .output [0 ]])
475
+ nonneg_diff_d = diff_d
476
+ if use_strides_workaround :
477
+ nonneg_diff_d = ctx .make_node ("Max" , [diff_d .output [0 ], const_zero .output [0 ]])
478
+ start_d = ctx .make_node ("Div" , [nonneg_diff_d .output [0 ], const_two .output [0 ]])
457
479
end_d = ctx .make_node ("Add" , [start_d .output [0 ], expect_d ])
458
480
459
481
starts = ctx .make_node ("Concat" , [start_h .output [0 ], start_w .output [0 ], start_d .output [0 ]],
@@ -471,10 +493,35 @@ def version_1(cls, ctx, node, **kwargs):
471
493
[node .output [0 ], starts .output [0 ], ends .output [0 ], slice_axes .output [0 ]],
472
494
shapes = output_shape_orig )
473
495
496
+ final_node = slice_node
497
+
498
+ if use_strides_workaround :
499
+ cz = const_zero .output [0 ]
500
+
501
+ neg_diff_h = ctx .make_node ("Neg" , [diff_h .output [0 ]])
502
+ shrink_h_by = ctx .make_node ("Max" , [neg_diff_h .output [0 ], const_zero .output [0 ]])
503
+ shb = shrink_h_by .output [0 ]
504
+
505
+ neg_diff_w = ctx .make_node ("Neg" , [diff_w .output [0 ]])
506
+ shrink_w_by = ctx .make_node ("Max" , [neg_diff_w .output [0 ], const_zero .output [0 ]])
507
+ swb = shrink_w_by .output [0 ]
508
+
509
+ if spatial == 3 :
510
+ neg_diff_d = ctx .make_node ("Neg" , [diff_d .output [0 ]])
511
+ shrink_d_by = ctx .make_node ("Max" , [neg_diff_d .output [0 ], const_zero .output [0 ]])
512
+ sdb = shrink_d_by .output [0 ]
513
+ pads = ctx .make_node ("Concat" , [cz , cz , cz , cz , cz , cz , shb , swb , sdb , cz ], attr = {"axis" : 0 })
514
+ padded_node = ctx .make_node ("Pad" , [slice_node .output [0 ], pads .output [0 ]])
515
+ else :
516
+ pads = ctx .make_node ("Concat" , [cz , cz , cz , cz , cz , shb , swb , cz ], attr = {"axis" : 0 })
517
+ padded_node = ctx .make_node ("Pad" , [slice_node .output [0 ], pads .output [0 ]])
518
+
519
+ final_node = padded_node
520
+
474
521
downstream_nodes = ctx .find_output_consumers (node .output [0 ])
475
522
downstream_nodes .remove (output_shape )
476
523
downstream_nodes .remove (slice_node )
477
- ctx .replace_all_inputs (node .output [0 ], slice_node .output [0 ], ops = downstream_nodes )
524
+ ctx .replace_all_inputs (node .output [0 ], final_node .output [0 ], ops = downstream_nodes )
478
525
479
526
conv_dims_attr (node , "strides" , spatial = spatial )
480
527
conv_dims_attr (node , "dilations" , spatial = spatial )
0 commit comments