Skip to content

Commit 5eb04f9

Browse files
Guard against implicit dependency on numpy if it is not imported (#361)
* Guard against implicit dependency on numpy if it is not imported * Fix pre-commits --------- Co-authored-by: Patrick Kidger <33688385+patrick-kidger@users.noreply.github.com>
1 parent c8fde7d commit 5eb04f9

File tree

1 file changed

+23
-14
lines changed

1 file changed

+23
-14
lines changed

jaxtyping/_array_types.py

Lines changed: 23 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,14 @@
3737
Union,
3838
)
3939

40+
from ._errors import AnnotationError
41+
from ._storage import (
42+
get_shape_memo,
43+
get_treeflatten_memo,
44+
get_treepath_memo,
45+
set_shape_memo,
46+
)
47+
4048

4149
# Bit of a hack, but jaxtyping provides nicer error messages than typeguard. This means
4250
# we sometimes want to use it as our runtime type checker everywhere, even in non-array
@@ -46,18 +54,11 @@
4654
# messages and (c) the import hook that places the checker on the bottom of the
4755
# decorator stack.) And resist the urge to write our own runtime type-checker, I really
4856
# don't want to have to keep that up-to-date with changes in the Python typing spec...
49-
if importlib.util.find_spec("numpy") is not None:
57+
IS_NUMPY_INSTALLED = importlib.util.find_spec("numpy") is not None
58+
if IS_NUMPY_INSTALLED:
5059
import numpy as np
5160
import numpy.typing as npt
5261

53-
from ._errors import AnnotationError
54-
from ._storage import (
55-
get_shape_memo,
56-
get_treeflatten_memo,
57-
get_treepath_memo,
58-
set_shape_memo,
59-
)
60-
6162

6263
_array_name_format = "dtype_and_shape"
6364

@@ -170,7 +171,11 @@ def _check_dims(
170171

171172

172173
def _dtype_is_numpy_struct_array(dtype):
173-
return dtype.type.__name__ == "void" and dtype is not np.dtype(np.void)
174+
return (
175+
IS_NUMPY_INSTALLED
176+
and (dtype.type.__name__ == "void")
177+
and (dtype is not np.dtype(np.void))
178+
)
174179

175180

176181
class _MetaAbstractArray(type):
@@ -548,12 +553,12 @@ def _make_array_cached(array_type, dim_str, dtypes, name):
548553
return array_type
549554
else:
550555
return _not_made
551-
elif array_type is np.bool_:
556+
elif IS_NUMPY_INSTALLED and array_type is np.bool_:
552557
if _check_scalar("bool", dtypes, dims):
553558
return array_type
554559
else:
555560
return _not_made
556-
elif array_type is np.generic or array_type is np.number:
561+
elif IS_NUMPY_INSTALLED and (array_type is np.generic or array_type is np.number):
557562
if _check_scalar("", dtypes, dims):
558563
return array_type
559564
else:
@@ -647,7 +652,7 @@ def __getitem__(cls, item: tuple[Any, str]):
647652
array_type = Union[constraints]
648653
else:
649654
array_type = bound
650-
if "npt" in globals().keys() and array_type is npt.ArrayLike:
655+
if IS_NUMPY_INSTALLED and array_type is npt.ArrayLike:
651656
# Work around https://github.com/numpy/numpy/commit/1041f940f91660c91770679c60f6e63539581c72
652657
# which removes `bool`/`int`/`float` from the union.
653658
array_type = Union[(*get_args(array_type), bool, int, float, complex)]
@@ -840,6 +845,10 @@ def make_numpy_struct_dtype(dtype: "np.dtype", name: str):
840845
A type annotation with classname `name` that matches exactly `dtype` when used like
841846
any other [`jaxtyping.AbstractDtype`][].
842847
"""
843-
if not (isinstance(dtype, np.dtype) and _dtype_is_numpy_struct_array(dtype)):
848+
if not (
849+
IS_NUMPY_INSTALLED
850+
and isinstance(dtype, np.dtype)
851+
and _dtype_is_numpy_struct_array(dtype)
852+
):
844853
raise ValueError(f"Expecting a numpy structured array dtype, not {dtype}")
845854
return _make_dtype(str(dtype), name)

0 commit comments

Comments
 (0)