
Commit f74d542

Fixed bug with ConvTranspose
Signed-off-by: Tom Wildenhain <[email protected]>
1 parent 9e8e42d commit f74d542
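In short: when converting Conv2DBackpropInput / Conv3DBackpropInputV2 with padding='SAME', strides > 1, and a non-constant output shape, ONNX ConvTranspose can produce a spatially smaller tensor than TF expects, so the converted graph must zero-pad the result. A back-of-the-envelope size check (my own sketch, assuming dilation 1 and zero pads; the numbers match the new 2-D test below):

# ONNX ConvTranspose spatial size when no explicit output_shape is given (dilation 1):
#   onnx_out = stride * (in - 1) + kernel - pad_begin - pad_end
# TF conv2d_transpose with padding='SAME' expects exactly in * stride.
onnx_out = 5 * (2 - 1) + 3  # = 8   (in=2, stride=5, kernel=3, pads=0)
tf_out = 2 * 5              # = 10; ONNX is stride - kernel = 2 rows short and needs zero-padding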

2 files changed: +73 -5 lines changed


tests/test_backend.py

Lines changed: 21 additions & 0 deletions
@@ -3349,6 +3349,16 @@ def func(input_sizes, filters, out_backprop):
         out_backprop_val = np.random.randint(low=0, high=256, size=[1, 10, 10, 5]).astype(np.float32)
         self._run_test_case(func, [_OUTPUT], {_INPUT: input_sizes_val, _INPUT1: filters_val, _INPUT2: out_backprop_val})
 
+    @check_opset_min_version(12, "Conv2DBackpropInput with strided workaround")
+    def test_Conv2DBackpropInput_strided_same(self):
+        def func(input_sizes, filters, out_backprop):
+            return conv2d_backprop_input(input_sizes, filters, out_backprop, strides=[1, 5, 10, 1], padding='SAME',
+                                         name=_TFOUTPUT)
+        input_sizes_val = np.array([1, 10, 10, 3], dtype=np.int32)
+        filters_val = np.random.randint(low=0, high=256, size=[3, 3, 3, 5]).astype(np.float32)
+        out_backprop_val = np.random.randint(low=0, high=256, size=[1, 2, 1, 5]).astype(np.float32)
+        self._run_test_case(func, [_OUTPUT], {_INPUT: input_sizes_val, _INPUT1: filters_val, _INPUT2: out_backprop_val})
+
     @check_opset_min_version(10, "Conv3DBackpropInputV2")
     def test_Conv3DBackpropInputV2_const(self):
         output_shape_val_ = np.array([1, 10, 10, 10, 3], dtype=np.int32)
@@ -3415,6 +3425,17 @@ def func(value, filters, output_shape):
         self._run_test_case(func, [_OUTPUT], {_INPUT: value_val, _INPUT1: filters_val, _INPUT2: output_shape_val},
                             rtol=1e-6)
 
+    @check_opset_min_version(12, "Conv3DBackpropInputV2 with strided workaround")
+    def test_Conv3DBackpropInputV2_strided_same(self):
+        def func(value, filters, output_shape):
+            return conv3d_transpose(value, filters, output_shape, strides=[1, 10, 4, 3, 1],
+                                    padding='SAME', data_format="NDHWC", name=_TFOUTPUT)
+        filters_val = np.random.randint(low=1, high=256, size=[1, 1, 1, 1, 1]).astype(np.float32)
+        value_val = np.random.randint(low=1, high=256, size=[1, 3, 2, 5, 1]).astype(np.float32)
+        output_shape_val = np.array([1, 30, 8, 15, 1], dtype=np.int32)
+        self._run_test_case(func, [_OUTPUT], {_INPUT: value_val, _INPUT1: filters_val, _INPUT2: output_shape_val},
+                            rtol=1e-6)
+
     @check_opset_min_version(8, "CategoryMapper")
     @skip_tf2()
     def test_hashtable_lookup(self):
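The shapes in these tests follow TF's 'SAME' rule: the forward conv output, and hence the transposed conv's input (out_backprop / value), has spatial size ceil(input / stride). A quick check of the values above (my own sketch, not part of the commit):

import math

def same_forward_size(input_size, stride):
    # TF 'SAME' padding: forward output size is ceil(input / stride).
    return math.ceil(input_size / stride)

# 2-D test: input 10x10 with strides (5, 10) -> out_backprop spatial dims (2, 1)
assert [same_forward_size(10, 5), same_forward_size(10, 10)] == [2, 1]
# 3-D test: output_shape 30x8x15 with strides (10, 4, 3) -> value spatial dims (3, 2, 5)
assert [same_forward_size(d, s) for d, s in zip((30, 8, 15), (10, 4, 3))] == [3, 2, 5]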

tf2onnx/onnx_opset/nn.py

Lines changed: 52 additions & 5 deletions
@@ -429,7 +429,18 @@ def version_1(cls, ctx, node, **kwargs):
 
             node.set_attr("output_shape", new_output_shape)
         else:
-            # FIXME: This case fails in edge cases where strides > 1
+            utils.make_sure(ctx.opset >= 10, "Opset 10 needed for Conv Backprop Input with non-constant shape")
+            strides = parse_dims_attr(node, node.get_attr('strides').ints, spatial)
+            use_strides_workaround = any(d > 1 for d in strides)
+            if use_strides_workaround and ctx.opset < 12:
+                # When strides > 1, ONNX and TF have an implementation difference in ConvTranspose. ONNX outputs a
+                # slightly smaller tensor which must be padded with a row of 0s. Pad with dynamic shape requires
+                # opset >= 11 and Max of int64 needs opset >= 12. Depending on the output_shape, this row of 0s might
+                # be shaved off, in which case TF and ONNX agree. When output_shape is dynamic it is impossible to
+                # know at conversion time whether this is the case and the workaround is needed.
+                logger.warning("Conv Backprop Input with strides > 1 and non-constant shape has known bug. "
+                               "Workaround requires opset 12.")
+                use_strides_workaround = False
             input_shape = ctx.make_node("Cast", [node.input[0]], attr={'to': TensorProto.INT64})
             output_shape = ctx.make_node("Shape", [node.output[0]])
             output_h = GraphBuilder(ctx).make_slice(
@@ -442,9 +453,17 @@ def version_1(cls, ctx, node, **kwargs):
                 {"data": input_shape.output[0], "ends": [3], "starts": [2], "axes": [0]})
             diff_h = ctx.make_node("Sub", [output_h, expect_h])
             diff_w = ctx.make_node("Sub", [output_w, expect_w])
+            nonneg_diff_h = diff_h
+            nonneg_diff_w = diff_w
+
+            if use_strides_workaround:
+                const_zero = ctx.make_const(utils.make_name(node.name + "_const_zero"), np.array([0], dtype=np.int64))
+                nonneg_diff_h = ctx.make_node("Max", [diff_h.output[0], const_zero.output[0]])
+                nonneg_diff_w = ctx.make_node("Max", [diff_w.output[0], const_zero.output[0]])
+
             const_two = ctx.make_const(utils.make_name(node.name + "_const_two"), np.array([2], dtype=np.int64))
-            start_h = ctx.make_node("Div", [diff_h.output[0], const_two.output[0]])
-            start_w = ctx.make_node("Div", [diff_w.output[0], const_two.output[0]])
+            start_h = ctx.make_node("Div", [nonneg_diff_h.output[0], const_two.output[0]])
+            start_w = ctx.make_node("Div", [nonneg_diff_w.output[0], const_two.output[0]])
             end_h = ctx.make_node("Add", [start_h.output[0], expect_h])
             end_w = ctx.make_node("Add", [start_w.output[0], expect_w])
             if spatial == 3:
@@ -453,7 +472,10 @@ def version_1(cls, ctx, node, **kwargs):
                 expect_d = GraphBuilder(ctx).make_slice(
                     {"data": input_shape.output[0], "ends": [4], "starts": [3], "axes": [0]})
                 diff_d = ctx.make_node("Sub", [output_d, expect_d])
-                start_d = ctx.make_node("Div", [diff_d.output[0], const_two.output[0]])
+                nonneg_diff_d = diff_d
+                if use_strides_workaround:
+                    nonneg_diff_d = ctx.make_node("Max", [diff_d.output[0], const_zero.output[0]])
+                start_d = ctx.make_node("Div", [nonneg_diff_d.output[0], const_two.output[0]])
                 end_d = ctx.make_node("Add", [start_d.output[0], expect_d])
 
                 starts = ctx.make_node("Concat", [start_h.output[0], start_w.output[0], start_d.output[0]],
@@ -471,10 +493,35 @@ def version_1(cls, ctx, node, **kwargs):
                 [node.output[0], starts.output[0], ends.output[0], slice_axes.output[0]],
                 shapes=output_shape_orig)
 
+            final_node = slice_node
+
+            if use_strides_workaround:
+                cz = const_zero.output[0]
+
+                neg_diff_h = ctx.make_node("Neg", [diff_h.output[0]])
+                shrink_h_by = ctx.make_node("Max", [neg_diff_h.output[0], const_zero.output[0]])
+                shb = shrink_h_by.output[0]
+
+                neg_diff_w = ctx.make_node("Neg", [diff_w.output[0]])
+                shrink_w_by = ctx.make_node("Max", [neg_diff_w.output[0], const_zero.output[0]])
+                swb = shrink_w_by.output[0]
+
+                if spatial == 3:
+                    neg_diff_d = ctx.make_node("Neg", [diff_d.output[0]])
+                    shrink_d_by = ctx.make_node("Max", [neg_diff_d.output[0], const_zero.output[0]])
+                    sdb = shrink_d_by.output[0]
+                    pads = ctx.make_node("Concat", [cz, cz, cz, cz, cz, cz, shb, swb, sdb, cz], attr={"axis": 0})
+                    padded_node = ctx.make_node("Pad", [slice_node.output[0], pads.output[0]])
+                else:
+                    pads = ctx.make_node("Concat", [cz, cz, cz, cz, cz, shb, swb, cz], attr={"axis": 0})
+                    padded_node = ctx.make_node("Pad", [slice_node.output[0], pads.output[0]])
+
+                final_node = padded_node
+
             downstream_nodes = ctx.find_output_consumers(node.output[0])
             downstream_nodes.remove(output_shape)
             downstream_nodes.remove(slice_node)
-            ctx.replace_all_inputs(node.output[0], slice_node.output[0], ops=downstream_nodes)
+            ctx.replace_all_inputs(node.output[0], final_node.output[0], ops=downstream_nodes)
 
         conv_dims_attr(node, "strides", spatial=spatial)
         conv_dims_attr(node, "dilations", spatial=spatial)
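The net effect of the Slice/Pad graph above, expressed per spatial axis in plain numpy (my own illustrative sketch, not converter code): if the ConvTranspose output is too large the excess is cropped starting at diff // 2, and if it is too small (possible when strides > 1) the shortfall is zero-padded at the end.

import numpy as np

def crop_or_pad(x, expect, axis):
    # diff > 0: ONNX produced extra rows -> crop them (the Slice path).
    # diff < 0: ONNX produced too few rows -> append max(-diff, 0) zeros (the Pad path).
    diff = x.shape[axis] - expect
    start = max(diff, 0) // 2
    idx = [slice(None)] * x.ndim
    idx[axis] = slice(start, start + expect)  # like ONNX Slice, numpy clamps out-of-range ends
    sliced = x[tuple(idx)]
    pad = [(0, max(-diff, 0)) if a == axis else (0, 0) for a in range(x.ndim)]
    return np.pad(sliced, pad)

Clamping the negative diff to zero before the Div is what requires Max on int64 tensors (opset 12), and feeding dynamic pads to Pad requires opset 11, which is why the workaround is gated on opset >= 12.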
