12
12
import tensorflow as tf
13
13
14
14
from backend_test_base import Tf2OnnxBackendTestBase
15
- from common import unittest_main , check_opset_min_version
15
+ from common import unittest_main , check_opset_min_version , check_tf_min_version
16
16
17
17
18
18
# pylint: disable=missing-docstring,invalid-name,unused-argument,using-constant-test
@@ -267,7 +267,8 @@ def case_graph():
267
267
output_names_with_port = ["output:0" ]
268
268
self .run_test_case (feed_dict , input_names_with_port , output_names_with_port )
269
269
270
- @check_opset_min_version (9 , "" )
270
+ @check_tf_min_version ("1.8" , "shape inference for Reshape op screws up" )
271
+ @check_opset_min_version (9 , "ConstantOfShape" )
271
272
def test_cond_with_different_output_shape (self ):
272
273
input_shape = (10 , 5 , 20 )
273
274
inputs = tf .placeholder (tf .float32 , input_shape , name = "input" )
@@ -277,29 +278,29 @@ def test_cond_with_different_output_shape(self):
277
278
inputs = tf .reshape (inputs , shape )
278
279
279
280
def pad_tensor (t , length ):
280
- """Pads the input tensor with 0s along the first dimension up to the length.
281
-
282
- Args:
283
- t: the input tensor, assuming the rank is at least 1.
284
- length: a tensor of shape [1] or an integer, indicating the first dimension
285
- of the input tensor t after padding, assuming length <= t.shape[0].
286
-
287
- Returns:
288
- padded_t: the padded tensor, whose first dimension is length. If the length
289
- is an integer, the first dimension of padded_t is set to length
290
- statically.
291
- """
292
- t_rank = tf .rank (t )
293
- t_shape = tf .shape (t )
294
- t_d0 = t_shape [0 ]
295
- pad_d0 = tf .expand_dims (length - t_d0 , 0 )
296
- pad_shape = tf .cond (
297
- # shape is [3], depending on input shape
298
- tf .greater (t_rank , 1 ), lambda : tf .concat ([pad_d0 , t_shape [1 :]], 0 ),
299
- # shape is always [1]
300
- lambda : tf .expand_dims (length - t_d0 , 0 ))
301
- padded_t = tf .concat ([t , tf .zeros (pad_shape , dtype = t .dtype )], 0 )
302
- return padded_t
281
+ """Pads the input tensor with 0s along the first dimension up to the length.
282
+
283
+ Args:
284
+ t: the input tensor, assuming the rank is at least 1.
285
+ length: a tensor of shape [1] or an integer, indicating the first dimension
286
+ of the input tensor t after padding, assuming length <= t.shape[0].
287
+
288
+ Returns:
289
+ padded_t: the padded tensor, whose first dimension is length. If the length
290
+ is an integer, the first dimension of padded_t is set to length
291
+ statically.
292
+ """
293
+ t_rank = tf .rank (t )
294
+ t_shape = tf .shape (t )
295
+ t_d0 = t_shape [0 ]
296
+ pad_d0 = tf .expand_dims (length - t_d0 , 0 )
297
+ pad_shape = tf .cond (
298
+ # shape is [3], depending on input shape
299
+ tf .greater (t_rank , 1 ), lambda : tf .concat ([pad_d0 , t_shape [1 :]], 0 ),
300
+ # shape is always [1]
301
+ lambda : tf .expand_dims (length - t_d0 , 0 ))
302
+ padded_t = tf .concat ([t , tf .zeros (pad_shape , dtype = t .dtype )], 0 )
303
+ return padded_t
303
304
304
305
output = pad_tensor (inputs , 20 )
305
306
_ = tf .identity (output , name = "output" )
0 commit comments