
Commit 9e8e42d

Merge pull request #1147 from onnx/tom/FixConv3DConst
Improve ConvTranspose shape inference to work with unknown batch dim
2 parents e5bb0b5 + b8bfb37 · commit 9e8e42d

File tree: 3 files changed, +52 -3 lines changed

tests/test_backend.py

Lines changed: 25 additions & 0 deletions
@@ -3282,6 +3282,31 @@ def func(filter_val, out_backprop_val):
         out_backprop_val = np.random.randint(low=0, high=256, size=[1, 5, 5, 5]).astype(np.float32)
         self._run_test_case(func, [_OUTPUT], {_INPUT: filters_val, _INPUT1: out_backprop_val})
 
+    @check_tf_min_version("1.15", "tf.repeat needs tf 1.15")
+    @check_opset_min_version(10, "Conv2DBackpropInput")
+    def test_Conv2DBackpropInput_shape_implied(self):
+        batch_dim_val = np.array(1, dtype=np.int32)
+        def func(filter_val, out_backprop_val, batch_dim):
+            out_backprop_val = tf.repeat(out_backprop_val, batch_dim, axis=0)
+            s = tf.shape(out_backprop_val)
+            t1 = tf.constant([0], dtype=tf.int32)
+            t2 = tf.constant([1], dtype=tf.int32)
+            batch_dim = tf.strided_slice(s, t1, t2, shrink_axis_mask=1)
+            # Sometimes the size given is a stack of constants with unknown batch dim
+            input_sizes_val = tf.stack([batch_dim, 10, 10, 3])
+            return conv2d_backprop_input(input_sizes=input_sizes_val, filter=filter_val,
+                                         out_backprop=out_backprop_val, strides=[1, 2, 2, 1],
+                                         padding='SAME', name=_TFOUTPUT)
+        filters_val = np.random.randint(low=0, high=256, size=[3, 3, 3, 5]).astype(np.float32)
+        out_backprop_val = np.random.randint(low=0, high=256, size=[1, 5, 5, 5]).astype(np.float32)
+        def graph_validator(g):
+            for n in g.get_nodes():
+                if n.type == 'ConvTranspose':
+                    return "output_shape" in n.attr
+            return False
+        self._run_test_case(func, [_OUTPUT], {_INPUT: filters_val, _INPUT1: out_backprop_val, _INPUT2: batch_dim_val},
+                            graph_validator=graph_validator)
+
     @check_opset_min_version(10, "Conv2DBackpropInput")
     def test_Conv2DBackpropInput_const_valid(self):
         input_sizes_val_ = np.array([1, 12, 12, 3], dtype=np.int32)
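
For reference, the 10x10 entries in input_sizes and the 5x5 out_backprop above are consistent because TF's SAME padding gives a forward conv output of ceil(input_size / stride), which conv2d_backprop_input requires to match the out_backprop size. A minimal check of that arithmetic (the helper name below is ours, not part of the test suite):

def same_padding_out_size(input_size, stride):
    # TF SAME padding: the forward conv's spatial output is ceil(input_size / stride),
    # independent of the kernel size, so this must equal the out_backprop size.
    return -(-input_size // stride)  # ceiling division

assert same_padding_out_size(10, 2) == 5  # input_sizes 10x10 <-> out_backprop 5x5 at stride 2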

tf2onnx/onnx_opset/nn.py

Lines changed: 25 additions & 2 deletions
@@ -366,6 +366,27 @@ def version_11(cls, ctx, node, **kwargs):
         # No change.
         cls.version_1(ctx, node, **kwargs)
 
+def get_shape_from_const_or_concat(ctx, node):
+    if node.is_const():
+        return node.get_tensor_value()
+    if node.type == 'Concat':
+        # Sometimes the shape is formed by concatenating a bunch of consts together
+        res = []
+        if any(ctx.get_shape(inp) != [1] for inp in node.input):
+            return None
+        for i, inp in enumerate(node.inputs):
+            # The concat is converted from a Pack. Conversion adds an unsqueeze to the inputs.
+            if node.inputs[i].type == 'Unsqueeze' and node.inputs[i].inputs[0].is_scalar():
+                res.append(node.inputs[i].inputs[0].get_tensor_value())
+            else:
+                if i == 0:
+                    # For the batch dimension we don't care if it is unknown
+                    res.append(-1)
+                else:
+                    return None
+        return res
+    return None
+
 @tf_op(["Conv2DBackpropInput", "Conv3DBackpropInputV2"])
 class ConvTranspose:
     @classmethod

@@ -386,8 +407,9 @@ def version_1(cls, ctx, node, **kwargs):
         output_shape_orig = node.output_shapes
 
         # output_shape is explicitly specified here, in this case pads values are auto generated/calculated.
-        if node.inputs[0].is_const():
-            output_shape = ctx.get_shape(node.output[0])
+        output_shape = get_shape_from_const_or_concat(ctx, node.inputs[0])
+        if output_shape is not None:
+            #output_shape = ctx.get_shape(node.output[0])
             if is_channels_last(node):
                 new_output_shape = [output_shape[1], output_shape[2]]
                 input_dims = [input_shape[1], input_shape[2]]

@@ -407,6 +429,7 @@ def version_1(cls, ctx, node, **kwargs):
 
             node.set_attr("output_shape", new_output_shape)
         else:
+            # FIXME: This case fails in edge cases where strides > 1
             input_shape = ctx.make_node("Cast", [node.input[0]], attr={'to': TensorProto.INT64})
             output_shape = ctx.make_node("Shape", [node.output[0]])
             output_h = GraphBuilder(ctx).make_slice(
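
As a rough illustration of what get_shape_from_const_or_concat matches: tf.stack([batch_dim, 10, 10, 3]) is rewritten during conversion to a Concat whose inputs are Unsqueezed scalars, and only the batch entry (position 0) is allowed to be non-constant, in which case it is reported as -1. The toy sketch below mirrors that loop with stand-in node objects; it is not the real tf2onnx graph API and omits the [1]-shape check.

class ToyNode:
    # Stand-in for a tf2onnx graph node; only what this sketch needs.
    def __init__(self, type_, value=None, inputs=None):
        self.type = type_
        self._value = value
        self.inputs = inputs or []
    def is_scalar(self):
        return self.type == 'Const' and not isinstance(self._value, (list, tuple))
    def get_tensor_value(self):
        return self._value

def toy_shape_from_concat(concat):
    # Mirrors the loop above: accept Unsqueeze(scalar const) inputs and
    # tolerate an unknown batch dimension at position 0, reported as -1.
    res = []
    for i, inp in enumerate(concat.inputs):
        if inp.type == 'Unsqueeze' and inp.inputs[0].is_scalar():
            res.append(inp.inputs[0].get_tensor_value())
        elif i == 0:
            res.append(-1)
        else:
            return None
    return res

# tf.stack([batch_dim, 10, 10, 3]) becomes Concat of Unsqueezed inputs;
# the batch entry comes from a StridedSlice, so it is not a constant.
concat = ToyNode('Concat', inputs=[
    ToyNode('Unsqueeze', inputs=[ToyNode('StridedSlice')]),
    ToyNode('Unsqueeze', inputs=[ToyNode('Const', 10)]),
    ToyNode('Unsqueeze', inputs=[ToyNode('Const', 10)]),
    ToyNode('Unsqueeze', inputs=[ToyNode('Const', 3)]),
])
print(toy_shape_from_concat(concat))  # [-1, 10, 10, 3]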

tf2onnx/onnx_opset/tensor.py

Lines changed: 2 additions & 1 deletion
@@ -1054,7 +1054,8 @@ def version_1(cls, ctx, node, **kwargs):
         # insert Unsqueeze on each input
         for i, n in enumerate(node.inputs):
             dtype = ctx.get_dtype(node.input[i])
-            shape = ctx.get_shape(node.input[i])
+            shape = ctx.get_shape(node.input[i]).copy()
+            shape.insert(axis, 1)
             new_node = ctx.make_node("Unsqueeze", [node.input[i]], op_name_scope=node.name, attr={"axes": [axis]},
                                      shapes=[shape], dtypes=[dtype])
             output_name = new_node.output[0]
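
The .copy() in the Pack handler matters because of ordinary Python list aliasing: without it, inserting the new axis would mutate the list that ctx.get_shape returned, which may be shared with the graph's own shape bookkeeping (our reading of the fix). A small standalone illustration of the pitfall:

cached_shape = [1, 5, 5, 5]    # shape recorded for the original input

aliased = cached_shape         # no copy: both names refer to the same list
aliased.insert(0, 1)
print(cached_shape)            # [1, 1, 5, 5, 5] -- the recorded shape was corrupted

cached_shape = [1, 5, 5, 5]
copied = cached_shape.copy()   # copy first, as the fix does
copied.insert(0, 1)
print(cached_shape)            # [1, 5, 5, 5]    -- original left intact
print(copied)                  # [1, 1, 5, 5, 5] -- correct shape for the Unsqueeze output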
