@@ -65,6 +65,24 @@ def _check_ns_shape_dtype(
6565 return desired_xp
6666
6767
68+ def _prepare_for_test (array : Array , xp : ModuleType ) -> Array :
69+ """
70+ Ensure that the array can be compared with xp.testing or np.testing.
71+
72+ This involves transferring it from GPU to CPU memory, densifying it, etc.
73+ """
74+ if is_torch_namespace (xp ):
75+ return array .cpu () # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
76+ if is_pydata_sparse_namespace (xp ):
77+ return array .todense () # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
78+ if is_array_api_strict_namespace (xp ):
79+ # Note: we deliberately did not add a `.to_device` method in _typing.pyi
80+ # even if it is required by the standard as many backends don't support it
81+ return array .to_device (xp .Device ("CPU_DEVICE" )) # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
82+ # Note: nothing to do for CuPy, because it uses a bespoke test function
83+ return array
84+
85+
6886def xp_assert_equal (actual : Array , desired : Array , err_msg : str = "" ) -> None :
6987 """
7088 Array-API compatible version of `np.testing.assert_array_equal`.
@@ -84,6 +102,8 @@ def xp_assert_equal(actual: Array, desired: Array, err_msg: str = "") -> None:
84102 numpy.testing.assert_array_equal : Similar function for NumPy arrays.
85103 """
86104 xp = _check_ns_shape_dtype (actual , desired )
105+ actual = _prepare_for_test (actual , xp )
106+ desired = _prepare_for_test (desired , xp )
87107
88108 if is_cupy_namespace (xp ):
89109 xp .testing .assert_array_equal (actual , desired , err_msg = err_msg )
@@ -102,22 +122,7 @@ def xp_assert_equal(actual: Array, desired: Array, err_msg: str = "") -> None:
102122 else :
103123 import numpy as np # pylint: disable=import-outside-toplevel
104124
105- if is_pydata_sparse_namespace (xp ):
106- actual = actual .todense () # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
107- desired = desired .todense () # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
108-
109- actual_np = None
110- desired_np = None
111- if is_array_api_strict_namespace (xp ):
112- # __array__ doesn't work on array-api-strict device arrays
113- # We need to convert to the CPU device first
114- actual_np = np .asarray (xp .asarray (actual , device = xp .Device ("CPU_DEVICE" )))
115- desired_np = np .asarray (xp .asarray (desired , device = xp .Device ("CPU_DEVICE" )))
116-
117- # JAX/Dask arrays work with `np.testing`
118- actual_np = actual if actual_np is None else actual_np
119- desired_np = desired if desired_np is None else desired_np
120- np .testing .assert_array_equal (actual_np , desired_np , err_msg = err_msg ) # pyright: ignore[reportUnknownArgumentType]
125+ np .testing .assert_array_equal (actual , desired , err_msg = err_msg )
121126
122127
123128def xp_assert_close (
@@ -165,6 +170,9 @@ def xp_assert_close(
165170 elif rtol is None :
166171 rtol = 1e-7
167172
173+ actual = _prepare_for_test (actual , xp )
174+ desired = _prepare_for_test (desired , xp )
175+
168176 if is_cupy_namespace (xp ):
169177 xp .testing .assert_allclose (
170178 actual , desired , rtol = rtol , atol = atol , err_msg = err_msg
@@ -176,26 +184,11 @@ def xp_assert_close(
176184 else :
177185 import numpy as np # pylint: disable=import-outside-toplevel
178186
179- if is_pydata_sparse_namespace (xp ):
180- actual = actual .todense () # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
181- desired = desired .todense () # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
182-
183- actual_np = None
184- desired_np = None
185- if is_array_api_strict_namespace (xp ):
186- # __array__ doesn't work on array-api-strict device arrays
187- # We need to convert to the CPU device first
188- actual_np = np .asarray (xp .asarray (actual , device = xp .Device ("CPU_DEVICE" )))
189- desired_np = np .asarray (xp .asarray (desired , device = xp .Device ("CPU_DEVICE" )))
190-
191- # JAX/Dask arrays work with `np.testing`
192- actual_np = actual if actual_np is None else actual_np
193- desired_np = desired if desired_np is None else desired_np
194-
187+ # JAX/Dask arrays work directly with `np.testing`
195188 assert isinstance (rtol , float )
196- np .testing .assert_allclose ( # pyright: ignore[reportCallIssue]
197- actual_np , # type: ignore[arg-type] # pyright: ignore[reportArgumentType]
198- desired_np , # type: ignore[arg-type] # pyright: ignore[reportArgumentType]
189+ np .testing .assert_allclose ( # type: ignore[call-overload] # pyright: ignore[reportCallIssue]
190+ actual , # pyright: ignore[reportArgumentType]
191+ desired , # pyright: ignore[reportArgumentType]
199192 rtol = rtol ,
200193 atol = atol ,
201194 err_msg = err_msg ,
0 commit comments