11from collections .abc import Callable
2- from contextlib import nullcontext
32from types import ModuleType
43from typing import cast
54
2120from array_api_extra ._lib ._utils ._typing import Array , Device
2221from array_api_extra .testing import lazy_xp_function
2322
24- # mypy: disable-error-code=decorated-any
23+ # mypy: disable-error-code=" decorated-any, explicit-any"
2524# pyright: reportUnknownParameterType=false,reportMissingParameterType=false
2625
27- param_assert_equal_close = pytest .mark .parametrize (
28- "func" ,
29- [
30- xp_assert_equal ,
31- xp_assert_less ,
32- pytest .param (
33- xp_assert_close ,
34- marks = pytest .mark .xfail_xp_backend (
35- Backend .SPARSE , reason = "no isdtype" , strict = False
36- ),
37- ),
38- ],
39- )
40-
4126
4227class TestAsNumPyArray :
4328 def test_basic (self , xp : ModuleType ):
@@ -57,136 +42,144 @@ def test_device(self, xp: ModuleType, library: Backend, device: Device):
5742 xp_assert_equal (actual , expect ) # type: ignore[arg-type] # pyright: ignore[reportArgumentType]
5843
5944
60- @pytest .mark .xfail_xp_backend (Backend .SPARSE , reason = "no isdtype" , strict = False )
61- @pytest .mark .parametrize ("func" , [xp_assert_equal , xp_assert_close ])
62- def test_assert_close_equal_basic (xp : ModuleType , func : Callable [..., None ]): # type: ignore[explicit-any]
63- func (xp .asarray (0 ), xp .asarray (0 ))
64- func (xp .asarray ([1 , 2 ]), xp .asarray ([1 , 2 ]))
65-
66- with pytest .raises (AssertionError , match = "shapes do not match" ):
67- func (xp .asarray ([0 ]), xp .asarray ([[0 ]]))
68-
69- with pytest .raises (AssertionError , match = "dtypes do not match" ):
70- func (xp .asarray (0 , dtype = xp .float32 ), xp .asarray (0 , dtype = xp .float64 ))
71-
72- with pytest .raises (AssertionError ):
73- func (xp .asarray ([1 , 2 ]), xp .asarray ([1 , 3 ]))
74-
75- with pytest .raises (AssertionError , match = "hello" ):
76- func (xp .asarray ([1 , 2 ]), xp .asarray ([1 , 3 ]), err_msg = "hello" )
77-
78-
79- @pytest .mark .skip_xp_backend (Backend .NUMPY , reason = "test other ns vs. numpy" )
80- @pytest .mark .skip_xp_backend (Backend .NUMPY_READONLY , reason = "test other ns vs. numpy" )
81- @pytest .mark .parametrize ("func" , [xp_assert_equal , xp_assert_close , xp_assert_less ])
82- def test_assert_close_equal_less_namespace (xp : ModuleType , func : Callable [..., None ]): # type: ignore[explicit-any]
83- with pytest .raises (AssertionError , match = "namespaces do not match" ):
84- func (xp .asarray (0 ), np .asarray (0 ))
85- with pytest .raises (TypeError , match = "Unrecognized array input" ):
86- func (xp .asarray (0 ), 0 )
87- with pytest .raises (TypeError , match = "list is not a supported array type" ):
88- func (xp .asarray ([0 ]), [0 ])
89-
90-
91- @param_assert_equal_close
92- @pytest .mark .parametrize ("check_shape" , [False , True ])
93- def test_assert_close_equal_less_shape ( # type: ignore[explicit-any]
94- xp : ModuleType ,
95- func : Callable [..., None ],
96- check_shape : bool ,
97- ):
98- context = (
99- pytest .raises (AssertionError , match = "shapes do not match" )
100- if check_shape
101- else nullcontext ()
102- )
103- with context :
104- # note: NaNs are handled by all 3 checks
105- func (xp .asarray ([xp .nan , xp .nan ]), xp .asarray (xp .nan ), check_shape = check_shape )
106-
107-
108- @param_assert_equal_close
109- @pytest .mark .parametrize ("check_dtype" , [False , True ])
110- def test_assert_close_equal_less_dtype ( # type: ignore[explicit-any]
111- xp : ModuleType ,
112- func : Callable [..., None ],
113- check_dtype : bool ,
114- ):
115- context = (
116- pytest .raises (AssertionError , match = "dtypes do not match" )
117- if check_dtype
118- else nullcontext ()
45+ class TestAssertEqualCloseLess :
46+ pr_assert_close = pytest .param ( # pyright: ignore[reportUnannotatedClassAttribute]
47+ xp_assert_close ,
48+ marks = pytest .mark .xfail_xp_backend (Backend .SPARSE , reason = "no isdtype" ),
11949 )
120- with context :
121- func (
122- xp .asarray (xp .nan , dtype = xp .float32 ),
123- xp .asarray (xp .nan , dtype = xp .float64 ),
124- check_dtype = check_dtype ,
125- )
126-
127-
128- @pytest .mark .parametrize ("func" , [xp_assert_equal , xp_assert_close , xp_assert_less ])
129- @pytest .mark .parametrize ("check_scalar" , [False , True ])
130- def test_assert_close_equal_less_scalar ( # type: ignore[explicit-any]
131- xp : ModuleType ,
132- func : Callable [..., None ],
133- check_scalar : bool ,
134- ):
135- context = (
136- pytest .raises (AssertionError , match = "array-ness does not match" )
137- if check_scalar
138- else nullcontext ()
139- )
140- with context :
141- func (np .asarray (xp .nan ), np .asarray (xp .nan )[()], check_scalar = check_scalar )
142-
14350
144- @pytest .mark .xfail_xp_backend (Backend .SPARSE , reason = "no isdtype" )
145- def test_assert_close_tolerance (xp : ModuleType ):
146- xp_assert_close (xp .asarray ([100.0 ]), xp .asarray ([102.0 ]), rtol = 0.03 )
147- with pytest .raises (AssertionError ):
148- xp_assert_close (xp .asarray ([100.0 ]), xp .asarray ([102.0 ]), rtol = 0.01 )
51+ @pytest .mark .parametrize ("func" , [xp_assert_equal , pr_assert_close ])
52+ def test_assert_equal_close_basic (self , xp : ModuleType , func : Callable [..., None ]):
53+ func (xp .asarray (0 ), xp .asarray (0 ))
54+ func (xp .asarray ([1 , 2 ]), xp .asarray ([1 , 2 ]))
14955
150- xp_assert_close (xp .asarray ([100.0 ]), xp .asarray ([102.0 ]), atol = 3 )
151- with pytest .raises (AssertionError ):
152- xp_assert_close (xp .asarray ([100.0 ]), xp .asarray ([102.0 ]), atol = 1 )
56+ with pytest .raises (AssertionError , match = "Mismatched elements" ):
57+ func (xp .asarray ([1 , 2 ]), xp .asarray ([2 , 1 ]))
15358
59+ with pytest .raises (AssertionError , match = "hello" ):
60+ func (xp .asarray ([1 , 2 ]), xp .asarray ([2 , 1 ]), err_msg = "hello" )
15461
155- def test_assert_less_basic (xp : ModuleType ):
156- xp_assert_less (xp .asarray (- 1 ), xp .asarray (0 ))
157- xp_assert_less (xp .asarray ([1 , 2 ]), xp .asarray ([2 , 3 ]))
158- with pytest .raises (AssertionError ):
159- xp_assert_less (xp .asarray ([1 , 1 ]), xp .asarray ([2 , 1 ]))
160- with pytest .raises (AssertionError , match = "hello" ):
161- xp_assert_less (xp .asarray ([1 , 1 ]), xp .asarray ([2 , 1 ]), err_msg = "hello" )
62+ @pytest .mark .parametrize ("func" , [xp_assert_equal , xp_assert_close , xp_assert_less ])
63+ def test_shape_dtype (self , xp : ModuleType , func : Callable [..., None ]):
64+ with pytest .raises (AssertionError , match = "shapes do not match" ):
65+ func (xp .asarray ([0 ]), xp .asarray ([[0 ]]))
16266
67+ with pytest .raises (AssertionError , match = "dtypes do not match" ):
68+ func (xp .asarray (0 , dtype = xp .float32 ), xp .asarray (0 , dtype = xp .float64 ))
16369
164- @pytest .mark .skip_xp_backend (Backend .SPARSE , reason = "index by sparse array" )
165- @pytest .mark .skip_xp_backend (Backend .ARRAY_API_STRICTEST , reason = "boolean indexing" )
166- @pytest .mark .parametrize ("func" , [xp_assert_equal , xp_assert_close ])
167- def test_assert_close_equal_none_shape (xp : ModuleType , func : Callable [..., None ]): # type: ignore[explicit-any]
168- """On Dask and other lazy backends, test that a shape with NaN's or None's
169- can be compared to a real shape.
170- """
171- a = xp .asarray ([1 , 2 ])
172- a = a [a > 1 ]
173-
174- func (a , xp .asarray ([2 ]))
175- with pytest .raises (AssertionError ):
176- func (a , xp .asarray ([2 , 3 ]))
177- with pytest .raises (AssertionError ):
178- func (a , xp .asarray (2 ))
179- with pytest .raises (AssertionError ):
180- func (a , xp .asarray ([3 ]))
181-
182- # Swap actual and desired
183- func (xp .asarray ([2 ]), a )
184- with pytest .raises (AssertionError ):
185- func (xp .asarray ([2 , 3 ]), a )
186- with pytest .raises (AssertionError ):
187- func (xp .asarray (2 ), a )
188- with pytest .raises (AssertionError ):
189- func (xp .asarray ([3 ]), a )
70+ @pytest .mark .skip_xp_backend (Backend .NUMPY , reason = "test other ns vs. numpy" )
71+ @pytest .mark .skip_xp_backend (
72+ Backend .NUMPY_READONLY , reason = "test other ns vs. numpy"
73+ )
74+ @pytest .mark .parametrize ("func" , [xp_assert_equal , xp_assert_close , xp_assert_less ])
75+ def test_namespace (self , xp : ModuleType , func : Callable [..., None ]):
76+ with pytest .raises (AssertionError , match = "namespaces do not match" ):
77+ func (xp .asarray (0 ), np .asarray (0 ))
78+ with pytest .raises (TypeError , match = "Unrecognized array input" ):
79+ func (xp .asarray (0 ), 0 )
80+ with pytest .raises (TypeError , match = "list is not a supported array type" ):
81+ func (xp .asarray ([0 ]), [0 ])
82+
83+ @pytest .mark .parametrize ("func" , [xp_assert_equal , pr_assert_close , xp_assert_less ])
84+ def test_check_shape (self , xp : ModuleType , func : Callable [..., None ]):
85+ a = xp .asarray ([1 ] if func is xp_assert_less else [2 ])
86+ b = xp .asarray (2 )
87+ c = xp .asarray (0 )
88+ d = xp .asarray ([2 , 2 ])
89+
90+ with pytest .raises (AssertionError , match = "shapes do not match" ):
91+ func (a , b )
92+ func (a , b , check_shape = False )
93+ with pytest .raises (AssertionError , match = "Mismatched elements" ):
94+ func (a , c , check_shape = False )
95+ with pytest .raises (AssertionError , match = r"shapes \(1,\), \(2,\) mismatch" ):
96+ func (a , d , check_shape = False )
97+
98+ @pytest .mark .parametrize ("func" , [xp_assert_equal , pr_assert_close , xp_assert_less ])
99+ def test_check_dtype (self , xp : ModuleType , func : Callable [..., None ]):
100+ a = xp .asarray (1 if func is xp_assert_less else 2 )
101+ b = xp .asarray (2 , dtype = xp .int16 )
102+ c = xp .asarray (0 , dtype = xp .int16 )
103+
104+ with pytest .raises (AssertionError , match = "dtypes do not match" ):
105+ func (a , b )
106+ func (a , b , check_dtype = False )
107+ with pytest .raises (AssertionError , match = "Mismatched elements" ):
108+ func (a , c , check_dtype = False )
109+
110+ @pytest .mark .parametrize ("func" , [xp_assert_equal , pr_assert_close , xp_assert_less ])
111+ @pytest .mark .xfail_xp_backend (
112+ Backend .SPARSE , reason = "sparse [()] returns np.generic"
113+ )
114+ def test_check_scalar (
115+ self , xp : ModuleType , library : Backend , func : Callable [..., None ]
116+ ):
117+ a = xp .asarray (1 if func is xp_assert_less else 2 )
118+ b = xp .asarray (2 )[()] # Note: only makes a difference on NumPy
119+ c = xp .asarray (0 )
120+
121+ func (a , b )
122+ if library .like (Backend .NUMPY ):
123+ with pytest .raises (AssertionError , match = "array-ness does not match" ):
124+ func (a , b , check_scalar = True )
125+ else :
126+ func (a , b , check_scalar = True )
127+ with pytest .raises (AssertionError , match = "Mismatched elements" ):
128+ func (a , c , check_scalar = True )
129+
130+ @pytest .mark .xfail_xp_backend (Backend .SPARSE , reason = "no isdtype" )
131+ @pytest .mark .parametrize ("dtype" , ["int64" , "float64" ])
132+ def test_assert_close_tolerance (self , dtype : str , xp : ModuleType ):
133+ a = xp .asarray ([100 ], dtype = getattr (xp , dtype ))
134+ b = xp .asarray ([102 ], dtype = getattr (xp , dtype ))
135+
136+ with pytest .raises (AssertionError , match = "Mismatched elements" ):
137+ xp_assert_close (a , b )
138+
139+ xp_assert_close (a , b , rtol = 0.03 )
140+ with pytest .raises (AssertionError , match = "Mismatched elements" ):
141+ xp_assert_close (a , b , rtol = 0.01 )
142+
143+ xp_assert_close (a , b , atol = 3 )
144+ with pytest .raises (AssertionError , match = "Mismatched elements" ):
145+ xp_assert_close (a , b , atol = 1 )
146+
147+ def test_assert_less (self , xp : ModuleType ):
148+ xp_assert_less (xp .asarray (- 1 ), xp .asarray (0 ))
149+ xp_assert_less (xp .asarray ([1 , 2 ]), xp .asarray ([2 , 3 ]))
150+ with pytest .raises (AssertionError , match = "Mismatched elements" ):
151+ xp_assert_less (xp .asarray ([1 , 1 ]), xp .asarray ([2 , 1 ]))
152+
153+ @pytest .mark .parametrize ("func" , [xp_assert_equal , pr_assert_close , xp_assert_less ])
154+ @pytest .mark .xfail_xp_backend (Backend .SPARSE , reason = "index by sparse array" )
155+ @pytest .mark .skip_xp_backend (Backend .ARRAY_API_STRICTEST , reason = "boolean indexing" )
156+ def test_none_shape (self , xp : ModuleType , func : Callable [..., None ]):
157+ """On Dask and other lazy backends, test that a shape with NaN's or None's
158+ can be compared to a real shape.
159+ """
160+ # actual has shape=(None, )
161+ a = xp .asarray ([1 ] if func is xp_assert_less else [2 ])
162+ a = a [a > 0 ]
163+
164+ func (a , xp .asarray ([2 ]))
165+ with pytest .raises (AssertionError , match = "shapes do not match" ):
166+ func (a , xp .asarray (2 ))
167+ with pytest .raises (AssertionError , match = "shapes do not match" ):
168+ func (a , xp .asarray ([2 , 3 ]))
169+ with pytest .raises (AssertionError , match = "Mismatched elements" ):
170+ func (a , xp .asarray ([0 ]))
171+
172+ # desired has shape=(None, )
173+ a = xp .asarray ([3 ] if func is xp_assert_less else [2 ])
174+ a = a [a > 0 ]
175+
176+ func (xp .asarray ([2 ]), a )
177+ with pytest .raises (AssertionError , match = "shapes do not match" ):
178+ func (xp .asarray (2 ), a )
179+ with pytest .raises (AssertionError , match = "shapes do not match" ):
180+ func (xp .asarray ([2 , 3 ]), a )
181+ with pytest .raises (AssertionError , match = "Mismatched elements" ):
182+ func (xp .asarray ([4 ]), a )
190183
191184
192185def good_lazy (x : Array ) -> Array :
0 commit comments