Skip to content

Commit 0ac8691

Browse files
authored
Merge pull request numpy#27771 from andyfaff/random
ENH: ``default_rng`` coerces ``RandomState`` to ``Generator``
2 parents f833f33 + e7446b9 commit 0ac8691

File tree

3 files changed

+28
-3
lines changed

3 files changed

+28
-3
lines changed

numpy/random/_generator.pyi

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ from numpy import (
1717
uint32,
1818
uint64,
1919
)
20-
from numpy.random import BitGenerator, SeedSequence
20+
from numpy.random import BitGenerator, SeedSequence, RandomState
2121
from numpy._typing import (
2222
ArrayLike,
2323
NDArray,
@@ -782,5 +782,5 @@ class Generator:
782782
def shuffle(self, x: ArrayLike, axis: int = ...) -> None: ...
783783

784784
def default_rng(
785-
seed: None | _ArrayLikeInt_co | SeedSequence | BitGenerator | Generator = ...
785+
seed: None | _ArrayLikeInt_co | SeedSequence | BitGenerator | Generator | RandomState = ...
786786
) -> Generator: ...

numpy/random/_generator.pyx

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ from ._bounded_integers cimport (_rand_bool, _rand_int32, _rand_int64,
2222
_rand_int16, _rand_int8, _rand_uint64, _rand_uint32, _rand_uint16,
2323
_rand_uint8, _gen_mask)
2424
from ._pcg64 import PCG64
25+
from ._mt19937 import MT19937
2526
from numpy.random cimport bitgen_t
2627
from ._common cimport (POISSON_LAM_MAX, CONS_POSITIVE, CONS_NONE,
2728
CONS_NON_NEGATIVE, CONS_BOUNDED_0_1, CONS_BOUNDED_GT_0_1,
@@ -4990,14 +4991,15 @@ def default_rng(seed=None):
49904991
49914992
Parameters
49924993
----------
4993-
seed : {None, int, array_like[ints], SeedSequence, BitGenerator, Generator}, optional
4994+
seed : {None, int, array_like[ints], SeedSequence, BitGenerator, Generator, RandomState}, optional
49944995
A seed to initialize the `BitGenerator`. If None, then fresh,
49954996
unpredictable entropy will be pulled from the OS. If an ``int`` or
49964997
``array_like[ints]`` is passed, then all values must be non-negative and will be
49974998
passed to `SeedSequence` to derive the initial `BitGenerator` state. One may also
49984999
pass in a `SeedSequence` instance.
49995000
Additionally, when passed a `BitGenerator`, it will be wrapped by
50005001
`Generator`. If passed a `Generator`, it will be returned unaltered.
5002+
When passed a legacy `RandomState` instance it will be coerced to a `Generator`.
50015003
50025004
Returns
50035005
-------
@@ -5070,6 +5072,10 @@ def default_rng(seed=None):
50705072
elif isinstance(seed, Generator):
50715073
# Pass through a Generator.
50725074
return seed
5075+
elif isinstance(seed, np.random.RandomState):
5076+
gen = np.random.Generator(seed._bit_generator)
5077+
return gen
5078+
50735079
# Otherwise we need to instantiate a new BitGenerator and Generator as
50745080
# normal.
50755081
return Generator(PCG64(seed))

numpy/random/tests/test_direct.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -559,3 +559,22 @@ def test_passthrough(self):
559559
rg2 = default_rng(rg)
560560
assert rg2 is rg
561561
assert rg2.bit_generator is bg
562+
563+
def test_coercion_RandomState_Generator(self):
564+
# use default_rng to coerce RandomState to Generator
565+
rs = RandomState(1234)
566+
rg = default_rng(rs)
567+
assert isinstance(rg.bit_generator, MT19937)
568+
assert rg.bit_generator is rs._bit_generator
569+
570+
# RandomState with a non MT19937 bit generator
571+
_original = np.random.get_bit_generator()
572+
bg = PCG64(12342298)
573+
np.random.set_bit_generator(bg)
574+
rs = np.random.mtrand._rand
575+
rg = default_rng(rs)
576+
assert rg.bit_generator is bg
577+
578+
# vital to get global state back to original, otherwise
579+
# other tests start to fail.
580+
np.random.set_bit_generator(_original)

0 commit comments

Comments
 (0)