Skip to content

Commit f9e02f7

Browse files
committed
ENH: default_rng coerces RandomState to Generator
1 parent b79f912 commit f9e02f7

File tree

2 files changed

+31
-1
lines changed

2 files changed

+31
-1
lines changed

numpy/random/_generator.pyx

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ from ._bounded_integers cimport (_rand_bool, _rand_int32, _rand_int64,
2222
_rand_int16, _rand_int8, _rand_uint64, _rand_uint32, _rand_uint16,
2323
_rand_uint8, _gen_mask)
2424
from ._pcg64 import PCG64
25+
from ._mt19937 import MT19937
2526
from numpy.random cimport bitgen_t
2627
from ._common cimport (POISSON_LAM_MAX, CONS_POSITIVE, CONS_NONE,
2728
CONS_NON_NEGATIVE, CONS_BOUNDED_0_1, CONS_BOUNDED_GT_0_1,
@@ -4990,14 +4991,15 @@ def default_rng(seed=None):
49904991
49914992
Parameters
49924993
----------
4993-
seed : {None, int, array_like[ints], SeedSequence, BitGenerator, Generator}, optional
4994+
seed : {None, int, array_like[ints], SeedSequence, BitGenerator, Generator, RandomState}, optional
49944995
A seed to initialize the `BitGenerator`. If None, then fresh,
49954996
unpredictable entropy will be pulled from the OS. If an ``int`` or
49964997
``array_like[ints]`` is passed, then all values must be non-negative and will be
49974998
passed to `SeedSequence` to derive the initial `BitGenerator` state. One may also
49984999
pass in a `SeedSequence` instance.
49995000
Additionally, when passed a `BitGenerator`, it will be wrapped by
50005001
`Generator`. If passed a `Generator`, it will be returned unaltered.
5002+
When passed a legacy `RandomState` instance it will be coerced to a `Generator`.
50015003
50025004
Returns
50035005
-------
@@ -5070,6 +5072,14 @@ def default_rng(seed=None):
50705072
elif isinstance(seed, Generator):
50715073
# Pass through a Generator.
50725074
return seed
5075+
elif isinstance(seed, np.random.RandomState):
5076+
rs_state = seed.get_state(legacy=False)
5077+
klass = getattr(np.random, rs_state['bit_generator'])
5078+
bg = klass()
5079+
bg.state = rs_state
5080+
gen = np.random.Generator(bg)
5081+
return gen
5082+
50735083
# Otherwise we need to instantiate a new BitGenerator and Generator as
50745084
# normal.
50755085
return Generator(PCG64(seed))

numpy/random/tests/test_direct.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -559,3 +559,23 @@ def test_passthrough(self):
559559
rg2 = default_rng(rg)
560560
assert rg2 is rg
561561
assert rg2.bit_generator is bg
562+
563+
def test_coercion_RandomState_Generator(self):
564+
# use default_rng to coerce RandomState to Generator
565+
rs = RandomState(1234)
566+
rg = default_rng(rs)
567+
assert isinstance(rg.bit_generator, MT19937)
568+
569+
assert_allclose(rg.random(), rs.rand())
570+
571+
# RandomState with a non MT19937 bit generator
572+
_original = np.random.get_bit_generator()
573+
bg = PCG64(12342298)
574+
np.random.set_bit_generator(bg)
575+
rs = np.random.mtrand._rand
576+
rg = default_rng(rs)
577+
assert_allclose(rg.random(), rs.rand())
578+
579+
# vital to get global state back to original, otherwise
580+
# other tests start to fail.
581+
np.random.set_bit_generator(_original)

0 commit comments

Comments
 (0)