Skip to content

Commit d9f11f2

Browse files
author
Bas van Beek
committed
ENH: Add dtype.__class_getitem__
1 parent 9ca8076 commit d9f11f2

File tree

4 files changed

+64
-7
lines changed

4 files changed

+64
-7
lines changed

numpy/__init__.pyi

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1091,6 +1091,9 @@ class dtype(Generic[_DTypeScalar_co]):
10911091
copy: bool = ...,
10921092
) -> dtype[object_]: ...
10931093

1094+
if sys.version_info >= (3, 9):
1095+
def __class_getitem__(self, item: Any) -> GenericAlias: ...
1096+
10941097
@overload
10951098
def __getitem__(self: dtype[void], key: List[str]) -> dtype[void]: ...
10961099
@overload

numpy/core/_add_newdocs.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6079,6 +6079,36 @@
60796079
60806080
"""))
60816081

6082+
if sys.version_info >= (3, 9):
6083+
add_newdoc('numpy.core.multiarray', 'dtype', ('__class_getitem__',
6084+
"""
6085+
__class_getitem__(item, /)
6086+
6087+
Return a parametrized wrapper around the `~numpy.dtype` type.
6088+
6089+
.. versionadded:: 1.22
6090+
6091+
Returns
6092+
-------
6093+
alias : types.GenericAlias
6094+
A parametrized `~numpy.dtype` type.
6095+
6096+
Examples
6097+
--------
6098+
>>> import numpy as np
6099+
6100+
>>> np.dtype[np.int64]
6101+
numpy.dtype[numpy.int64]
6102+
6103+
Note
6104+
----
6105+
This method is only available for python 3.9 and later.
6106+
6107+
See Also
6108+
--------
6109+
:pep:`585` : Type hinting generics in standard collections.
6110+
6111+
"""))
60826112

60836113
##############################################################################
60846114
#

numpy/core/src/multiarray/descriptor.c

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -257,7 +257,7 @@ static PyArray_Descr *
257257
_convert_from_tuple(PyObject *obj, int align)
258258
{
259259
if (PyTuple_GET_SIZE(obj) != 2) {
260-
PyErr_Format(PyExc_TypeError,
260+
PyErr_Format(PyExc_TypeError,
261261
"Tuple must have size 2, but has size %zd",
262262
PyTuple_GET_SIZE(obj));
263263
return NULL;
@@ -449,8 +449,8 @@ _convert_from_array_descr(PyObject *obj, int align)
449449
for (int i = 0; i < n; i++) {
450450
PyObject *item = PyList_GET_ITEM(obj, i);
451451
if (!PyTuple_Check(item) || (PyTuple_GET_SIZE(item) < 2)) {
452-
PyErr_Format(PyExc_TypeError,
453-
"Field elements must be 2- or 3-tuples, got '%R'",
452+
PyErr_Format(PyExc_TypeError,
453+
"Field elements must be 2- or 3-tuples, got '%R'",
454454
item);
455455
goto fail;
456456
}
@@ -461,7 +461,7 @@ _convert_from_array_descr(PyObject *obj, int align)
461461
}
462462
else if (PyTuple_Check(name)) {
463463
if (PyTuple_GET_SIZE(name) != 2) {
464-
PyErr_Format(PyExc_TypeError,
464+
PyErr_Format(PyExc_TypeError,
465465
"If a tuple, the first element of a field tuple must have "
466466
"two elements, not %zd",
467467
PyTuple_GET_SIZE(name));
@@ -475,7 +475,7 @@ _convert_from_array_descr(PyObject *obj, int align)
475475
}
476476
}
477477
else {
478-
PyErr_SetString(PyExc_TypeError,
478+
PyErr_SetString(PyExc_TypeError,
479479
"First element of field tuple is "
480480
"neither a tuple nor str");
481481
goto fail;
@@ -3112,6 +3112,13 @@ static PyMethodDef arraydescr_methods[] = {
31123112
{"newbyteorder",
31133113
(PyCFunction)arraydescr_newbyteorder,
31143114
METH_VARARGS, NULL},
3115+
3116+
/* for typing; requires python >= 3.9 */
3117+
#ifdef Py_GENERICALIASOBJECT_H
3118+
{"__class_getitem__",
3119+
(PyCFunction)Py_GenericAlias,
3120+
METH_CLASS | METH_O, NULL},
3121+
#endif
31153122
{NULL, NULL, 0, NULL} /* sentinel */
31163123
};
31173124

numpy/core/tests/test_dtype.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
import ctypes
55
import gc
66
import warnings
7+
import types
8+
from typing import Any
79

810
import numpy as np
911
from numpy.core._rational_tests import rational
@@ -111,9 +113,9 @@ def test_richcompare_invalid_dtype_comparison(self, operation):
111113
@pytest.mark.parametrize("dtype",
112114
['Bool', 'Bytes0', 'Complex32', 'Complex64',
113115
'Datetime64', 'Float16', 'Float32', 'Float64',
114-
'Int8', 'Int16', 'Int32', 'Int64',
116+
'Int8', 'Int16', 'Int32', 'Int64',
115117
'Object0', 'Str0', 'Timedelta64',
116-
'UInt8', 'UInt16', 'Uint32', 'UInt32',
118+
'UInt8', 'UInt16', 'Uint32', 'UInt32',
117119
'Uint64', 'UInt64', 'Void0',
118120
"Float128", "Complex128"])
119121
def test_numeric_style_types_are_invalid(self, dtype):
@@ -1549,3 +1551,18 @@ class mytype:
15491551
# Tests that a dtype must have its type field set up to np.dtype
15501552
# or in this case a builtin instance.
15511553
create_custom_field_dtype(blueprint, mytype, 2)
1554+
1555+
1556+
@pytest.mark.skipif(sys.version_info < (3, 9), reason="Requires python 3.9")
1557+
class TestClassGetItem:
1558+
def test_dtype(self) -> None:
1559+
alias = np.dtype[Any]
1560+
assert isinstance(alias, types.GenericAlias)
1561+
assert alias.__origin__ is np.dtype
1562+
1563+
@pytest.mark.parametrize("code", np.typecodes["All"])
1564+
def test_dtype_subclass(self, code: str) -> None:
1565+
cls = type(np.dtype(code))
1566+
alias = cls[Any]
1567+
assert isinstance(alias, types.GenericAlias)
1568+
assert alias.__origin__ is cls

0 commit comments

Comments
 (0)