Skip to content

Commit 71ab906

Browse files
authored
Merge pull request numpy#26089 from mtsokol/random-typing-update
TYP: Adjust typing for `np.random.integers` and `np.random.randint`
2 parents 56dab50 + f84ea13 commit 71ab906

File tree

3 files changed

+374
-166
lines changed

3 files changed

+374
-166
lines changed

numpy/random/_generator.pyi

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,7 @@ class Generator:
210210
self,
211211
low: int,
212212
high: None | int = ...,
213+
size: None = ...,
213214
) -> int: ...
214215
@overload
215216
def integers( # type: ignore[misc]
@@ -221,6 +222,15 @@ class Generator:
221222
endpoint: bool = ...,
222223
) -> bool: ...
223224
@overload
225+
def integers( # type: ignore[misc]
226+
self,
227+
low: int,
228+
high: None | int = ...,
229+
size: None = ...,
230+
dtype: type[np.bool] = ...,
231+
endpoint: bool = ...,
232+
) -> np.bool: ...
233+
@overload
224234
def integers( # type: ignore[misc]
225235
self,
226236
low: int,
@@ -230,6 +240,96 @@ class Generator:
230240
endpoint: bool = ...,
231241
) -> int: ...
232242
@overload
243+
def integers( # type: ignore[misc]
244+
self,
245+
low: int,
246+
high: None | int = ...,
247+
size: None = ...,
248+
dtype: dtype[uint8] | type[uint8] | _UInt8Codes | _SupportsDType[dtype[uint8]] = ...,
249+
endpoint: bool = ...,
250+
) -> uint8: ...
251+
@overload
252+
def integers( # type: ignore[misc]
253+
self,
254+
low: int,
255+
high: None | int = ...,
256+
size: None = ...,
257+
dtype: dtype[uint16] | type[uint16] | _UInt16Codes | _SupportsDType[dtype[uint16]] = ...,
258+
endpoint: bool = ...,
259+
) -> uint16: ...
260+
@overload
261+
def integers( # type: ignore[misc]
262+
self,
263+
low: int,
264+
high: None | int = ...,
265+
size: None = ...,
266+
dtype: dtype[uint32] | type[uint32] | _UInt32Codes | _SupportsDType[dtype[uint32]] = ...,
267+
endpoint: bool = ...,
268+
) -> uint32: ...
269+
@overload
270+
def integers( # type: ignore[misc]
271+
self,
272+
low: int,
273+
high: None | int = ...,
274+
size: None = ...,
275+
dtype: dtype[uint] | type[uint] | _UIntCodes | _SupportsDType[dtype[uint]] = ...,
276+
endpoint: bool = ...,
277+
) -> uint: ...
278+
@overload
279+
def integers( # type: ignore[misc]
280+
self,
281+
low: int,
282+
high: None | int = ...,
283+
size: None = ...,
284+
dtype: dtype[uint64] | type[uint64] | _UInt64Codes | _SupportsDType[dtype[uint64]] = ...,
285+
endpoint: bool = ...,
286+
) -> uint64: ...
287+
@overload
288+
def integers( # type: ignore[misc]
289+
self,
290+
low: int,
291+
high: None | int = ...,
292+
size: None = ...,
293+
dtype: dtype[int8] | type[int8] | _Int8Codes | _SupportsDType[dtype[int8]] = ...,
294+
endpoint: bool = ...,
295+
) -> int8: ...
296+
@overload
297+
def integers( # type: ignore[misc]
298+
self,
299+
low: int,
300+
high: None | int = ...,
301+
size: None = ...,
302+
dtype: dtype[int16] | type[int16] | _Int16Codes | _SupportsDType[dtype[int16]] = ...,
303+
endpoint: bool = ...,
304+
) -> int16: ...
305+
@overload
306+
def integers( # type: ignore[misc]
307+
self,
308+
low: int,
309+
high: None | int = ...,
310+
size: None = ...,
311+
dtype: dtype[int32] | type[int32] | _Int32Codes | _SupportsDType[dtype[int32]] = ...,
312+
endpoint: bool = ...,
313+
) -> int32: ...
314+
@overload
315+
def integers( # type: ignore[misc]
316+
self,
317+
low: int,
318+
high: None | int = ...,
319+
size: None = ...,
320+
dtype: dtype[int_] | type[int] | type[int_] | _IntCodes | _SupportsDType[dtype[int_]] = ...,
321+
endpoint: bool = ...,
322+
) -> int_: ...
323+
@overload
324+
def integers( # type: ignore[misc]
325+
self,
326+
low: int,
327+
high: None | int = ...,
328+
size: None = ...,
329+
dtype: dtype[int64] | type[int64] | _Int64Codes | _SupportsDType[dtype[int64]] = ...,
330+
endpoint: bool = ...,
331+
) -> int64: ...
332+
@overload
233333
def integers( # type: ignore[misc]
234334
self,
235335
low: _ArrayLikeInt_co,

numpy/random/mtrand.pyi

Lines changed: 110 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,14 @@ from numpy import (
1111
int16,
1212
int32,
1313
int64,
14+
int_,
1415
long,
15-
ulong,
1616
uint8,
1717
uint16,
1818
uint32,
1919
uint64,
20+
uint,
21+
ulong,
2022
)
2123
from numpy.random.bit_generator import BitGenerator
2224
from numpy._typing import (
@@ -34,6 +36,7 @@ from numpy._typing import (
3436
_Int16Codes,
3537
_Int32Codes,
3638
_Int64Codes,
39+
_IntCodes,
3740
_LongCodes,
3841
_ShapeLike,
3942
_SingleCodes,
@@ -42,6 +45,7 @@ from numpy._typing import (
4245
_UInt16Codes,
4346
_UInt32Codes,
4447
_UInt64Codes,
48+
_UIntCodes,
4549
_ULongCodes,
4650
)
4751

@@ -114,6 +118,7 @@ class RandomState:
114118
self,
115119
low: int,
116120
high: None | int = ...,
121+
size: None = ...,
117122
) -> int: ...
118123
@overload
119124
def randint( # type: ignore[misc]
@@ -124,6 +129,14 @@ class RandomState:
124129
dtype: type[bool] = ...,
125130
) -> bool: ...
126131
@overload
132+
def randint( # type: ignore[misc]
133+
self,
134+
low: int,
135+
high: None | int = ...,
136+
size: None = ...,
137+
dtype: type[np.bool] = ...,
138+
) -> np.bool: ...
139+
@overload
127140
def randint( # type: ignore[misc]
128141
self,
129142
low: int,
@@ -132,6 +145,102 @@ class RandomState:
132145
dtype: type[int] = ...,
133146
) -> int: ...
134147
@overload
148+
def randint( # type: ignore[misc]
149+
self,
150+
low: int,
151+
high: None | int = ...,
152+
size: None = ...,
153+
dtype: dtype[uint8] | type[uint8] | _UInt8Codes | _SupportsDType[dtype[uint8]] = ...,
154+
) -> uint8: ...
155+
@overload
156+
def randint( # type: ignore[misc]
157+
self,
158+
low: int,
159+
high: None | int = ...,
160+
size: None = ...,
161+
dtype: dtype[uint16] | type[uint16] | _UInt16Codes | _SupportsDType[dtype[uint16]] = ...,
162+
) -> uint16: ...
163+
@overload
164+
def randint( # type: ignore[misc]
165+
self,
166+
low: int,
167+
high: None | int = ...,
168+
size: None = ...,
169+
dtype: dtype[uint32] | type[uint32] | _UInt32Codes | _SupportsDType[dtype[uint32]] = ...,
170+
) -> uint32: ...
171+
@overload
172+
def randint( # type: ignore[misc]
173+
self,
174+
low: int,
175+
high: None | int = ...,
176+
size: None = ...,
177+
dtype: dtype[uint] | type[uint] | _UIntCodes | _SupportsDType[dtype[uint]] = ...,
178+
) -> uint: ...
179+
@overload
180+
def randint( # type: ignore[misc]
181+
self,
182+
low: int,
183+
high: None | int = ...,
184+
size: None = ...,
185+
dtype: dtype[ulong] | type[ulong] | _ULongCodes | _SupportsDType[dtype[ulong]] = ...,
186+
) -> ulong: ...
187+
@overload
188+
def randint( # type: ignore[misc]
189+
self,
190+
low: int,
191+
high: None | int = ...,
192+
size: None = ...,
193+
dtype: dtype[uint64] | type[uint64] | _UInt64Codes | _SupportsDType[dtype[uint64]] = ...,
194+
) -> uint64: ...
195+
@overload
196+
def randint( # type: ignore[misc]
197+
self,
198+
low: int,
199+
high: None | int = ...,
200+
size: None = ...,
201+
dtype: dtype[int8] | type[int8] | _Int8Codes | _SupportsDType[dtype[int8]] = ...,
202+
) -> int8: ...
203+
@overload
204+
def randint( # type: ignore[misc]
205+
self,
206+
low: int,
207+
high: None | int = ...,
208+
size: None = ...,
209+
dtype: dtype[int16] | type[int16] | _Int16Codes | _SupportsDType[dtype[int16]] = ...,
210+
) -> int16: ...
211+
@overload
212+
def randint( # type: ignore[misc]
213+
self,
214+
low: int,
215+
high: None | int = ...,
216+
size: None = ...,
217+
dtype: dtype[int32] | type[int32] | _Int32Codes | _SupportsDType[dtype[int32]] = ...,
218+
) -> int32: ...
219+
@overload
220+
def randint( # type: ignore[misc]
221+
self,
222+
low: int,
223+
high: None | int = ...,
224+
size: None = ...,
225+
dtype: dtype[int_] | type[int_] | _IntCodes | _SupportsDType[dtype[int_]] = ...,
226+
) -> int_: ...
227+
@overload
228+
def randint( # type: ignore[misc]
229+
self,
230+
low: int,
231+
high: None | int = ...,
232+
size: None = ...,
233+
dtype: dtype[long] | type[long] | _LongCodes | _SupportsDType[dtype[long]] = ...,
234+
) -> long: ...
235+
@overload
236+
def randint( # type: ignore[misc]
237+
self,
238+
low: int,
239+
high: None | int = ...,
240+
size: None = ...,
241+
dtype: dtype[int64] | type[int64] | _Int64Codes | _SupportsDType[dtype[int64]] = ...,
242+
) -> int64: ...
243+
@overload
135244
def randint( # type: ignore[misc]
136245
self,
137246
low: _ArrayLikeInt_co,

0 commit comments

Comments
 (0)