
Commit 348b4dd

Merge pull request #507 from lucienwang1009/cond_enhance_1

enhance cond_rewriter by allowing different output shapes

2 parents: 25973f3 + 9331aa4
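
For context (not part of the commit): before this change, cond_rewriter required the two tf.cond branch outputs to have compatible static shapes; after it, only equal rank is required. A minimal sketch of the newly supported pattern, written against the TF 1.x graph API used in the test below (names are illustrative):

```python
import tensorflow as tf  # TF 1.x graph API, matching the test below

flag = tf.placeholder(tf.bool, [], name="flag")

# Same rank (2) but different static first dimensions: previously rejected
# by cond_rewriter's shape-compatibility check, now accepted by the
# rank-only check introduced in this commit.
out = tf.cond(flag,
              lambda: tf.zeros([1, 4]),
              lambda: tf.zeros([2, 4]))
```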

3 files changed (+59, -9 lines)

tests/test_cond.py (47 additions, 1 deletion)

```diff
@@ -12,7 +12,7 @@
 import tensorflow as tf
 
 from backend_test_base import Tf2OnnxBackendTestBase
-from common import unittest_main
+from common import unittest_main, check_opset_min_version, check_tf_min_version
 
 
 # pylint: disable=missing-docstring,invalid-name,unused-argument,using-constant-test
@@ -267,6 +267,52 @@ def case_graph():
         output_names_with_port = ["output:0"]
         self.run_test_case(feed_dict, input_names_with_port, output_names_with_port)
 
+    @check_tf_min_version("1.8", "shape inference for Reshape op screws up")
+    @check_opset_min_version(9, "ConstantOfShape")
+    def test_cond_with_different_output_shape(self):
+        input_shape = (10, 5, 20)
+        inputs = tf.placeholder(tf.float32, input_shape, name="input")
+
+        shape = tf.placeholder(tf.int32, (len(input_shape),), name="shape")
+        # cheat onnx shape inference
+        inputs = tf.reshape(inputs, shape)
+
+        def pad_tensor(t, length):
+            """Pads the input tensor with 0s along the first dimension up to the length.
+
+            Args:
+                t: the input tensor, assuming the rank is at least 1.
+                length: a tensor of shape [1] or an integer, indicating the first dimension
+                    of the input tensor t after padding, assuming length >= t.shape[0].
+
+            Returns:
+                padded_t: the padded tensor, whose first dimension is length. If the length
+                    is an integer, the first dimension of padded_t is set to length
+                    statically.
+            """
+            t_rank = tf.rank(t)
+            t_shape = tf.shape(t)
+            t_d0 = t_shape[0]
+            pad_d0 = tf.expand_dims(length - t_d0, 0)
+            pad_shape = tf.cond(
+                # shape is [3], depending on input shape
+                tf.greater(t_rank, 1), lambda: tf.concat([pad_d0, t_shape[1:]], 0),
+                # shape is always [1]
+                lambda: tf.expand_dims(length - t_d0, 0))
+            padded_t = tf.concat([t, tf.zeros(pad_shape, dtype=t.dtype)], 0)
+            return padded_t
+
+        output = pad_tensor(inputs, 20)
+        _ = tf.identity(output, name="output")
+        input_names_with_port = ["input:0", "shape:0"]
+        feed_dict = {
+            "input:0": np.ones(input_shape, dtype=np.float32),
+            "shape:0": np.array(input_shape, dtype=np.int32)
+        }
+
+        output_names_with_port = ["output:0"]
+        self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-06)
+
 
 if __name__ == '__main__':
     unittest_main()
```
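
For intuition about what the new test checks, here is a NumPy equivalent of pad_tensor (not part of the commit; the helper name is illustrative). With the test's input shape (10, 5, 20) and length 20, the first dimension of the output becomes 20:

```python
import numpy as np

def pad_tensor_np(t, length):
    # Zero-pad t along dim 0 up to `length`, mirroring the TF pad_tensor above.
    pad_d0 = length - t.shape[0]
    pad_shape = (pad_d0,) + t.shape[1:] if t.ndim > 1 else (pad_d0,)
    return np.concatenate([t, np.zeros(pad_shape, dtype=t.dtype)], axis=0)

t = np.ones((10, 5, 20), dtype=np.float32)
assert pad_tensor_np(t, 20).shape == (20, 5, 20)
```

The point of exercising this through tf.cond is that the two branches computing pad_shape return rank-1 tensors whose static shapes differ, which is exactly the case the relaxed rank-only check must now handle.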

tf2onnx/onnx_opset/generator.py (2 additions, 0 deletions)

```diff
@@ -42,7 +42,9 @@ def version_7(cls, ctx, node, **kwargs):
         # T output = Fill(int32 dims, T value, @int32 index_type)
         # T outputs = Tile(T value, int64 repeats (e.g. dims))
         fill_shape = ctx.get_shape(node.input[0])
+        utils.make_sure(fill_shape is not None, "shape of {} is None".format(node.input[0]))
         fill_shape_dims = fill_shape[0]
+        utils.make_sure(fill_shape_dims > 0, "opset 7 requires fill shape length > 0, or please try opset > 7")
         val_dtype = ctx.get_dtype(node.input[1])
         val_shape = ctx.get_shape(node.input[1])
```
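
Both guards rely on tf2onnx's utils.make_sure helper. A hedged sketch of its behavior (the real utility may differ in detail): it fails fast with a descriptive error instead of letting the converter crash later, e.g. with a bare TypeError when fill_shape is None:

```python
def make_sure(bool_val, error_msg, *args):
    # Sketch of tf2onnx.utils.make_sure: raise with a formatted message
    # when the invariant does not hold.
    if not bool_val:
        raise ValueError("make_sure failure: " + error_msg % args)
```

Note the diff pre-formats its messages with str.format, so no extra args are passed.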

tf2onnx/rewriter/cond_rewriter.py (10 additions, 8 deletions)

```diff
@@ -106,16 +106,21 @@ def _get_output_shape_dtype(self, cond_context):
             true_output = cond_context.true_branch_context.output[i]
             false_output = cond_context.false_branch_context.output[i]
             true_shape = self.g.get_shape(true_output)
+            utils.make_sure(true_shape is not None, "Shape of {} is None".format(true_output))
+            true_rank = len(true_shape)
             true_dtype = self.g.get_dtype(true_output)
             false_shape = self.g.get_shape(false_output)
+            utils.make_sure(false_shape is not None, "Shape of {} is None".format(false_output))
+            false_rank = len(false_shape)
             false_dtype = self.g.get_dtype(false_output)
-            if not utils.are_shapes_compatible(true_shape, false_shape):
+            # just require rank is equal
+            if true_rank != false_rank:
                 raise RuntimeError(
-                    "the shape of outputs {} and {} mismatch: {}, {}".format(
+                    "the rank of outputs {} and {} mismatch: {}, {}".format(
                         true_output,
                         false_output,
-                        true_shape,
-                        false_shape
+                        true_rank,
+                        false_rank
                     )
                 )
             if true_dtype != false_dtype:
@@ -127,10 +132,7 @@ def _get_output_shape_dtype(self, cond_context):
                         false_dtype
                     )
                 )
-            # in tf, the shape of different branched can be different,
-            # for example output shape of branch A can be [-1] while branch B can be [1].
-            # Under this case, we should set output shape to be [-1]
-            output_shapes.append(utils.create_vague_shape_like(utils.merge_shapes(true_shape, false_shape)))
+            output_shapes.append(utils.create_vague_shape_like(true_shape))
             output_dtypes.append(true_dtype)
         return output_shapes, output_dtypes
```
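
The net effect is that the rewritten ONNX If node's output now only preserves rank. A hedged sketch of what utils.create_vague_shape_like plausibly does (the real helper may differ):

```python
def create_vague_shape_like(shape):
    # Sketch: keep the rank of `shape`, mark every dimension unknown (-1).
    return [-1] * len(shape)

# With branch output shapes [1, 4] and [2, 4] (equal rank, unequal dims),
# the If output is annotated [-1, -1]; downstream shape inference can
# tighten it later if more information becomes available.
assert create_vague_shape_like([1, 4]) == [-1, -1]
```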
