|
37 | 37 | Union, |
38 | 38 | ) |
39 | 39 |
|
| 40 | +from ._errors import AnnotationError |
| 41 | +from ._storage import ( |
| 42 | + get_shape_memo, |
| 43 | + get_treeflatten_memo, |
| 44 | + get_treepath_memo, |
| 45 | + set_shape_memo, |
| 46 | +) |
| 47 | + |
40 | 48 |
|
41 | 49 | # Bit of a hack, but jaxtyping provides nicer error messages than typeguard. This means |
42 | 50 | # we sometimes want to use it as our runtime type checker everywhere, even in non-array |
|
46 | 54 | # messages and (c) the import hook that places the checker on the bottom of the |
47 | 55 | # decorator stack.) And resist the urge to write our own runtime type-checker, I really |
48 | 56 | # don't want to have to keep that up-to-date with changes in the Python typing spec... |
49 | | -if importlib.util.find_spec("numpy") is not None: |
| 57 | +IS_NUMPY_INSTALLED = importlib.util.find_spec("numpy") is not None |
| 58 | +if IS_NUMPY_INSTALLED: |
50 | 59 | import numpy as np |
51 | 60 | import numpy.typing as npt |
52 | 61 |
|
53 | | -from ._errors import AnnotationError |
54 | | -from ._storage import ( |
55 | | - get_shape_memo, |
56 | | - get_treeflatten_memo, |
57 | | - get_treepath_memo, |
58 | | - set_shape_memo, |
59 | | -) |
60 | | - |
61 | 62 |
|
62 | 63 | _array_name_format = "dtype_and_shape" |
63 | 64 |
|
@@ -170,7 +171,11 @@ def _check_dims( |
170 | 171 |
|
171 | 172 |
|
172 | 173 | def _dtype_is_numpy_struct_array(dtype): |
173 | | - return dtype.type.__name__ == "void" and dtype is not np.dtype(np.void) |
| 174 | + return ( |
| 175 | + IS_NUMPY_INSTALLED |
| 176 | + and (dtype.type.__name__ == "void") |
| 177 | + and (dtype is not np.dtype(np.void)) |
| 178 | + ) |
174 | 179 |
|
175 | 180 |
|
176 | 181 | class _MetaAbstractArray(type): |
@@ -548,12 +553,12 @@ def _make_array_cached(array_type, dim_str, dtypes, name): |
548 | 553 | return array_type |
549 | 554 | else: |
550 | 555 | return _not_made |
551 | | - elif array_type is np.bool_: |
| 556 | + elif IS_NUMPY_INSTALLED and array_type is np.bool_: |
552 | 557 | if _check_scalar("bool", dtypes, dims): |
553 | 558 | return array_type |
554 | 559 | else: |
555 | 560 | return _not_made |
556 | | - elif array_type is np.generic or array_type is np.number: |
| 561 | + elif IS_NUMPY_INSTALLED and (array_type is np.generic or array_type is np.number): |
557 | 562 | if _check_scalar("", dtypes, dims): |
558 | 563 | return array_type |
559 | 564 | else: |
@@ -647,7 +652,7 @@ def __getitem__(cls, item: tuple[Any, str]): |
647 | 652 | array_type = Union[constraints] |
648 | 653 | else: |
649 | 654 | array_type = bound |
650 | | - if "npt" in globals().keys() and array_type is npt.ArrayLike: |
| 655 | + if IS_NUMPY_INSTALLED and array_type is npt.ArrayLike: |
651 | 656 | # Work around https://github.com/numpy/numpy/commit/1041f940f91660c91770679c60f6e63539581c72 |
652 | 657 | # which removes `bool`/`int`/`float` from the union. |
653 | 658 | array_type = Union[(*get_args(array_type), bool, int, float, complex)] |
@@ -840,6 +845,10 @@ def make_numpy_struct_dtype(dtype: "np.dtype", name: str): |
840 | 845 | A type annotation with classname `name` that matches exactly `dtype` when used like |
841 | 846 | any other [`jaxtyping.AbstractDtype`][]. |
842 | 847 | """ |
843 | | - if not (isinstance(dtype, np.dtype) and _dtype_is_numpy_struct_array(dtype)): |
| 848 | + if not ( |
| 849 | + IS_NUMPY_INSTALLED |
| 850 | + and isinstance(dtype, np.dtype) |
| 851 | + and _dtype_is_numpy_struct_array(dtype) |
| 852 | + ): |
844 | 853 | raise ValueError(f"Expecting a numpy structured array dtype, not {dtype}") |
845 | 854 | return _make_dtype(str(dtype), name) |
0 commit comments