Skip to content

Commit a179060

Browse files
authored
Merge pull request numpy#27792 from jorenham/typing/generic-bool
TYP: Generic ``numpy.bool`` and statically typed boolean logic
2 parents 484f9bf + c5cf9e3 commit a179060

File tree

5 files changed

+151
-35
lines changed

5 files changed

+151
-35
lines changed

numpy/__init__.pyi

Lines changed: 90 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -801,13 +801,16 @@ _SCT = TypeVar("_SCT", bound=generic)
801801
_SCT_co = TypeVar("_SCT_co", bound=generic, covariant=True)
802802
_NumberT = TypeVar("_NumberT", bound=number[Any])
803803
_FloatingT_co = TypeVar("_FloatingT_co", bound=floating[Any], default=floating[Any], covariant=True)
804+
_IntegerT = TypeVar("_IntegerT", bound=integer)
804805
_IntegerT_co = TypeVar("_IntegerT_co", bound=integer[Any], default=integer[Any], covariant=True)
805806

806807
_NBit = TypeVar("_NBit", bound=NBitBase, default=Any)
807808
_NBit1 = TypeVar("_NBit1", bound=NBitBase, default=Any)
808809
_NBit2 = TypeVar("_NBit2", bound=NBitBase, default=_NBit1)
809810

810811
_ItemT_co = TypeVar("_ItemT_co", default=Any, covariant=True)
812+
_BoolItemT = TypeVar("_BoolItemT", bound=builtins.bool)
813+
_BoolItemT_co = TypeVar("_BoolItemT_co", bound=builtins.bool, default=builtins.bool, covariant=True)
811814
_NumberItemT_co = TypeVar("_NumberItemT_co", bound=int | float | complex, default=int | float | complex, covariant=True)
812815
_InexactItemT_co = TypeVar("_InexactItemT_co", bound=float | complex, default=float | complex, covariant=True)
813816
_FlexibleItemT_co = TypeVar(
@@ -823,6 +826,9 @@ _TD64UnitT = TypeVar("_TD64UnitT", bound=_TD64Unit, default=_TD64Unit)
823826

824827
### Type Aliases (for internal use only)
825828

829+
_Falsy: TypeAlias = L[False, 0] | np.bool[L[False]]
830+
_Truthy: TypeAlias = L[True, 1] | np.bool[L[True]]
831+
826832
_1D: TypeAlias = tuple[int]
827833
_2D: TypeAlias = tuple[int, int]
828834
_2Tuple: TypeAlias = tuple[_T, _T]
@@ -1144,8 +1150,9 @@ nan: Final[float] = ...
11441150
pi: Final[float] = ...
11451151

11461152
little_endian: Final[builtins.bool] = ...
1147-
True_: Final[np.bool] = ...
1148-
False_: Final[np.bool] = ...
1153+
1154+
False_: Final[np.bool[L[False]]] = ...
1155+
True_: Final[np.bool[L[True]]] = ...
11491156

11501157
newaxis: Final[None] = None
11511158

@@ -3559,36 +3566,103 @@ class number(generic[_NumberItemT_co], Generic[_NBit, _NumberItemT_co]):
35593566
__gt__: _ComparisonOpGT[_NumberLike_co, _ArrayLikeNumber_co]
35603567
__ge__: _ComparisonOpGE[_NumberLike_co, _ArrayLikeNumber_co]
35613568

3562-
class bool(_RealMixin, generic[builtins.bool]):
3563-
def __init__(self, value: object = ..., /) -> None: ...
3569+
class bool(generic[_BoolItemT_co], Generic[_BoolItemT_co]):
3570+
@property
3571+
def itemsize(self) -> L[1]: ...
3572+
@property
3573+
def nbytes(self) -> L[1]: ...
3574+
@property
3575+
def real(self) -> Self: ...
3576+
@property
3577+
def imag(self) -> np.bool[L[False]]: ...
3578+
3579+
@overload
3580+
def __init__(self: np.bool[L[False]], /) -> None: ...
3581+
@overload
3582+
def __init__(self: np.bool[L[False]], value: _Falsy = ..., /) -> None: ...
3583+
@overload
3584+
def __init__(self: np.bool[L[True]], value: _Truthy, /) -> None: ...
3585+
@overload
3586+
def __init__(self, value: object, /) -> None: ...
35643587

3588+
def __bool__(self, /) -> _BoolItemT_co: ...
3589+
@overload
3590+
def __int__(self: np.bool[L[False]], /) -> L[0]: ...
3591+
@overload
3592+
def __int__(self: np.bool[L[True]], /) -> L[1]: ...
3593+
@overload
3594+
def __int__(self, /) -> L[0, 1]: ...
35653595
@deprecated("In future, it will be an error for 'np.bool' scalars to be interpreted as an index")
3566-
def __index__(self, /) -> int: ...
3596+
def __index__(self, /) -> L[0, 1]: ...
35673597
def __abs__(self) -> Self: ...
3568-
def __invert__(self) -> np.bool: ...
3598+
3599+
@overload
3600+
def __invert__(self: np.bool[L[False]], /) -> np.bool[L[True]]: ...
3601+
@overload
3602+
def __invert__(self: np.bool[L[True]], /) -> np.bool[L[False]]: ...
3603+
@overload
3604+
def __invert__(self, /) -> np.bool: ...
35693605

35703606
__add__: _BoolOp[np.bool]
35713607
__radd__: _BoolOp[np.bool]
35723608
__sub__: _BoolSub
35733609
__rsub__: _BoolSub
35743610
__mul__: _BoolOp[np.bool]
35753611
__rmul__: _BoolOp[np.bool]
3612+
__truediv__: _BoolTrueDiv
3613+
__rtruediv__: _BoolTrueDiv
35763614
__floordiv__: _BoolOp[int8]
35773615
__rfloordiv__: _BoolOp[int8]
35783616
__pow__: _BoolOp[int8]
35793617
__rpow__: _BoolOp[int8]
3580-
__truediv__: _BoolTrueDiv
3581-
__rtruediv__: _BoolTrueDiv
3618+
35823619
__lshift__: _BoolBitOp[int8]
35833620
__rlshift__: _BoolBitOp[int8]
35843621
__rshift__: _BoolBitOp[int8]
35853622
__rrshift__: _BoolBitOp[int8]
3586-
__and__: _BoolBitOp[np.bool]
3587-
__rand__: _BoolBitOp[np.bool]
3588-
__xor__: _BoolBitOp[np.bool]
3589-
__rxor__: _BoolBitOp[np.bool]
3590-
__or__: _BoolBitOp[np.bool]
3591-
__ror__: _BoolBitOp[np.bool]
3623+
3624+
@overload
3625+
def __and__(self: np.bool[L[False]], other: builtins.bool | np.bool, /) -> np.bool[L[False]]: ...
3626+
@overload
3627+
def __and__(self, other: L[False] | np.bool[L[False]], /) -> np.bool[L[False]]: ...
3628+
@overload
3629+
def __and__(self, other: L[True] | np.bool[L[True]], /) -> Self: ...
3630+
@overload
3631+
def __and__(self, other: builtins.bool | np.bool, /) -> np.bool: ...
3632+
@overload
3633+
def __and__(self, other: _IntegerT, /) -> _IntegerT: ...
3634+
@overload
3635+
def __and__(self, other: int, /) -> np.bool | intp: ...
3636+
__rand__ = __and__
3637+
3638+
@overload
3639+
def __xor__(self: np.bool[L[False]], other: _BoolItemT | np.bool[_BoolItemT], /) -> np.bool[_BoolItemT]: ...
3640+
@overload
3641+
def __xor__(self: np.bool[L[True]], other: L[True] | np.bool[L[True]], /) -> np.bool[L[False]]: ...
3642+
@overload
3643+
def __xor__(self, other: L[False] | np.bool[L[False]], /) -> Self: ...
3644+
@overload
3645+
def __xor__(self, other: builtins.bool | np.bool, /) -> np.bool: ...
3646+
@overload
3647+
def __xor__(self, other: _IntegerT, /) -> _IntegerT: ...
3648+
@overload
3649+
def __xor__(self, other: int, /) -> np.bool | intp: ...
3650+
__rxor__ = __xor__
3651+
3652+
@overload
3653+
def __or__(self: np.bool[L[True]], other: builtins.bool | np.bool, /) -> np.bool[L[True]]: ...
3654+
@overload
3655+
def __or__(self, other: L[False] | np.bool[L[False]], /) -> Self: ...
3656+
@overload
3657+
def __or__(self, other: L[True] | np.bool[L[True]], /) -> np.bool[L[True]]: ...
3658+
@overload
3659+
def __or__(self, other: builtins.bool | np.bool, /) -> np.bool: ...
3660+
@overload
3661+
def __or__(self, other: _IntegerT, /) -> _IntegerT: ...
3662+
@overload
3663+
def __or__(self, other: int, /) -> np.bool | intp: ...
3664+
__ror__ = __or__
3665+
35923666
__mod__: _BoolMod
35933667
__rmod__: _BoolMod
35943668
__divmod__: _BoolDivMod
@@ -3599,7 +3673,8 @@ class bool(_RealMixin, generic[builtins.bool]):
35993673
__gt__: _ComparisonOpGT[_NumberLike_co, _ArrayLikeNumber_co]
36003674
__ge__: _ComparisonOpGE[_NumberLike_co, _ArrayLikeNumber_co]
36013675

3602-
bool_: TypeAlias = bool
3676+
# NOTE: This should _not_ be `Final` or a `TypeAlias`
3677+
bool_ = bool
36033678

36043679
# NOTE: The `object_` constructor returns the passed object, so instances with type
36053680
# `object_` cannot exists (at runtime).

numpy/typing/tests/data/reveal/bitwise_ops.pyi

Lines changed: 52 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,31 @@
1-
from typing import Any
1+
from typing import Any, Literal as L, TypeAlias
22

33
import numpy as np
44
import numpy.typing as npt
55
from numpy._typing import _64Bit, _32Bit
66

77
from typing_extensions import assert_type
88

9-
i8 = np.int64(1)
10-
u8 = np.uint64(1)
9+
FalseType: TypeAlias = L[False]
10+
TrueType: TypeAlias = L[True]
1111

12-
i4 = np.int32(1)
13-
u4 = np.uint32(1)
12+
i4: np.int32
13+
i8: np.int64
1414

15-
b_ = np.bool(1)
15+
u4: np.uint32
16+
u8: np.uint64
1617

17-
b = bool(1)
18-
i = int(1)
18+
b_: np.bool[bool]
19+
b0_: np.bool[FalseType]
20+
b1_: np.bool[TrueType]
1921

20-
AR = np.array([0, 1, 2], dtype=np.int32)
21-
AR.setflags(write=False)
22+
b: bool
23+
b0: FalseType
24+
b1: TrueType
25+
26+
i: int
27+
28+
AR: npt.NDArray[np.int32]
2229

2330

2431
assert_type(i8 << i8, np.int64)
@@ -119,13 +126,45 @@ assert_type(b_ & b, np.bool)
119126

120127
assert_type(b_ << i, np.int_)
121128
assert_type(b_ >> i, np.int_)
122-
assert_type(b_ | i, np.int_)
123-
assert_type(b_ ^ i, np.int_)
124-
assert_type(b_ & i, np.int_)
129+
assert_type(b_ | i, np.bool | np.int_)
130+
assert_type(b_ ^ i, np.bool | np.int_)
131+
assert_type(b_ & i, np.bool | np.int_)
125132

126133
assert_type(~i8, np.int64)
127134
assert_type(~i4, np.int32)
128135
assert_type(~u8, np.uint64)
129136
assert_type(~u4, np.uint32)
130137
assert_type(~b_, np.bool)
138+
assert_type(~b0_, np.bool[TrueType])
139+
assert_type(~b1_, np.bool[FalseType])
131140
assert_type(~AR, npt.NDArray[np.int32])
141+
142+
assert_type(b_ | b0_, np.bool)
143+
assert_type(b0_ | b_, np.bool)
144+
assert_type(b_ | b1_, np.bool[TrueType])
145+
assert_type(b1_ | b_, np.bool[TrueType])
146+
147+
assert_type(b_ ^ b0_, np.bool)
148+
assert_type(b0_ ^ b_, np.bool)
149+
assert_type(b_ ^ b1_, np.bool)
150+
assert_type(b1_ ^ b_, np.bool)
151+
152+
assert_type(b_ & b0_, np.bool[FalseType])
153+
assert_type(b0_ & b_, np.bool[FalseType])
154+
assert_type(b_ & b1_, np.bool)
155+
assert_type(b1_ & b_, np.bool)
156+
157+
assert_type(b0_ | b0_, np.bool[FalseType])
158+
assert_type(b0_ | b1_, np.bool[TrueType])
159+
assert_type(b1_ | b0_, np.bool[TrueType])
160+
assert_type(b1_ | b1_, np.bool[TrueType])
161+
162+
assert_type(b0_ ^ b0_, np.bool[FalseType])
163+
assert_type(b0_ ^ b1_, np.bool[TrueType])
164+
assert_type(b1_ ^ b0_, np.bool[TrueType])
165+
assert_type(b1_ ^ b1_, np.bool[FalseType])
166+
167+
assert_type(b0_ & b0_, np.bool[FalseType])
168+
assert_type(b0_ & b1_, np.bool[FalseType])
169+
assert_type(b1_ & b0_, np.bool[FalseType])
170+
assert_type(b1_ & b1_, np.bool[TrueType])
Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
1-
import numpy as np
2-
1+
from typing import Literal
32
from typing_extensions import assert_type
43

4+
import numpy as np
5+
56
assert_type(np.e, float)
67
assert_type(np.euler_gamma, float)
78
assert_type(np.inf, float)
89
assert_type(np.nan, float)
910
assert_type(np.pi, float)
1011

1112
assert_type(np.little_endian, bool)
12-
assert_type(np.True_, np.bool)
13-
assert_type(np.False_, np.bool)
13+
14+
assert_type(np.True_, np.bool[Literal[True]])
15+
assert_type(np.False_, np.bool[Literal[False]])

numpy/typing/tests/data/reveal/numerictypes.pyi

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
from typing import Literal
2+
from typing_extensions import assert_type
23

34
import numpy as np
45

5-
from typing_extensions import assert_type
66

77
assert_type(
88
np.ScalarType,
@@ -44,7 +44,7 @@ assert_type(np.ScalarType[0], type[int])
4444
assert_type(np.ScalarType[3], type[bool])
4545
assert_type(np.ScalarType[8], type[np.csingle])
4646
assert_type(np.ScalarType[10], type[np.clongdouble])
47-
assert_type(np.bool_, type[np.bool])
47+
assert_type(np.bool_(object()), np.bool)
4848

4949
assert_type(np.typecodes["Character"], Literal["c"])
5050
assert_type(np.typecodes["Complex"], Literal["FDG"])

numpy/typing/tests/data/reveal/scalars.pyi

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ assert_type(V[["field1", "field2"]], np.void)
5252
V[0] = 5
5353

5454
# Aliases
55-
assert_type(np.bool_(), np.bool)
55+
assert_type(np.bool_(), np.bool[Literal[False]])
5656
assert_type(np.byte(), np.byte)
5757
assert_type(np.short(), np.short)
5858
assert_type(np.intc(), np.intc)

0 commit comments

Comments
 (0)