Skip to content

Commit ec68363

Browse files
committed
update _unpack_to_numpy to handle JAX and PyTorch arrays with sys.modules
1 parent 2cd9857 commit ec68363

File tree

1 file changed

+27
-9
lines changed

1 file changed

+27
-9
lines changed

lib/matplotlib/cbook.py

Lines changed: 27 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import time
2020
import traceback
2121
import types
22+
import warnings
2223
import weakref
2324

2425
import numpy as np
@@ -2234,13 +2235,32 @@ def _picklable_class_constructor(mixin_class, fmt, attr_name, base_class):
22342235
return cls.__new__(cls)
22352236

22362237

2238+
def _is_torch_array(x):
2239+
try:
2240+
return isinstance(x, sys.modules['torch'].Tensor)
2241+
except (KeyError, AttributeError, TypeError):
2242+
return False
2243+
except Exception as e:
2244+
warnings.warn(f"Error checking if {x} is a PyTorch Tensor: \n {e} \
2245+
\n Please report this issue to the developers.")
2246+
return False
2247+
2248+
2249+
def _is_jax_array(x):
2250+
try:
2251+
return isinstance(x, sys.modules['jax'].Array)
2252+
except (KeyError, AttributeError, TypeError):
2253+
return False
2254+
except Exception as e:
2255+
warnings.warn(f"Error checking if {x} is a JAX Array: \n {e} \
2256+
\n Please report this issue to the developers.")
2257+
return False
2258+
2259+
22372260
def _unpack_to_numpy(x):
22382261
"""Internal helper to extract data from e.g. pandas and xarray objects."""
22392262
if isinstance(x, np.ndarray):
2240-
# If numpy array, return directly
2241-
return x
2242-
if isinstance(x, np.generic):
2243-
# If numpy scalar, return directly
2263+
# If numpy, return directly
22442264
return x
22452265
if hasattr(x, 'to_numpy'):
22462266
# Assume that any to_numpy() method actually returns a numpy array
@@ -2251,12 +2271,10 @@ def _unpack_to_numpy(x):
22512271
# so in this case we do not want to return a function
22522272
if isinstance(xtmp, np.ndarray):
22532273
return xtmp
2254-
if hasattr(x, '__array__'):
2255-
# Assume that any to __array__() method returns a numpy array
2256-
# (e.g. TensorFlow, JAX or PyTorch arrays)
2274+
if _is_torch_array(x) or _is_jax_array(x):
22572275
xtmp = x.__array__()
2258-
# Anything that doesn't return ndarray via __array__() method
2259-
# will be filtered by the following check
2276+
2277+
# In case __array__() method does not return a numpy array in future
22602278
if isinstance(xtmp, np.ndarray):
22612279
return xtmp
22622280
return x

0 commit comments

Comments
 (0)