Skip to content

Commit 9331aa4

Browse files
fix unittest and pylint
1 parent 867109d commit 9331aa4

File tree

3 files changed

+29
-26
lines changed

3 files changed

+29
-26
lines changed

tests/test_cond.py

Lines changed: 26 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
import tensorflow as tf
1313

1414
from backend_test_base import Tf2OnnxBackendTestBase
15-
from common import unittest_main, check_opset_min_version
15+
from common import unittest_main, check_opset_min_version, check_tf_min_version
1616

1717

1818
# pylint: disable=missing-docstring,invalid-name,unused-argument,using-constant-test
@@ -267,7 +267,8 @@ def case_graph():
267267
output_names_with_port = ["output:0"]
268268
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port)
269269

270-
@check_opset_min_version(9, "")
270+
@check_tf_min_version("1.8", "shape inference for Reshape op screws up")
271+
@check_opset_min_version(9, "ConstantOfShape")
271272
def test_cond_with_different_output_shape(self):
272273
input_shape = (10, 5, 20)
273274
inputs = tf.placeholder(tf.float32, input_shape, name="input")
@@ -277,29 +278,29 @@ def test_cond_with_different_output_shape(self):
277278
inputs = tf.reshape(inputs, shape)
278279

279280
def pad_tensor(t, length):
280-
"""Pads the input tensor with 0s along the first dimension up to the length.
281-
282-
Args:
283-
t: the input tensor, assuming the rank is at least 1.
284-
length: a tensor of shape [1] or an integer, indicating the first dimension
285-
of the input tensor t after padding, assuming length <= t.shape[0].
286-
287-
Returns:
288-
padded_t: the padded tensor, whose first dimension is length. If the length
289-
is an integer, the first dimension of padded_t is set to length
290-
statically.
291-
"""
292-
t_rank = tf.rank(t)
293-
t_shape = tf.shape(t)
294-
t_d0 = t_shape[0]
295-
pad_d0 = tf.expand_dims(length - t_d0, 0)
296-
pad_shape = tf.cond(
297-
# shape is [3], depending on input shape
298-
tf.greater(t_rank, 1), lambda: tf.concat([pad_d0, t_shape[1:]], 0),
299-
# shape is always [1]
300-
lambda: tf.expand_dims(length - t_d0, 0))
301-
padded_t = tf.concat([t, tf.zeros(pad_shape, dtype=t.dtype)], 0)
302-
return padded_t
281+
"""Pads the input tensor with 0s along the first dimension up to the length.
282+
283+
Args:
284+
t: the input tensor, assuming the rank is at least 1.
285+
length: a tensor of shape [1] or an integer, indicating the first dimension
286+
of the input tensor t after padding, assuming length <= t.shape[0].
287+
288+
Returns:
289+
padded_t: the padded tensor, whose first dimension is length. If the length
290+
is an integer, the first dimension of padded_t is set to length
291+
statically.
292+
"""
293+
t_rank = tf.rank(t)
294+
t_shape = tf.shape(t)
295+
t_d0 = t_shape[0]
296+
pad_d0 = tf.expand_dims(length - t_d0, 0)
297+
pad_shape = tf.cond(
298+
# shape is [3], depending on input shape
299+
tf.greater(t_rank, 1), lambda: tf.concat([pad_d0, t_shape[1:]], 0),
300+
# shape is always [1]
301+
lambda: tf.expand_dims(length - t_d0, 0))
302+
padded_t = tf.concat([t, tf.zeros(pad_shape, dtype=t.dtype)], 0)
303+
return padded_t
303304

304305
output = pad_tensor(inputs, 20)
305306
_ = tf.identity(output, name="output")

tf2onnx/onnx_opset/generator.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,9 @@ def version_7(cls, ctx, node, **kwargs):
4242
# T output = Fill(int32 dims, T value, @int32 index_type)
4343
# T outputs = Tile(T value, int64 repeats (e.g. dims))
4444
fill_shape = ctx.get_shape(node.input[0])
45+
utils.make_sure(fill_shape is not None, "shape of {} is None".format(node.input[0]))
4546
fill_shape_dims = fill_shape[0]
47+
utils.make_sure(fill_shape_dims > 0, "opset 7 requires fill shape length > 0, or please try opset > 7")
4648
val_dtype = ctx.get_dtype(node.input[1])
4749
val_shape = ctx.get_shape(node.input[1])
4850

tf2onnx/rewriter/cond_rewriter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ def _get_output_shape_dtype(self, cond_context):
107107
false_output = cond_context.false_branch_context.output[i]
108108
true_shape = self.g.get_shape(true_output)
109109
utils.make_sure(true_shape is not None, "Shape of {} is None".format(true_output))
110-
true_rank = len(true_shape)
110+
true_rank = len(true_shape)
111111
true_dtype = self.g.get_dtype(true_output)
112112
false_shape = self.g.get_shape(false_output)
113113
utils.make_sure(false_shape is not None, "Shape of {} is None".format(false_output))

0 commit comments

Comments
 (0)