@@ -2080,15 +2080,17 @@ def test_space_to_batchnd(self):
2080
2080
2081
2081
@check_opset_min_version (10 , "is_inf" )
2082
2082
def test_isinf (self ):
2083
- x_val1 = np .array ([1.0 , 2.0 , - 3.0 , - 4.0 ], dtype = np .float32 ).reshape ((2 , 2 ))
2084
- x_val2 = np .array ([np .inf , np .inf , np .inf , np .inf ], dtype = np .float32 ).reshape ((2 , 2 ))
2085
- x_val3 = np .array ([1.0 , np .inf , - 3.0 , np .inf ], dtype = np .float32 ).reshape ((2 , 2 ))
2086
- for x_val in [x_val1 , x_val2 , x_val3 ]:
2087
- x = tf .placeholder (tf .float32 , x_val .shape , name = _TFINPUT )
2088
- x_ = tf .is_inf (x )
2089
- _ = tf .identity (x_ , name = _TFOUTPUT )
2090
- self ._run_test_case ([_OUTPUT ], {_INPUT : x_val })
2091
- tf .reset_default_graph ()
2083
+ x_types = [np .float32 , np .float64 ]
2084
+ for x_type in x_types :
2085
+ x_val1 = np .array ([1.0 , - 2.0 , 3.0 , - 4.0 ], dtype = x_type ).reshape ((2 , 2 ))
2086
+ x_val2 = np .array ([np .inf , np .inf , np .inf , np .inf ], dtype = x_type ).reshape ((1 , 4 ))
2087
+ x_val3 = np .array ([1.0 , np .inf , - 3.0 , np .inf , 5.0 , np .inf , - 7.0 , np .inf , 9.0 ], dtype = x_type ).reshape ((3 , 3 ))
2088
+ for x_val in [x_val1 , x_val2 , x_val3 ]:
2089
+ x = tf .placeholder (x_type , x_val .shape , name = _TFINPUT )
2090
+ x_ = tf .is_inf (x )
2091
+ _ = tf .identity (x_ , name = _TFOUTPUT )
2092
+ self ._run_test_case ([_OUTPUT ], {_INPUT : x_val })
2093
+ tf .reset_default_graph ()
2092
2094
2093
2095
if __name__ == '__main__' :
2094
2096
unittest_main ()
0 commit comments