22
22
is_jax_namespace ,
23
23
is_numpy_namespace ,
24
24
is_pydata_sparse_namespace ,
25
+ is_torch_array ,
25
26
is_torch_namespace ,
26
27
to_device ,
27
28
)
@@ -62,18 +63,28 @@ def _check_ns_shape_dtype(
62
63
msg = f"namespaces do not match: { actual_xp } != f{ desired_xp } "
63
64
assert actual_xp == desired_xp , msg
64
65
65
- if check_shape :
66
- actual_shape = actual .shape
67
- desired_shape = desired .shape
68
- if is_dask_namespace (desired_xp ):
69
- # Dask uses nan instead of None for unknown shapes
70
- if any (math .isnan (i ) for i in cast (tuple [float , ...], actual_shape )):
71
- actual_shape = actual .compute ().shape # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
72
- if any (math .isnan (i ) for i in cast (tuple [float , ...], desired_shape )):
73
- desired_shape = desired .compute ().shape # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
66
+ # Dask uses nan instead of None for unknown shapes
67
+ actual_shape = cast (tuple [float , ...], actual .shape )
68
+ desired_shape = cast (tuple [float , ...], desired .shape )
69
+ assert None not in actual_shape # Requires explicit support
70
+ assert None not in desired_shape
71
+ if is_dask_namespace (desired_xp ):
72
+ if any (math .isnan (i ) for i in actual_shape ):
73
+ actual_shape = actual .compute ().shape # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
74
+ if any (math .isnan (i ) for i in desired_shape ):
75
+ desired_shape = desired .compute ().shape # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
74
76
77
+ if check_shape :
75
78
msg = f"shapes do not match: { actual_shape } != f{ desired_shape } "
76
79
assert actual_shape == desired_shape , msg
80
+ else :
81
+ # Ignore shape, but check flattened size. This is normally done by
82
+ # np.testing.assert_array_equal etc even when strict=False, but not for
83
+ # non-materializable arrays.
84
+ actual_size = math .prod (actual_shape ) # pyright: ignore[reportUnknownArgumentType]
85
+ desired_size = math .prod (desired_shape ) # pyright: ignore[reportUnknownArgumentType]
86
+ msg = f"sizes do not match: { actual_size } != f{ desired_size } "
87
+ assert actual_size == desired_size , msg
77
88
78
89
if check_dtype :
79
90
msg = f"dtypes do not match: { actual .dtype } != { desired .dtype } "
@@ -90,6 +101,15 @@ def _check_ns_shape_dtype(
90
101
return desired_xp
91
102
92
103
104
+ def _is_materializable (x : Array ) -> bool :
105
+ """
106
+ Return True if you can call `as_numpy_array(x)`; False otherwise.
107
+ """
108
+ # Important: here we assume that we're not tracing -
109
+ # e.g. we're not inside `jax.jit`` nor `cupy.cuda.Stream.begin_capture`.
110
+ return not is_torch_array (x ) or x .device .type != "meta" # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
111
+
112
+
93
113
def as_numpy_array (array : Array , * , xp : ModuleType ) -> np .typing .NDArray [Any ]: # type: ignore[explicit-any]
94
114
"""
95
115
Convert array to NumPy, bypassing GPU-CPU transfer guards and densification guards.
@@ -146,6 +166,8 @@ def xp_assert_equal(
146
166
numpy.testing.assert_array_equal : Similar function for NumPy arrays.
147
167
"""
148
168
xp = _check_ns_shape_dtype (actual , desired , check_dtype , check_shape , check_scalar )
169
+ if not _is_materializable (actual ):
170
+ return
149
171
actual_np = as_numpy_array (actual , xp = xp )
150
172
desired_np = as_numpy_array (desired , xp = xp )
151
173
np .testing .assert_array_equal (actual_np , desired_np , err_msg = err_msg )
@@ -181,6 +203,8 @@ def xp_assert_less(
181
203
numpy.testing.assert_array_equal : Similar function for NumPy arrays.
182
204
"""
183
205
xp = _check_ns_shape_dtype (x , y , check_dtype , check_shape , check_scalar )
206
+ if not _is_materializable (x ):
207
+ return
184
208
x_np = as_numpy_array (x , xp = xp )
185
209
y_np = as_numpy_array (y , xp = xp )
186
210
np .testing .assert_array_less (x_np , y_np , err_msg = err_msg )
@@ -229,6 +253,8 @@ def xp_assert_close(
229
253
The default `atol` and `rtol` differ from `xp.all(xpx.isclose(a, b))`.
230
254
"""
231
255
xp = _check_ns_shape_dtype (actual , desired , check_dtype , check_shape , check_scalar )
256
+ if not _is_materializable (actual ):
257
+ return
232
258
233
259
if rtol is None :
234
260
if xp .isdtype (actual .dtype , ("real floating" , "complex floating" )):
0 commit comments