1313import sys
1414import warnings
1515from collections .abc import Collection
16+ from functools import lru_cache
1617from typing import (
1718 TYPE_CHECKING ,
1819 Any ,
6162_API_VERSIONS : Final = _API_VERSIONS_OLD | frozenset ({"2024.12" })
6263
6364
64- def _is_jax_zero_gradient_array (x : object ) -> TypeGuard [_ZeroGradientArray ]:
65- @cache
65+ @lru_cache (100 )
6666def _issubclass_fast (cls : type , modname : str , clsname : str ) -> bool :
6767 try :
6868 mod = sys .modules [modname ]
@@ -72,6 +72,7 @@ def _issubclass_fast(cls: type, modname: str, clsname: str) -> bool:
7272 return issubclass (cls , parent_cls )
7373
7474
75+ def _is_jax_zero_gradient_array (x : object ) -> TypeGuard [_ZeroGradientArray ]:
7576 """Return True if `x` is a zero-gradient array.
7677
7778 These arrays are a design quirk of Jax that may one day be removed.
@@ -276,7 +277,7 @@ def is_array_api_obj(x: object) -> TypeIs[_ArrayApiObj]: # pyright: ignore[repo
276277 return hasattr (x , '__array_namespace__' ) or _is_array_api_cls (type (x ))
277278
278279
279- @cache
280+ @lru_cache ( 100 )
280281def _is_array_api_cls (cls : type ) -> bool :
281282 return (
282283 # TODO: drop support for numpy<2 which didn't have __array_namespace__
@@ -296,7 +297,7 @@ def _compat_module_name() -> str:
296297 return __name__ .removesuffix (".common._helpers" )
297298
298299
299- @cache
300+ @lru_cache ( 100 )
300301def is_numpy_namespace (xp : Namespace ) -> bool :
301302 """
302303 Returns True if `xp` is a NumPy namespace.
@@ -318,7 +319,7 @@ def is_numpy_namespace(xp: Namespace) -> bool:
318319 return xp .__name__ in {"numpy" , _compat_module_name () + ".numpy" }
319320
320321
321- @cache
322+ @lru_cache ( 100 )
322323def is_cupy_namespace (xp : Namespace ) -> bool :
323324 """
324325 Returns True if `xp` is a CuPy namespace.
@@ -340,7 +341,7 @@ def is_cupy_namespace(xp: Namespace) -> bool:
340341 return xp .__name__ in {"cupy" , _compat_module_name () + ".cupy" }
341342
342343
343- @cache
344+ @lru_cache ( 100 )
344345def is_torch_namespace (xp : Namespace ) -> bool :
345346 """
346347 Returns True if `xp` is a PyTorch namespace.
@@ -381,7 +382,7 @@ def is_ndonnx_namespace(xp: Namespace) -> bool:
381382 return xp .__name__ == "ndonnx"
382383
383384
384- @cache
385+ @lru_cache ( 100 )
385386def is_dask_namespace (xp : Namespace ) -> bool :
386387 """
387388 Returns True if `xp` is a Dask namespace.
@@ -922,7 +923,7 @@ def size(x: HasShape[Collection[SupportsIndex | None]]) -> int | None:
922923 return None if math .isnan (out ) else out
923924
924925
925- @cache
926+ @lru_cache ( 100 )
926927def _is_writeable_cls (cls : type ) -> bool | None :
927928 if (
928929 _issubclass_fast (cls , "numpy" , "generic" )
@@ -954,7 +955,7 @@ def is_writeable_array(x: object) -> bool:
954955 return hasattr (x , '__array_namespace__' )
955956
956957
957- @cache
958+ @lru_cache ( 100 )
958959def _is_lazy_cls (cls : type ) -> bool | None :
959960 if (
960961 _issubclass_fast (cls , "numpy" , "ndarray" )
@@ -1054,7 +1055,7 @@ def is_lazy_array(x: object) -> bool:
10541055 "to_device" ,
10551056]
10561057
1057- _all_ignore = ['cache ' , 'sys' , 'math' , 'inspect' , 'warnings' ]
1058+ _all_ignore = ['lru_cache ' , 'sys' , 'math' , 'inspect' , 'warnings' ]
10581059
10591060def __dir__ () -> list [str ]:
10601061 return __all__
0 commit comments