Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,8 @@ ignore = [
"FIX", # flake8-fixme
"ISC001", # Conflicts with formatter
"PYI041", # Use `float` instead of `int | float`
"TD002", # Missing author in TODO
"TD003", # Missing issue link for this TODO
]

[tool.ruff.lint.pydocstyle]
Expand Down
18 changes: 17 additions & 1 deletion src/array_api_typing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,25 @@
"Array",
"HasArrayNamespace",
"HasDType",
"HasDevice",
"HasMatrixTranspose",
"HasNDim",
"HasShape",
"HasSize",
"HasTranspose",
"__version__",
"__version_tuple__",
)

from ._array import Array, HasArrayNamespace, HasDType
from ._array import (
Array,
HasArrayNamespace,
HasDevice,
HasDType,
HasMatrixTranspose,
HasNDim,
HasShape,
HasSize,
HasTranspose,
)
from ._version import version as __version__, version_tuple as __version_tuple__
127 changes: 125 additions & 2 deletions src/array_api_typing/_array.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,22 @@
# Public names exported by this module.
__all__ = (
"Array",
"HasArrayNamespace",
"HasDType",
"HasDevice",
"HasMatrixTranspose",
"HasNDim",
"HasShape",
"HasSize",
"HasTranspose",
)

from types import ModuleType
from typing import Literal, Protocol
from typing import Literal, Protocol, Self
from typing_extensions import TypeVar

# Type parameter of `HasArrayNamespace` (and `Array`); covariant, defaults to `ModuleType`.
NamespaceT_co = TypeVar("NamespaceT_co", covariant=True, default=ModuleType)
# Return type of the `dtype` property; covariant, no default.
DTypeT_co = TypeVar("DTypeT_co", covariant=True)
# Return type of `HasDevice.device`; covariant, defaults to `object`.
DeviceT_co = TypeVar("DeviceT_co", covariant=True, default=object)


class HasArrayNamespace(Protocol[NamespaceT_co]):
Expand Down Expand Up @@ -67,10 +75,125 @@ def dtype(self, /) -> DTypeT_co:
...


class HasDevice(Protocol[DeviceT_co]):
    """Protocol for array classes that have a device attribute.

    The protocol is generic over the type of the device object, so it can be
    narrowed as ``HasDevice[SomeDeviceType]``.
    """

    @property
    def device(self) -> DeviceT_co:
        """Hardware device the array data resides on."""
        ...


class HasMatrixTranspose(Protocol):
"""Protocol for array classes that have a matrix transpose attribute."""

@property
def mT(self) -> Self: # noqa: N802
"""Transpose of a matrix (or a stack of matrices).

If an array instance has fewer than two dimensions, an error should be
raised.

Returns:
Self: array whose last two dimensions (axes) are permuted in reverse
order relative to original array (i.e., for an array instance
having shape `(..., M, N)`, the returned array must have shape
`(..., N, M))`. The returned array must have the same data type
as the original array.

"""
...


class HasNDim(Protocol):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a situation where you'd wanna use HasNDim over e.g. HasShape? Otherwise we probably should keep this private, given that

There should be one-- and preferably only one --obvious way to do it.
Although that way may not be obvious at first unless you're Dutch.

Copy link
Collaborator Author

@nstarman nstarman Aug 18, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you're checking for the attribute ndim.
Not sure I'm understanding this comment. Ndim is in the Array API spec...

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well, my point is that I think we should only provide users with protocols that help them annotate their array-api code. Otherwise, it'll just be confusing for the users, and a waste of time for us.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh. As the Array API is the intersection of the array libraries, IMO pretty much everything is useful.

Copy link
Member

@jorenham jorenham Aug 22, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yea of course, everything in the array-api is designed for a reason. But all of those reasons apply in the runtime world. And our goal is to provide an API for the static-typing world, which is but a shadow of the runtime one, where only a subset of the API has practical use.

Copy link
Collaborator Author

@nstarman nstarman Sep 23, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Python comes with batteries.
The Array API is the intersection of the array libraries and each method is useful / used by many people / has long historical precedent, which is why it's in all the libraries.
IMO this library isn't to ideologically improve on the Array API, it's to provide the tools to describe it statically.

My suggestion is to have all these low-level protocols. Where we go beyond the baseline Array API is in how these protocols are parametrized and how we build intersection types. This is where necessity drives us to make decisions.

Copy link
Collaborator Author

@nstarman nstarman Sep 23, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But the main difference is that almost all of the optype protocols were written with a specific use-case in mind

I don't disagree with this idea, but let's consider the situation. There are numerous array libraries which people have spent a long time working on, thinking about their APIs for many years. Numpy in particular sets much of the standard and has many people thinking about the things in it. They got together to rethink numpy and array libraries in general, to mass adoption. We've been to some of those meetings where they're agonizing over a small piece of API. IMO the presumption of usefulness is very clear.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jorenham can we get in protocols for the methods?

#22 (comment)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jorenham can we get in protocols for the methods?

#22 (comment)

What methods, exactly?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Methods and attributes of Array.

"""Protocol for array classes that have a number of dimensions attribute."""

@property
def ndim(self) -> int:
"""Number of array dimensions (axes).

Returns:
int: number of array dimensions (axes).

"""
...


class HasShape(Protocol):
"""Protocol for array classes that have a shape attribute."""

@property
def shape(self) -> tuple[int | None, ...]:
"""Shape of the array.

Returns:
tuple[int | None, ...]: array dimensions. An array dimension must be None
if and only if a dimension is unknown.

Notes:
For array libraries having graph-based computational models, array
dimensions may be unknown due to data-dependent operations (e.g.,
boolean indexing; `A[:, B > 0]`) and thus cannot be statically
resolved without knowing array contents.

"""
...


class HasSize(Protocol):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's the use-case of this one? Is there anything this can help with, that HasShape can't?
Put differently; should we make this public API or not?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't we have Protocols for all attributes in the Array API?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't see why we should write a protocol if there's no use for it.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

def is_sized(obj: Any, /) -> TypeGuard[HasSize]: ...

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok, and what would you use that for in a real-world scenario

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But what I'm trying to prevent is situations where there are multiple non-obvious ways of achieving the same result. And HasSize seems to me like it could have a lot of overlap with HasShape.

This library isn't the venue for changing the Array API...

I know, and I wasn't suggesting anything like that.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is that not the practical effect of omission?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If the Array API spec has methods size and shape, and the corresponding typing library has only HasShape then by omission it re-interprets the Array API. Yes people can write their own HasSize, but the target audience of the Array API isn't just lovers of typing.

Copy link
Member

@jorenham jorenham Oct 3, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is that not the practical effect of omission?

No? AFAIK, this library depends on the array api, not the other way around. But maybe I'm being naive here 🤷🏻.

Plus, there are already stub-like types in the array-api itself (https://github.com/data-apis/array-api/tree/main/src/array_api_stubs) that are used for documentation purposes (and don't have any use-case besides that).

Copy link
Collaborator Author

@nstarman nstarman Oct 3, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

They aren't planning on publishing those. As I understand it, we depend on the Array API spec and our remit is to describe it statically. As the official typing library for that spec, omission is change.

We have HasNamespace. We should have similar for the other methods and attributes of Array. We should have useful intersection types related to Array. We should have DoesFunction protocols for the functions and a Namespace protocol for their collection.

"""Protocol for array classes that have a size attribute."""

@property
def size(self) -> int | None:
"""Number of elements in an array.

Returns:
int | None: number of elements in an array. The returned value must
be `None` if and only if one or more array dimensions are
unknown.

Notes:
This must equal the product of the array's dimensions.

"""
...


class HasTranspose(Protocol):
"""Protocol for array classes that support the transpose operation."""

@property
def T(self) -> Self: # noqa: N802
"""Transpose of the array.

The array instance must be two-dimensional. If the array instance is not
two-dimensional, an error should be raised.

Returns:
Self: two-dimensional array whose first and last dimensions (axes)
are permuted in reverse order relative to original array. The
returned array must have the same data type as the original
array.

Notes:
Limiting the transpose to two-dimensional arrays (matrices) deviates
from the NumPy et al practice of reversing all axes for arrays
having more than two-dimensions. This is intentional, as reversing
all axes was found to be problematic (e.g., conflicting with the
mathematical definition of a transpose which is limited to matrices;
not operating on batches of matrices; et cetera). In order to
reverse all axes, one is recommended to use the functional
`PermuteDims` interface found in this specification.

"""
...


class Array(
HasArrayNamespace[NamespaceT_co],
# ------ Attributes -------
HasDType[DTypeT_co],
# ------- Methods ---------
HasArrayNamespace[NamespaceT_co],
# -------------------------
Protocol[DTypeT_co, NamespaceT_co],
):
Expand Down
52 changes: 49 additions & 3 deletions tests/integration/test_numpy1p0.pyi
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# mypy: disable-error-code="no-redef"

from types import ModuleType
from typing import Any
from typing import Any, assert_type

import numpy.array_api as np # type: ignore[import-not-found, unused-ignore]
from numpy import dtype
Expand Down Expand Up @@ -39,6 +39,48 @@ _: xpt.HasDType[dtype[Any]] = nparr
_: xpt.HasDType[dtype[Any]] = nparr_i32
_: xpt.HasDType[dtype[Any]] = nparr_f32

# =========================================================
# `xpt.HasDevice`

_: xpt.HasDevice = nparr
_: xpt.HasDevice = nparr_i32
_: xpt.HasDevice = nparr_f32

# =========================================================
# `xpt.HasMatrixTranspose`

_: xpt.HasMatrixTranspose = nparr
_: xpt.HasMatrixTranspose = nparr_i32
_: xpt.HasMatrixTranspose = nparr_f32

# =========================================================
# `xpt.HasNDim`

_: xpt.HasNDim = nparr
_: xpt.HasNDim = nparr_i32
_: xpt.HasNDim = nparr_f32

# =========================================================
# `xpt.HasShape`

_: xpt.HasShape = nparr
_: xpt.HasShape = nparr_i32
_: xpt.HasShape = nparr_f32

# =========================================================
# `xpt.HasSize`

_: xpt.HasSize = nparr
_: xpt.HasSize = nparr_i32
_: xpt.HasSize = nparr_f32

# =========================================================
# `xpt.HasTranspose`

_: xpt.HasTranspose = nparr
_: xpt.HasTranspose = nparr_i32
_: xpt.HasTranspose = nparr_f32

# =========================================================
# `xpt.Array`

Expand All @@ -49,5 +91,9 @@ a_ns: xpt.Array[Any, ModuleType] = nparr
# Note that `np.array_api` uses dtype objects, not dtype classes, so we can't
# type annotate specific dtypes like `np.float32` or `np.int32`.
_: xpt.Array[dtype[Any]] = nparr
_: xpt.Array[dtype[Any]] = nparr_i32
_: xpt.Array[dtype[Any]] = nparr_f32
x_f32: xpt.Array[dtype[Any]] = nparr_f32
x_i32: xpt.Array[dtype[Any]] = nparr_i32

# Check Attribute `.dtype`
assert_type(x_f32.dtype, dtype[Any])
assert_type(x_i32.dtype, dtype[Any])
66 changes: 60 additions & 6 deletions tests/integration/test_numpy2p0.pyi
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# mypy: disable-error-code="no-redef"

from types import ModuleType
from typing import Any, TypeAlias
from typing import Any, TypeAlias, assert_type

import numpy as np
import numpy.typing as npt
Expand All @@ -11,12 +11,13 @@ import array_api_typing as xpt
# DType aliases
F32: TypeAlias = np.float32
I32: TypeAlias = np.int32
B: TypeAlias = np.bool_

# Define NDArrays against which we can test the protocols
nparr: npt.NDArray[Any]
nparr_i32: npt.NDArray[I32]
nparr_f32: npt.NDArray[F32]
nparr_b: npt.NDArray[np.bool_]
nparr_b: npt.NDArray[B]

# =========================================================
# `xpt.HasArrayNamespace`
Expand All @@ -42,7 +43,55 @@ _: xpt.HasArrayNamespace[dict[str, int]] = nparr # not caught
_: xpt.HasDType[Any] = nparr
_: xpt.HasDType[np.dtype[I32]] = nparr_i32
_: xpt.HasDType[np.dtype[F32]] = nparr_f32
_: xpt.HasDType[np.dtype[np.bool_]] = nparr_b
_: xpt.HasDType[np.dtype[B]] = nparr_b

# =========================================================
# `xpt.HasDevice`

_: xpt.HasDevice = nparr
_: xpt.HasDevice = nparr_i32
_: xpt.HasDevice = nparr_f32
_: xpt.HasDevice = nparr_b

# =========================================================
# `xpt.HasMatrixTranspose`

_: xpt.HasMatrixTranspose = nparr
_: xpt.HasMatrixTranspose = nparr_i32
_: xpt.HasMatrixTranspose = nparr_f32
_: xpt.HasMatrixTranspose = nparr_b

# =========================================================
# `xpt.HasNDim`

_: xpt.HasNDim = nparr
_: xpt.HasNDim = nparr_i32
_: xpt.HasNDim = nparr_f32
_: xpt.HasNDim = nparr_b

# =========================================================
# `xpt.HasShape`

_: xpt.HasShape = nparr
_: xpt.HasShape = nparr_i32
_: xpt.HasShape = nparr_f32
_: xpt.HasShape = nparr_b

# =========================================================
# `xpt.HasSize`

_: xpt.HasSize = nparr
_: xpt.HasSize = nparr_i32
_: xpt.HasSize = nparr_f32
_: xpt.HasSize = nparr_b

# =========================================================
# `xpt.HasTranspose`

_: xpt.HasTranspose = nparr
_: xpt.HasTranspose = nparr_i32
_: xpt.HasTranspose = nparr_f32
_: xpt.HasTranspose = nparr_b

# =========================================================
# `xpt.Array`
Expand All @@ -52,6 +101,11 @@ a_ns: xpt.Array[Any, ModuleType] = nparr

# Check DTypeT_co assignment
_: xpt.Array[Any] = nparr
_: xpt.Array[np.dtype[I32]] = nparr_i32
_: xpt.Array[np.dtype[F32]] = nparr_f32
_: xpt.Array[np.dtype[np.bool_]] = nparr_b
x_f32: xpt.Array[np.dtype[F32]] = nparr_f32
x_i32: xpt.Array[np.dtype[I32]] = nparr_i32
x_b: xpt.Array[np.dtype[B]] = nparr_b

# Check Attribute `.dtype`
assert_type(x_f32.dtype, np.dtype[F32])
assert_type(x_i32.dtype, np.dtype[I32])
assert_type(x_b.dtype, np.dtype[B])
Loading
Loading