2222 is_jax_namespace ,
2323 is_numpy_namespace ,
2424 is_pydata_sparse_namespace ,
25+ is_torch_array ,
2526 is_torch_namespace ,
2627 to_device ,
2728)
@@ -62,18 +63,28 @@ def _check_ns_shape_dtype(
6263 msg = f"namespaces do not match: { actual_xp } != f{ desired_xp } "
6364 assert actual_xp == desired_xp , msg
6465
65- if check_shape :
66- actual_shape = actual .shape
67- desired_shape = desired .shape
68- if is_dask_namespace (desired_xp ):
69- # Dask uses nan instead of None for unknown shapes
70- if any (math .isnan (i ) for i in cast (tuple [float , ...], actual_shape )):
71- actual_shape = actual .compute ().shape # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
72- if any (math .isnan (i ) for i in cast (tuple [float , ...], desired_shape )):
73- desired_shape = desired .compute ().shape # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
66+ # Dask uses nan instead of None for unknown shapes
67+ actual_shape = cast (tuple [float , ...], actual .shape )
68+ desired_shape = cast (tuple [float , ...], desired .shape )
69+ assert None not in actual_shape # Requires explicit support
70+ assert None not in desired_shape
71+ if is_dask_namespace (desired_xp ):
72+ if any (math .isnan (i ) for i in actual_shape ):
73+ actual_shape = actual .compute ().shape # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
74+ if any (math .isnan (i ) for i in desired_shape ):
75+ desired_shape = desired .compute ().shape # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
7476
77+ if check_shape :
7578 msg = f"shapes do not match: { actual_shape } != f{ desired_shape } "
7679 assert actual_shape == desired_shape , msg
80+ else :
81+ # Ignore shape, but check flattened size. This is normally done by
82+ # np.testing.assert_array_equal etc even when strict=False, but not for
83+ # non-materializable arrays.
84+ actual_size = math .prod (actual_shape ) # pyright: ignore[reportUnknownArgumentType]
85+ desired_size = math .prod (desired_shape ) # pyright: ignore[reportUnknownArgumentType]
86+ msg = f"sizes do not match: { actual_size } != f{ desired_size } "
87+ assert actual_size == desired_size , msg
7788
7889 if check_dtype :
7990 msg = f"dtypes do not match: { actual .dtype } != { desired .dtype } "
@@ -90,6 +101,17 @@ def _check_ns_shape_dtype(
90101 return desired_xp
91102
92103
104+ def _is_materializable (x : Array ) -> bool :
105+ """
106+ Check if the array is materializable, e.g. `as_numpy_array` can be called on it
107+ and one can assume that `__dlpack__` will succeed (if implemented, and given a
108+ compatible device).
109+ """
110+ # Important: here we assume that we're not tracing -
111+ # e.g. we're not inside `jax.jit`` nor `cupy.cuda.Stream.begin_capture`.
112+ return not is_torch_array (x ) or x .device .type != "meta" # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
113+
114+
93115def as_numpy_array (array : Array , * , xp : ModuleType ) -> np .typing .NDArray [Any ]: # type: ignore[explicit-any]
94116 """
95117 Convert array to NumPy, bypassing GPU-CPU transfer guards and densification guards.
@@ -100,11 +122,7 @@ def as_numpy_array(array: Array, *, xp: ModuleType) -> np.typing.NDArray[Any]:
100122 return array .todense () # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
101123
102124 if is_torch_namespace (xp ):
103- if array .device .type == "meta" : # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
104- # Can't materialize; generate dummy data instead
105- array = xp .zeros_like (array , device = "cpu" )
106- else :
107- array = to_device (array , "cpu" )
125+ array = to_device (array , "cpu" )
108126 if is_array_api_strict_namespace (xp ):
109127 cpu : Device = xp .Device ("CPU_DEVICE" )
110128 array = to_device (array , cpu )
@@ -150,6 +168,8 @@ def xp_assert_equal(
150168 numpy.testing.assert_array_equal : Similar function for NumPy arrays.
151169 """
152170 xp = _check_ns_shape_dtype (actual , desired , check_dtype , check_shape , check_scalar )
171+ if not _is_materializable (actual ):
172+ return
153173 actual_np = as_numpy_array (actual , xp = xp )
154174 desired_np = as_numpy_array (desired , xp = xp )
155175 np .testing .assert_array_equal (actual_np , desired_np , err_msg = err_msg )
@@ -185,6 +205,8 @@ def xp_assert_less(
185205 numpy.testing.assert_array_equal : Similar function for NumPy arrays.
186206 """
187207 xp = _check_ns_shape_dtype (x , y , check_dtype , check_shape , check_scalar )
208+ if not _is_materializable (x ):
209+ return
188210 x_np = as_numpy_array (x , xp = xp )
189211 y_np = as_numpy_array (y , xp = xp )
190212 np .testing .assert_array_less (x_np , y_np , err_msg = err_msg )
@@ -233,6 +255,8 @@ def xp_assert_close(
233255 The default `atol` and `rtol` differ from `xp.all(xpx.isclose(a, b))`.
234256 """
235257 xp = _check_ns_shape_dtype (actual , desired , check_dtype , check_shape , check_scalar )
258+ if not _is_materializable (actual ):
259+ return
236260
237261 if rtol is None :
238262 if xp .isdtype (actual .dtype , ("real floating" , "complex floating" )):
0 commit comments