Skip to content

Commit c109fd6

Browse files
committed
TYP: Sane defaults for the platform-specific NBitBase types.
This will help for those that don't use the mypy plugin.
1 parent b664fe0 commit c109fd6

File tree

4 files changed

+127
-100
lines changed

4 files changed

+127
-100
lines changed

numpy/_typing/__init__.py

Lines changed: 11 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -2,95 +2,20 @@
22

33
from __future__ import annotations
44

5-
from .._utils import set_module
6-
from typing import final
7-
8-
9-
@final # Disallow the creation of arbitrary `NBitBase` subclasses
10-
@set_module("numpy.typing")
11-
class NBitBase:
12-
"""
13-
A type representing `numpy.number` precision during static type checking.
14-
15-
Used exclusively for the purpose static type checking, `NBitBase`
16-
represents the base of a hierarchical set of subclasses.
17-
Each subsequent subclass is herein used for representing a lower level
18-
of precision, *e.g.* ``64Bit > 32Bit > 16Bit``.
19-
20-
.. versionadded:: 1.20
21-
22-
Examples
23-
--------
24-
Below is a typical usage example: `NBitBase` is herein used for annotating
25-
a function that takes a float and integer of arbitrary precision
26-
as arguments and returns a new float of whichever precision is largest
27-
(*e.g.* ``np.float16 + np.int64 -> np.float64``).
28-
29-
.. code-block:: python
30-
31-
>>> from __future__ import annotations
32-
>>> from typing import TypeVar, TYPE_CHECKING
33-
>>> import numpy as np
34-
>>> import numpy.typing as npt
35-
36-
>>> T1 = TypeVar("T1", bound=npt.NBitBase)
37-
>>> T2 = TypeVar("T2", bound=npt.NBitBase)
38-
39-
>>> def add(a: np.floating[T1], b: np.integer[T2]) -> np.floating[T1 | T2]:
40-
... return a + b
41-
42-
>>> a = np.float16()
43-
>>> b = np.int64()
44-
>>> out = add(a, b)
45-
46-
>>> if TYPE_CHECKING:
47-
... reveal_locals()
48-
... # note: Revealed local types are:
49-
... # note: a: numpy.floating[numpy.typing._16Bit*]
50-
... # note: b: numpy.signedinteger[numpy.typing._64Bit*]
51-
... # note: out: numpy.floating[numpy.typing._64Bit*]
52-
53-
"""
54-
55-
def __init_subclass__(cls) -> None:
56-
allowed_names = {
57-
"NBitBase", "_256Bit", "_128Bit", "_96Bit", "_80Bit",
58-
"_64Bit", "_32Bit", "_16Bit", "_8Bit",
59-
}
60-
if cls.__name__ not in allowed_names:
61-
raise TypeError('cannot inherit from final class "NBitBase"')
62-
super().__init_subclass__()
63-
64-
65-
# Silence errors about subclassing a `@final`-decorated class
66-
class _256Bit(NBitBase): # type: ignore[misc]
67-
pass
68-
69-
class _128Bit(_256Bit): # type: ignore[misc]
70-
pass
71-
72-
class _96Bit(_128Bit): # type: ignore[misc]
73-
pass
74-
75-
class _80Bit(_96Bit): # type: ignore[misc]
76-
pass
77-
78-
class _64Bit(_80Bit): # type: ignore[misc]
79-
pass
80-
81-
class _32Bit(_64Bit): # type: ignore[misc]
82-
pass
83-
84-
class _16Bit(_32Bit): # type: ignore[misc]
85-
pass
86-
87-
class _8Bit(_16Bit): # type: ignore[misc]
88-
pass
89-
90-
915
from ._nested_sequence import (
926
_NestedSequence as _NestedSequence,
937
)
8+
from ._nbit_base import (
9+
NBitBase as NBitBase,
10+
_8Bit as _8Bit,
11+
_16Bit as _16Bit,
12+
_32Bit as _32Bit,
13+
_64Bit as _64Bit,
14+
_80Bit as _80Bit,
15+
_96Bit as _96Bit,
16+
_128Bit as _128Bit,
17+
_256Bit as _256Bit,
18+
)
9419
from ._nbit import (
9520
_NBitByte as _NBitByte,
9621
_NBitShort as _NBitShort,

numpy/_typing/_nbit.py

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,19 @@
11
"""A module with the precisions of platform-specific `~numpy.number`s."""
22

3-
from typing import Any
3+
from typing import TypeAlias
4+
from ._nbit_base import _8Bit, _16Bit, _32Bit, _64Bit, _96Bit, _128Bit
5+
46

57
# To-be replaced with a `npt.NBitBase` subclass by numpy's mypy plugin
6-
_NBitByte = Any
7-
_NBitShort = Any
8-
_NBitIntC = Any
9-
_NBitIntP = Any
10-
_NBitInt = Any
11-
_NBitLong = Any
12-
_NBitLongLong = Any
8+
_NBitByte: TypeAlias = _8Bit
9+
_NBitShort: TypeAlias = _16Bit
10+
_NBitIntC: TypeAlias = _32Bit
11+
_NBitIntP: TypeAlias = _32Bit | _64Bit
12+
_NBitInt: TypeAlias = _NBitIntP
13+
_NBitLong: TypeAlias = _32Bit | _64Bit
14+
_NBitLongLong: TypeAlias = _64Bit
1315

14-
_NBitHalf = Any
15-
_NBitSingle = Any
16-
_NBitDouble = Any
17-
_NBitLongDouble = Any
16+
_NBitHalf: TypeAlias = _16Bit
17+
_NBitSingle: TypeAlias = _32Bit
18+
_NBitDouble: TypeAlias = _64Bit
19+
_NBitLongDouble: TypeAlias = _64Bit | _96Bit | _128Bit

numpy/_typing/_nbit_base.py

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
"""A module with the precisions of generic `~numpy.number` types."""
2+
from .._utils import set_module
3+
from typing import final
4+
5+
6+
@final # Disallow the creation of arbitrary `NBitBase` subclasses
7+
@set_module("numpy.typing")
8+
class NBitBase:
9+
"""
10+
A type representing `numpy.number` precision during static type checking.
11+
12+
Used exclusively for the purpose static type checking, `NBitBase`
13+
represents the base of a hierarchical set of subclasses.
14+
Each subsequent subclass is herein used for representing a lower level
15+
of precision, *e.g.* ``64Bit > 32Bit > 16Bit``.
16+
17+
.. versionadded:: 1.20
18+
19+
Examples
20+
--------
21+
Below is a typical usage example: `NBitBase` is herein used for annotating
22+
a function that takes a float and integer of arbitrary precision
23+
as arguments and returns a new float of whichever precision is largest
24+
(*e.g.* ``np.float16 + np.int64 -> np.float64``).
25+
26+
.. code-block:: python
27+
28+
>>> from __future__ import annotations
29+
>>> from typing import TypeVar, TYPE_CHECKING
30+
>>> import numpy as np
31+
>>> import numpy.typing as npt
32+
33+
>>> S = TypeVar("S", bound=npt.NBitBase)
34+
>>> T = TypeVar("T", bound=npt.NBitBase)
35+
36+
>>> def add(a: np.floating[S], b: np.integer[T]) -> np.floating[S | T]:
37+
... return a + b
38+
39+
>>> a = np.float16()
40+
>>> b = np.int64()
41+
>>> out = add(a, b)
42+
43+
>>> if TYPE_CHECKING:
44+
... reveal_locals()
45+
... # note: Revealed local types are:
46+
... # note: a: numpy.floating[numpy.typing._16Bit*]
47+
... # note: b: numpy.signedinteger[numpy.typing._64Bit*]
48+
... # note: out: numpy.floating[numpy.typing._64Bit*]
49+
50+
"""
51+
52+
def __init_subclass__(cls) -> None:
53+
allowed_names = {
54+
"NBitBase", "_256Bit", "_128Bit", "_96Bit", "_80Bit",
55+
"_64Bit", "_32Bit", "_16Bit", "_8Bit",
56+
}
57+
if cls.__name__ not in allowed_names:
58+
raise TypeError('cannot inherit from final class "NBitBase"')
59+
super().__init_subclass__()
60+
61+
@final
62+
@set_module("numpy._typing")
63+
# Silence errors about subclassing a `@final`-decorated class
64+
class _256Bit(NBitBase): # type: ignore[misc]
65+
pass
66+
67+
@final
68+
@set_module("numpy._typing")
69+
class _128Bit(_256Bit): # type: ignore[misc]
70+
pass
71+
72+
@final
73+
@set_module("numpy._typing")
74+
class _96Bit(_128Bit): # type: ignore[misc]
75+
pass
76+
77+
@final
78+
@set_module("numpy._typing")
79+
class _80Bit(_96Bit): # type: ignore[misc]
80+
pass
81+
82+
@final
83+
@set_module("numpy._typing")
84+
class _64Bit(_80Bit): # type: ignore[misc]
85+
pass
86+
87+
@final
88+
@set_module("numpy._typing")
89+
class _32Bit(_64Bit): # type: ignore[misc]
90+
pass
91+
92+
@final
93+
@set_module("numpy._typing")
94+
class _16Bit(_32Bit): # type: ignore[misc]
95+
pass
96+
97+
@final
98+
@set_module("numpy._typing")
99+
class _8Bit(_16Bit): # type: ignore[misc]
100+
pass

numpy/typing/mypy_plugin.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,9 +69,10 @@ def _get_precision_dict() -> dict[str, str]:
6969
("_NBitLongDouble", np.longdouble),
7070
]
7171
ret = {}
72+
module = "numpy._typing"
7273
for name, typ in names:
7374
n: int = 8 * typ().dtype.itemsize
74-
ret[f'numpy._typing._nbit.{name}'] = f"numpy._{n}Bit"
75+
ret[f'{module}._nbit.{name}'] = f"{module}._nbit_base._{n}Bit"
7576
return ret
7677

7778

@@ -92,7 +93,6 @@ def _get_extended_precision_list() -> list[str]:
9293
]
9394
return [i for i in extended_names if hasattr(np, i)]
9495

95-
9696
def _get_c_intp_name() -> str:
9797
# Adapted from `np.core._internal._getintp_ctype`
9898
char = np.dtype('n').char

0 commit comments

Comments
 (0)