Skip to content

Commit 624c3ba

Browse files
committed
feat: HasDType
Signed-off-by: Nathaniel Starkman <[email protected]>
1 parent ebd973c commit 624c3ba

File tree

1 file changed

+46
-12
lines changed

1 file changed

+46
-12
lines changed

src/array_api_typing/_array.py

Lines changed: 46 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,9 @@
2525
_array_docstrings = tomllib.load(f)["docstrings"]
2626

2727
NS_co = TypeVar("NS_co", covariant=True, default=ModuleType)
28-
T_contra = TypeVar("T_contra", contravariant=True)
28+
Other_contra = TypeVar("Other_contra", contravariant=True)
2929
R_co = TypeVar("R_co", covariant=True, default=Never)
30+
DType_co = TypeVar("DType_co", covariant=True)
3031

3132

3233
class HasArrayNamespace(Protocol[NS_co]):
@@ -52,27 +53,60 @@ def __array_namespace__(
5253
) -> NS_co: ...
5354

5455

56+
class HasDType(Protocol[DType_co]):
57+
"""Protocol for array classes that have a data type attribute."""
58+
59+
@property
60+
def dtype(self) -> DType_co:
61+
"""Data type of the array elements."""
62+
...
63+
64+
5565
@docstring_setter(**_array_docstrings)
5666
class Array(
67+
# ------ Attributes -------
68+
HasDType[DType_co],
69+
# ------ Methods -------
5770
HasArrayNamespace[NS_co],
5871
op.CanPosSelf,
5972
op.CanNegSelf,
60-
op.CanAddSame[T_contra, R_co],
61-
op.CanSubSame[T_contra, R_co],
62-
op.CanMulSame[T_contra, R_co],
63-
op.CanTruedivSame[T_contra, R_co],
64-
op.CanFloordivSame[T_contra, R_co],
65-
op.CanModSame[T_contra, R_co],
66-
op.CanPowSame[T_contra, R_co],
67-
Protocol[T_contra, R_co, NS_co],
73+
op.CanAddSame[Other_contra, R_co],
74+
op.CanSubSame[Other_contra, R_co],
75+
op.CanMulSame[Other_contra, R_co],
76+
op.CanTruedivSame[Other_contra, R_co],
77+
op.CanFloordivSame[Other_contra, R_co],
78+
op.CanModSame[Other_contra, R_co],
79+
op.CanPowSame[Other_contra, R_co],
80+
Protocol[DType_co, Other_contra, R_co, NS_co],
6881
):
69-
"""Array API specification for array object attributes and methods."""
82+
"""Array API specification for array object attributes and methods.
83+
84+
The type is: ``Array[+DType, -Other = Never, +R = Never, +NS = ModuleType] =
85+
Array[+DType, Self | Other, Self | R, NS]`` where:
86+
87+
- `DType` is the data type of the array elements.
88+
- `Other` is the type of objects that can be used with the array (e.g., for
89+
binary operations). For example, with numeric arrays, it is common to be
90+
able to add `float` and `int` objects to the array, not just other arrays
91+
of the same dtype. This would be annotated as `Other = float`. When not
92+
specified, `Other` only allows `Self` objects.
93+
- `R` is the return type of the array operations. For example, the return
94+
type of the division between integer arrays can be a float array. This
95+
would be annotated as `R = float`. When not specified, `R` only allows
96+
`Self` objects.
97+
- `NS` is the type of the array namespace. It defaults to `ModuleType`,
98+
which is the most common form of array namespace (e.g., `numpy`, `cupy`,
99+
etc.). However, it can be any type, e.g. a `types.SimpleNamespace`, to
100+
allow for wrapper libraries to semi-dynamically define their own array
101+
namespaces based on the wrapped array type.
102+
103+
"""
70104

71105

72106
BoolArray: TypeAlias = Array[bool, Array[float, Never, NS_co], NS_co]
73107
"""Array API specification for boolean array object attributes and methods.
74108
75-
Specifically, this type alias fills the `T_contra` type variable with
109+
Specifically, this type alias fills the `Other_contra` type variable with
76110
`bool`, allowing for `bool` objects to be added, subtracted, multiplied, etc. to
77111
the array object.
78112
@@ -81,7 +115,7 @@ class Array(
81115
NumericArray: TypeAlias = Array[float | int, NS_co]
82116
"""Array API specification for numeric array object attributes and methods.
83117
84-
Specifically, this type alias fills the `T_contra` type variable with `float
118+
Specifically, this type alias fills the `Other_contra` type variable with `float
85119
| int`, allowing for `float | int` objects to be added, subtracted, multiplied,
86120
etc. to the array object.
87121

0 commit comments

Comments
 (0)