48
48
lazy_xp_function (sinc , static_argnames = "xp" )
49
49
50
50
51
- NUMPY_GE2 = int (np .__version__ .split ("." )[0 ]) >= 2
51
+ NUMPY_VERSION = tuple ( int (v ) for v in np .__version__ .split ("." )[2 ])
52
52
53
53
54
- @pytest .mark .skip_xp_backend (
55
- Backend .SPARSE , reason = "read-only backend without .at support"
56
- )
57
54
class TestApplyWhere :
58
55
@staticmethod
59
56
def f1 (x : Array , y : Array | int = 10 ) -> Array :
@@ -153,6 +150,14 @@ def test_dont_overwrite_fill_value(self, xp: ModuleType):
153
150
xp_assert_equal (actual , xp .asarray ([100 , 12 ]))
154
151
xp_assert_equal (fill_value , xp .asarray ([100 , 200 ]))
155
152
153
+ @pytest .mark .skip_xp_backend (
154
+ Backend .ARRAY_API_STRICTEST ,
155
+ reason = "no boolean indexing -> run everywhere" ,
156
+ )
157
+ @pytest .mark .skip_xp_backend (
158
+ Backend .SPARSE ,
159
+ reason = "no indexing by sparse array -> run everywhere" ,
160
+ )
156
161
def test_dont_run_on_false (self , xp : ModuleType ):
157
162
x = xp .asarray ([1.0 , 2.0 , 0.0 ])
158
163
y = xp .asarray ([0.0 , 3.0 , 4.0 ])
@@ -192,6 +197,7 @@ def test_device(self, xp: ModuleType, device: Device):
192
197
y = apply_where (x % 2 == 0 , x , self .f1 , fill_value = x )
193
198
assert get_device (y ) == device
194
199
200
+ @pytest .mark .xfail_xp_backend (Backend .SPARSE , reason = "no isdtype" )
195
201
@pytest .mark .filterwarnings ("ignore::RuntimeWarning" ) # overflows, etc.
196
202
@hypothesis .settings (
197
203
# The xp and library fixtures are not regenerated between hypothesis iterations
@@ -218,7 +224,7 @@ def test_hypothesis( # type: ignore[explicit-any,decorated-any]
218
224
):
219
225
if (
220
226
library in (Backend .NUMPY , Backend .NUMPY_READONLY )
221
- and not NUMPY_GE2
227
+ and NUMPY_VERSION < ( 2 , 0 )
222
228
and dtype is np .float32
223
229
):
224
230
pytest .xfail (reason = "NumPy 1.x dtype promotion for scalars" )
@@ -562,6 +568,9 @@ def test_xp(self, xp: ModuleType):
562
568
assert y .shape == (1 , 1 , 1 , 3 )
563
569
564
570
571
+ @pytest .mark .filterwarnings ( # array_api_strictest
572
+ "ignore:invalid value encountered:RuntimeWarning:array_api_strict"
573
+ )
565
574
@pytest .mark .xfail_xp_backend (Backend .SPARSE , reason = "no isdtype" )
566
575
class TestIsClose :
567
576
@pytest .mark .parametrize ("swap" , [False , True ])
@@ -680,13 +689,15 @@ def test_bool_dtype(self, xp: ModuleType):
680
689
isclose (xp .asarray (True ), b , atol = 1 ), xp .asarray ([True , True , True ])
681
690
)
682
691
692
+ @pytest .mark .skip_xp_backend (Backend .ARRAY_API_STRICTEST , reason = "unknown shape" )
683
693
def test_none_shape (self , xp : ModuleType ):
684
694
a = xp .asarray ([1 , 5 , 0 ])
685
695
b = xp .asarray ([1 , 4 , 2 ])
686
696
b = b [a < 5 ]
687
697
a = a [a < 5 ]
688
698
xp_assert_equal (isclose (a , b ), xp .asarray ([True , False ]))
689
699
700
+ @pytest .mark .skip_xp_backend (Backend .ARRAY_API_STRICTEST , reason = "unknown shape" )
690
701
def test_none_shape_bool (self , xp : ModuleType ):
691
702
a = xp .asarray ([True , True , False ])
692
703
b = xp .asarray ([True , False , True ])
@@ -819,8 +830,30 @@ def test_empty(self, xp: ModuleType):
819
830
a = xp .asarray ([])
820
831
xp_assert_equal (nunique (a ), xp .asarray (0 ))
821
832
822
- def test_device (self , xp : ModuleType , device : Device ):
823
- a = xp .asarray (0.0 , device = device )
833
+ def test_size1 (self , xp : ModuleType ):
834
+ a = xp .asarray ([123 ])
835
+ xp_assert_equal (nunique (a ), xp .asarray (1 ))
836
+
837
+ def test_all_equal (self , xp : ModuleType ):
838
+ a = xp .asarray ([123 , 123 , 123 ])
839
+ xp_assert_equal (nunique (a ), xp .asarray (1 ))
840
+
841
+ @pytest .mark .xfail_xp_backend (Backend .DASK , reason = "No equal_nan kwarg in unique" )
842
+ @pytest .mark .xfail_xp_backend (
843
+ Backend .SPARSE , reason = "Non-compliant equal_nan=True behaviour"
844
+ )
845
+ def test_nan (self , xp : ModuleType , library : Backend ):
846
+ is_numpy = library in (Backend .NUMPY , Backend .NUMPY_READONLY )
847
+ if is_numpy and NUMPY_VERSION < (1 , 24 ):
848
+ pytest .xfail ("NumPy <1.24 has no equal_nan kwarg in unique" )
849
+
850
+ # Each NaN is counted separately
851
+ a = xp .asarray ([xp .nan , 123.0 , xp .nan ])
852
+ xp_assert_equal (nunique (a ), xp .asarray (3 ))
853
+
854
+ @pytest .mark .parametrize ("size" , [0 , 1 , 2 ])
855
+ def test_device (self , xp : ModuleType , device : Device , size : int ):
856
+ a = xp .asarray ([0.0 ] * size , device = device )
824
857
assert get_device (nunique (a )) == device
825
858
826
859
def test_xp (self , xp : ModuleType ):
@@ -895,6 +928,7 @@ def test_sequence_of_tuples_width(self, xp: ModuleType):
895
928
896
929
897
930
@pytest .mark .xfail_xp_backend (Backend .SPARSE , reason = "no argsort" )
931
+ @pytest .mark .skip_xp_backend (Backend .ARRAY_API_STRICTEST , reason = "no unique_values" )
898
932
class TestSetDiff1D :
899
933
@pytest .mark .xfail_xp_backend (Backend .DASK , reason = "NaN-shaped arrays" )
900
934
@pytest .mark .xfail_xp_backend (
0 commit comments