Skip to content

Commit 07124b5

Browse files
author
Bas van Beek
committed
ENH: Add ndarray.__class_getitem__
1 parent dc7dafe commit 07124b5

File tree

4 files changed

+67
-3
lines changed

4 files changed

+67
-3
lines changed

numpy/__init__.pyi

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,9 @@ from abc import abstractmethod
99
from types import TracebackType, MappingProxyType
1010
from contextlib import ContextDecorator
1111

12+
if sys.version_info >= (3, 9):
13+
from types import GenericAlias
14+
1215
from numpy._pytesttester import PytestTester
1316
from numpy.core.multiarray import flagsobj
1417
from numpy.core._internal import _ctypes
@@ -1697,6 +1700,10 @@ class ndarray(_ArrayOrScalarCommon, Generic[_ShapeType, _DType_co]):
16971700
strides: _ShapeLike = ...,
16981701
order: _OrderKACF = ...,
16991702
) -> _ArraySelf: ...
1703+
1704+
if sys.version_info >= (3, 9):
1705+
def __class_getitem__(self, item: Any) -> GenericAlias: ...
1706+
17001707
@overload
17011708
def __array__(self, dtype: None = ..., /) -> ndarray[Any, _DType_co]: ...
17021709
@overload

numpy/core/_add_newdocs.py

Lines changed: 38 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
1010
"""
1111

12+
import sys
1213
from numpy.core.function_base import add_newdoc
1314
from numpy.core.overrides import array_function_like_doc
1415

@@ -796,7 +797,7 @@
796797
object : array_like
797798
An array, any object exposing the array interface, an object whose
798799
__array__ method returns an array, or any (nested) sequence.
799-
If object is a scalar, a 0-dimensional array containing object is
800+
If object is a scalar, a 0-dimensional array containing object is
800801
returned.
801802
dtype : data-type, optional
802803
The desired data-type for the array. If not given, then the type will
@@ -2201,8 +2202,8 @@
22012202
empty : Create an array, but leave its allocated memory unchanged (i.e.,
22022203
it contains "garbage").
22032204
dtype : Create a data-type.
2204-
numpy.typing.NDArray : A :term:`generic <generic type>` version
2205-
of ndarray.
2205+
numpy.typing.NDArray : An ndarray alias :term:`generic <generic type>`
2206+
w.r.t. its `dtype.type <numpy.dtype.type>`.
22062207
22072208
Notes
22082209
-----
@@ -2798,6 +2799,40 @@
27982799
"""))
27992800

28002801

2802+
if sys.version_info > (3, 9):
2803+
add_newdoc('numpy.core.multiarray', 'ndarray', ('__class_getitem__',
2804+
"""a.__class_getitem__(item, /)
2805+
2806+
Return a parametrized wrapper around the `~numpy.ndarray` type.
2807+
2808+
.. versionadded:: 1.22
2809+
2810+
Returns
2811+
-------
2812+
alias : types.GenericAlias
2813+
A parametrized `~numpy.ndarray` type.
2814+
2815+
Examples
2816+
--------
2817+
>>> from typing import Any
2818+
>>> import numpy as np
2819+
2820+
>>> np.ndarray[Any, np.dtype]
2821+
numpy.ndarray[typing.Any, numpy.dtype]
2822+
2823+
Note
2824+
----
2825+
This method is only available for python 3.9 and later.
2826+
2827+
See Also
2828+
--------
2829+
:pep:`585` : Type hinting generics in standard collections.
2830+
numpy.typing.NDArray : An ndarray alias :term:`generic <generic type>`
2831+
w.r.t. its `dtype.type <numpy.dtype.type>`.
2832+
2833+
"""))
2834+
2835+
28012836
add_newdoc('numpy.core.multiarray', 'ndarray', ('__deepcopy__',
28022837
"""a.__deepcopy__(memo, /) -> Deep copy of array.
28032838

numpy/core/src/multiarray/methods.c

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2756,6 +2756,13 @@ NPY_NO_EXPORT PyMethodDef array_methods[] = {
27562756
(PyCFunction) array_format,
27572757
METH_VARARGS, NULL},
27582758

2759+
/* for typing; requires python >= 3.9 */
2760+
#ifdef Py_GENERICALIASOBJECT_H
2761+
{"__class_getitem__",
2762+
(PyCFunction)Py_GenericAlias,
2763+
METH_CLASS | METH_O, NULL},
2764+
#endif
2765+
27592766
/* Original and Extended methods added 2005 */
27602767
{"all",
27612768
(PyCFunction)array_all,

numpy/core/tests/test_arraymethod.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,10 @@
33
this is private API, but when added, public API may be added here.
44
"""
55

6+
import sys
7+
import types
8+
from typing import Any, Type
9+
610
import pytest
711

812
import numpy as np
@@ -56,3 +60,14 @@ def test_invalid_arguments(self, args, error):
5660
# This is private API, which may be modified freely
5761
with pytest.raises(error):
5862
self.method._simple_strided_call(*args)
63+
64+
65+
@pytest.mark.parametrize(
66+
"cls", [np.ndarray, np.recarray, np.chararray, np.matrix, np.memmap]
67+
)
68+
@pytest.mark.skipif(sys.version_info < (3, 9), reason="Requires python 3.9")
69+
def test_class_getitem(cls: Type[np.ndarray]) -> None:
70+
"""Test `ndarray.__class_getitem__`."""
71+
alias = cls[Any, Any]
72+
assert isinstance(alias, types.GenericAlias)
73+
assert alias.__origin__ is cls

0 commit comments

Comments
 (0)