Skip to content

Commit bc4bb19

Browse files
Sheppard, KevinSheppard, Kevin
authored andcommitted
FIX: Fix pickling support
Fix pickling support and finalize test Change RNG_NAME to RNG_MOD_NAME to be more explicit and avoid conflict Clean up appveyor
1 parent 2c8a993 commit bc4bb19

File tree

16 files changed

+70
-34
lines changed

16 files changed

+70
-34
lines changed

README.md

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,6 @@ identical sequence of random numbers for a given seed.
3535
## Plans
3636
It is essentially complete. There are a few rough edges that need to be smoothed.
3737

38-
* Pickling support
3938
* Stream support for MLFG and MRG32K3A
4039
* Creation of additional streams from a RandomState where supported (i.e.
4140
a `next_stream()` method)

appveyor.yml

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,5 @@ build_script:
2222
- pip install . -vvv
2323

2424
test_script:
25-
- echo %cd%
2625
- cd ..
27-
- echo %cd%
2826
- nosetests randomstate

randomstate/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
from __future__ import division, absolute_import, print_function
22

3-
from randomstate.prng.mt19937.mt19937 import *
3+
from randomstate.prng.mt19937 import *
44
import randomstate.prng

randomstate/config.pxi

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
# Autogenerated
22

3-
DEF RNG_NAME='xorshift128'
3+
DEF RNG_MOD_NAME='xorshift128'

randomstate/interface.pyx

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -28,22 +28,22 @@ np.import_array()
2828
include "config.pxi"
2929
#include "src/common/binomial.pxi"
3030

31-
IF RNG_NAME == 'pcg32':
31+
IF RNG_MOD_NAME == 'pcg32':
3232
include "shims/pcg-32/pcg-32.pxi"
33-
IF RNG_NAME == 'pcg64':
33+
IF RNG_MOD_NAME == 'pcg64':
3434
IF PCG128_EMULATED:
3535
include "shims/pcg-64/pcg-64-emulated.pxi"
3636
ELSE:
3737
include "shims/pcg-64/pcg-64.pxi"
38-
IF RNG_NAME == 'mt19937':
38+
IF RNG_MOD_NAME == 'mt19937':
3939
include "shims/random-kit/random-kit.pxi"
40-
IF RNG_NAME == 'xorshift128':
40+
IF RNG_MOD_NAME == 'xorshift128':
4141
include "shims/xorshift128/xorshift128.pxi"
42-
IF RNG_NAME == 'xorshift1024':
42+
IF RNG_MOD_NAME == 'xorshift1024':
4343
include "shims/xorshift1024/xorshift1024.pxi"
44-
IF RNG_NAME == 'mrg32k3a':
44+
IF RNG_MOD_NAME == 'mrg32k3a':
4545
include "shims/mrg32k3a/mrg32k3a.pxi"
46-
IF RNG_NAME == 'mlfg_1279_861':
46+
IF RNG_MOD_NAME == 'mlfg_1279_861':
4747
include "shims/mlfg-1279-861/mlfg-1279-861.pxi"
4848

4949
IF NORMAL_METHOD == 'inv':
@@ -213,9 +213,7 @@ cdef class RandomState:
213213
self.set_state(state)
214214

215215
def __reduce__(self):
216-
# TODO: This is wrong
217-
# TODO: Removed np.random.__RandomState_ctor - This is needed on a RNG-by-RNG basis
218-
return (randomstate.prng.__generic_ctor, (RNG_NAME,), self.get_state())
216+
return (randomstate.prng.__generic_ctor, (RNG_MOD_NAME,), self.get_state())
219217

220218
IF RNG_NAME == 'mt19937':
221219
def seed(self, seed=None):

randomstate/prng/__init__.py

Lines changed: 36 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,46 @@
1-
from randomstate.prng.mt19937 import mt19937
2-
from randomstate.prng.mlfg_1279_861 import mlfg_1279_861
3-
from randomstate.prng.mrg32k3a import mrg32k3a
4-
from randomstate.prng.pcg32 import pcg32
5-
from randomstate.prng.pcg64 import pcg64
6-
from randomstate.prng.xorshift1024 import xorshift1024
7-
from randomstate.prng.xorshift128 import xorshift128
1+
from .xorshift128 import xorshift128
2+
from .xorshift1024 import xorshift1024
3+
from .mlfg_1279_861 import mlfg_1279_861
4+
from .mt19937 import mt19937
5+
from .mrg32k3a import mrg32k3a
6+
from .pcg32 import pcg32
7+
from .pcg64 import pcg64
88

9-
def __generic_ctor(rng_name='mt19937'):
10-
print(rng_name)
11-
if rng_name == 'mt19937':
9+
def __generic_ctor(mod_name='mt19937'):
10+
"""
11+
Pickling helper function that returns a mod_name.RandomState object
12+
13+
Parameters
14+
----------
15+
mod_name: str
16+
String containing the module name
17+
18+
Returns
19+
-------
20+
rs: RandomState
21+
RandomState from the module randomstate.prng.mod_name
22+
"""
23+
print(mod_name)
24+
try:
25+
mod_name = mod_name.decode('ascii')
26+
except AttributeError:
27+
pass
28+
print(mod_name)
29+
if mod_name == 'mt19937':
1230
mod = mt19937
13-
elif rng_name == 'mlfg_1279_861':
31+
elif mod_name == 'mlfg_1279_861':
1432
mod = mlfg_1279_861
15-
elif rng_name == 'mrg32k3a':
33+
elif mod_name == 'mrg32k3a':
1634
mod = mrg32k3a
17-
elif rng_name == 'pcg32':
35+
elif mod_name == 'pcg32':
1836
mod = pcg32
19-
elif rng_name == 'pcg64':
37+
elif mod_name == 'pcg64':
2038
mod = pcg64
21-
elif rng_name == 'pcg32':
39+
elif mod_name == 'pcg32':
2240
mod = pcg32
23-
elif rng_name == 'xorshift128+':
41+
elif mod_name == 'xorshift128':
2442
mod = xorshift128
25-
elif rng_name == 'xorshift1024*':
43+
elif mod_name == 'xorshift1024':
2644
mod = xorshift1024
45+
2746
return mod.RandomState(0)
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .mlfg_1279_861 import *

randomstate/prng/mrg32k3a/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .mrg32k3a import *

randomstate/prng/mt19937/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .mt19937 import *

randomstate/prng/pcg32/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .pcg32 import *

0 commit comments

Comments
 (0)