@@ -324,28 +324,48 @@ def _contains_jax_arrays(x: object) -> bool: # numpydoc ignore=PR01,RT01
324324 """
325325 Test if x is a JAX array or a nested collection with any JAX arrays in it.
326326 """
327- if is_jax_array (x ):
328- return True
329- if isinstance (x , list | tuple ):
330- return any (_contains_jax_arrays (i ) for i in x ) # pyright: ignore[reportUnknownArgumentType]
331- if isinstance (x , dict ):
332- return any (_contains_jax_arrays (i ) for i in x .values ()) # pyright: ignore[reportUnknownArgumentType]
333- return False
327+ seen = set ()
328+
329+ def recursion (x : object ) -> bool : # numpydoc ignore=GL08
330+ if id (x ) in seen :
331+ return False
332+ seen .add (id (x ))
333+
334+ if is_jax_array (x ):
335+ return True
336+ if isinstance (x , list | tuple ):
337+ return any (recursion (i ) for i in x ) # pyright: ignore[reportUnknownArgumentType]
338+ if isinstance (x , dict ):
339+ return any (recursion (i ) for i in x .values ()) # pyright: ignore[reportUnknownArgumentType]
340+ return False
341+
342+ return recursion (x )
334343
335344
336345def _as_numpy (x : object ) -> Any : # type: ignore[no-any-explicit] # numpydoc ignore=PR01,RT01
337346 """Recursively convert Array API objects in x to NumPy."""
338347 import numpy as np # pylint: disable=import-outside-toplevel
339348
340- if is_array_api_obj (x ):
341- return np .asarray (x )
342- if isinstance (x , list ) or type (x ) is tuple : # pylint: disable=unidiomatic-typecheck
343- return type (x )(_as_numpy (i ) for i in x ) # pyright: ignore[reportUnknownArgumentType]
344- if isinstance (x , tuple ): # namedtuple
345- return type (x )(* (_as_numpy (i ) for i in x )) # pyright: ignore[reportUnknownArgumentType]
346- if isinstance (x , dict ):
347- return {k : _as_numpy (v ) for k , v in x .items ()} # pyright: ignore[reportUnknownArgumentType]
348- return x
349+ seen = set ()
350+
351+ def recursion (x : Any ) -> Any : # type: ignore[no-any-explicit] # numpydoc ignore=GL08
352+ if is_array_api_obj (x ):
353+ return np .asarray (x )
354+ if not isinstance (x , list | tuple | dict ):
355+ return x
356+
357+ if id (x ) in seen : # pyright: ignore[reportUnknownArgumentType]
358+ return x # Recursive collections can't contain arrays
359+ seen .add (id (x )) # pyright: ignore[reportUnknownArgumentType]
360+
361+ if isinstance (x , list ) or type (x ) is tuple : # pylint: disable=unidiomatic-typecheck # pyright: ignore[reportUnknownArgumentType]
362+ return type (x )(recursion (i ) for i in x ) # pyright: ignore[reportUnknownArgumentType]
363+ if isinstance (x , tuple ): # namedtuple
364+ return type (x )(* (recursion (i ) for i in x )) # pyright: ignore[reportUnknownArgumentType]
365+ # dict
366+ return {k : recursion (v ) for k , v in x .items ()}
367+
368+ return recursion (x )
349369
350370
351371def _lazy_apply_wrapper ( # type: ignore[no-any-explicit] # numpydoc ignore=PR01,RT01
0 commit comments