55See also ..testing for public testing utilities.
66"""
77
8+ from __future__ import annotations
9+
810import math
911from types import ModuleType
10- from typing import cast
12+ from typing import Any , cast
1113
14+ import numpy as np
1215import pytest
1316
1417from ._utils ._compat import (
1518 array_namespace ,
1619 is_array_api_strict_namespace ,
1720 is_cupy_namespace ,
1821 is_dask_namespace ,
22+ is_jax_namespace ,
23+ is_numpy_namespace ,
1924 is_pydata_sparse_namespace ,
2025 is_torch_namespace ,
26+ to_device ,
2127)
22- from ._utils ._typing import Array
28+ from ._utils ._typing import Array , Device
2329
24- __all__ = ["xp_assert_close" , "xp_assert_equal" ]
30+ __all__ = ["as_numpy_array" , " xp_assert_close" , "xp_assert_equal" , "xp_assert_less " ]
2531
2632
2733def _check_ns_shape_dtype (
28- actual : Array , desired : Array
34+ actual : Array ,
35+ desired : Array ,
36+ check_dtype : bool ,
37+ check_shape : bool ,
38+ check_scalar : bool ,
2939) -> ModuleType : # numpydoc ignore=RT03
3040 """
3141 Assert that namespace, shape and dtype of the two arrays match.
@@ -36,6 +46,11 @@ def _check_ns_shape_dtype(
3646 The array produced by the tested function.
3747 desired : Array
3848 The expected array (typically hardcoded).
49+ check_dtype, check_shape : bool, default: True
50+ Whether to check agreement between actual and desired dtypes and shapes
51+ check_scalar : bool, default: False
52+ NumPy only: whether to check agreement between actual and desired types -
53+ 0d array vs scalar.
3954
4055 Returns
4156 -------
@@ -47,43 +62,67 @@ def _check_ns_shape_dtype(
4762 msg = f"namespaces do not match: { actual_xp } != f{ desired_xp } "
4863 assert actual_xp == desired_xp , msg
4964
50- actual_shape = actual .shape
51- desired_shape = desired .shape
52- if is_dask_namespace (desired_xp ):
53- # Dask uses nan instead of None for unknown shapes
54- if any (math .isnan (i ) for i in cast (tuple [float , ...], actual_shape )):
55- actual_shape = actual .compute ().shape # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
56- if any (math .isnan (i ) for i in cast (tuple [float , ...], desired_shape )):
57- desired_shape = desired .compute ().shape # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
58-
59- msg = f"shapes do not match: { actual_shape } != f{ desired_shape } "
60- assert actual_shape == desired_shape , msg
61-
62- msg = f"dtypes do not match: { actual .dtype } != { desired .dtype } "
63- assert actual .dtype == desired .dtype , msg
65+ if check_shape :
66+ actual_shape = actual .shape
67+ desired_shape = desired .shape
68+ if is_dask_namespace (desired_xp ):
69+ # Dask uses nan instead of None for unknown shapes
70+ if any (math .isnan (i ) for i in cast (tuple [float , ...], actual_shape )):
71+ actual_shape = actual .compute ().shape # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
72+ if any (math .isnan (i ) for i in cast (tuple [float , ...], desired_shape )):
73+ desired_shape = desired .compute ().shape # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
74+
75+ msg = f"shapes do not match: { actual_shape } != f{ desired_shape } "
76+ assert actual_shape == desired_shape , msg
77+
78+ if check_dtype :
79+ msg = f"dtypes do not match: { actual .dtype } != { desired .dtype } "
80+ assert actual .dtype == desired .dtype , msg
81+
82+ if is_numpy_namespace (actual_xp ) and check_scalar :
83+ # only NumPy distinguishes between scalars and arrays; we do if check_scalar.
84+ _msg = (
85+ "array-ness does not match:\n Actual: "
86+ f"{ type (actual )} \n Desired: { type (desired )} "
87+ )
88+ assert np .isscalar (actual ) == np .isscalar (desired ), _msg
6489
6590 return desired_xp
6691
6792
68- def _prepare_for_test (array : Array , xp : ModuleType ) -> Array :
93+ def as_numpy_array (array : Array , * , xp : ModuleType ) -> np . typing . NDArray [ Any ]: # type: ignore[explicit-any]
6994 """
70- Ensure that the array can be compared with xp.testing or np.testing.
71-
72- This involves transferring it from GPU to CPU memory, densifying it, etc.
95+ Convert array to NumPy, bypassing GPU-CPU transfer guards and densification guards.
7396 """
74- if is_torch_namespace (xp ):
75- return array . cpu () # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
97+ if is_cupy_namespace (xp ):
98+ return xp . asnumpy ( array )
7699 if is_pydata_sparse_namespace (xp ):
77100 return array .todense () # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
101+
102+ if is_torch_namespace (xp ):
103+ array = to_device (array , "cpu" )
78104 if is_array_api_strict_namespace (xp ):
79- # Note: we deliberately did not add a `.to_device` method in _typing.pyi
80- # even if it is required by the standard as many backends don't support it
81- return array .to_device (xp .Device ("CPU_DEVICE" )) # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
82- # Note: nothing to do for CuPy, because it uses a bespoke test function
83- return array
105+ cpu : Device = xp .Device ("CPU_DEVICE" )
106+ array = to_device (array , cpu )
107+ if is_jax_namespace (xp ):
108+ import jax
84109
110+ # Note: only needed if the transfer guard is enabled
111+ cpu = cast (Device , jax .devices ("cpu" )[0 ])
112+ array = to_device (array , cpu )
85113
86- def xp_assert_equal (actual : Array , desired : Array , err_msg : str = "" ) -> None :
114+ return np .asarray (array )
115+
116+
117+ def xp_assert_equal (
118+ actual : Array ,
119+ desired : Array ,
120+ * ,
121+ err_msg : str = "" ,
122+ check_dtype : bool = True ,
123+ check_shape : bool = True ,
124+ check_scalar : bool = False ,
125+ ) -> None :
87126 """
88127 Array-API compatible version of `np.testing.assert_array_equal`.
89128
@@ -95,34 +134,56 @@ def xp_assert_equal(actual: Array, desired: Array, err_msg: str = "") -> None:
95134 The expected array (typically hardcoded).
96135 err_msg : str, optional
97136 Error message to display on failure.
137+ check_dtype, check_shape : bool, default: True
138+ Whether to check agreement between actual and desired dtypes and shapes
139+ check_scalar : bool, default: False
140+ NumPy only: whether to check agreement between actual and desired types -
141+ 0d array vs scalar.
98142
99143 See Also
100144 --------
101145 xp_assert_close : Similar function for inexact equality checks.
102146 numpy.testing.assert_array_equal : Similar function for NumPy arrays.
103147 """
104- xp = _check_ns_shape_dtype (actual , desired )
105- actual = _prepare_for_test (actual , xp )
106- desired = _prepare_for_test (desired , xp )
148+ xp = _check_ns_shape_dtype (actual , desired , check_dtype , check_shape , check_scalar )
149+ actual_np = as_numpy_array (actual , xp = xp )
150+ desired_np = as_numpy_array (desired , xp = xp )
151+ np .testing .assert_array_equal (actual_np , desired_np , err_msg = err_msg )
107152
108- if is_cupy_namespace (xp ):
109- xp .testing .assert_array_equal (actual , desired , err_msg = err_msg )
110- elif is_torch_namespace (xp ):
111- # PyTorch recommends using `rtol=0, atol=0` like this
112- # to test for exact equality
113- xp .testing .assert_close (
114- actual ,
115- desired ,
116- rtol = 0 ,
117- atol = 0 ,
118- equal_nan = True ,
119- check_dtype = False ,
120- msg = err_msg or None ,
121- )
122- else :
123- import numpy as np # pylint: disable=import-outside-toplevel
124153
125- np .testing .assert_array_equal (actual , desired , err_msg = err_msg )
154+ def xp_assert_less (
155+ x : Array ,
156+ y : Array ,
157+ * ,
158+ err_msg : str = "" ,
159+ check_dtype : bool = True ,
160+ check_shape : bool = True ,
161+ check_scalar : bool = False ,
162+ ) -> None :
163+ """
164+ Array-API compatible version of `np.testing.assert_array_less`.
165+
166+ Parameters
167+ ----------
168+ x, y : Array
169+ The arrays to compare according to ``x < y`` (elementwise).
170+ err_msg : str, optional
171+ Error message to display on failure.
172+ check_dtype, check_shape : bool, default: True
173+ Whether to check agreement between actual and desired dtypes and shapes
174+ check_scalar : bool, default: False
175+ NumPy only: whether to check agreement between actual and desired types -
176+ 0d array vs scalar.
177+
178+ See Also
179+ --------
180+ xp_assert_close : Similar function for inexact equality checks.
181+ numpy.testing.assert_array_equal : Similar function for NumPy arrays.
182+ """
183+ xp = _check_ns_shape_dtype (x , y , check_dtype , check_shape , check_scalar )
184+ x_np = as_numpy_array (x , xp = xp )
185+ y_np = as_numpy_array (y , xp = xp )
186+ np .testing .assert_array_less (x_np , y_np , err_msg = err_msg )
126187
127188
128189def xp_assert_close (
@@ -132,6 +193,9 @@ def xp_assert_close(
132193 rtol : float | None = None ,
133194 atol : float = 0 ,
134195 err_msg : str = "" ,
196+ check_dtype : bool = True ,
197+ check_shape : bool = True ,
198+ check_scalar : bool = False ,
135199) -> None :
136200 """
137201 Array-API compatible version of `np.testing.assert_allclose`.
@@ -148,6 +212,11 @@ def xp_assert_close(
148212 Absolute tolerance. Default: 0.
149213 err_msg : str, optional
150214 Error message to display on failure.
215+ check_dtype, check_shape : bool, default: True
216+ Whether to check agreement between actual and desired dtypes and shapes
217+ check_scalar : bool, default: False
218+ NumPy only: whether to check agreement between actual and desired types -
219+ 0d array vs scalar.
151220
152221 See Also
153222 --------
@@ -159,40 +228,26 @@ def xp_assert_close(
159228 -----
160229 The default `atol` and `rtol` differ from `xp.all(xpx.isclose(a, b))`.
161230 """
162- xp = _check_ns_shape_dtype (actual , desired )
163-
164- floating = xp .isdtype (actual .dtype , ("real floating" , "complex floating" ))
165- if rtol is None and floating :
166- # multiplier of 4 is used as for `np.float64` this puts the default `rtol`
167- # roughly half way between sqrt(eps) and the default for
168- # `numpy.testing.assert_allclose`, 1e-7
169- rtol = xp .finfo (actual .dtype ).eps ** 0.5 * 4
170- elif rtol is None :
171- rtol = 1e-7
172-
173- actual = _prepare_for_test (actual , xp )
174- desired = _prepare_for_test (desired , xp )
175-
176- if is_cupy_namespace (xp ):
177- xp .testing .assert_allclose (
178- actual , desired , rtol = rtol , atol = atol , err_msg = err_msg
179- )
180- elif is_torch_namespace (xp ):
181- xp .testing .assert_close (
182- actual , desired , rtol = rtol , atol = atol , equal_nan = True , msg = err_msg or None
183- )
184- else :
185- import numpy as np # pylint: disable=import-outside-toplevel
186-
187- # JAX/Dask arrays work directly with `np.testing`
188- assert isinstance (rtol , float )
189- np .testing .assert_allclose ( # type: ignore[call-overload] # pyright: ignore[reportCallIssue]
190- actual , # pyright: ignore[reportArgumentType]
191- desired , # pyright: ignore[reportArgumentType]
192- rtol = rtol ,
193- atol = atol ,
194- err_msg = err_msg ,
195- )
231+ xp = _check_ns_shape_dtype (actual , desired , check_dtype , check_shape , check_scalar )
232+
233+ if rtol is None :
234+ if xp .isdtype (actual .dtype , ("real floating" , "complex floating" )):
235+ # multiplier of 4 is used as for `np.float64` this puts the default `rtol`
236+ # roughly half way between sqrt(eps) and the default for
237+ # `numpy.testing.assert_allclose`, 1e-7
238+ rtol = xp .finfo (actual .dtype ).eps ** 0.5 * 4
239+ else :
240+ rtol = 1e-7
241+
242+ actual_np = as_numpy_array (actual , xp = xp )
243+ desired_np = as_numpy_array (desired , xp = xp )
244+ np .testing .assert_allclose ( # pyright: ignore[reportCallIssue]
245+ actual_np ,
246+ desired_np ,
247+ rtol = rtol , # pyright: ignore[reportArgumentType]
248+ atol = atol ,
249+ err_msg = err_msg ,
250+ )
196251
197252
198253def xfail (
0 commit comments