Skip to content

Commit 8d6519c

Browse files
committed
update _unpack_to_numpy to handle JAX and PyTorch arrays
1 parent b61bb0b commit 8d6519c

File tree

1 file changed

+10
-1
lines changed

1 file changed

+10
-1
lines changed

lib/matplotlib/cbook.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2237,7 +2237,10 @@ def _picklable_class_constructor(mixin_class, fmt, attr_name, base_class):
22372237
def _unpack_to_numpy(x):
22382238
"""Internal helper to extract data from e.g. pandas and xarray objects."""
22392239
if isinstance(x, np.ndarray):
2240-
# If numpy, return directly
2240+
# If numpy array, return directly
2241+
return x
2242+
if isinstance(x, np.generic):
2243+
# If numpy scalar, return directly
22412244
return x
22422245
if hasattr(x, 'to_numpy'):
22432246
# Assume that any to_numpy() method actually returns a numpy array
@@ -2248,6 +2251,12 @@ def _unpack_to_numpy(x):
22482251
# so in this case we do not want to return a function
22492252
if isinstance(xtmp, np.ndarray):
22502253
return xtmp
2254+
if hasattr(x, '__array__'):
2255+
# Assume that any to __array__() method returns a numpy array (e.g. TensorFlow, JAX or PyTorch arrays)
2256+
x = x.__array__()
2257+
# Anything that doesn't return ndarray via __array__() method will be filtered by the following check
2258+
if isinstance(x, np.ndarray):
2259+
return x
22512260
return x
22522261

22532262

0 commit comments

Comments
 (0)