77import pytest
88
99from array_api_extra ._lib ._backends import Backend
10- from array_api_extra ._lib ._testing import xp_assert_close , xp_assert_equal
10+ from array_api_extra ._lib ._testing import (
11+ xp_assert_close ,
12+ xp_assert_equal ,
13+ xp_assert_less ,
14+ )
1115from array_api_extra ._lib ._utils ._compat import (
1216 array_namespace ,
1317 is_dask_namespace ,
2327 "func" ,
2428 [
2529 xp_assert_equal ,
30+ xp_assert_less ,
2631 pytest .param (
2732 xp_assert_close ,
2833 marks = pytest .mark .xfail_xp_backend (
3338)
3439
3540
36- @param_assert_equal_close
41+ @pytest .mark .xfail_xp_backend (Backend .SPARSE , reason = "no isdtype" , strict = False )
42+ @pytest .mark .parametrize ("func" , [xp_assert_equal , xp_assert_close ])
3743def test_assert_close_equal_basic (xp : ModuleType , func : Callable [..., None ]): # type: ignore[explicit-any]
3844 func (xp .asarray (0 ), xp .asarray (0 ))
3945 func (xp .asarray ([1 , 2 ]), xp .asarray ([1 , 2 ]))
@@ -53,8 +59,8 @@ def test_assert_close_equal_basic(xp: ModuleType, func: Callable[..., None]): #
5359
5460@pytest .mark .skip_xp_backend (Backend .NUMPY , reason = "test other ns vs. numpy" )
5561@pytest .mark .skip_xp_backend (Backend .NUMPY_READONLY , reason = "test other ns vs. numpy" )
56- @pytest .mark .parametrize ("func" , [xp_assert_equal , xp_assert_close ])
57- def test_assert_close_equal_namespace (xp : ModuleType , func : Callable [..., None ]): # type: ignore[explicit-any]
62+ @pytest .mark .parametrize ("func" , [xp_assert_equal , xp_assert_close , xp_assert_less ])
63+ def test_assert_close_equal_less_namespace (xp : ModuleType , func : Callable [..., None ]): # type: ignore[explicit-any]
5864 with pytest .raises (AssertionError , match = "namespaces do not match" ):
5965 func (xp .asarray (0 ), np .asarray (0 ))
6066 with pytest .raises (TypeError , match = "Unrecognized array input" ):
@@ -65,7 +71,7 @@ def test_assert_close_equal_namespace(xp: ModuleType, func: Callable[..., None])
6571
6672@param_assert_equal_close
6773@pytest .mark .parametrize ("check_shape" , [False , True ])
68- def test_assert_close_equal_shape ( # type: ignore[explicit-any]
74+ def test_assert_close_equal_less_shape ( # type: ignore[explicit-any]
6975 xp : ModuleType ,
7076 func : Callable [..., None ],
7177 check_shape : bool ,
@@ -76,12 +82,12 @@ def test_assert_close_equal_shape( # type: ignore[explicit-any]
7682 else nullcontext ()
7783 )
7884 with context :
79- func (xp .asarray ([0 , 0 ]), xp .asarray (0 ), check_shape = check_shape )
85+ func (xp .asarray ([xp . nan , xp . nan ]), xp .asarray (xp . nan ), check_shape = check_shape )
8086
8187
8288@param_assert_equal_close
8389@pytest .mark .parametrize ("check_dtype" , [False , True ])
84- def test_assert_close_equal_dtype ( # type: ignore[explicit-any]
90+ def test_assert_close_equal_less_dtype ( # type: ignore[explicit-any]
8591 xp : ModuleType ,
8692 func : Callable [..., None ],
8793 check_dtype : bool ,
@@ -92,12 +98,17 @@ def test_assert_close_equal_dtype( # type: ignore[explicit-any]
9298 else nullcontext ()
9399 )
94100 with context :
95- func (xp .asarray (0.0 ), xp .asarray (0 ), check_dtype = check_dtype )
101+ func (
102+ xp .asarray (xp .nan , dtype = xp .float32 ),
103+ xp .asarray (xp .nan , dtype = xp .float64 ),
104+ check_dtype = check_dtype ,
105+ )
96106
97107
98- @pytest .mark .parametrize ("func" , [xp_assert_equal , xp_assert_close ])
108+ @pytest .mark .parametrize ("func" , [xp_assert_equal , xp_assert_close , xp_assert_less ])
99109@pytest .mark .parametrize ("check_scalar" , [False , True ])
100- def test_assert_close_equal_scalar ( # type: ignore[explicit-any]
110+ def test_assert_close_equal_less_scalar ( # type: ignore[explicit-any]
111+ xp : ModuleType ,
101112 func : Callable [..., None ],
102113 check_scalar : bool ,
103114):
@@ -107,7 +118,7 @@ def test_assert_close_equal_scalar( # type: ignore[explicit-any]
107118 else nullcontext ()
108119 )
109120 with context :
110- func (np .asarray (0 ), np .asarray (0 )[()], check_scalar = check_scalar )
121+ func (np .asarray (xp . nan ), np .asarray (xp . nan )[()], check_scalar = check_scalar )
111122
112123
113124@pytest .mark .xfail_xp_backend (Backend .SPARSE , reason = "no isdtype" )
@@ -121,9 +132,18 @@ def test_assert_close_tolerance(xp: ModuleType):
121132 xp_assert_close (xp .asarray ([100.0 ]), xp .asarray ([102.0 ]), atol = 1 )
122133
123134
124- @param_assert_equal_close
135+ def test_assert_less_basic (xp : ModuleType ):
136+ xp_assert_less (xp .asarray (- 1 ), xp .asarray (0 ))
137+ xp_assert_less (xp .asarray ([1 , 2 ]), xp .asarray ([2 , 3 ]))
138+ with pytest .raises (AssertionError ):
139+ xp_assert_less (xp .asarray ([1 , 1 ]), xp .asarray ([2 , 1 ]))
140+ with pytest .raises (AssertionError , match = "hello" ):
141+ xp_assert_less (xp .asarray ([1 , 1 ]), xp .asarray ([2 , 1 ]), err_msg = "hello" )
142+
143+
125144@pytest .mark .skip_xp_backend (Backend .SPARSE , reason = "index by sparse array" )
126145@pytest .mark .skip_xp_backend (Backend .ARRAY_API_STRICTEST , reason = "boolean indexing" )
146+ @pytest .mark .parametrize ("func" , [xp_assert_equal , xp_assert_close ])
127147def test_assert_close_equal_none_shape (xp : ModuleType , func : Callable [..., None ]): # type: ignore[explicit-any]
128148 """On Dask and other lazy backends, test that a shape with NaN's or None's
129149 can be compared to a real shape.
0 commit comments