2222 is_jax_namespace ,
2323 is_numpy_namespace ,
2424 is_pydata_sparse_namespace ,
25+ is_torch_array ,
2526 is_torch_namespace ,
2627 to_device ,
2728)
@@ -62,18 +63,28 @@ def _check_ns_shape_dtype(
6263 msg = f"namespaces do not match: { actual_xp } != f{ desired_xp } "
6364 assert actual_xp == desired_xp , msg
6465
65- if check_shape :
66- actual_shape = actual .shape
67- desired_shape = desired .shape
68- if is_dask_namespace (desired_xp ):
69- # Dask uses nan instead of None for unknown shapes
70- if any (math .isnan (i ) for i in cast (tuple [float , ...], actual_shape )):
71- actual_shape = actual .compute ().shape # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
72- if any (math .isnan (i ) for i in cast (tuple [float , ...], desired_shape )):
73- desired_shape = desired .compute ().shape # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
66+ # Dask uses nan instead of None for unknown shapes
67+ actual_shape = cast (tuple [float , ...], actual .shape )
68+ desired_shape = cast (tuple [float , ...], desired .shape )
69+ assert None not in actual_shape # Requires explicit support
70+ assert None not in desired_shape
71+ if is_dask_namespace (desired_xp ):
72+ if any (math .isnan (i ) for i in actual_shape ):
73+ actual_shape = actual .compute ().shape # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
74+ if any (math .isnan (i ) for i in desired_shape ):
75+ desired_shape = desired .compute ().shape # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
7476
77+ if check_shape :
7578 msg = f"shapes do not match: { actual_shape } != f{ desired_shape } "
7679 assert actual_shape == desired_shape , msg
80+ else :
81+ # Ignore shape, but check flattened size. This is normally done by
82+ # np.testing.assert_array_equal etc even when strict=False, but not for
83+ # non-materializable arrays.
84+ actual_size = math .prod (actual_shape ) # pyright: ignore[reportUnknownArgumentType]
85+ desired_size = math .prod (desired_shape ) # pyright: ignore[reportUnknownArgumentType]
86+ msg = f"sizes do not match: { actual_size } != f{ desired_size } "
87+ assert actual_size == desired_size , msg
7788
7889 if check_dtype :
7990 msg = f"dtypes do not match: { actual .dtype } != { desired .dtype } "
@@ -90,6 +101,15 @@ def _check_ns_shape_dtype(
90101 return desired_xp
91102
92103
104+ def _is_materializable (x : Array ) -> bool :
105+ """
106+ Return True if you can call `as_numpy_array(x)`; False otherwise.
107+ """
108+ # Important: here we assume that we're not tracing -
109+ # e.g. we're not inside `jax.jit`` nor `cupy.cuda.Stream.begin_capture`.
110+ return not is_torch_array (x ) or x .device .type != "meta" # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
111+
112+
93113def as_numpy_array (array : Array , * , xp : ModuleType ) -> np .typing .NDArray [Any ]: # type: ignore[explicit-any]
94114 """
95115 Convert array to NumPy, bypassing GPU-CPU transfer guards and densification guards.
@@ -146,6 +166,8 @@ def xp_assert_equal(
146166 numpy.testing.assert_array_equal : Similar function for NumPy arrays.
147167 """
148168 xp = _check_ns_shape_dtype (actual , desired , check_dtype , check_shape , check_scalar )
169+ if not _is_materializable (actual ):
170+ return
149171 actual_np = as_numpy_array (actual , xp = xp )
150172 desired_np = as_numpy_array (desired , xp = xp )
151173 np .testing .assert_array_equal (actual_np , desired_np , err_msg = err_msg )
@@ -181,6 +203,8 @@ def xp_assert_less(
181203 numpy.testing.assert_array_equal : Similar function for NumPy arrays.
182204 """
183205 xp = _check_ns_shape_dtype (x , y , check_dtype , check_shape , check_scalar )
206+ if not _is_materializable (x ):
207+ return
184208 x_np = as_numpy_array (x , xp = xp )
185209 y_np = as_numpy_array (y , xp = xp )
186210 np .testing .assert_array_less (x_np , y_np , err_msg = err_msg )
@@ -229,6 +253,8 @@ def xp_assert_close(
229253 The default `atol` and `rtol` differ from `xp.all(xpx.isclose(a, b))`.
230254 """
231255 xp = _check_ns_shape_dtype (actual , desired , check_dtype , check_shape , check_scalar )
256+ if not _is_materializable (actual ):
257+ return
232258
233259 if rtol is None :
234260 if xp .isdtype (actual .dtype , ("real floating" , "complex floating" )):
0 commit comments