types: is_array(_like) provides TypeGuard#1148
types: is_array(_like) provides TypeGuard#1148nstarman wants to merge 5 commits intopatrick-kidger:mainfrom
Conversation
tests/test_jit.py
Outdated
| array_tree = [{"a": a, "b": b}, (c,)] | ||
| mlp_add = jtu.tree_map(lambda u: u + 1 if eqx.is_array(u) else u, general_tree[-1]) | ||
| mlp_add = jtu.tree_map( | ||
| lambda u: jnp.asarray(u) + 1 if eqx.is_array(u) else u, general_tree[-1] |
There was a problem hiding this comment.
is_array checks for np.generic, but this doen't support u+1. Do you want to restrict from np.generic to np.number | np.bool | ... ?
There was a problem hiding this comment.
That sounds reasonable to me. I don't think jnp.array(<some np.generic that is not a number or bool>) is going to work in general anyway.
|
@patrick-kidger this improvement to the return type is uncovering a real issue: |
|
numpy/numpy#30335 will eventually allow for some simplifications! |
| overload, | ||
| runtime_checkable, | ||
| TypeAlias, | ||
| TypeGuard, |
There was a problem hiding this comment.
Perhaps we should check for 3.13+ and use TypeIs in that case?
There was a problem hiding this comment.
I don't think that's what we want in this case. TypeIs doesn't enable type narrowing within the union.
equinox/_filters.py
Outdated
| _ARRAY_TYPES += (_TypedNdArray,) | ||
|
|
||
| # Type alias for type checkers. Never disappears from unions: T | Never == T | ||
| TypedNdArray: TypeAlias = _TypedNdArray # pyright: ignore[reportInvalidTypeForm] |
There was a problem hiding this comment.
Do we need both TypedNdArray and _TypedNdArray? We could just put _TypedNdArray directly into the union.
tests/test_jit.py
Outdated
| array_tree = [{"a": a, "b": b}, (c,)] | ||
| mlp_add = jtu.tree_map(lambda u: u + 1 if eqx.is_array(u) else u, general_tree[-1]) | ||
| mlp_add = jtu.tree_map( | ||
| lambda u: jnp.asarray(u) + 1 if eqx.is_array(u) else u, general_tree[-1] |
There was a problem hiding this comment.
That sounds reasonable to me. I don't think jnp.array(<some np.generic that is not a number or bool>) is going to work in general anyway.
There was a problem hiding this comment.
Pull request overview
This PR enhances type narrowing capabilities by adding TypeGuard return type annotations to the is_array and is_array_like functions, enabling better static type checking when using these predicates.
Key changes:
- Added
TypeGuardreturn types tois_arrayandis_array_likefunctions for improved type narrowing - Refined array type definitions to accept
np.numberandnp.bool_instead of the broadernp.genericto exclude non-array-like types (e.g.,np.object_,np.flexible) - Updated tooling versions (pyright v1.1.406→v1.1.407, ruff v0.13.0→v0.14.7)
Reviewed changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
equinox/_filters.py |
Core changes: added TypeGuard annotations, refined NDArrayType definition, introduced ArrayTypes and ArrayLikeTypes type aliases with TYPE_CHECKING branches, and added HasJaxArray protocol |
equinox/nn/_mlp.py |
Added jnp.asarray() call to handle NumPy scalar types (np.number/np.bool_) that TypeGuard now includes in array types |
equinox/_doc_utils.py |
Simplified doc_repr function logic by flattening nested conditionals |
.pre-commit-config.yaml |
Updated pre-commit hook versions for ruff and pyright |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
Signed-off-by: nstarman <nstarman@users.noreply.github.com>
| # array-like. | ||
| NDArrayType: TypeAlias = np.ndarray | np.number | np.bool_ | ||
| _NDARRAY_TYPES: Final = get_args(NDArrayType) | ||
| _ARRAY_TYPES = _NDARRAY_TYPES + (jax.Array,) |
There was a problem hiding this comment.
I think the change to type narrow NDArrayType from generic to number | bool means it excludes bfloat16!
🤔
We want to include that, while still excluding np.object_ and np.flexible.
There was a problem hiding this comment.
Ah, interesting! I think just adding jax.dtypes.bfloat16 should work?
It would also probably be fine to include np.generic as an overestimation, to be robust to some of the more esoteric 4-bit (etc) dtypes as well?
There was a problem hiding this comment.
That looks to be necessary since jax doesn't provide a nice union of the ml_dtypes that it imports.
I wish they would! Or just subclass np.number
Signed-off-by: nstarman <nstarman@users.noreply.github.com>
|
Yeesh. Did not expect this small bit of type narrowing to be so difficult. The state of array scalars (instances of np.generic) is a big mess. |
|
|
||
| # Type alias for type checkers. Never disappears from unions: T | Never == T | ||
| if TYPE_CHECKING: | ||
| ArrayTypes: TypeAlias = NDArrayType | jax.Array | LiteralArray | TypedNdArray # type: ignore[valid-type] |
There was a problem hiding this comment.
The type: ignore here is the kind of thing that'll also introduce maintenance overhead, from the oddities between what type checkers accept / how JAX handles TypedNdArray / etc all intersect, and how none of these things are stable in how they behave going into the future.
Let's just drop TypedNdArray from this union, since it's so rare?
| # e.g. `np.object_` and `np.flexible`. But `ml_dtypes` also defines dtypes | ||
| # that inherit from `np.generic` and can't easily be listed here individually. | ||
| @runtime_checkable | ||
| class ArrayLikeGeneric(Protocol): |
There was a problem hiding this comment.
I'd rather keep this file simple than overly precise, I think. JAX will inevitably change something, and I don't want the maintenance burden of tracking things too precisely!
Let's just use np.generic and accept that this is an upper bound?
This PR:
is_arrayandis_arraylikeso that they do better type narrowing.uv run pre-commit autoupdate)