Skip to content

Commit 867109d

Browse files
enhance cond_rewriter by allowing different output shapes
1 parent 03a6379 commit 867109d

File tree

2 files changed

+56
-9
lines changed

2 files changed

+56
-9
lines changed

tests/test_cond.py

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
import tensorflow as tf
1313

1414
from backend_test_base import Tf2OnnxBackendTestBase
15-
from common import unittest_main
15+
from common import unittest_main, check_opset_min_version
1616

1717

1818
# pylint: disable=missing-docstring,invalid-name,unused-argument,using-constant-test
@@ -267,6 +267,51 @@ def case_graph():
267267
output_names_with_port = ["output:0"]
268268
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port)
269269

270+
@check_opset_min_version(9, "ConstantOfShape")  # tf.zeros with a runtime shape needs ConstantOfShape (opset 9)
def test_cond_with_different_output_shape(self):
    """Cond branches whose output shapes differ (but whose ranks agree).

    The true branch of the inner tf.cond produces a shape-[3] tensor while the
    false branch produces a shape-[1] tensor; the rewriter must accept this as
    long as the ranks are handled, so the static input shape is deliberately
    erased via a reshape driven by a placeholder ("cheat onnx shape inference").
    """
    input_shape = (10, 5, 20)
    inputs = tf.placeholder(tf.float32, input_shape, name="input")

    shape = tf.placeholder(tf.int32, (len(input_shape),), name="shape")
    # cheat onnx shape inference
    inputs = tf.reshape(inputs, shape)

    def pad_tensor(t, length):
        """Pads the input tensor with 0s along the first dimension up to the length.

        Args:
            t: the input tensor, assuming the rank is at least 1.
            length: a tensor of shape [1] or an integer, indicating the first dimension
                of the input tensor t after padding, assuming length >= t.shape[0]
                (the pad amount length - t.shape[0] must be non-negative, since it is
                used as a dimension of tf.zeros).

        Returns:
            padded_t: the padded tensor, whose first dimension is length. If the length
                is an integer, the first dimension of padded_t is set to length
                statically.
        """
        t_rank = tf.rank(t)
        t_shape = tf.shape(t)
        t_d0 = t_shape[0]
        pad_d0 = tf.expand_dims(length - t_d0, 0)
        pad_shape = tf.cond(
            # shape is [3], depending on input shape
            tf.greater(t_rank, 1), lambda: tf.concat([pad_d0, t_shape[1:]], 0),
            # shape is always [1]
            lambda: tf.expand_dims(length - t_d0, 0))
        padded_t = tf.concat([t, tf.zeros(pad_shape, dtype=t.dtype)], 0)
        return padded_t

    output = pad_tensor(inputs, 20)
    _ = tf.identity(output, name="output")
    input_names_with_port = ["input:0", "shape:0"]
    feed_dict = {
        "input:0": np.ones(input_shape, dtype=np.float32),
        "shape:0": np.array(input_shape, dtype=np.int32)
    }

    output_names_with_port = ["output:0"]
    self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-06)
314+
270315

271316
if __name__ == '__main__':
272317
unittest_main()

tf2onnx/rewriter/cond_rewriter.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -106,16 +106,21 @@ def _get_output_shape_dtype(self, cond_context):
106106
true_output = cond_context.true_branch_context.output[i]
107107
false_output = cond_context.false_branch_context.output[i]
108108
true_shape = self.g.get_shape(true_output)
109+
utils.make_sure(true_shape is not None, "Shape of {} is None".format(true_output))
110+
true_rank = len(true_shape)
109111
true_dtype = self.g.get_dtype(true_output)
110112
false_shape = self.g.get_shape(false_output)
113+
utils.make_sure(false_shape is not None, "Shape of {} is None".format(false_output))
114+
false_rank = len(false_shape)
111115
false_dtype = self.g.get_dtype(false_output)
112-
if not utils.are_shapes_compatible(true_shape, false_shape):
116+
# just require rank is equal
117+
if true_rank != false_rank:
113118
raise RuntimeError(
114-
"the shape of outputs {} and {} mismatch: {}, {}".format(
119+
"the rank of outputs {} and {} mismatch: {}, {}".format(
115120
true_output,
116121
false_output,
117-
true_shape,
118-
false_shape
122+
true_rank,
123+
false_rank
119124
)
120125
)
121126
if true_dtype != false_dtype:
@@ -127,10 +132,7 @@ def _get_output_shape_dtype(self, cond_context):
127132
false_dtype
128133
)
129134
)
130-
# in tf, the shape of different branched can be different,
131-
# for example output shape of branch A can be [-1] while branch B can be [1].
132-
# Under this case, we should set output shape to be [-1]
133-
output_shapes.append(utils.create_vague_shape_like(utils.merge_shapes(true_shape, false_shape)))
135+
output_shapes.append(utils.create_vague_shape_like(true_shape))
134136
output_dtypes.append(true_dtype)
135137
return output_shapes, output_dtypes
136138

0 commit comments

Comments
 (0)