1919import time
2020import traceback
2121import types
22+ import warnings
2223import weakref
2324
2425import numpy as np
@@ -2234,13 +2235,32 @@ def _picklable_class_constructor(mixin_class, fmt, attr_name, base_class):
22342235 return cls .__new__ (cls )
22352236
22362237
2238+ def _is_torch_array (x ):
2239+ try :
2240+ return isinstance (x , sys .modules ['torch' ].Tensor )
2241+ except (KeyError , AttributeError , TypeError ):
2242+ return False
2243+ except Exception as e :
2244+ warnings .warn (f"Error checking if { x } is a PyTorch Tensor: \n { e } \
2245+ \n Please report this issue to the developers." )
2246+ return False
2247+
2248+
2249+ def _is_jax_array (x ):
2250+ try :
2251+ return isinstance (x , sys .modules ['jax' ].Array )
2252+ except (KeyError , AttributeError , TypeError ):
2253+ return False
2254+ except Exception as e :
2255+ warnings .warn (f"Error checking if { x } is a JAX Array: \n { e } \
2256+ \n Please report this issue to the developers." )
2257+ return False
2258+
2259+
22372260def _unpack_to_numpy (x ):
22382261 """Internal helper to extract data from e.g. pandas and xarray objects."""
22392262 if isinstance (x , np .ndarray ):
2240- # If numpy array, return directly
2241- return x
2242- if isinstance (x , np .generic ):
2243- # If numpy scalar, return directly
2263+ # If numpy, return directly
22442264 return x
22452265 if hasattr (x , 'to_numpy' ):
22462266 # Assume that any to_numpy() method actually returns a numpy array
@@ -2251,12 +2271,10 @@ def _unpack_to_numpy(x):
22512271 # so in this case we do not want to return a function
22522272 if isinstance (xtmp , np .ndarray ):
22532273 return xtmp
2254- if hasattr (x , '__array__' ):
2255- # Assume that any to __array__() method returns a numpy array
2256- # (e.g. TensorFlow, JAX or PyTorch arrays)
2274+ if _is_torch_array (x ) or _is_jax_array (x ):
22572275 xtmp = x .__array__ ()
2258- # Anything that doesn't return ndarray via __array__() method
2259- # will be filtered by the following check
2276+
2277+ # In case __array__() method does not return a numpy array in future
22602278 if isinstance (xtmp , np .ndarray ):
22612279 return xtmp
22622280 return x
0 commit comments