Skip to content

Commit e093c7e

Browse files
TYP: Make array shape type variable covariant and bound
Fixes numpy#25729 This change allows future changes to the static typing of numpy that modify or only work with certain numbers of dimensions. It also applies the change to subclasses of ndarray and adds tests. It allows users to statically type their array shapes with subtypes of tuple (e.g. NamedTuple) and tuples of int subtypes (e.g. Literal or NewType). For a discussion of the merits of TypeVarTuple vs a tuple-bound TypeVar, see the linked PR
1 parent 0469e1d commit e093c7e

File tree

11 files changed

+94
-50
lines changed

11 files changed

+94
-50
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# Editor temporary/working/backup files #
22
#########################################
3+
env/
34
.#*
45
[#]*#
56
*~
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
``ndarray`` shape-type parameter is now covariant and bound to ``tuple[int, ...]``
2+
----------------------------------------------------------------------------------
3+
Static typing for ``ndarray`` is a long-term effort that continues
4+
with this change. It is a generic type with type parameters for
5+
the shape and the data type. Previously, the shape type parameter could be
6+
any value. This change restricts it to a tuple of ints, as one would expect
7+
from using ``ndarray.shape``. Further, the shape-type parameter has been
8+
changed from invariant to covariant. This change also applies to subpytes of
9+
``ndarray``, e.g. ``np.ma.MaskedArray``. See the `typing docs <https://typing.readthedocs.io/en/latest/reference/generics.html#variance-of-generic-types>`_
10+
for more information.

doc/release/upcoming_changes/README.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ So for example: ``123.new_feature.rst`` would have the content::
4040
The ``my_new_feature`` option is now available for `my_favorite_function`.
4141
To use it, write ``np.my_favorite_function(..., my_new_feature=True)``.
4242

43-
``highlight`` is usually formatted as bulled points making the fragment
43+
``highlight`` is usually formatted as bullet points making the fragment
4444
``* This is a highlight``.
4545

4646
Note the use of single-backticks to get an internal link (assuming

numpy/__init__.pyi

Lines changed: 19 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1500,10 +1500,8 @@ _DType = TypeVar("_DType", bound=dtype[Any])
15001500
_DType_co = TypeVar("_DType_co", covariant=True, bound=dtype[Any])
15011501
_FlexDType = TypeVar("_FlexDType", bound=dtype[flexible])
15021502

1503-
# TODO: Set the `bound` to something more suitable once we
1504-
# have proper shape support
1505-
_ShapeType = TypeVar("_ShapeType", bound=Any)
1506-
_ShapeType2 = TypeVar("_ShapeType2", bound=Any)
1503+
_ShapeType_co = TypeVar("_ShapeType_co", covariant=True, bound=tuple[int, ...])
1504+
_ShapeType2 = TypeVar("_ShapeType2", bound=tuple[int, ...])
15071505
_NumberType = TypeVar("_NumberType", bound=number[Any])
15081506

15091507
if sys.version_info >= (3, 12):
@@ -1553,7 +1551,7 @@ class _SupportsImag(Protocol[_T_co]):
15531551
@property
15541552
def imag(self) -> _T_co: ...
15551553

1556-
class ndarray(_ArrayOrScalarCommon, Generic[_ShapeType, _DType_co]):
1554+
class ndarray(_ArrayOrScalarCommon, Generic[_ShapeType_co, _DType_co]):
15571555
__hash__: ClassVar[None]
15581556
@property
15591557
def base(self) -> None | NDArray[Any]: ...
@@ -1563,14 +1561,14 @@ class ndarray(_ArrayOrScalarCommon, Generic[_ShapeType, _DType_co]):
15631561
def size(self) -> int: ...
15641562
@property
15651563
def real(
1566-
self: ndarray[_ShapeType, dtype[_SupportsReal[_ScalarType]]], # type: ignore[type-var]
1567-
) -> ndarray[_ShapeType, _dtype[_ScalarType]]: ...
1564+
self: ndarray[_ShapeType_co, dtype[_SupportsReal[_ScalarType]]], # type: ignore[type-var]
1565+
) -> ndarray[_ShapeType_co, _dtype[_ScalarType]]: ...
15681566
@real.setter
15691567
def real(self, value: ArrayLike) -> None: ...
15701568
@property
15711569
def imag(
1572-
self: ndarray[_ShapeType, dtype[_SupportsImag[_ScalarType]]], # type: ignore[type-var]
1573-
) -> ndarray[_ShapeType, _dtype[_ScalarType]]: ...
1570+
self: ndarray[_ShapeType_co, dtype[_SupportsImag[_ScalarType]]], # type: ignore[type-var]
1571+
) -> ndarray[_ShapeType_co, _dtype[_ScalarType]]: ...
15741572
@imag.setter
15751573
def imag(self, value: ArrayLike) -> None: ...
15761574
def __new__(
@@ -1591,11 +1589,11 @@ class ndarray(_ArrayOrScalarCommon, Generic[_ShapeType, _DType_co]):
15911589
@overload
15921590
def __array__(
15931591
self, dtype: None = ..., /, *, copy: None | bool = ...
1594-
) -> ndarray[_ShapeType, _DType_co]: ...
1592+
) -> ndarray[_ShapeType_co, _DType_co]: ...
15951593
@overload
15961594
def __array__(
15971595
self, dtype: _DType, /, *, copy: None | bool = ...
1598-
) -> ndarray[_ShapeType, _DType]: ...
1596+
) -> ndarray[_ShapeType_co, _DType]: ...
15991597

16001598
def __array_ufunc__(
16011599
self,
@@ -1646,12 +1644,12 @@ class ndarray(_ArrayOrScalarCommon, Generic[_ShapeType, _DType_co]):
16461644
@overload
16471645
def __getitem__(self: NDArray[void], key: str) -> NDArray[Any]: ...
16481646
@overload
1649-
def __getitem__(self: NDArray[void], key: list[str]) -> ndarray[_ShapeType, _dtype[void]]: ...
1647+
def __getitem__(self: NDArray[void], key: list[str]) -> ndarray[_ShapeType_co, _dtype[void]]: ...
16501648

16511649
@property
16521650
def ctypes(self) -> _ctypes[int]: ...
16531651
@property
1654-
def shape(self) -> _Shape: ...
1652+
def shape(self) -> _ShapeType_co: ...
16551653
@shape.setter
16561654
def shape(self, value: _ShapeLike) -> None: ...
16571655
@property
@@ -3786,7 +3784,7 @@ _MemMapModeKind: TypeAlias = L[
37863784
"write", "w+",
37873785
]
37883786

3789-
class memmap(ndarray[_ShapeType, _DType_co]):
3787+
class memmap(ndarray[_ShapeType_co, _DType_co]):
37903788
__array_priority__: ClassVar[float]
37913789
filename: str | None
37923790
offset: int
@@ -3824,7 +3822,7 @@ class memmap(ndarray[_ShapeType, _DType_co]):
38243822
def __array_finalize__(self, obj: object) -> None: ...
38253823
def __array_wrap__(
38263824
self,
3827-
array: memmap[_ShapeType, _DType_co],
3825+
array: memmap[_ShapeType_co, _DType_co],
38283826
context: None | tuple[ufunc, tuple[Any, ...], int] = ...,
38293827
return_scalar: builtins.bool = ...,
38303828
) -> Any: ...
@@ -3927,7 +3925,7 @@ class poly1d:
39273925
k: None | _ArrayLikeComplex_co | _ArrayLikeObject_co = ...,
39283926
) -> poly1d: ...
39293927

3930-
class matrix(ndarray[_ShapeType, _DType_co]):
3928+
class matrix(ndarray[_ShapeType_co, _DType_co]):
39313929
__array_priority__: ClassVar[float]
39323930
def __new__(
39333931
subtype,
@@ -3963,13 +3961,13 @@ class matrix(ndarray[_ShapeType, _DType_co]):
39633961
@overload
39643962
def __getitem__(self: NDArray[void], key: str, /) -> matrix[Any, dtype[Any]]: ...
39653963
@overload
3966-
def __getitem__(self: NDArray[void], key: list[str], /) -> matrix[_ShapeType, dtype[void]]: ...
3964+
def __getitem__(self: NDArray[void], key: list[str], /) -> matrix[_ShapeType_co, dtype[void]]: ...
39673965

39683966
def __mul__(self, other: ArrayLike, /) -> matrix[Any, Any]: ...
39693967
def __rmul__(self, other: ArrayLike, /) -> matrix[Any, Any]: ...
3970-
def __imul__(self, other: ArrayLike, /) -> matrix[_ShapeType, _DType_co]: ...
3968+
def __imul__(self, other: ArrayLike, /) -> matrix[_ShapeType_co, _DType_co]: ...
39713969
def __pow__(self, other: ArrayLike, /) -> matrix[Any, Any]: ...
3972-
def __ipow__(self, other: ArrayLike, /) -> matrix[_ShapeType, _DType_co]: ...
3970+
def __ipow__(self, other: ArrayLike, /) -> matrix[_ShapeType_co, _DType_co]: ...
39733971

39743972
@overload
39753973
def sum(self, axis: None = ..., dtype: DTypeLike = ..., out: None = ...) -> Any: ...
@@ -4065,14 +4063,14 @@ class matrix(ndarray[_ShapeType, _DType_co]):
40654063
@property
40664064
def I(self) -> matrix[Any, Any]: ...
40674065
@property
4068-
def A(self) -> ndarray[_ShapeType, _DType_co]: ...
4066+
def A(self) -> ndarray[_ShapeType_co, _DType_co]: ...
40694067
@property
40704068
def A1(self) -> ndarray[Any, _DType_co]: ...
40714069
@property
40724070
def H(self) -> matrix[Any, _DType_co]: ...
40734071
def getT(self) -> matrix[Any, _DType_co]: ...
40744072
def getI(self) -> matrix[Any, Any]: ...
4075-
def getA(self) -> ndarray[_ShapeType, _DType_co]: ...
4073+
def getA(self) -> ndarray[_ShapeType_co, _DType_co]: ...
40764074
def getA1(self) -> ndarray[Any, _DType_co]: ...
40774075
def getH(self) -> matrix[Any, _DType_co]: ...
40784076

numpy/_core/defchararray.pyi

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ from numpy import (
1616
int_,
1717
object_,
1818
_OrderKACF,
19-
_ShapeType,
19+
_ShapeType_co,
2020
_CharDType,
2121
_SupportsBuffer,
2222
)
@@ -35,7 +35,7 @@ from numpy._core.multiarray import compare_chararrays as compare_chararrays
3535
_SCT = TypeVar("_SCT", str_, bytes_)
3636
_CharArray = chararray[Any, dtype[_SCT]]
3737

38-
class chararray(ndarray[_ShapeType, _CharDType]):
38+
class chararray(ndarray[_ShapeType_co, _CharDType]):
3939
@overload
4040
def __new__(
4141
subtype,
@@ -436,20 +436,20 @@ class chararray(ndarray[_ShapeType, _CharDType]):
436436
) -> _CharArray[bytes_]: ...
437437

438438
def zfill(self, width: _ArrayLikeInt_co) -> chararray[Any, _CharDType]: ...
439-
def capitalize(self) -> chararray[_ShapeType, _CharDType]: ...
440-
def title(self) -> chararray[_ShapeType, _CharDType]: ...
441-
def swapcase(self) -> chararray[_ShapeType, _CharDType]: ...
442-
def lower(self) -> chararray[_ShapeType, _CharDType]: ...
443-
def upper(self) -> chararray[_ShapeType, _CharDType]: ...
444-
def isalnum(self) -> ndarray[_ShapeType, dtype[np.bool]]: ...
445-
def isalpha(self) -> ndarray[_ShapeType, dtype[np.bool]]: ...
446-
def isdigit(self) -> ndarray[_ShapeType, dtype[np.bool]]: ...
447-
def islower(self) -> ndarray[_ShapeType, dtype[np.bool]]: ...
448-
def isspace(self) -> ndarray[_ShapeType, dtype[np.bool]]: ...
449-
def istitle(self) -> ndarray[_ShapeType, dtype[np.bool]]: ...
450-
def isupper(self) -> ndarray[_ShapeType, dtype[np.bool]]: ...
451-
def isnumeric(self) -> ndarray[_ShapeType, dtype[np.bool]]: ...
452-
def isdecimal(self) -> ndarray[_ShapeType, dtype[np.bool]]: ...
439+
def capitalize(self) -> chararray[_ShapeType_co, _CharDType]: ...
440+
def title(self) -> chararray[_ShapeType_co, _CharDType]: ...
441+
def swapcase(self) -> chararray[_ShapeType_co, _CharDType]: ...
442+
def lower(self) -> chararray[_ShapeType_co, _CharDType]: ...
443+
def upper(self) -> chararray[_ShapeType_co, _CharDType]: ...
444+
def isalnum(self) -> ndarray[_ShapeType_co, dtype[np.bool]]: ...
445+
def isalpha(self) -> ndarray[_ShapeType_co, dtype[np.bool]]: ...
446+
def isdigit(self) -> ndarray[_ShapeType_co, dtype[np.bool]]: ...
447+
def islower(self) -> ndarray[_ShapeType_co, dtype[np.bool]]: ...
448+
def isspace(self) -> ndarray[_ShapeType_co, dtype[np.bool]]: ...
449+
def istitle(self) -> ndarray[_ShapeType_co, dtype[np.bool]]: ...
450+
def isupper(self) -> ndarray[_ShapeType_co, dtype[np.bool]]: ...
451+
def isnumeric(self) -> ndarray[_ShapeType_co, dtype[np.bool]]: ...
452+
def isdecimal(self) -> ndarray[_ShapeType_co, dtype[np.bool]]: ...
453453

454454
__all__: list[str]
455455

numpy/_core/records.pyi

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ from numpy import (
1616
void,
1717
_ByteOrder,
1818
_SupportsBuffer,
19-
_ShapeType,
19+
_ShapeType_co,
2020
_DType_co,
2121
_OrderKACF,
2222
)
@@ -49,7 +49,7 @@ class record(void):
4949
@overload
5050
def __getitem__(self, key: list[str]) -> record: ...
5151

52-
class recarray(ndarray[_ShapeType, _DType_co]):
52+
class recarray(ndarray[_ShapeType_co, _DType_co]):
5353
# NOTE: While not strictly mandatory, we're demanding here that arguments
5454
# for the `format_parser`- and `dtype`-based dtype constructors are
5555
# mutually exclusive
@@ -114,7 +114,7 @@ class recarray(ndarray[_ShapeType, _DType_co]):
114114
@overload
115115
def __getitem__(self, indx: str) -> NDArray[Any]: ...
116116
@overload
117-
def __getitem__(self, indx: list[str]) -> recarray[_ShapeType, dtype[record]]: ...
117+
def __getitem__(self, indx: list[str]) -> recarray[_ShapeType_co, dtype[record]]: ...
118118
@overload
119119
def field(self, attr: int | str, val: None = ...) -> Any: ...
120120
@overload

numpy/ma/core.pyi

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,7 @@ from numpy import (
1515
angle as angle
1616
)
1717

18-
# TODO: Set the `bound` to something more suitable once we
19-
# have proper shape support
20-
_ShapeType = TypeVar("_ShapeType", bound=Any)
18+
_ShapeType_co = TypeVar("_ShapeType_co", bound=tuple[int, ...], covariant=True)
2119
_DType_co = TypeVar("_DType_co", bound=dtype[Any], covariant=True)
2220

2321
__all__: list[str]
@@ -165,7 +163,7 @@ class MaskedIterator:
165163
def __setitem__(self, index, value): ...
166164
def __next__(self): ...
167165

168-
class MaskedArray(ndarray[_ShapeType, _DType_co]):
166+
class MaskedArray(ndarray[_ShapeType_co, _DType_co]):
169167
__array_priority__: Any
170168
def __new__(cls, data=..., mask=..., dtype=..., copy=..., subok=..., ndmin=..., fill_value=..., keep_mask=..., hard_mask=..., shrink=..., order=...): ...
171169
def __array_finalize__(self, obj): ...
@@ -300,7 +298,7 @@ class MaskedArray(ndarray[_ShapeType, _DType_co]):
300298
def __reduce__(self): ...
301299
def __deepcopy__(self, memo=...): ...
302300

303-
class mvoid(MaskedArray[_ShapeType, _DType_co]):
301+
class mvoid(MaskedArray[_ShapeType_co, _DType_co]):
304302
def __new__(
305303
self,
306304
data,

numpy/ma/mrecords.pyi

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,10 @@ from numpy.ma import MaskedArray
55

66
__all__: list[str]
77

8-
# TODO: Set the `bound` to something more suitable once we
9-
# have proper shape support
10-
_ShapeType = TypeVar("_ShapeType", bound=Any)
8+
_ShapeType_co = TypeVar("_ShapeType_co", covariant=True, bound=tuple[int, ...])
119
_DType_co = TypeVar("_DType_co", bound=dtype[Any], covariant=True)
1210

13-
class MaskedRecords(MaskedArray[_ShapeType, _DType_co]):
11+
class MaskedRecords(MaskedArray[_ShapeType_co, _DType_co]):
1412
def __new__(
1513
cls,
1614
shape,
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
from typing import Any
2+
import numpy as np
3+
4+
# test bounds of _ShapeType_co
5+
6+
np.ndarray[tuple[str, str], Any] # E: Value of type variable

numpy/typing/tests/data/pass/shape.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
from typing import Any, NamedTuple
2+
3+
import numpy as np
4+
from typing_extensions import assert_type
5+
6+
7+
# Subtype of tuple[int, int]
8+
class XYGrid(NamedTuple):
9+
x_axis: int
10+
y_axis: int
11+
12+
arr: np.ndarray[XYGrid, Any] = np.empty(XYGrid(2, 2))
13+
14+
# Test variance of _ShapeType_co
15+
def accepts_2d(a: np.ndarray[tuple[int, int], Any]) -> None:
16+
return None
17+
18+
accepts_2d(arr)

0 commit comments

Comments
 (0)