Skip to content

Commit 321714d

Browse files
committed
add UT
1 parent 45786b9 commit 321714d

File tree

1 file changed

+9
-9
lines changed

1 file changed

+9
-9
lines changed

tests/test_backend.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2410,16 +2410,16 @@ def func(x):
24102410

24112411
@check_opset_min_version(7, "fill")
24122412
def test_zeros_like(self):
2413-
input_val = np.random.random_sample([10, 20]).astype(np.float32)
2414-
def func(x):
2415-
res = tf.zeros_like(x)
2416-
return tf.identity(res, name=_TFOUTPUT)
2417-
self._run_test_case(func, [_OUTPUT], {_INPUT: input_val})
2413+
input_x = np.random.random_sample([10, 20]).astype(np.float32)
2414+
input_y = np.array([20, 10]).astype(np.int64)
24182415

2419-
def func(x):
2420-
res = tf.zeros_like(x, dtype=tf.int32)
2421-
return tf.identity(res, name=_TFOUTPUT)
2422-
self._run_test_case(func, [_OUTPUT], {_INPUT: input_val})
2416+
def func(x, y):
2417+
z = tf.reshape(x, y)
2418+
return tf.zeros_like(z, name=_TFOUTPUT)
2419+
2420+
self._run_test_case(func, [_OUTPUT], {_INPUT: input_x, _INPUT1: input_y})
2421+
self._run_test_case(func, [_OUTPUT], {_INPUT: input_x.astype(np.int32), _INPUT1: input_y})
2422+
self._run_test_case(func, [_OUTPUT], {_INPUT: input_x > 0.5, _INPUT1: input_y})
24232423

24242424
@check_opset_min_version(9, "is_nan")
24252425
def test_isnan(self):

0 commit comments

Comments
 (0)