@@ -3488,6 +3488,17 @@ def func(x):
3488
3488
return tf .identity (picks , name = _TFOUTPUT )
3489
3489
self ._run_test_case (func , [_OUTPUT ], {_INPUT : x_val })
3490
3490
3491
+ @check_opset_min_version (9 , "IsNaN" )
3492
+ def test_where_ismulinf (self ):
3493
+ x_val1 = np .array ([np .inf ], dtype = np .float32 )
3494
+ x_val2 = np .array ([0 ], dtype = np .float32 )
3495
+ true_result = np .array ([np .inf ], dtype = np .float32 )
3496
+ def func (x1 , x2 ):
3497
+ mul = tf .multiply (x1 , x2 )
3498
+ picks = tf .where (x1 < mul , true_result , x2 )
3499
+ return tf .identity (picks , name = _TFOUTPUT )
3500
+ self ._run_test_case (func , [_OUTPUT ], {_INPUT : x_val1 , _INPUT1 : x_val2 })
3501
+
3491
3502
@check_opset_min_version (9 , "Where for strings needs opset 9" )
3492
3503
@skip_tfjs ("Technically tf where doesn't support strings and tfjs doesn't like it" )
3493
3504
def test_where_string (self ):
@@ -5542,7 +5553,7 @@ def func(x):
5542
5553
self ._run_test_case (func , [_OUTPUT ], {_INPUT : x_val0 }, rtol = 1e-6 , atol = 1e-4 )
5543
5554
self ._run_test_case (func , [_OUTPUT ], {_INPUT : x_val }, rtol = 1e-6 , atol = 1e-4 )
5544
5555
5545
- x_val = np .random .random (size = [4 , 3 ]).astype (np .float32 ) * 2048. - 1024
5556
+ x_val = np .random .random (size = [4 , 3 ]).astype (np .float32 ) * 2048. - 1024.
5546
5557
x_val [0 , 0 ] = - 1024
5547
5558
x_val [0 , 1 ] = - 1023
5548
5559
x_val [0 , 2 ] = 1024
@@ -5579,7 +5590,7 @@ def func(x):
5579
5590
self ._run_test_case (func , [_OUTPUT ], {_INPUT : x_val0 }, rtol = 1e-6 , atol = 1e-4 )
5580
5591
self ._run_test_case (func , [_OUTPUT ], {_INPUT : x_val }, rtol = 1e-6 , atol = 1e-4 )
5581
5592
5582
- x_val = np .random .random (size = [4 , 3 ]).astype (np .float32 ) * 2048. - 1024
5593
+ x_val = np .random .random (size = [4 , 3 ]).astype (np .float32 ) * 2048. - 1024.
5583
5594
x_val [0 , 0 ] = - 1024
5584
5595
x_val [0 , 1 ] = - 1023
5585
5596
x_val [0 , 2 ] = 1024
0 commit comments