|
15 | 15 | import functools
|
16 | 16 | import warnings
|
17 | 17 |
|
| 18 | +from collections import namedtuple |
18 | 19 | from collections.abc import Sequence
|
19 | 20 | from copy import deepcopy
|
20 | 21 | from typing import NewType, cast
|
@@ -601,6 +602,31 @@ def update(
|
601 | 602 | return None
|
602 | 603 |
|
603 | 604 |
|
| 605 | +RandomGeneratorState = namedtuple("RandomGeneratorState", ["bit_generator_state", "seed_seq_state"]) |
| 606 | + |
| 607 | + |
| 608 | +def get_state_from_generator( |
| 609 | + rng: np.random.Generator | np.random.BitGenerator, |
| 610 | +) -> RandomGeneratorState: |
| 611 | + assert isinstance(rng, (np.random.Generator | np.random.BitGenerator)) |
| 612 | + bit_gen: np.random.BitGenerator = ( |
| 613 | + rng.bit_generator if isinstance(rng, np.random.Generator) else rng |
| 614 | + ) |
| 615 | + |
| 616 | + return RandomGeneratorState( |
| 617 | + bit_generator_state=bit_gen.state, |
| 618 | + seed_seq_state=bit_gen.seed_seq.state, # type: ignore[attr-defined] |
| 619 | + ) |
| 620 | + |
| 621 | + |
| 622 | +def random_generator_from_state(state: RandomGeneratorState) -> np.random.Generator: |
| 623 | + seed_seq = np.random.SeedSequence(**state.seed_seq_state) |
| 624 | + bit_generator_class = getattr(np.random, state.bit_generator_state["bit_generator"]) |
| 625 | + bit_generator = bit_generator_class(seed_seq) |
| 626 | + bit_generator.state = state.bit_generator_state |
| 627 | + return np.random.Generator(bit_generator) |
| 628 | + |
| 629 | + |
604 | 630 | def get_random_generator(
|
605 | 631 | seed: RandomGenerator | np.random.RandomState = None, copy: bool = True
|
606 | 632 | ) -> np.random.Generator:
|
@@ -645,6 +671,10 @@ def get_random_generator(
|
645 | 671 | # In the former case, it will return seed, in the latter it will return
|
646 | 672 | # a new Generator object that has the same BitGenerator. This would potentially
|
647 | 673 | # make the new generator be shared across many users. To avoid this, we
|
648 |
| - # deepcopy by default. |
| 674 | + # copy by default. |
| 675 | + # Also, because of https://github.com/numpy/numpy/issues/27727, we can't use |
| 676 | + # deepcopy. We must rebuild a Generator without losing the SeedSequence information |
| 677 | + if isinstance(seed, np.random.Generator | np.random.BitGenerator): |
| 678 | + return random_generator_from_state(get_state_from_generator(seed)) |
649 | 679 | seed = deepcopy(seed)
|
650 | 680 | return np.random.default_rng(seed)
|
0 commit comments