Skip to content

Commit 930abe8

Browse files
committed
✨ HasDType
Signed-off-by: Nathaniel Starkman <[email protected]> Signed-off-by: nstarman <[email protected]>
1 parent e666327 commit 930abe8

File tree

4 files changed

+89
-31
lines changed

4 files changed

+89
-31
lines changed

pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,8 @@ ignore = [
130130
"ISC001", # Conflicts with formatter
131131
"PLW1641", # Object does not implement `__hash__` method
132132
"PYI041", # Use `float` instead of `int | float`
133+
"TD002", # Missing author in TODO
134+
"TD003", # Missing issue link for this TODO
133135
]
134136

135137
[tool.ruff.lint.pylint]

src/array_api_typing/_array.py

Lines changed: 58 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
from pathlib import Path
99
from types import ModuleType
10-
from typing import Literal, Never, Protocol, TypeAlias
10+
from typing import Any, Literal, Never, Protocol, TypeAlias
1111
from typing_extensions import TypeVar
1212

1313
import optype as op
@@ -27,6 +27,7 @@
2727
NS_co = TypeVar("NS_co", covariant=True, default=ModuleType)
2828
Other_contra = TypeVar("Other_contra", contravariant=True, default=Never)
2929
R_co = TypeVar("R_co", covariant=True, default=Never)
30+
DType_co = TypeVar("DType_co", covariant=True)
3031

3132

3233
class HasArrayNamespace(Protocol[NS_co]):
@@ -52,8 +53,20 @@ def __array_namespace__(
5253
) -> NS_co: ...
5354

5455

56+
class HasDType(Protocol[DType_co]):
57+
"""Protocol for array classes that have a data type attribute."""
58+
59+
@property
60+
def dtype(self) -> DType_co:
61+
"""Data type of the array elements."""
62+
...
63+
64+
5565
@docstring_setter(**_array_docstrings)
5666
class Array(
67+
# ------ Attributes -------
68+
HasDType[DType_co],
69+
# ------ Methods -------
5770
HasArrayNamespace[NS_co],
5871
op.CanPosSelf,
5972
op.CanNegSelf,
@@ -64,22 +77,23 @@ class Array(
6477
op.CanFloordivSame[Other_contra, R_co],
6578
op.CanModSame[Other_contra, R_co],
6679
op.CanPowSame[Other_contra, R_co],
67-
Protocol[Other_contra, R_co, NS_co],
80+
Protocol[DType_co, Other_contra, R_co, NS_co],
6881
):
6982
"""Array API specification for array object attributes and methods.
7083
71-
The type is: ``Array[-Other = Never, +R = Never, +NS = ModuleType] =
72-
Array[Self | Other, Self | R, NS]`` where:
84+
The type is: ``Array[+DType, -Other = Never, +R = Never, +NS = ModuleType] =
85+
Array[+DType, Self | Other, Self | R, NS]`` where:
7386
87+
- `DType` is the data type of the array elements.
7488
- `Other` is the type of objects that can be used with the array (e.g., for
7589
binary operations). For example, with numeric arrays, it is common to be
7690
able to add `float` and `int` objects to the array, not just other arrays
7791
of the same dtype. This would be annotated as `Other = float`. When not
7892
specified, `Other` only allows `Self` objects.
7993
- `R` is the return type of the array operations. For example, the return
8094
type of the division between integer arrays can be a float array. This
81-
would be annotated as `R = float`. When not specified, `R` only allows
82-
`Self` objects.
95+
would be annotated as `R = Array[float]`. When not specified, `R` only
96+
allows `Self` objects.
8397
- `NS` is the type of the array namespace. It defaults to `ModuleType`,
8498
which is the most common form of array namespace (e.g., `numpy`, `cupy`,
8599
etc.). However, it can be any type, e.g. a `types.SimpleNamespace`, to
@@ -89,20 +103,47 @@ class Array(
89103
"""
90104

91105

92-
BoolArray: TypeAlias = Array[bool, Array[float, Never, NS_co], NS_co]
93-
"""Array API specification for boolean array object attributes and methods.
106+
# TODO: are there ways to tighten the dtype in both Arrays?
107+
BoolArray: TypeAlias = Array[DType_co, bool, Array[Any, float, Never, NS_co], NS_co]
108+
"""Array API specification for arrays that work with boolean values.
94109
95-
Specifically, this type alias fills the `Other_contra` type variable with
96-
`bool`, allowing for `bool` objects to be added, subtracted, multiplied, etc. to
97-
the array object.
110+
The type is: ``BoolArray[+DType, +NS = ModuleType] = Array[+DType, Self | bool,
111+
Self | Array[Any, float, Self, NS], Self, NS]`` where:
98112
99-
"""
113+
- `DType` is the data type of the array elements.
114+
- The second type variable -- `Other` -- is filled with `bool`, allowing for
115+
`bool` objects to be added, subtracted, multiplied, etc. to the array object.
116+
- The third type variable -- `R` -- is filled with `Array[Any, float, Self,
117+
NS]`, which is the return type of the array operations. For example, the
118+
return type of the division between boolean arrays can be a float array.
119+
- `NS` is the type of the array namespace. It defaults to `ModuleType`, which is
120+
the most common form of array namespace (e.g., `numpy`, `cupy`, etc.).
121+
However, it can be any type, e.g. a `types.SimpleNamespace`, to allow for
122+
wrapper libraries to semi-dynamically define their own array namespaces based
123+
on the wrapped array type.
100124
101-
NumericArray: TypeAlias = Array[float | int, NS_co]
102-
"""Array API specification for numeric array object attributes and methods.
125+
"""
103126

104-
Specifically, this type alias fills the `Other_contra` type variable with `float
105-
| int`, allowing for `float | int` objects to be added, subtracted, multiplied,
106-
etc. to the array object.
127+
# TODO: are there ways to tighten the dtype in both Arrays?
128+
NumericArray: TypeAlias = Array[
129+
DType_co, float | int, Array[Any, float | int, Never, NS_co], NS_co
130+
]
131+
"""Array API specification for arrays that work with numeric values.
132+
133+
the type is: ``NumericArray[+DType, +NS = ModuleType] = Array[+DType, Self |
134+
float | int, Self | Array[Any, float | int, Self, NS], NS]`` where:
135+
136+
- `DType` is the data type of the array elements.
137+
- The second type variable -- `Other` -- is filled with `float | int`, allowing
138+
for `float | int` objects to be added, subtracted, multiplied, etc. to the
139+
array object.
140+
- The third type variable -- `R` -- is filled with `Array[Any, float | int,
141+
Self, NS]`, which is the return type of the array operations. For example, the
142+
return type of the division between integer arrays can be a float array.
143+
- `NS` is the type of the array namespace. It defaults to `ModuleType`, which is
144+
the most common form of array namespace (e.g., `numpy`, `cupy`, etc.).
145+
However, it can be any type, e.g. a `types.SimpleNamespace`, to allow for
146+
wrapper libraries to semi-dynamically define their own array namespaces based
147+
on the wrapped array type.
107148
108149
"""

tests/integration/test_numpy1.pyi

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,4 +42,4 @@ arr_i: xpt.Array[int | float, xpt.Array[float | int]] = nparr_i32
4242
# =========================================================
4343
# Check np.ndarray against BoolArray and NumericArray type aliases
4444

45-
numericarray: NumericArray = nparr_f32
45+
numericarray: NumericArray[Any] = nparr_f32

tests/integration/test_numpy2.pyi

Lines changed: 28 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,23 @@
1-
from typing import Any, Never, TypeAlias
1+
from typing import Any, TypeAlias
22

33
import numpy as np
44
import numpy.typing as npt
55

66
import array_api_typing as xpt
7-
from array_api_typing._array import BoolArray, NumericArray
7+
from array_api_typing._array import BoolArray, HasDType, NumericArray
88

99
F: TypeAlias = np.floating[Any]
10+
DtF: TypeAlias = np.dtype[F]
1011
F32: TypeAlias = np.float32
12+
DtF32: TypeAlias = np.dtype[F32]
13+
1114
I: TypeAlias = np.integer[Any]
15+
DtI: TypeAlias = np.dtype[I]
1216
I32: TypeAlias = np.int32
17+
DtI32: TypeAlias = np.dtype[I32]
18+
19+
DtB: TypeAlias = np.dtype[np.bool_]
20+
DtN: TypeAlias = np.dtype[np.number]
1321

1422
# Define NDArrays against which we can test the protocols
1523
nparr: npt.NDArray[Any]
@@ -24,28 +32,35 @@ arr_ns: xpt.HasArrayNamespace[Any] = nparr
2432
arr_ns_i32: xpt.HasArrayNamespace[Any] = nparr_i32
2533
arr_ns_f32: xpt.HasArrayNamespace[Any] = nparr_f32
2634

35+
# =========================================================
36+
# Ensure that `np.ndarray` instances are assignable to `xpt.HasDType`.
37+
38+
arr_dtype: HasDType[Any] = nparr
39+
arr_dtype_i32: HasDType[DtI32] = nparr_i32
40+
arr_dtype_f32: HasDType[DtF32] = nparr_f32
41+
2742
# =========================================================
2843
# Ensure that `np.ndarray` instances are assignable to `xpt.Array`.
2944

3045
# Generic Array type
31-
arr_array: xpt.Array[Never] = nparr
46+
arr_array: xpt.Array[Any] = nparr
3247

3348
# Float Array types
34-
arr_float: xpt.Array[float] = nparr_f32
35-
arr_f: xpt.Array[F] = nparr_f32
36-
arr_f32: xpt.Array[F32] = nparr_f32
49+
arr_float: xpt.Array[DtF, float] = nparr_f32
50+
arr_f: xpt.Array[DtF, F] = nparr_f32
51+
arr_f32: xpt.Array[DtF32, F32] = nparr_f32
3752

3853
# Integer Array types
39-
arr_int: xpt.Array[int, xpt.Array[float | int]] = nparr_i32
40-
arr_i: xpt.Array[I, xpt.Array[float | int]] = nparr_i32
41-
arr_i32: xpt.Array[I32, xpt.Array[F32 | I32]] = nparr_i32
54+
arr_int: xpt.Array[DtI32, int, xpt.Array[DtN, float | int]] = nparr_i32
55+
arr_i: xpt.Array[DtI32, I, xpt.Array[DtN, float | int]] = nparr_i32
56+
arr_i32: xpt.Array[DtI32, I32, xpt.Array[DtN, F32 | I32]] = nparr_i32
4257

4358
# Boolean Array types
44-
arr_bool: xpt.Array[bool, xpt.Array[float | int | bool]] = nparr_b
45-
arr_b: xpt.Array[np.bool_, xpt.Array[F | I | np.bool_]] = nparr_b
59+
arr_bool: xpt.Array[DtB, bool, xpt.Array[DtN | DtB, float | int | bool]] = nparr_b
60+
arr_b: xpt.Array[DtB, np.bool_, xpt.Array[DtN | DtB, F | I | np.bool_]] = nparr_b
4661

4762
# =========================================================
4863
# Check np.ndarray against BoolArray and NumericArray type aliases
4964

50-
boolarray: BoolArray = nparr_b
51-
numericarray: NumericArray = nparr_f32
65+
boolarray: BoolArray[DtB] = nparr_b # TODO: the Other type isn't correct
66+
numericarray: NumericArray[DtN] = nparr_f32

0 commit comments

Comments
 (0)