Skip to content

Commit 5580a4c

Browse files
author
Bas van Beek
committed
ENH: Add a protocol for representing nested sequences
1 parent df0b1bd commit 5580a4c

File tree

2 files changed

+94
-0
lines changed

2 files changed

+94
-0
lines changed

numpy/typing/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -219,6 +219,7 @@ class _32Bit(_64Bit): ... # type: ignore[misc]
219219
class _16Bit(_32Bit): ... # type: ignore[misc]
220220
class _8Bit(_16Bit): ... # type: ignore[misc]
221221

222+
from ._nested_sequence import _NestedSequence
222223
from ._nbit import (
223224
_NBitByte,
224225
_NBitShort,

numpy/typing/_nested_sequence.py

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
"""A module containing the `_NestedSequence` protocol."""
2+
3+
from __future__ import annotations
4+
5+
import sys
6+
from typing import (
7+
Any,
8+
Iterator,
9+
overload,
10+
TypeVar,
11+
Protocol,
12+
)
13+
14+
__all__ = ["_NestedSequence"]
15+
16+
_T_co = TypeVar("_T_co", covariant=True)
17+
18+
19+
class _NestedSequence(Protocol[_T_co]):
20+
"""A protocol for representing nested sequences.
21+
22+
Warning
23+
-------
24+
`_NestedSequence` currently does not work in combination with typevars,
25+
*e.g.* ``def func(a: _NestedSequnce[T]) -> T: ...``.
26+
27+
See Also
28+
--------
29+
`collections.abc.Sequence`
30+
ABCs for read-only and mutable :term:`sequences`.
31+
32+
Examples
33+
--------
34+
.. code-block:: python
35+
36+
>>> from __future__ import annotations
37+
38+
>>> from typing import TYPE_CHECKING
39+
>>> import numpy as np
40+
>>> from numpy.typing import _NestedSequnce
41+
42+
>>> def get_dtype(seq: _NestedSequnce[float]) -> np.dtype[np.float64]:
43+
... return np.asarray(seq).dtype
44+
45+
>>> a = get_dtype([1.0])
46+
>>> b = get_dtype([[1.0]])
47+
>>> c = get_dtype([[[1.0]]])
48+
>>> d = get_dtype([[[[1.0]]]])
49+
50+
>>> if TYPE_CHECKING:
51+
... reveal_locals()
52+
... # note: Revealed local types are:
53+
... # note: a: numpy.dtype[numpy.floating[numpy.typing._64Bit]]
54+
... # note: b: numpy.dtype[numpy.floating[numpy.typing._64Bit]]
55+
... # note: c: numpy.dtype[numpy.floating[numpy.typing._64Bit]]
56+
... # note: d: numpy.dtype[numpy.floating[numpy.typing._64Bit]]
57+
58+
"""
59+
60+
def __len__(self, /) -> int:
61+
"""Implement ``len(self)``."""
62+
raise NotImplementedError
63+
64+
@overload
65+
def __getitem__(self, index: int, /) -> _T_co | _NestedSequence[_T_co]: ...
66+
@overload
67+
def __getitem__(self, index: slice, /) -> _NestedSequence[_T_co]: ...
68+
69+
def __getitem__(self, index, /):
70+
"""Implement ``self[x]``."""
71+
raise NotImplementedError
72+
73+
def __contains__(self, x: object, /) -> bool:
74+
"""Implement ``x in self``."""
75+
raise NotImplementedError
76+
77+
def __iter__(self, /) -> Iterator[_T_co | _NestedSequence[_T_co]]:
78+
"""Implement ``iter(self)``."""
79+
raise NotImplementedError
80+
81+
def __reversed__(self, /) -> Iterator[_T_co | _NestedSequence[_T_co]]:
82+
"""Implement ``reversed(self)``."""
83+
raise NotImplementedError
84+
85+
def count(self, value: Any, /) -> int:
86+
"""Return the number of occurrences of `value`."""
87+
raise NotImplementedError
88+
89+
def index(
90+
self, value: Any, start: int = 0, stop: int = sys.maxsize, /
91+
) -> int:
92+
"""Return the first index of `value`."""
93+
raise NotImplementedError

0 commit comments

Comments
 (0)