|
12 | 12 | import tensorflow as tf
|
13 | 13 |
|
14 | 14 | from backend_test_base import Tf2OnnxBackendTestBase
|
15 |
| -from common import unittest_main |
| 15 | +from common import unittest_main, check_opset_min_version, check_tf_min_version |
16 | 16 |
|
17 | 17 |
|
18 | 18 | # pylint: disable=missing-docstring,invalid-name,unused-argument,using-constant-test
|
@@ -267,6 +267,52 @@ def case_graph():
|
267 | 267 | output_names_with_port = ["output:0"]
|
268 | 268 | self.run_test_case(feed_dict, input_names_with_port, output_names_with_port)
|
269 | 269 |
|
| 270 | + @check_tf_min_version("1.8", "shape inference for Reshape op screws up") |
| 271 | + @check_opset_min_version(9, "ConstantOfShape") |
| 272 | + def test_cond_with_different_output_shape(self): |
| 273 | + input_shape = (10, 5, 20) |
| 274 | + inputs = tf.placeholder(tf.float32, input_shape, name="input") |
| 275 | + |
| 276 | + shape = tf.placeholder(tf.int32, (len(input_shape),), name="shape") |
| 277 | + # cheat onnx shape inference |
| 278 | + inputs = tf.reshape(inputs, shape) |
| 279 | + |
| 280 | + def pad_tensor(t, length): |
| 281 | + """Pads the input tensor with 0s along the first dimension up to the length. |
| 282 | +
|
| 283 | + Args: |
| 284 | + t: the input tensor, assuming the rank is at least 1. |
| 285 | + length: a tensor of shape [1] or an integer, indicating the first dimension |
| 286 | + of the input tensor t after padding, assuming length <= t.shape[0]. |
| 287 | +
|
| 288 | + Returns: |
| 289 | + padded_t: the padded tensor, whose first dimension is length. If the length |
| 290 | + is an integer, the first dimension of padded_t is set to length |
| 291 | + statically. |
| 292 | + """ |
| 293 | + t_rank = tf.rank(t) |
| 294 | + t_shape = tf.shape(t) |
| 295 | + t_d0 = t_shape[0] |
| 296 | + pad_d0 = tf.expand_dims(length - t_d0, 0) |
| 297 | + pad_shape = tf.cond( |
| 298 | + # shape is [3], depending on input shape |
| 299 | + tf.greater(t_rank, 1), lambda: tf.concat([pad_d0, t_shape[1:]], 0), |
| 300 | + # shape is always [1] |
| 301 | + lambda: tf.expand_dims(length - t_d0, 0)) |
| 302 | + padded_t = tf.concat([t, tf.zeros(pad_shape, dtype=t.dtype)], 0) |
| 303 | + return padded_t |
| 304 | + |
| 305 | + output = pad_tensor(inputs, 20) |
| 306 | + _ = tf.identity(output, name="output") |
| 307 | + input_names_with_port = ["input:0", "shape:0"] |
| 308 | + feed_dict = { |
| 309 | + "input:0": np.ones(input_shape, dtype=np.float32), |
| 310 | + "shape:0": np.array(input_shape, dtype=np.int32) |
| 311 | + } |
| 312 | + |
| 313 | + output_names_with_port = ["output:0"] |
| 314 | + self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-06) |
| 315 | + |
270 | 316 |
|
271 | 317 | if __name__ == '__main__':
|
272 | 318 | unittest_main()
|
0 commit comments