Skip to content

Commit 09463d0

Browse files
committed
revert NDArrayType narrowing
Signed-off-by: nstarman <nstarman@users.noreply.github.com>
1 parent e3b8c17 commit 09463d0

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

equinox/_filters.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,10 @@
2828

2929

3030
# Define array types
31-
# Note that we don't accept `np.generic` as a suitable duck-type for arrays,
32-
# because `generic` includes `np.object_` and `np.flexible` which are not
33-
# array-like.
34-
NDArrayType: TypeAlias = np.ndarray | np.number | np.bool_
31+
# Note that `np.generic` covers more than just the array-like dtypes,
32+
# e.g. `np.object_` and `np.flexible`. But `ml_dtypes` also defines dtypes
33+
# that inherit from `np.generic` and can't easily be listed here individually.
34+
NDArrayType: TypeAlias = np.ndarray | np.generic
3535
_NDARRAY_TYPES: Final = get_args(NDArrayType)
3636
_ARRAY_TYPES = _NDARRAY_TYPES + (jax.Array,)
3737

0 commit comments

Comments
 (0)