@@ -2410,16 +2410,16 @@ def func(x):
2410
2410
2411
2411
@check_opset_min_version (7 , "fill" )
2412
2412
def test_zeros_like (self ):
2413
- input_val = np .random .random_sample ([10 , 20 ]).astype (np .float32 )
2414
- def func (x ):
2415
- res = tf .zeros_like (x )
2416
- return tf .identity (res , name = _TFOUTPUT )
2417
- self ._run_test_case (func , [_OUTPUT ], {_INPUT : input_val })
2413
+ input_x = np .random .random_sample ([10 , 20 ]).astype (np .float32 )
2414
+ input_y = np .array ([20 , 10 ]).astype (np .int64 )
2418
2415
2419
- def func (x ):
2420
- res = tf .zeros_like (x , dtype = tf .int32 )
2421
- return tf .identity (res , name = _TFOUTPUT )
2422
- self ._run_test_case (func , [_OUTPUT ], {_INPUT : input_val })
2416
+ def func (x , y ):
2417
+ z = tf .reshape (x , y )
2418
+ return tf .zeros_like (z , name = _TFOUTPUT )
2419
+
2420
+ self ._run_test_case (func , [_OUTPUT ], {_INPUT : input_x , _INPUT1 : input_y })
2421
+ self ._run_test_case (func , [_OUTPUT ], {_INPUT : input_x .astype (np .int32 ), _INPUT1 : input_y })
2422
+ self ._run_test_case (func , [_OUTPUT ], {_INPUT : input_x > 0.5 , _INPUT1 : input_y })
2423
2423
2424
2424
@check_opset_min_version (9 , "is_nan" )
2425
2425
def test_isnan (self ):
0 commit comments