Skip to content

Commit c258f83

Browse files
authored
Merge pull request numpy#26081 from Jacob-Stevens-Haas/covariant-shapetype
TYP: Make array _ShapeType bound and covariant
2 parents 7578274 + 476bc6b commit c258f83

File tree

10 files changed

+97
-50
lines changed

10 files changed

+97
-50
lines changed
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
``ndarray`` shape-type parameter is now covariant and bound to ``tuple[int, ...]``
2+
----------------------------------------------------------------------------------
3+
Static typing for ``ndarray`` is a long-term effort that continues
4+
with this change. It is a generic type with type parameters for
5+
the shape and the data type. Previously, the shape type parameter could be
6+
any value. This change restricts it to a tuple of ints, as one would expect
7+
from using ``ndarray.shape``. Further, the shape-type parameter has been
8+
changed from invariant to covariant. This change also applies to the subtypes
9+
of ``ndarray``, e.g. ``numpy.ma.MaskedArray``. See the
10+
`typing docs <https://typing.readthedocs.io/en/latest/reference/generics.html#variance-of-generic-types>`_
11+
for more information.

doc/release/upcoming_changes/README.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ So for example: ``123.new_feature.rst`` would have the content::
4040
The ``my_new_feature`` option is now available for `my_favorite_function`.
4141
To use it, write ``np.my_favorite_function(..., my_new_feature=True)``.
4242

43-
``highlight`` is usually formatted as bulled points making the fragment
43+
``highlight`` is usually formatted as bullet points making the fragment
4444
``* This is a highlight``.
4545

4646
Note the use of single-backticks to get an internal link (assuming

numpy/__init__.pyi

Lines changed: 22 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1500,10 +1500,9 @@ _DType = TypeVar("_DType", bound=dtype[Any])
15001500
_DType_co = TypeVar("_DType_co", covariant=True, bound=dtype[Any])
15011501
_FlexDType = TypeVar("_FlexDType", bound=dtype[flexible])
15021502

1503-
# TODO: Set the `bound` to something more suitable once we
1504-
# have proper shape support
1505-
_ShapeType = TypeVar("_ShapeType", bound=Any)
1506-
_ShapeType2 = TypeVar("_ShapeType2", bound=Any)
1503+
_ShapeType_co = TypeVar("_ShapeType_co", covariant=True, bound=tuple[int, ...])
1504+
_ShapeType2 = TypeVar("_ShapeType2", bound=tuple[int, ...])
1505+
_Shape2DType_co = TypeVar("_Shape2DType_co", covariant=True, bound=tuple[int, int])
15071506
_NumberType = TypeVar("_NumberType", bound=number[Any])
15081507

15091508
if sys.version_info >= (3, 12):
@@ -1553,7 +1552,7 @@ class _SupportsImag(Protocol[_T_co]):
15531552
@property
15541553
def imag(self) -> _T_co: ...
15551554

1556-
class ndarray(_ArrayOrScalarCommon, Generic[_ShapeType, _DType_co]):
1555+
class ndarray(_ArrayOrScalarCommon, Generic[_ShapeType_co, _DType_co]):
15571556
__hash__: ClassVar[None]
15581557
@property
15591558
def base(self) -> None | NDArray[Any]: ...
@@ -1563,14 +1562,14 @@ class ndarray(_ArrayOrScalarCommon, Generic[_ShapeType, _DType_co]):
15631562
def size(self) -> int: ...
15641563
@property
15651564
def real(
1566-
self: ndarray[_ShapeType, dtype[_SupportsReal[_ScalarType]]], # type: ignore[type-var]
1567-
) -> ndarray[_ShapeType, _dtype[_ScalarType]]: ...
1565+
self: ndarray[_ShapeType_co, dtype[_SupportsReal[_ScalarType]]], # type: ignore[type-var]
1566+
) -> ndarray[_ShapeType_co, _dtype[_ScalarType]]: ...
15681567
@real.setter
15691568
def real(self, value: ArrayLike) -> None: ...
15701569
@property
15711570
def imag(
1572-
self: ndarray[_ShapeType, dtype[_SupportsImag[_ScalarType]]], # type: ignore[type-var]
1573-
) -> ndarray[_ShapeType, _dtype[_ScalarType]]: ...
1571+
self: ndarray[_ShapeType_co, dtype[_SupportsImag[_ScalarType]]], # type: ignore[type-var]
1572+
) -> ndarray[_ShapeType_co, _dtype[_ScalarType]]: ...
15741573
@imag.setter
15751574
def imag(self, value: ArrayLike) -> None: ...
15761575
def __new__(
@@ -1591,11 +1590,11 @@ class ndarray(_ArrayOrScalarCommon, Generic[_ShapeType, _DType_co]):
15911590
@overload
15921591
def __array__(
15931592
self, dtype: None = ..., /, *, copy: None | bool = ...
1594-
) -> ndarray[_ShapeType, _DType_co]: ...
1593+
) -> ndarray[_ShapeType_co, _DType_co]: ...
15951594
@overload
15961595
def __array__(
15971596
self, dtype: _DType, /, *, copy: None | bool = ...
1598-
) -> ndarray[_ShapeType, _DType]: ...
1597+
) -> ndarray[_ShapeType_co, _DType]: ...
15991598

16001599
def __array_ufunc__(
16011600
self,
@@ -1646,12 +1645,12 @@ class ndarray(_ArrayOrScalarCommon, Generic[_ShapeType, _DType_co]):
16461645
@overload
16471646
def __getitem__(self: NDArray[void], key: str) -> NDArray[Any]: ...
16481647
@overload
1649-
def __getitem__(self: NDArray[void], key: list[str]) -> ndarray[_ShapeType, _dtype[void]]: ...
1648+
def __getitem__(self: NDArray[void], key: list[str]) -> ndarray[_ShapeType_co, _dtype[void]]: ...
16501649

16511650
@property
16521651
def ctypes(self) -> _ctypes[int]: ...
16531652
@property
1654-
def shape(self) -> _Shape: ...
1653+
def shape(self) -> _ShapeType_co: ...
16551654
@shape.setter
16561655
def shape(self, value: _ShapeLike) -> None: ...
16571656
@property
@@ -3786,7 +3785,7 @@ _MemMapModeKind: TypeAlias = L[
37863785
"write", "w+",
37873786
]
37883787

3789-
class memmap(ndarray[_ShapeType, _DType_co]):
3788+
class memmap(ndarray[_ShapeType_co, _DType_co]):
37903789
__array_priority__: ClassVar[float]
37913790
filename: str | None
37923791
offset: int
@@ -3824,7 +3823,7 @@ class memmap(ndarray[_ShapeType, _DType_co]):
38243823
def __array_finalize__(self, obj: object) -> None: ...
38253824
def __array_wrap__(
38263825
self,
3827-
array: memmap[_ShapeType, _DType_co],
3826+
array: memmap[_ShapeType_co, _DType_co],
38283827
context: None | tuple[ufunc, tuple[Any, ...], int] = ...,
38293828
return_scalar: builtins.bool = ...,
38303829
) -> Any: ...
@@ -3927,7 +3926,9 @@ class poly1d:
39273926
k: None | _ArrayLikeComplex_co | _ArrayLikeObject_co = ...,
39283927
) -> poly1d: ...
39293928

3930-
class matrix(ndarray[_ShapeType, _DType_co]):
3929+
3930+
3931+
class matrix(ndarray[_Shape2DType_co, _DType_co]):
39313932
__array_priority__: ClassVar[float]
39323933
def __new__(
39333934
subtype,
@@ -3963,13 +3964,13 @@ class matrix(ndarray[_ShapeType, _DType_co]):
39633964
@overload
39643965
def __getitem__(self: NDArray[void], key: str, /) -> matrix[Any, dtype[Any]]: ...
39653966
@overload
3966-
def __getitem__(self: NDArray[void], key: list[str], /) -> matrix[_ShapeType, dtype[void]]: ...
3967+
def __getitem__(self: NDArray[void], key: list[str], /) -> matrix[_Shape2DType_co, dtype[void]]: ...
39673968

39683969
def __mul__(self, other: ArrayLike, /) -> matrix[Any, Any]: ...
39693970
def __rmul__(self, other: ArrayLike, /) -> matrix[Any, Any]: ...
3970-
def __imul__(self, other: ArrayLike, /) -> matrix[_ShapeType, _DType_co]: ...
3971+
def __imul__(self, other: ArrayLike, /) -> matrix[_Shape2DType_co, _DType_co]: ...
39713972
def __pow__(self, other: ArrayLike, /) -> matrix[Any, Any]: ...
3972-
def __ipow__(self, other: ArrayLike, /) -> matrix[_ShapeType, _DType_co]: ...
3973+
def __ipow__(self, other: ArrayLike, /) -> matrix[_Shape2DType_co, _DType_co]: ...
39733974

39743975
@overload
39753976
def sum(self, axis: None = ..., dtype: DTypeLike = ..., out: None = ...) -> Any: ...
@@ -4065,14 +4066,14 @@ class matrix(ndarray[_ShapeType, _DType_co]):
40654066
@property
40664067
def I(self) -> matrix[Any, Any]: ...
40674068
@property
4068-
def A(self) -> ndarray[_ShapeType, _DType_co]: ...
4069+
def A(self) -> ndarray[_Shape2DType_co, _DType_co]: ...
40694070
@property
40704071
def A1(self) -> ndarray[Any, _DType_co]: ...
40714072
@property
40724073
def H(self) -> matrix[Any, _DType_co]: ...
40734074
def getT(self) -> matrix[Any, _DType_co]: ...
40744075
def getI(self) -> matrix[Any, Any]: ...
4075-
def getA(self) -> ndarray[_ShapeType, _DType_co]: ...
4076+
def getA(self) -> ndarray[_Shape2DType_co, _DType_co]: ...
40764077
def getA1(self) -> ndarray[Any, _DType_co]: ...
40774078
def getH(self) -> matrix[Any, _DType_co]: ...
40784079

numpy/_core/defchararray.pyi

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ from numpy import (
1616
int_,
1717
object_,
1818
_OrderKACF,
19-
_ShapeType,
19+
_ShapeType_co,
2020
_CharDType,
2121
_SupportsBuffer,
2222
)
@@ -35,7 +35,7 @@ from numpy._core.multiarray import compare_chararrays as compare_chararrays
3535
_SCT = TypeVar("_SCT", str_, bytes_)
3636
_CharArray = chararray[Any, dtype[_SCT]]
3737

38-
class chararray(ndarray[_ShapeType, _CharDType]):
38+
class chararray(ndarray[_ShapeType_co, _CharDType]):
3939
@overload
4040
def __new__(
4141
subtype,
@@ -436,20 +436,20 @@ class chararray(ndarray[_ShapeType, _CharDType]):
436436
) -> _CharArray[bytes_]: ...
437437

438438
def zfill(self, width: _ArrayLikeInt_co) -> chararray[Any, _CharDType]: ...
439-
def capitalize(self) -> chararray[_ShapeType, _CharDType]: ...
440-
def title(self) -> chararray[_ShapeType, _CharDType]: ...
441-
def swapcase(self) -> chararray[_ShapeType, _CharDType]: ...
442-
def lower(self) -> chararray[_ShapeType, _CharDType]: ...
443-
def upper(self) -> chararray[_ShapeType, _CharDType]: ...
444-
def isalnum(self) -> ndarray[_ShapeType, dtype[np.bool]]: ...
445-
def isalpha(self) -> ndarray[_ShapeType, dtype[np.bool]]: ...
446-
def isdigit(self) -> ndarray[_ShapeType, dtype[np.bool]]: ...
447-
def islower(self) -> ndarray[_ShapeType, dtype[np.bool]]: ...
448-
def isspace(self) -> ndarray[_ShapeType, dtype[np.bool]]: ...
449-
def istitle(self) -> ndarray[_ShapeType, dtype[np.bool]]: ...
450-
def isupper(self) -> ndarray[_ShapeType, dtype[np.bool]]: ...
451-
def isnumeric(self) -> ndarray[_ShapeType, dtype[np.bool]]: ...
452-
def isdecimal(self) -> ndarray[_ShapeType, dtype[np.bool]]: ...
439+
def capitalize(self) -> chararray[_ShapeType_co, _CharDType]: ...
440+
def title(self) -> chararray[_ShapeType_co, _CharDType]: ...
441+
def swapcase(self) -> chararray[_ShapeType_co, _CharDType]: ...
442+
def lower(self) -> chararray[_ShapeType_co, _CharDType]: ...
443+
def upper(self) -> chararray[_ShapeType_co, _CharDType]: ...
444+
def isalnum(self) -> ndarray[_ShapeType_co, dtype[np.bool]]: ...
445+
def isalpha(self) -> ndarray[_ShapeType_co, dtype[np.bool]]: ...
446+
def isdigit(self) -> ndarray[_ShapeType_co, dtype[np.bool]]: ...
447+
def islower(self) -> ndarray[_ShapeType_co, dtype[np.bool]]: ...
448+
def isspace(self) -> ndarray[_ShapeType_co, dtype[np.bool]]: ...
449+
def istitle(self) -> ndarray[_ShapeType_co, dtype[np.bool]]: ...
450+
def isupper(self) -> ndarray[_ShapeType_co, dtype[np.bool]]: ...
451+
def isnumeric(self) -> ndarray[_ShapeType_co, dtype[np.bool]]: ...
452+
def isdecimal(self) -> ndarray[_ShapeType_co, dtype[np.bool]]: ...
453453

454454
__all__: list[str]
455455

numpy/_core/records.pyi

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ from numpy import (
1616
void,
1717
_ByteOrder,
1818
_SupportsBuffer,
19-
_ShapeType,
19+
_ShapeType_co,
2020
_DType_co,
2121
_OrderKACF,
2222
)
@@ -49,7 +49,7 @@ class record(void):
4949
@overload
5050
def __getitem__(self, key: list[str]) -> record: ...
5151

52-
class recarray(ndarray[_ShapeType, _DType_co]):
52+
class recarray(ndarray[_ShapeType_co, _DType_co]):
5353
# NOTE: While not strictly mandatory, we're demanding here that arguments
5454
# for the `format_parser`- and `dtype`-based dtype constructors are
5555
# mutually exclusive
@@ -114,7 +114,7 @@ class recarray(ndarray[_ShapeType, _DType_co]):
114114
@overload
115115
def __getitem__(self, indx: str) -> NDArray[Any]: ...
116116
@overload
117-
def __getitem__(self, indx: list[str]) -> recarray[_ShapeType, dtype[record]]: ...
117+
def __getitem__(self, indx: list[str]) -> recarray[_ShapeType_co, dtype[record]]: ...
118118
@overload
119119
def field(self, attr: int | str, val: None = ...) -> Any: ...
120120
@overload

numpy/ma/core.pyi

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,7 @@ from numpy import (
1515
angle as angle
1616
)
1717

18-
# TODO: Set the `bound` to something more suitable once we
19-
# have proper shape support
20-
_ShapeType = TypeVar("_ShapeType", bound=Any)
18+
_ShapeType_co = TypeVar("_ShapeType_co", bound=tuple[int, ...], covariant=True)
2119
_DType_co = TypeVar("_DType_co", bound=dtype[Any], covariant=True)
2220

2321
__all__: list[str]
@@ -165,7 +163,7 @@ class MaskedIterator:
165163
def __setitem__(self, index, value): ...
166164
def __next__(self): ...
167165

168-
class MaskedArray(ndarray[_ShapeType, _DType_co]):
166+
class MaskedArray(ndarray[_ShapeType_co, _DType_co]):
169167
__array_priority__: Any
170168
def __new__(cls, data=..., mask=..., dtype=..., copy=..., subok=..., ndmin=..., fill_value=..., keep_mask=..., hard_mask=..., shrink=..., order=...): ...
171169
def __array_finalize__(self, obj): ...
@@ -300,7 +298,7 @@ class MaskedArray(ndarray[_ShapeType, _DType_co]):
300298
def __reduce__(self): ...
301299
def __deepcopy__(self, memo=...): ...
302300

303-
class mvoid(MaskedArray[_ShapeType, _DType_co]):
301+
class mvoid(MaskedArray[_ShapeType_co, _DType_co]):
304302
def __new__(
305303
self,
306304
data,

numpy/ma/mrecords.pyi

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,10 @@ from numpy.ma import MaskedArray
55

66
__all__: list[str]
77

8-
# TODO: Set the `bound` to something more suitable once we
9-
# have proper shape support
10-
_ShapeType = TypeVar("_ShapeType", bound=Any)
8+
_ShapeType_co = TypeVar("_ShapeType_co", covariant=True, bound=tuple[int, ...])
119
_DType_co = TypeVar("_DType_co", bound=dtype[Any], covariant=True)
1210

13-
class MaskedRecords(MaskedArray[_ShapeType, _DType_co]):
11+
class MaskedRecords(MaskedArray[_ShapeType_co, _DType_co]):
1412
def __new__(
1513
cls,
1614
shape,
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
from typing import Any
2+
import numpy as np
3+
4+
# test bounds of _ShapeType_co
5+
6+
np.ndarray[tuple[str, str], Any] # E: Value of type variable

numpy/typing/tests/data/pass/shape.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
from typing import Any, NamedTuple
2+
3+
import numpy as np
4+
from typing_extensions import assert_type
5+
6+
7+
# Subtype of tuple[int, int]
8+
class XYGrid(NamedTuple):
9+
x_axis: int
10+
y_axis: int
11+
12+
arr: np.ndarray[XYGrid, Any] = np.empty(XYGrid(2, 2))
13+
14+
# Test variance of _ShapeType_co
15+
def accepts_2d(a: np.ndarray[tuple[int, int], Any]) -> None:
16+
return None
17+
18+
accepts_2d(arr)
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
from typing import Any, NamedTuple
2+
3+
import numpy as np
4+
from typing_extensions import assert_type
5+
6+
7+
# Subtype of tuple[int, int]
8+
class XYGrid(NamedTuple):
9+
x_axis: int
10+
y_axis: int
11+
12+
arr: np.ndarray[XYGrid, Any]
13+
14+
# Test shape property matches shape typevar
15+
assert_type(arr.shape, XYGrid)

0 commit comments

Comments
 (0)