Skip to content

Commit 211a84c

Browse files
committed
Recursion guard
1 parent 1dbe2d0 commit 211a84c

File tree

2 files changed

+47
-17
lines changed

2 files changed

+47
-17
lines changed

src/array_api_extra/_lib/_lazy.py

Lines changed: 36 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -324,28 +324,48 @@ def _contains_jax_arrays(x: object) -> bool: # numpydoc ignore=PR01,RT01
324324
"""
325325
Test if x is a JAX array or a nested collection with any JAX arrays in it.
326326
"""
327-
if is_jax_array(x):
328-
return True
329-
if isinstance(x, list | tuple):
330-
return any(_contains_jax_arrays(i) for i in x) # pyright: ignore[reportUnknownArgumentType]
331-
if isinstance(x, dict):
332-
return any(_contains_jax_arrays(i) for i in x.values()) # pyright: ignore[reportUnknownArgumentType]
333-
return False
327+
seen = set()
328+
329+
def recursion(x: object) -> bool: # numpydoc ignore=GL08
330+
if id(x) in seen:
331+
return False
332+
seen.add(id(x))
333+
334+
if is_jax_array(x):
335+
return True
336+
if isinstance(x, list | tuple):
337+
return any(recursion(i) for i in x) # pyright: ignore[reportUnknownArgumentType]
338+
if isinstance(x, dict):
339+
return any(recursion(i) for i in x.values()) # pyright: ignore[reportUnknownArgumentType]
340+
return False
341+
342+
return recursion(x)
334343

335344

336345
def _as_numpy(x: object) -> Any: # type: ignore[no-any-explicit] # numpydoc ignore=PR01,RT01
337346
"""Recursively convert Array API objects in x to NumPy."""
338347
import numpy as np # pylint: disable=import-outside-toplevel
339348

340-
if is_array_api_obj(x):
341-
return np.asarray(x)
342-
if isinstance(x, list) or type(x) is tuple: # pylint: disable=unidiomatic-typecheck
343-
return type(x)(_as_numpy(i) for i in x) # pyright: ignore[reportUnknownArgumentType]
344-
if isinstance(x, tuple): # namedtuple
345-
return type(x)(*(_as_numpy(i) for i in x)) # pyright: ignore[reportUnknownArgumentType]
346-
if isinstance(x, dict):
347-
return {k: _as_numpy(v) for k, v in x.items()} # pyright: ignore[reportUnknownArgumentType]
348-
return x
349+
seen = set()
350+
351+
def recursion(x: Any) -> Any: # type: ignore[no-any-explicit] # numpydoc ignore=GL08
352+
if is_array_api_obj(x):
353+
return np.asarray(x)
354+
if not isinstance(x, list | tuple | dict):
355+
return x
356+
357+
if id(x) in seen: # pyright: ignore[reportUnknownArgumentType]
358+
return x # Recursive collections can't contain arrays
359+
seen.add(id(x)) # pyright: ignore[reportUnknownArgumentType]
360+
361+
if isinstance(x, list) or type(x) is tuple: # pylint: disable=unidiomatic-typecheck # pyright: ignore[reportUnknownArgumentType]
362+
return type(x)(recursion(i) for i in x) # pyright: ignore[reportUnknownArgumentType]
363+
if isinstance(x, tuple): # namedtuple
364+
return type(x)(*(recursion(i) for i in x)) # pyright: ignore[reportUnknownArgumentType]
365+
# dict
366+
return {k: recursion(v) for k, v in x.items()}
367+
368+
return recursion(x)
349369

350370

351371
def _lazy_apply_wrapper( # type: ignore[no-any-explicit] # numpydoc ignore=PR01,RT01

tests/test_lazy.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from array_api_extra._lib import Backend
1010
from array_api_extra._lib._testing import xp_assert_equal
1111
from array_api_extra._lib._utils import _compat
12-
from array_api_extra._lib._utils._compat import array_namespace
12+
from array_api_extra._lib._utils._compat import array_namespace, is_dask_array
1313
from array_api_extra._lib._utils._typing import Array, Device
1414
from array_api_extra.testing import lazy_xp_function
1515

@@ -288,12 +288,18 @@ class NT(NamedTuple):
288288

289289

290290
def check_lazy_apply_kwargs(x: Array, expect_cls: type, as_numpy: bool) -> Array:
291+
is_dask = is_dask_array(x)
292+
recursive: list[object] = []
293+
if not is_dask: # dask.delayed crashes on recursion
294+
recursive.append(recursive)
295+
291296
def eager(
292297
x: Array,
293298
z: dict[str, list[Array] | tuple[Array, ...] | NT],
294299
msg: str,
295300
msgs: list[str],
296301
scalar: int,
302+
recursive: list[list[object]],
297303
) -> Array:
298304
assert isinstance(x, expect_cls)
299305
# JAX will crash if x isn't material
@@ -310,6 +316,9 @@ def eager(
310316
assert isinstance(msgs[0], str)
311317
assert scalar == 1 # must be hidden from JAX
312318
assert isinstance(scalar, int)
319+
assert isinstance(recursive, list)
320+
if not is_dask:
321+
assert recursive[0][0] is recursive[0]
313322
return x + 1 # type: ignore[operator]
314323

315324
# Use explicit namespace to bypass monkey-patching by lazy_xp_function
@@ -323,6 +332,7 @@ def eager(
323332
msgs=["Hello World"],
324333
# This will be automatically cast to jax.Array if we don't wrap it
325334
scalar=1,
335+
recursive=recursive,
326336
shape=x.shape,
327337
dtype=x.dtype,
328338
as_numpy=as_numpy,

0 commit comments

Comments
 (0)