Skip to content

Commit b8bfb37

Browse files
Improve ConvTranspose shape inference to work with unknown batch dim
Signed-off-by: Tom Wildenhain <[email protected]>
1 parent 0e2ea55 commit b8bfb37

File tree

3 files changed

+52
-3
lines changed

3 files changed

+52
-3
lines changed

tests/test_backend.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3257,6 +3257,31 @@ def func(filter_val, out_backprop_val):
32573257
out_backprop_val = np.random.randint(low=0, high=256, size=[1, 5, 5, 5]).astype(np.float32)
32583258
self._run_test_case(func, [_OUTPUT], {_INPUT: filters_val, _INPUT1: out_backprop_val})
32593259

3260+
@check_tf_min_version("1.15", "tf.repeat needs tf 1.15")
3261+
@check_opset_min_version(10, "Conv2DBackpropInput")
3262+
def test_Conv2DBackpropInput_shape_implied(self):
3263+
batch_dim_val = np.array(1, dtype=np.int32)
3264+
def func(filter_val, out_backprop_val, batch_dim):
3265+
out_backprop_val = tf.repeat(out_backprop_val, batch_dim, axis=0)
3266+
s = tf.shape(out_backprop_val)
3267+
t1 = tf.constant([0], dtype=tf.int32)
3268+
t2 = tf.constant([1], dtype=tf.int32)
3269+
batch_dim = tf.strided_slice(s, t1, t2, shrink_axis_mask=1)
3270+
# Sometimes the size given is a stack of constants with unknown batch dim
3271+
input_sizes_val = tf.stack([batch_dim, 10, 10, 3])
3272+
return conv2d_backprop_input(input_sizes=input_sizes_val, filter=filter_val,
3273+
out_backprop=out_backprop_val, strides=[1, 2, 2, 1],
3274+
padding='SAME', name=_TFOUTPUT)
3275+
filters_val = np.random.randint(low=0, high=256, size=[3, 3, 3, 5]).astype(np.float32)
3276+
out_backprop_val = np.random.randint(low=0, high=256, size=[1, 5, 5, 5]).astype(np.float32)
3277+
def graph_validator(g):
3278+
for n in g.get_nodes():
3279+
if n.type == 'ConvTranspose':
3280+
return "output_shape" in n.attr
3281+
return False
3282+
self._run_test_case(func, [_OUTPUT], {_INPUT: filters_val, _INPUT1: out_backprop_val, _INPUT2: batch_dim_val},
3283+
graph_validator=graph_validator)
3284+
32603285
@check_opset_min_version(10, "Conv2DBackpropInput")
32613286
def test_Conv2DBackpropInput_const_valid(self):
32623287
input_sizes_val_ = np.array([1, 12, 12, 3], dtype=np.int32)

tf2onnx/onnx_opset/nn.py

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -366,6 +366,27 @@ def version_11(cls, ctx, node, **kwargs):
366366
# No change.
367367
cls.version_1(ctx, node, **kwargs)
368368

369+
def get_shape_from_const_or_concat(ctx, node):
370+
if node.is_const():
371+
return node.get_tensor_value()
372+
if node.type == 'Concat':
373+
# Sometimes the shape is formed by concatenating a bunch of consts together
374+
res = []
375+
if any(ctx.get_shape(inp) != [1] for inp in node.input):
376+
return None
377+
for i, inp in enumerate(node.inputs):
378+
# The concat is converted from a Pack. Conversion adds an unsqueeze to the inputs.
379+
if node.inputs[i].type == 'Unsqueeze' and node.inputs[i].inputs[0].is_scalar():
380+
res.append(node.inputs[i].inputs[0].get_tensor_value())
381+
else:
382+
if i == 0:
383+
# For the batch dimension we don't care if it is unknown
384+
res.append(-1)
385+
else:
386+
return None
387+
return res
388+
return None
389+
369390
@tf_op(["Conv2DBackpropInput", "Conv3DBackpropInputV2"])
370391
class ConvTranspose:
371392
@classmethod
@@ -386,8 +407,9 @@ def version_1(cls, ctx, node, **kwargs):
386407
output_shape_orig = node.output_shapes
387408

388409
# output_shape is explicitly specified here, in this case pads values are auto generated/calculated.
389-
if node.inputs[0].is_const():
390-
output_shape = ctx.get_shape(node.output[0])
410+
output_shape = get_shape_from_const_or_concat(ctx, node.inputs[0])
411+
if output_shape is not None:
412+
#output_shape = ctx.get_shape(node.output[0])
391413
if is_channels_last(node):
392414
new_output_shape = [output_shape[1], output_shape[2]]
393415
input_dims = [input_shape[1], input_shape[2]]
@@ -407,6 +429,7 @@ def version_1(cls, ctx, node, **kwargs):
407429

408430
node.set_attr("output_shape", new_output_shape)
409431
else:
432+
# FIXME: This case fails in edge cases where strides > 1
410433
input_shape = ctx.make_node("Cast", [node.input[0]], attr={'to': TensorProto.INT64})
411434
output_shape = ctx.make_node("Shape", [node.output[0]])
412435
output_h = GraphBuilder(ctx).make_slice(

tf2onnx/onnx_opset/tensor.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1054,7 +1054,8 @@ def version_1(cls, ctx, node, **kwargs):
10541054
# insert Unsqueeze on each input
10551055
for i, n in enumerate(node.inputs):
10561056
dtype = ctx.get_dtype(node.input[i])
1057-
shape = ctx.get_shape(node.input[i])
1057+
shape = ctx.get_shape(node.input[i]).copy()
1058+
shape.insert(axis, 1)
10581059
new_node = ctx.make_node("Unsqueeze", [node.input[i]], op_name_scope=node.name, attr={"axes": [axis]},
10591060
shapes=[shape], dtypes=[dtype])
10601061
output_name = new_node.output[0]

0 commit comments

Comments (0)