1414
1515from scipy ._lib ._array_api import (xp_assert_equal , xp_assert_close , is_numpy ,
1616 is_array_api_strict )
17+ from scipy ._lib ._lazy_testing import lazy_xp_function
1718from scipy ._lib ._util import (_aligned_zeros , check_random_state , MapWrapper ,
1819 getfullargspec_no_self , FullArgSpec ,
1920 rng_integers , _validate_int , _rename_parameter ,
2324
2425skip_xp_backends = pytest .mark .skip_xp_backends
2526
27+ lazy_xp_function (_contains_nan , static_argnames = ("nan_policy" , "xp_omit_okay" , "xp" ))
28+ # FIXME @jax.jit fails: complex bool mask
29+ lazy_xp_function (_lazywhere , jax_jit = False , static_argnames = ("f" , "f2" ))
30+
31+
2632@pytest .mark .slow
2733def test__aligned_zeros ():
2834 niter = 10
@@ -344,6 +350,7 @@ def test_contains_nan_with_strings(self):
344350 data4 = np .array ([["1" , 2 ], [3 , np .nan ]], dtype = 'object' )
345351 assert _contains_nan (data4 )
346352
353+ @pytest .mark .skip_xp_backends ("jax.numpy" , reason = "lazy backends tested separately" )
347354 @pytest .mark .parametrize ("nan_policy" , ['propagate' , 'omit' , 'raise' ])
348355 def test_array_api (self , xp , nan_policy ):
349356 rng = np .random .default_rng (932347235892482 )
@@ -359,9 +366,40 @@ def test_array_api(self, xp, nan_policy):
359366 elif nan_policy == 'omit' and not is_numpy (xp ):
360367 with pytest .raises (ValueError , match = "nan_policy='omit' is incompatible" ):
361368 _contains_nan (x , nan_policy )
369+ assert _contains_nan (x , nan_policy , xp_omit_okay = True )
362370 elif nan_policy == 'propagate' :
363371 assert _contains_nan (x , nan_policy )
364372
373+ @pytest .mark .skip_xp_backends ("numpy" , reason = "lazy backends only" )
374+ @pytest .mark .skip_xp_backends ("cupy" , reason = "lazy backends only" )
375+ @pytest .mark .skip_xp_backends ("array_api_strict" , reason = "lazy backends only" )
376+ @pytest .mark .skip_xp_backends ("torch" , reason = "lazy backends only" )
377+ def test_array_api_lazy (self , xp ):
378+ rng = np .random .default_rng (932347235892482 )
379+ x0 = rng .random (size = (2 , 3 , 4 ))
380+ x = xp .asarray (x0 )
381+
382+ xp_assert_equal (_contains_nan (x ), xp .asarray (False ))
383+ xp_assert_equal (_contains_nan (x , "propagate" ), xp .asarray (False ))
384+ xp_assert_equal (_contains_nan (x , "omit" , xp_omit_okay = True ), xp .asarray (False ))
385+ # Lazy arrays don't support "omit" and "raise" policies
386+ # TODO test that we're emitting a user-friendly error message.
387+ # Blocked by https://github.com/data-apis/array-api-compat/pull/228
388+ with pytest .raises (TypeError ):
389+ _contains_nan (x , "omit" )
390+ with pytest .raises (TypeError ):
391+ _contains_nan (x , "raise" )
392+
393+ x = xpx .at (x )[1 , 2 , 1 ].set (np .nan )
394+
395+ xp_assert_equal (_contains_nan (x ), xp .asarray (True ))
396+ xp_assert_equal (_contains_nan (x , "propagate" ), xp .asarray (True ))
397+ xp_assert_equal (_contains_nan (x , "omit" , xp_omit_okay = True ), xp .asarray (True ))
398+ with pytest .raises (TypeError ):
399+ _contains_nan (x , "omit" )
400+ with pytest .raises (TypeError ):
401+ _contains_nan (x , "raise" )
402+
365403
366404def test__rng_html_rewrite ():
367405 def mock_str ():
0 commit comments