Skip to content

Commit d567754

Browse files
authored
BUG: Ensure seed sequences are restored through pickling (numpy#26260)
Explicitly store and restore the seed sequence. Closes numpy#26234 --- * BUG: Ensure seed sequences are restored through pickling. Explicitly store and restore the seed sequence. Closes numpy#26234 * CLN: Simplify refactor. Make more use of set and getstate to avoid changes in the pickling functions * BUG: Correct behavior for legacy pickles. Add test for legacy pickles. Include pickles for tests * MAINT: Correct types for pickle related functions * REF: Switch from string to type * REF: Switch to returning bit generators. Explicitly return bit generator rather than ctor
1 parent c3ce003 commit d567754

16 files changed

+167
-48
lines changed

numpy/random/_generator.pyi

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -68,9 +68,12 @@ class Generator:
6868
def __init__(self, bit_generator: BitGenerator) -> None: ...
6969
def __repr__(self) -> str: ...
7070
def __str__(self) -> str: ...
71-
def __getstate__(self) -> dict[str, Any]: ...
72-
def __setstate__(self, state: dict[str, Any]) -> None: ...
73-
def __reduce__(self) -> tuple[Callable[[str], Generator], tuple[str], dict[str, Any]]: ...
71+
def __getstate__(self) -> None: ...
72+
def __setstate__(self, state: dict[str, Any] | None) -> None: ...
73+
def __reduce__(self) -> tuple[
74+
Callable[[BitGenerator], Generator],
75+
tuple[BitGenerator],
76+
None]: ...
7477
@property
7578
def bit_generator(self) -> BitGenerator: ...
7679
def spawn(self, n_children: int) -> list[Generator]: ...

numpy/random/_generator.pyx

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -214,17 +214,19 @@ cdef class Generator:
214214

215215
# Pickling support:
216216
def __getstate__(self):
217-
return self.bit_generator.state
217+
return None
218218

219-
def __setstate__(self, state):
220-
self.bit_generator.state = state
219+
def __setstate__(self, bit_gen):
220+
if isinstance(bit_gen, dict):
221+
# Legacy path
222+
# Prior to 2.0.x only the state of the underlying bit generator
223+
# was preserved and any seed sequence information was lost
224+
self.bit_generator.state = bit_gen
221225

222226
def __reduce__(self):
223-
ctor, name_tpl, state = self._bit_generator.__reduce__()
224-
225227
from ._pickle import __generator_ctor
226-
# Requirements of __generator_ctor are (name, ctor)
227-
return __generator_ctor, (name_tpl[0], ctor), state
228+
# Requirements of __generator_ctor are (bit_generator, )
229+
return __generator_ctor, (self._bit_generator, ), None
228230

229231
@property
230232
def bit_generator(self):

numpy/random/_pickle.py

Lines changed: 21 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from .bit_generator import BitGenerator
12
from .mtrand import RandomState
23
from ._philox import Philox
34
from ._pcg64 import PCG64, PCG64DXSM
@@ -14,27 +15,30 @@
1415
}
1516

1617

17-
def __bit_generator_ctor(bit_generator_name='MT19937'):
18+
def __bit_generator_ctor(bit_generator: str | type[BitGenerator] = 'MT19937'):
1819
"""
1920
Pickling helper function that returns a bit generator object
2021
2122
Parameters
2223
----------
23-
bit_generator_name : str
24-
String containing the name of the BitGenerator
24+
bit_generator : type[BitGenerator] or str
25+
BitGenerator class or string containing the name of the BitGenerator
2526
2627
Returns
2728
-------
28-
bit_generator : BitGenerator
29+
BitGenerator
2930
BitGenerator instance
3031
"""
31-
if bit_generator_name in BitGenerators:
32-
bit_generator = BitGenerators[bit_generator_name]
32+
if isinstance(bit_generator, type):
33+
bit_gen_class = bit_generator
34+
elif bit_generator in BitGenerators:
35+
bit_gen_class = BitGenerators[bit_generator]
3336
else:
34-
raise ValueError(str(bit_generator_name) + ' is not a known '
35-
'BitGenerator module.')
37+
raise ValueError(
38+
str(bit_generator) + ' is not a known BitGenerator module.'
39+
)
3640

37-
return bit_generator()
41+
return bit_gen_class()
3842

3943

4044
def __generator_ctor(bit_generator_name="MT19937",
@@ -44,8 +48,9 @@ def __generator_ctor(bit_generator_name="MT19937",
4448
4549
Parameters
4650
----------
47-
bit_generator_name : str
48-
String containing the core BitGenerator's name
51+
bit_generator_name : str or BitGenerator
52+
String containing the core BitGenerator's name or a
53+
BitGenerator instance
4954
bit_generator_ctor : callable, optional
5055
Callable function that takes bit_generator_name as its only argument
5156
and returns an instantiated bit generator.
@@ -55,6 +60,9 @@ def __generator_ctor(bit_generator_name="MT19937",
5560
rg : Generator
5661
Generator using the named core BitGenerator
5762
"""
63+
if isinstance(bit_generator_name, BitGenerator):
64+
return Generator(bit_generator_name)
65+
# Legacy path that uses a bit generator name and ctor
5866
return Generator(bit_generator_ctor(bit_generator_name))
5967

6068

@@ -76,5 +84,6 @@ def __randomstate_ctor(bit_generator_name="MT19937",
7684
rs : RandomState
7785
Legacy RandomState using the named core BitGenerator
7886
"""
79-
87+
if isinstance(bit_generator_name, BitGenerator):
88+
return RandomState(bit_generator_name)
8089
return RandomState(bit_generator_ctor(bit_generator_name))

numpy/random/bit_generator.pyi

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -92,11 +92,17 @@ class SeedSequence(ISpawnableSeedSequence):
9292
class BitGenerator(abc.ABC):
9393
lock: Lock
9494
def __init__(self, seed: None | _ArrayLikeInt_co | SeedSequence = ...) -> None: ...
95-
def __getstate__(self) -> dict[str, Any]: ...
96-
def __setstate__(self, state: dict[str, Any]) -> None: ...
95+
def __getstate__(self) -> tuple[dict[str, Any], ISeedSequence]: ...
96+
def __setstate__(
97+
self, state_seed_seq: dict[str, Any] | tuple[dict[str, Any], ISeedSequence]
98+
) -> None: ...
9799
def __reduce__(
98100
self,
99-
) -> tuple[Callable[[str], BitGenerator], tuple[str], tuple[dict[str, Any]]]: ...
101+
) -> tuple[
102+
Callable[[str], BitGenerator],
103+
tuple[str],
104+
tuple[dict[str, Any], ISeedSequence]
105+
]: ...
100106
@abc.abstractmethod
101107
@property
102108
def state(self) -> Mapping[str, Any]: ...

numpy/random/bit_generator.pyx

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -537,14 +537,27 @@ cdef class BitGenerator():
537537

538538
# Pickling support:
539539
def __getstate__(self):
540-
return self.state
540+
return self.state, self._seed_seq
541541

542-
def __setstate__(self, state):
543-
self.state = state
542+
def __setstate__(self, state_seed_seq):
543+
544+
if isinstance(state_seed_seq, dict):
545+
# Legacy path
546+
# Prior to 2.0.x only the state of the underlying bit generator
547+
# was preserved and any seed sequence information was lost
548+
self.state = state_seed_seq
549+
else:
550+
self._seed_seq = state_seed_seq[1]
551+
self.state = state_seed_seq[0]
544552

545553
def __reduce__(self):
546554
from ._pickle import __bit_generator_ctor
547-
return __bit_generator_ctor, (self.state['bit_generator'],), self.state
555+
556+
return (
557+
__bit_generator_ctor,
558+
(type(self), ),
559+
(self.state, self._seed_seq)
560+
)
548561

549562
@property
550563
def state(self):

numpy/random/meson.build

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,9 @@ py.install_sources(
139139
'tests/data/philox-testset-2.csv',
140140
'tests/data/sfc64-testset-1.csv',
141141
'tests/data/sfc64-testset-2.csv',
142+
'tests/data/sfc64_np126.pkl.gz',
143+
'tests/data/generator_pcg64_np126.pkl.gz',
144+
'tests/data/generator_pcg64_np121.pkl.gz',
142145
],
143146
subdir: 'numpy/random/tests/data'
144147
)

numpy/random/mtrand.pyi

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ class RandomState:
7373
def __str__(self) -> str: ...
7474
def __getstate__(self) -> dict[str, Any]: ...
7575
def __setstate__(self, state: dict[str, Any]) -> None: ...
76-
def __reduce__(self) -> tuple[Callable[[str], RandomState], tuple[str], dict[str, Any]]: ...
76+
def __reduce__(self) -> tuple[Callable[[BitGenerator], RandomState], tuple[BitGenerator], dict[str, Any]]: ...
7777
def seed(self, seed: None | _ArrayLikeFloat_co = ...) -> None: ...
7878
@overload
7979
def get_state(self, legacy: Literal[False] = ...) -> dict[str, Any]: ...

numpy/random/mtrand.pyx

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -205,10 +205,13 @@ cdef class RandomState:
205205
self.set_state(state)
206206

207207
def __reduce__(self):
208-
ctor, name_tpl, _ = self._bit_generator.__reduce__()
209-
210208
from ._pickle import __randomstate_ctor
211-
return __randomstate_ctor, (name_tpl[0], ctor), self.get_state(legacy=False)
209+
# The third argument containing the state is required here since
210+
# RandomState contains state information in addition to the state
211+
# contained in the bit generator that describes the Gaussian
212+
# generator. This argument is passed to __setstate__ after the
213+
# RandomState is created.
214+
return __randomstate_ctor, (self._bit_generator, ), self.get_state(legacy=False)
212215

213216
cdef _initialize_bit_generator(self, bit_generator):
214217
self._bit_generator = bit_generator
Binary file not shown.
Binary file not shown.

0 commit comments

Comments
 (0)