Skip to content

types: is_array(_like) provides TypeGuard#1148

Open
nstarman wants to merge 5 commits intopatrick-kidger:mainfrom
nstarman:typeguards
Open

types: is_array(_like) provides TypeGuard#1148
nstarman wants to merge 5 commits intopatrick-kidger:mainfrom
nstarman:typeguards

Conversation

@nstarman
Copy link
Copy Markdown
Contributor

@nstarman nstarman commented Nov 19, 2025

This PR:

  1. enhances the return type annotations for is_array and is_arraylike so that they do better type narrowing.
  2. updates pyright to get rid of a CI warning (ruff also got updated as part of uv run pre-commit autoupdate)
  3. simplifies a few functions I saw while addressing pyright errors

@nstarman nstarman changed the title types: is_array(_like) provides TypeIs types: is_array(_like) provides TypeGuard Nov 30, 2025
array_tree = [{"a": a, "b": b}, (c,)]
mlp_add = jtu.tree_map(lambda u: u + 1 if eqx.is_array(u) else u, general_tree[-1])
mlp_add = jtu.tree_map(
lambda u: jnp.asarray(u) + 1 if eqx.is_array(u) else u, general_tree[-1]
Copy link
Copy Markdown
Contributor Author

@nstarman nstarman Nov 30, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is_array checks for np.generic, but this doen't support u+1. Do you want to restrict from np.generic to np.number | np.bool | ... ?

Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That sounds reasonable to me. I don't think jnp.array(<some np.generic that is not a number or bool>) is going to work in general anyway.

@nstarman nstarman marked this pull request as ready for review November 30, 2025 21:28
@nstarman
Copy link
Copy Markdown
Contributor Author

@patrick-kidger this improvement to the return type is uncovering a real issue: np.generic isn't always a full stand-in for np.ndarray. Also, and this isn't an issue we can solve, jax has limited static support for __jax_array__.

@nstarman
Copy link
Copy Markdown
Contributor Author

numpy/numpy#30335 will eventually allow for some simplifications!

overload,
runtime_checkable,
TypeAlias,
TypeGuard,
Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps we should check for 3.13+ and use TypeIs in that case?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think that's what we want in this case. TypeIs doesn't enable type narrowing within the union.

_ARRAY_TYPES += (_TypedNdArray,)

# Type alias for type checkers. Never disappears from unions: T | Never == T
TypedNdArray: TypeAlias = _TypedNdArray # pyright: ignore[reportInvalidTypeForm]
Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need both TypedNdArray and _TypedNdArray? We could just put _TypedNdArray directly into the union.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removed.

array_tree = [{"a": a, "b": b}, (c,)]
mlp_add = jtu.tree_map(lambda u: u + 1 if eqx.is_array(u) else u, general_tree[-1])
mlp_add = jtu.tree_map(
lambda u: jnp.asarray(u) + 1 if eqx.is_array(u) else u, general_tree[-1]
Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That sounds reasonable to me. I don't think jnp.array(<some np.generic that is not a number or bool>) is going to work in general anyway.

Copilot AI review requested due to automatic review settings December 3, 2025 22:49
Copy link
Copy Markdown

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR enhances type narrowing capabilities by adding TypeGuard return type annotations to the is_array and is_array_like functions, enabling better static type checking when using these predicates.

Key changes:

  • Added TypeGuard return types to is_array and is_array_like functions for improved type narrowing
  • Refined array type definitions to accept np.number and np.bool_ instead of the broader np.generic to exclude non-array-like types (e.g., np.object_, np.flexible)
  • Updated tooling versions (pyright v1.1.406→v1.1.407, ruff v0.13.0→v0.14.7)

Reviewed changes

Copilot reviewed 4 out of 4 changed files in this pull request and generated 2 comments.

File Description
equinox/_filters.py Core changes: added TypeGuard annotations, refined NDArrayType definition, introduced ArrayTypes and ArrayLikeTypes type aliases with TYPE_CHECKING branches, and added HasJaxArray protocol
equinox/nn/_mlp.py Added jnp.asarray() call to handle NumPy scalar types (np.number/np.bool_) that TypeGuard now includes in array types
equinox/_doc_utils.py Simplified doc_repr function logic by flattening nested conditionals
.pre-commit-config.yaml Updated pre-commit hook versions for ruff and pyright

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Signed-off-by: nstarman <nstarman@users.noreply.github.com>
# array-like.
NDArrayType: TypeAlias = np.ndarray | np.number | np.bool_
_NDARRAY_TYPES: Final = get_args(NDArrayType)
_ARRAY_TYPES = _NDARRAY_TYPES + (jax.Array,)
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the change to type narrow NDArrayType from generic to number | bool means it excludes bfloat16!
🤔
We want to include that, while still excluding np.object_ and np.flexible.

Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, interesting! I think just adding jax.dtypes.bfloat16 should work?

It would also probably be fine to include np.generic as an overestimation, to be robust to some of the more esoteric 4-bit (etc) dtypes as well?

Copy link
Copy Markdown
Contributor Author

@nstarman nstarman Dec 9, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That looks to be necessary since jax doesn't provide a nice union of the ml_dtypes that it imports.
I wish they would! Or just subclass np.number

Signed-off-by: nstarman <nstarman@users.noreply.github.com>
@nstarman
Copy link
Copy Markdown
Contributor Author

nstarman commented Dec 9, 2025

Yeesh. Did not expect this small bit of type narrowing to be so difficult. The state of array scalars (instances of np.generic) is a big mess.


# Type alias for type checkers. Never disappears from unions: T | Never == T
if TYPE_CHECKING:
ArrayTypes: TypeAlias = NDArrayType | jax.Array | LiteralArray | TypedNdArray # type: ignore[valid-type]
Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The type: ignore here is the kind of thing that'll also introduce maintenance overhead, from the oddities between what type checkers accept / how JAX handles TypedNdArray / etc all intersect, and how none of these things are stable in how they behave going into the future.

Let's just drop TypedNdArray from this union, since it's so rare?

# e.g. `np.object_` and `np.flexible`. But `ml_dtypes` also defines dtypes
# that inherit from `np.generic` and can't easily be listed here individually.
@runtime_checkable
class ArrayLikeGeneric(Protocol):
Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd rather keep this file simple than overly precise, I think. JAX will inevitably change something, and I don't want the maintenance burden of tracking things too precisely!

Let's just use np.generic and accept that this is an upper bound?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants