|
1 | 1 | from __future__ import annotations
|
2 | 2 |
|
| 3 | +import sys |
3 | 4 | import itertools
|
4 | 5 | import pickle
|
5 | 6 |
|
|
16 | 17 | from matplotlib import _api, cbook
|
17 | 18 | import matplotlib.colors as mcolors
|
18 | 19 | from matplotlib.cbook import delete_masked_points
|
| 20 | +from types import ModuleType |
19 | 21 |
|
20 | 22 |
|
21 | 23 | class Test_delete_masked_points:
|
@@ -925,3 +927,45 @@ def test_auto_format_str(fmt, value, result):
|
925 | 927 | """Apply *value* to the format string *fmt*."""
|
926 | 928 | assert cbook._auto_format_str(fmt, value) == result
|
927 | 929 | assert cbook._auto_format_str(fmt, np.float64(value)) == result
|
| 930 | + |
| 931 | + |
| 932 | +def test_unpack_to_numpy_from_torch(): |
| 933 | + """Test that torch tensors are converted to numpy arrays. |
| 934 | + We don't want to create a dependency on torch in the test suite, so we mock it. |
| 935 | + """ |
| 936 | + class Tensor: |
| 937 | + def __init__(self, data): |
| 938 | + self.data = data |
| 939 | + def __array__(self): |
| 940 | + return self.data |
| 941 | + torch = ModuleType('torch') |
| 942 | + torch.Tensor = Tensor |
| 943 | + sys.modules['torch'] = torch |
| 944 | + |
| 945 | + data = np.arange(10) |
| 946 | + torch_tensor = torch.Tensor(data) |
| 947 | + |
| 948 | + result = cbook._unpack_to_numpy(torch_tensor) |
| 949 | + assert isinstance(result, np.ndarray) |
| 950 | + |
| 951 | + |
| 952 | +def test_unpack_to_numpy_from_jax(): |
| 953 | + """Test that jax arrays are converted to numpy arrays. |
| 954 | + We don't want to create a dependency on jax in the test suite, so we mock it. |
| 955 | + """ |
| 956 | + class Array: |
| 957 | + def __init__(self, data): |
| 958 | + self.data = data |
| 959 | + def __array__(self): |
| 960 | + return self.data |
| 961 | + |
| 962 | + jax = ModuleType('jax') |
| 963 | + jax.Array = Array |
| 964 | + |
| 965 | + sys.modules['jax'] = jax |
| 966 | + |
| 967 | + data = np.arange(10) |
| 968 | + jax_array = jax.Array(data) |
| 969 | + |
| 970 | + result = cbook._unpack_to_numpy(jax_array) |
| 971 | + assert isinstance(result, np.ndarray) |
0 commit comments