Skip to content

Commit 9fe2834

Browse files
authored
🏷️ fix stubtest errors in numpy.random (#251)
1 parent 6c7b475 commit 9fe2834

File tree

4 files changed

+85
-82
lines changed

4 files changed

+85
-82
lines changed

.mypyignore-todo

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -231,9 +231,6 @@ numpy.polynomial.polynomial.polyvander
231231
numpy.polynomial.polynomial.polyvander2d
232232
numpy.polynomial.polynomial.polyvander3d
233233

234-
numpy\.random(\.bit_generator)?\.BitGenerator\.capsule
235-
numpy\.random\.((_mt19937\.)?MT19937|(_pcg64\.)?PCG64(DXSM)?|(_philox\.)?Philox|(_sfc64\.)?SFC64|(bit_generator\.)?(Seedless)?SeedSequence)\.__(reduce|setstate)_cython__
236-
237234
numpy\.testing(\._private.utils)?\.assert_array_equal
238235
numpy\.testing(\._private.utils)?\.assert_array_almost_equal
239236
numpy\.testing(\._private.utils)?\.assert_array_compare
Lines changed: 73 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,21 @@
11
import abc
2+
from binascii import Incomplete
23
from collections.abc import Callable, Mapping, Sequence
34
from threading import Lock
4-
from typing import Any, Generic, Literal, NamedTuple, TypeAlias, TypedDict, overload, type_check_only
5-
from typing_extensions import Self, TypeVar
5+
from typing import Any, ClassVar, Generic, Literal as L, NamedTuple, TypeAlias, TypedDict, overload, type_check_only
6+
from typing_extensions import CapsuleType, Self, TypeVar
67

78
import numpy as np
89
from numpy._typing import NDArray, _ArrayLikeInt_co, _DTypeLike, _ShapeLike, _UInt32Codes, _UInt64Codes
910

1011
__all__ = ["BitGenerator", "SeedSequence"]
1112

13+
###
14+
1215
_StateT = TypeVar("_StateT", bound=Mapping[str, object], default=Mapping[str, Any])
1316

14-
_DTypeLikeUint32: TypeAlias = _DTypeLike[np.uint32] | _UInt32Codes
15-
_DTypeLikeUint64: TypeAlias = _DTypeLike[np.uint64] | _UInt64Codes
17+
_ToDTypeUInt32: TypeAlias = _DTypeLike[np.uint32] | _UInt32Codes
18+
_ToDTypeUInt64: TypeAlias = _DTypeLike[np.uint64] | _UInt64Codes
1619

1720
###
1821

@@ -25,93 +28,96 @@ class _SeedSeqState(TypedDict):
2528

2629
@type_check_only
2730
class _Interface(NamedTuple):
28-
state_address: Any
29-
state: Any
30-
next_uint64: Any
31-
next_uint32: Any
32-
next_double: Any
33-
bit_generator: Any
34-
35-
class ISeedSequence(abc.ABC):
36-
@abc.abstractmethod
37-
def generate_state(
38-
self,
39-
n_words: int,
40-
dtype: _DTypeLikeUint32 | _DTypeLikeUint64 = ...,
41-
) -> NDArray[np.uint32 | np.uint64]: ...
42-
43-
class ISpawnableSeedSequence(ISeedSequence, abc.ABC):
44-
@abc.abstractmethod
45-
def spawn(self, n_children: int) -> list[Self]: ...
46-
47-
class SeedlessSeedSequence(ISpawnableSeedSequence):
48-
def generate_state(
49-
self,
50-
n_words: int,
51-
dtype: _DTypeLikeUint32 | _DTypeLikeUint64 = ...,
52-
) -> NDArray[np.uint32 | np.uint64]: ...
53-
def spawn(self, n_children: int) -> list[Self]: ...
31+
state_address: Incomplete
32+
state: Incomplete
33+
next_uint64: Incomplete
34+
next_uint32: Incomplete
35+
next_double: Incomplete
36+
bit_generator: Incomplete
5437

55-
class SeedSequence(ISpawnableSeedSequence):
56-
entropy: int | Sequence[int] | None
57-
spawn_key: tuple[int, ...]
58-
pool_size: int
59-
n_children_spawned: int
60-
pool: NDArray[np.uint32]
38+
@type_check_only
39+
class _CythonMixin:
40+
def __setstate_cython__(self, pyx_state: object, /) -> None: ...
41+
def __reduce_cython__(self) -> Any: ...
6142

62-
@property
63-
def state(self) -> _SeedSeqState: ...
43+
@type_check_only
44+
class _GenerateStateMixin(_CythonMixin):
45+
@overload
46+
def generate_state(self, /, n_words: int, dtype: _ToDTypeUInt32 = ...) -> NDArray[np.uint32]: ...
47+
@overload
48+
def generate_state(self, /, n_words: int, dtype: _ToDTypeUInt64) -> NDArray[np.uint64]: ...
49+
@overload
50+
def generate_state(self, /, n_words: int, dtype: _ToDTypeUInt32 | _ToDTypeUInt64 = ...) -> NDArray[np.uint32 | np.uint64]: ...
6451

65-
#
66-
def __init__(
67-
self,
68-
entropy: int | Sequence[int] | _ArrayLikeInt_co | None = None,
69-
*,
70-
spawn_key: Sequence[int] = ...,
71-
pool_size: int = ...,
72-
n_children_spawned: int = ...,
73-
) -> None: ...
74-
def generate_state(
75-
self,
76-
n_words: int,
77-
dtype: _DTypeLikeUint32 | _DTypeLikeUint64 = ...,
78-
) -> NDArray[np.uint32 | np.uint64]: ...
79-
def spawn(self, n_children: int) -> list[SeedSequence]: ...
52+
###
8053

81-
class BitGenerator(abc.ABC, Generic[_StateT]):
54+
class BitGenerator(_CythonMixin, abc.ABC, Generic[_StateT]):
8255
lock: Lock
8356

84-
def __init__(self, /, seed: _ArrayLikeInt_co | SeedSequence | None = None) -> None: ...
85-
def __getstate__(self) -> tuple[_StateT, ISeedSequence]: ...
86-
def __setstate__(self, state_seed_seq: _StateT | tuple[Mapping[str, Any], ISeedSequence]) -> None: ...
87-
def __reduce__(self) -> tuple[Callable[[str], Self], tuple[str], tuple[Mapping[str, Any], ISeedSequence]]: ...
88-
8957
#
9058
@property
9159
def state(self, /) -> _StateT: ...
9260
@state.setter
9361
def state(self, state: _StateT, /) -> None: ...
94-
95-
#
9662
@property
9763
def seed_seq(self) -> ISeedSequence: ...
9864
@property
9965
def ctypes(self) -> _Interface: ...
10066
@property
10167
def cffi(self) -> _Interface: ...
68+
@property
69+
def capsule(self) -> CapsuleType: ...
10270

10371
#
104-
def spawn(self, n_children: int) -> list[Self]: ...
72+
def __init__(self, /, seed: _ArrayLikeInt_co | SeedSequence | None = None) -> None: ...
73+
def __reduce__(self) -> tuple[Callable[[str], Self], tuple[str], tuple[Mapping[str, Any], ISeedSequence]]: ...
74+
def spawn(self, /, n_children: int) -> list[Self]: ...
75+
def _benchmark(self, /, cnt: int, method: str = "uint64") -> None: ...
10576

10677
#
10778
@overload
108-
def random_raw(self, /, size: None = None, output: Literal[True] = True) -> int: ...
79+
def random_raw(self, /, size: None = None, output: L[True] = True) -> int: ...
10980
@overload
110-
def random_raw(self, /, size: _ShapeLike, output: Literal[True] = True) -> NDArray[np.uint64]: ...
81+
def random_raw(self, /, size: _ShapeLike, output: L[True] = True) -> NDArray[np.uint64]: ...
11182
@overload
112-
def random_raw(self, /, size: _ShapeLike | None, output: Literal[False]) -> None: ...
83+
def random_raw(self, /, size: _ShapeLike | None, output: L[False]) -> None: ...
11384
@overload
114-
def random_raw(self, /, size: _ShapeLike | None = None, *, output: Literal[False]) -> None: ...
85+
def random_raw(self, /, size: _ShapeLike | None = None, *, output: L[False]) -> None: ...
86+
87+
###
88+
89+
class ISeedSequence(abc.ABC):
90+
@abc.abstractmethod
91+
def generate_state(self, /, n_words: int, dtype: _ToDTypeUInt32 | _ToDTypeUInt64 = ...) -> NDArray[np.uint32 | np.uint64]: ...
92+
93+
class ISpawnableSeedSequence(ISeedSequence, abc.ABC):
94+
@abc.abstractmethod
95+
def spawn(self, /, n_children: int) -> list[Self]: ...
96+
97+
class SeedlessSeedSequence(_GenerateStateMixin, ISpawnableSeedSequence):
98+
def spawn(self, /, n_children: int) -> list[Self]: ...
11599

100+
class SeedSequence(_GenerateStateMixin, ISpawnableSeedSequence):
101+
__pyx_vtable__: ClassVar[CapsuleType] = ...
102+
103+
entropy: int | Sequence[int] | None
104+
spawn_key: tuple[int, ...]
105+
pool_size: int
106+
n_children_spawned: int
107+
pool: NDArray[np.uint32]
108+
109+
def __init__(
110+
self,
111+
/,
112+
entropy: _ArrayLikeInt_co | None = None,
113+
*,
114+
spawn_key: Sequence[int] = (),
115+
pool_size: int = 4,
116+
n_children_spawned: int = ...,
117+
) -> None: ...
118+
119+
#
120+
def spawn(self, /, n_children: int) -> list[Self]: ...
116121
#
117-
def _benchmark(self, cnt: int, method: str = ...) -> None: ...
122+
@property
123+
def state(self) -> _SeedSeqState: ...

test/static/accept/random.pyi

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,8 +66,8 @@ assert_type(sfc64.lock, threading.Lock)
6666
assert_type(seed_seq.pool, npt.NDArray[np.uint32])
6767
assert_type(seed_seq.entropy, int | Sequence[int] | None)
6868
assert_type(seed_seq.spawn(1), list[np.random.SeedSequence])
69-
assert_type(seed_seq.generate_state(8, "uint32"), npt.NDArray[np.uint32 | np.uint64])
70-
assert_type(seed_seq.generate_state(8, "uint64"), npt.NDArray[np.uint32 | np.uint64])
69+
assert_type(seed_seq.generate_state(8, "uint32"), npt.NDArray[np.uint32])
70+
assert_type(seed_seq.generate_state(8, "uint64"), npt.NDArray[np.uint64])
7171

7272
assert_type(def_gen.standard_normal(), float)
7373
assert_type(def_gen.standard_normal(dtype=np.float32), float)

test/static/reject/random.pyi

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -24,16 +24,16 @@ np.random.SeedSequence(SEED_STR) # type: ignore[arg-type] # pyright: ignore[re
2424

2525
seed_seq: np.random.bit_generator.SeedSequence = ...
2626
seed_seq.spawn(11.5) # type: ignore[arg-type] # pyright: ignore[reportArgumentType]
27-
seed_seq.generate_state(3.2) # type: ignore[arg-type] # pyright: ignore[reportArgumentType]
28-
seed_seq.generate_state(3, np.uint8) # type: ignore[arg-type] # pyright: ignore[reportArgumentType]
29-
seed_seq.generate_state(3, "uint8") # type: ignore[arg-type] # pyright: ignore[reportArgumentType]
30-
seed_seq.generate_state(3, "u1") # type: ignore[arg-type] # pyright: ignore[reportArgumentType]
31-
seed_seq.generate_state(3, np.uint16) # type: ignore[arg-type] # pyright: ignore[reportArgumentType]
32-
seed_seq.generate_state(3, "uint16") # type: ignore[arg-type] # pyright: ignore[reportArgumentType]
33-
seed_seq.generate_state(3, "u2") # type: ignore[arg-type] # pyright: ignore[reportArgumentType]
34-
seed_seq.generate_state(3, np.int32) # type: ignore[arg-type] # pyright: ignore[reportArgumentType]
35-
seed_seq.generate_state(3, "int32") # type: ignore[arg-type] # pyright: ignore[reportArgumentType]
36-
seed_seq.generate_state(3, "i4") # type: ignore[arg-type] # pyright: ignore[reportArgumentType]
27+
seed_seq.generate_state(3.2) # type: ignore[call-overload] # pyright: ignore[reportCallIssue, reportArgumentType]
28+
seed_seq.generate_state(3, np.uint8) # type: ignore[arg-type] # pyright: ignore[reportCallIssue, reportArgumentType]
29+
seed_seq.generate_state(3, "uint8") # type: ignore[call-overload] # pyright: ignore[reportCallIssue, reportArgumentType]
30+
seed_seq.generate_state(3, "u1") # type: ignore[call-overload] # pyright: ignore[reportCallIssue, reportArgumentType]
31+
seed_seq.generate_state(3, np.uint16) # type: ignore[arg-type] # pyright: ignore[reportCallIssue, reportArgumentType]
32+
seed_seq.generate_state(3, "uint16") # type: ignore[call-overload] # pyright: ignore[reportCallIssue, reportArgumentType]
33+
seed_seq.generate_state(3, "u2") # type: ignore[call-overload] # pyright: ignore[reportCallIssue, reportArgumentType]
34+
seed_seq.generate_state(3, np.int32) # type: ignore[arg-type] # pyright: ignore[reportCallIssue, reportArgumentType]
35+
seed_seq.generate_state(3, "int32") # type: ignore[call-overload] # pyright: ignore[reportCallIssue, reportArgumentType]
36+
seed_seq.generate_state(3, "i4") # type: ignore[call-overload] # pyright: ignore[reportCallIssue, reportArgumentType]
3737

3838
# Bit Generators
3939
np.random.MT19937(SEED_FLOAT) # type: ignore[arg-type] # pyright: ignore[reportArgumentType]

0 commit comments

Comments
 (0)