Skip to content

Commit 87ee115

Browse files
committed
document _is_torch_array and _is_jax_array. Add tests for them.
1 parent a72a1ef commit 87ee115

File tree

2 files changed

+50
-2
lines changed

2 files changed

+50
-2
lines changed

lib/matplotlib/cbook.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2237,16 +2237,20 @@ def _picklable_class_constructor(mixin_class, fmt, attr_name, base_class):
22372237
def _is_torch_array(x):
22382238
"""Check if 'x' is a PyTorch Tensor."""
22392239
try:
2240+
# we're intentionally not attempting to import torch. If somebody
2241+
# has created a torch array, torch should already be in sys.modules
22402242
return isinstance(x, sys.modules['torch'].Tensor)
2241-
except Exception as e: # not using bare `except` to bypass flake8
2243+
except Exception:
22422244
return False
22432245

22442246

22452247
def _is_jax_array(x):
22462248
"""Check if 'x' is a JAX Array."""
22472249
try:
2250+
# we're intentionally not attempting to import jax. If somebody
2251+
# has created a jax array, jax should already be in sys.modules
22482252
return isinstance(x, sys.modules['jax'].Array)
2249-
except Exception as e: # not using bare `except` to bypass flake8
2253+
except Exception:
22502254
return False
22512255

22522256

lib/matplotlib/tests/test_cbook.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import sys
34
import itertools
45
import pickle
56

@@ -16,6 +17,7 @@
1617
from matplotlib import _api, cbook
1718
import matplotlib.colors as mcolors
1819
from matplotlib.cbook import delete_masked_points
20+
from types import ModuleType
1921

2022

2123
class Test_delete_masked_points:
@@ -925,3 +927,45 @@ def test_auto_format_str(fmt, value, result):
925927
"""Apply *value* to the format string *fmt*."""
926928
assert cbook._auto_format_str(fmt, value) == result
927929
assert cbook._auto_format_str(fmt, np.float64(value)) == result
930+
931+
932+
def test_unpack_to_numpy_from_torch():
933+
"""Test that torch tensors are converted to numpy arrays.
934+
We don't want to create a dependency on torch in the test suite, so we mock it.
935+
"""
936+
class Tensor:
937+
def __init__(self, data):
938+
self.data = data
939+
def __array__(self):
940+
return self.data
941+
torch = ModuleType('torch')
942+
torch.Tensor = Tensor
943+
sys.modules['torch'] = torch
944+
945+
data = np.arange(10)
946+
torch_tensor = torch.Tensor(data)
947+
948+
result = cbook._unpack_to_numpy(torch_tensor)
949+
assert isinstance(result, np.ndarray)
950+
951+
952+
def test_unpack_to_numpy_from_jax():
953+
"""Test that jax arrays are converted to numpy arrays.
954+
We don't want to create a dependency on jax in the test suite, so we mock it.
955+
"""
956+
class Array:
957+
def __init__(self, data):
958+
self.data = data
959+
def __array__(self):
960+
return self.data
961+
962+
jax = ModuleType('jax')
963+
jax.Array = Array
964+
965+
sys.modules['jax'] = jax
966+
967+
data = np.arange(10)
968+
jax_array = jax.Array(data)
969+
970+
result = cbook._unpack_to_numpy(jax_array)
971+
assert isinstance(result, np.ndarray)

0 commit comments

Comments
 (0)