19
19
import time
20
20
import traceback
21
21
import types
22
+ import warnings
22
23
import weakref
23
24
24
25
import numpy as np
@@ -2234,13 +2235,32 @@ def _picklable_class_constructor(mixin_class, fmt, attr_name, base_class):
2234
2235
return cls.__new__(cls)
2235
2236
2236
2237
2238
+ def _is_torch_array(x):
2239
+ try:
2240
+ return isinstance(x, sys.modules['torch'].Tensor)
2241
+ except (KeyError, AttributeError, TypeError):
2242
+ return False
2243
+ except Exception as e:
2244
+ warnings.warn(f"Error checking if {x} is a PyTorch Tensor: \n {e} \
2245
+ \n Please report this issue to the developers.")
2246
+ return False
2247
+
2248
+
2249
+ def _is_jax_array(x):
2250
+ try:
2251
+ return isinstance(x, sys.modules['jax'].Array)
2252
+ except (KeyError, AttributeError, TypeError):
2253
+ return False
2254
+ except Exception as e:
2255
+ warnings.warn(f"Error checking if {x} is a JAX Array: \n {e} \
2256
+ \n Please report this issue to the developers.")
2257
+ return False
2258
+
2259
+
2237
2260
def _unpack_to_numpy(x):
2238
2261
"""Internal helper to extract data from e.g. pandas and xarray objects."""
2239
2262
if isinstance(x, np.ndarray):
2240
- # If numpy array, return directly
2241
- return x
2242
- if isinstance(x, np.generic):
2243
- # If numpy scalar, return directly
2263
+ # If numpy, return directly
2244
2264
return x
2245
2265
if hasattr(x, 'to_numpy'):
2246
2266
# Assume that any to_numpy() method actually returns a numpy array
@@ -2251,12 +2271,10 @@ def _unpack_to_numpy(x):
2251
2271
# so in this case we do not want to return a function
2252
2272
if isinstance(xtmp, np.ndarray):
2253
2273
return xtmp
2254
- if hasattr(x, '__array__'):
2255
- # Assume that any to __array__() method returns a numpy array
2256
- # (e.g. TensorFlow, JAX or PyTorch arrays)
2274
+ if _is_torch_array(x) or _is_jax_array(x):
2257
2275
xtmp = x.__array__()
2258
- # Anything that doesn't return ndarray via __array__() method
2259
- # will be filtered by the following check
2276
+
2277
+ # In case __array__() method does not return a numpy array in future
2260
2278
if isinstance(xtmp, np.ndarray):
2261
2279
return xtmp
2262
2280
return x
0 commit comments