99from types import ModuleType
1010from typing import cast
1111
12+ import numpy as np
1213import pytest
1314
1415from ._utils ._compat import (
1516 array_namespace ,
1617 is_array_api_strict_namespace ,
1718 is_cupy_namespace ,
1819 is_dask_namespace ,
20+ is_numpy_namespace ,
1921 is_pydata_sparse_namespace ,
2022 is_torch_namespace ,
2123)
2527
2628
2729def _check_ns_shape_dtype (
28- actual : Array , desired : Array
30+ actual : Array ,
31+ desired : Array ,
32+ check_dtype : bool ,
33+ check_shape : bool ,
34+ check_scalar : bool ,
2935) -> ModuleType : # numpydoc ignore=RT03
3036 """
3137 Assert that namespace, shape and dtype of the two arrays match.
@@ -47,43 +53,64 @@ def _check_ns_shape_dtype(
4753 msg = f"namespaces do not match: { actual_xp } != f{ desired_xp } "
4854 assert actual_xp == desired_xp , msg
4955
50- actual_shape = actual .shape
51- desired_shape = desired .shape
52- if is_dask_namespace (desired_xp ):
53- # Dask uses nan instead of None for unknown shapes
54- if any (math .isnan (i ) for i in cast (tuple [float , ...], actual_shape )):
55- actual_shape = actual .compute ().shape # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
56- if any (math .isnan (i ) for i in cast (tuple [float , ...], desired_shape )):
57- desired_shape = desired .compute ().shape # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
58-
59- msg = f"shapes do not match: { actual_shape } != f{ desired_shape } "
60- assert actual_shape == desired_shape , msg
61-
62- msg = f"dtypes do not match: { actual .dtype } != { desired .dtype } "
63- assert actual .dtype == desired .dtype , msg
56+ if check_shape :
57+ actual_shape = actual .shape
58+ desired_shape = desired .shape
59+ if is_dask_namespace (desired_xp ):
60+ # Dask uses nan instead of None for unknown shapes
61+ if any (math .isnan (i ) for i in cast (tuple [float , ...], actual_shape )):
62+ actual_shape = actual .compute ().shape # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
63+ if any (math .isnan (i ) for i in cast (tuple [float , ...], desired_shape )):
64+ desired_shape = desired .compute ().shape # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
65+
66+ msg = f"shapes do not match: { actual_shape } != f{ desired_shape } "
67+ assert actual_shape == desired_shape , msg
68+
69+ if check_dtype :
70+ msg = f"dtypes do not match: { actual .dtype } != { desired .dtype } "
71+ assert actual .dtype == desired .dtype , msg
72+
73+ if is_numpy_namespace (actual_xp ) and check_scalar :
74+ # only NumPy distinguishes between scalars and arrays; we do if check_scalar.
75+ _msg = (
76+ "array-ness does not match:\n Actual: "
77+ f"{ type (actual )} \n Desired: { type (desired )} "
78+ )
79+ assert (np .isscalar (actual ) and np .isscalar (desired )) or (
80+ not np .isscalar (actual ) and not np .isscalar (desired )
81+ ), _msg
6482
6583 return desired_xp
6684
6785
6886def _prepare_for_test (array : Array , xp : ModuleType ) -> Array :
6987 """
70- Ensure that the array can be compared with xp.testing or np.testing.
88+ Ensure that the array can be compared with np.testing.
7189
7290 This involves transferring it from GPU to CPU memory, densifying it, etc.
7391 """
7492 if is_torch_namespace (xp ):
75- return array .cpu () # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
93+ return np . asarray ( array .cpu ()) # type: ignore[attr-defined, return-value ] # pyright: ignore[reportAttributeAccessIssue, reportUnknownArgumentType, reportReturnType ]
7694 if is_pydata_sparse_namespace (xp ):
7795 return array .todense () # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
7896 if is_array_api_strict_namespace (xp ):
7997 # Note: we deliberately did not add a `.to_device` method in _typing.pyi
8098 # even if it is required by the standard as many backends don't support it
8199 return array .to_device (xp .Device ("CPU_DEVICE" )) # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
82- # Note: nothing to do for CuPy, because it uses a bespoke test function
100+ if is_cupy_namespace (xp ):
101+ return xp .asnumpy (array )
83102 return array
84103
85104
86- def xp_assert_equal (actual : Array , desired : Array , err_msg : str = "" ) -> None :
105+ def xp_assert_equal (
106+ actual : Array ,
107+ desired : Array ,
108+ * ,
109+ err_msg : str = "" ,
110+ check_dtype : bool = True ,
111+ check_shape : bool = True ,
112+ check_scalar : bool = False ,
113+ ) -> None :
87114 """
88115 Array-API compatible version of `np.testing.assert_array_equal`.
89116
@@ -95,34 +122,21 @@ def xp_assert_equal(actual: Array, desired: Array, err_msg: str = "") -> None:
95122 The expected array (typically hardcoded).
96123 err_msg : str, optional
97124 Error message to display on failure.
125+ check_dtype, check_shape : bool, default: True
126+ Whether to check agreement between actual and desired dtypes and shapes
127+ check_scalar : bool, default: False
128+ NumPy only: whether to check agreement between actual and desired types -
129+ 0d array vs scalar.
98130
99131 See Also
100132 --------
101133 xp_assert_close : Similar function for inexact equality checks.
102134 numpy.testing.assert_array_equal : Similar function for NumPy arrays.
103135 """
104- xp = _check_ns_shape_dtype (actual , desired )
136+ xp = _check_ns_shape_dtype (actual , desired , check_dtype , check_shape , check_scalar )
105137 actual = _prepare_for_test (actual , xp )
106138 desired = _prepare_for_test (desired , xp )
107-
108- if is_cupy_namespace (xp ):
109- xp .testing .assert_array_equal (actual , desired , err_msg = err_msg )
110- elif is_torch_namespace (xp ):
111- # PyTorch recommends using `rtol=0, atol=0` like this
112- # to test for exact equality
113- xp .testing .assert_close (
114- actual ,
115- desired ,
116- rtol = 0 ,
117- atol = 0 ,
118- equal_nan = True ,
119- check_dtype = False ,
120- msg = err_msg or None ,
121- )
122- else :
123- import numpy as np # pylint: disable=import-outside-toplevel
124-
125- np .testing .assert_array_equal (actual , desired , err_msg = err_msg )
139+ np .testing .assert_array_equal (actual , desired , err_msg = err_msg )
126140
127141
128142def xp_assert_close (
@@ -132,6 +146,9 @@ def xp_assert_close(
132146 rtol : float | None = None ,
133147 atol : float = 0 ,
134148 err_msg : str = "" ,
149+ check_dtype : bool = True ,
150+ check_shape : bool = True ,
151+ check_scalar : bool = False ,
135152) -> None :
136153 """
137154 Array-API compatible version of `np.testing.assert_allclose`.
@@ -148,6 +165,11 @@ def xp_assert_close(
148165 Absolute tolerance. Default: 0.
149166 err_msg : str, optional
150167 Error message to display on failure.
168+ check_dtype, check_shape : bool, default: True
169+ Whether to check agreement between actual and desired dtypes and shapes
170+ check_scalar : bool, default: False
171+ NumPy only: whether to check agreement between actual and desired types -
172+ 0d array vs scalar.
151173
152174 See Also
153175 --------
@@ -159,7 +181,7 @@ def xp_assert_close(
159181 -----
160182 The default `atol` and `rtol` differ from `xp.all(xpx.isclose(a, b))`.
161183 """
162- xp = _check_ns_shape_dtype (actual , desired )
184+ xp = _check_ns_shape_dtype (actual , desired , check_dtype , check_shape , check_scalar )
163185
164186 floating = xp .isdtype (actual .dtype , ("real floating" , "complex floating" ))
165187 if rtol is None and floating :
@@ -173,26 +195,15 @@ def xp_assert_close(
173195 actual = _prepare_for_test (actual , xp )
174196 desired = _prepare_for_test (desired , xp )
175197
176- if is_cupy_namespace (xp ):
177- xp .testing .assert_allclose (
178- actual , desired , rtol = rtol , atol = atol , err_msg = err_msg
179- )
180- elif is_torch_namespace (xp ):
181- xp .testing .assert_close (
182- actual , desired , rtol = rtol , atol = atol , equal_nan = True , msg = err_msg or None
183- )
184- else :
185- import numpy as np # pylint: disable=import-outside-toplevel
186-
187- # JAX/Dask arrays work directly with `np.testing`
188- assert isinstance (rtol , float )
189- np .testing .assert_allclose ( # type: ignore[call-overload] # pyright: ignore[reportCallIssue]
190- actual , # pyright: ignore[reportArgumentType]
191- desired , # pyright: ignore[reportArgumentType]
192- rtol = rtol ,
193- atol = atol ,
194- err_msg = err_msg ,
195- )
198+ # JAX/Dask arrays work directly with `np.testing`
199+ assert isinstance (rtol , float )
200+ np .testing .assert_allclose ( # type: ignore[call-overload] # pyright: ignore[reportCallIssue]
201+ actual , # pyright: ignore[reportArgumentType]
202+ desired , # pyright: ignore[reportArgumentType]
203+ rtol = rtol ,
204+ atol = atol ,
205+ err_msg = err_msg ,
206+ )
196207
197208
198209def xfail (request : pytest .FixtureRequest , reason : str ) -> None :
0 commit comments