Skip to content

Commit f6074f6

Browse files
Implemented suggestion #10
``` In [1]: import mkl_random In [2]: rs = mkl_random.RandomState(1234, brng='NONDETERM') In [3]: rs Out[3]: <mkl_random.mklrand.RandomState at 0x2ba18d58faf0> In [4]: rs.get_state() Out[4]: ('NON_DETERMINISTIC', b'\x02RNG\x14\x00\x00\x00\x08\x00\x00\x00\x00\x00\x00\x00\x00\x00\xe0\x00\x00\x00\x00\x00\n\x00\x00\x00') In [6]: rs.seed(126) # brng has not been reset In [7]: rs.get_state() Out[7]: ('NON_DETERMINISTIC', b'\x02RNG\x14\x00\x00\x00\x08\x00\x00\x00\x00\x00\x00\x00\x00\x00\xe0\x00\x00\x00\x00\x00\n\x00\x00\x00') In [8]: rs.seed(1234, brng='MT2203') In [9]: rs.get_state() Out[9]: ('MT2203', b"\x02RNG\x14\x00\x00\x00(\x01\x00\x00\x00\x00\x00\x00\x00\x00\x90\x00\x00\x00\x00\x80i\xc0w\xa7\x0eF\xd8L\xfbi\rk\xd6[\t\x80F\xa0\xae\xa2as3\x9eQ\xd5f\x9b\x95>\xb9\x0f\xf5\xbe\xb0p\xe7o\xf e\xe22\xc3\xad\x80D?\xe3X\x14\x86(D\xed(\xb1\xeb\xc6\xc5\xae\x00\xf8m,@\xd8\xd2\xd2\x98\x8b\xe74\x16\xd6\x07\xe3\xa9\x94q\xe5\xf1\xe1\xd8\x94\xf6\xf3\x8e8z\xddYrMym\xd9\x9a\xf1\x1e\xa8v\x97\ x1a\xcdvY\x82\xe9\x086!\xc8\x8db\xe5<\x1aH\xac&\x1c\xf8\x87\xc3\xefm\xc6\x17q\xb9\xda\xdcw\xcd.\xffI!3\xe8\x8e\xd5\x89\x19=\xca\x94\x88\xb2e|\xe5\xa0\xf0\xe3\xe5\xea\x0f\xd7K\xd4\xf4\xe2\x17 _\x1e\x89\x8f\xe7\x9dp\xa8}B\xa22l\x1c\xed\xedvtu6\xf3up\xc2W\xa8\xd9\xcf\xb59\r\xa6@\xce\xc7\x7f\xc2\x05[\xed\xd4y:h&\x8b\xe4\xc9\xcc\xa6\xf9\x07'\xa8g4\xd7\t\x07\xcf#\x94\xe8%\xa6\xa1\xcd\ xc1nd\xcc\x92\x80n\xe1i\x8a\xdc\x0b\x0e|\xc1[\xe2QWX\x98\xc6\x9f\x08\xe3\x1e\x00 W\xd6\xf7\x9c+\x11\xfa\xee\xbf\nH\xe6nmE\x00\x00\x00_\x05\x81\xaf\x80\xef~\xdd\x00\x80\xd7\xef\x08\x01\x0c\x0 2") ```
1 parent 762619c commit f6074f6

File tree

5 files changed

+43
-17
lines changed

5 files changed

+43
-17
lines changed

mkl_random/_version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
1-
__version__ = '1.1.1'
1+
__version__ = '1.2.0'
22

mkl_random/mklrand.pyx

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
#!/usr/bin/env python
2-
# Copyright (c) 2017-2019, Intel Corporation
2+
# Copyright (c) 2017-2020, Intel Corporation
33
#
44
# Redistribution and use in source and binary forms, with or without
55
# modification, are permitted provided that the following conditions are met:
@@ -24,9 +24,6 @@
2424
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
2525
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
2626

27-
from __future__ import absolute_import
28-
include "numpy.pxd"
29-
3027
cdef extern from "Python.h":
3128
void* PyMem_Malloc(size_t n)
3229
void PyMem_Free(void* buf)
@@ -38,6 +35,7 @@ cdef extern from "Python.h":
3835
void PyErr_Clear()
3936

4037

38+
include "numpy.pxd"
4139
from libc.string cimport memset, memcpy
4240

4341
cdef extern from "math.h":
@@ -83,6 +81,7 @@ cdef extern from "randomkit.h":
8381
void irk_get_state_mkl(irk_state * state, char * buf)
8482
int irk_set_state_mkl(irk_state * state, char * buf)
8583
int irk_get_brng_mkl(irk_state *state) nogil
84+
int irk_get_brng_and_stream_mkl(irk_state *state, unsigned int * stream_id) nogil
8685
int irk_leapfrog_stream_mkl(irk_state *state, int k, int nstreams) nogil
8786
int irk_skipahead_stream_mkl(irk_state *state, long long int nskips) nogil
8887

@@ -181,8 +180,9 @@ ctypedef void (* irk_discd_long_vec)(irk_state *state, npy_intp len, long *res,
181180
ctypedef void (* irk_discdptr_vec)(irk_state *state, npy_intp len, int *res, double *a) nogil
182181

183182

184-
# Initialize numpy
185-
import_array()
183+
cdef int r = _import_array()
184+
if (r<0):
185+
raise ImportError("Failed to import NumPy")
186186

187187
cimport cython
188188
import numpy as np
@@ -888,15 +888,15 @@ _brng_dict_stream_max = {
888888
NONDETERM: 1,
889889
}
890890

891-
def _default_fallback_brng_token_(brng):
891+
cdef irk_brng_t _default_fallback_brng_token_(brng):
892892
cdef irk_brng_t brng_token
893893
warnings.warn(("The basic random generator specification {given} is not recognized. "
894894
"\"MT19937\" will be used instead").format(given=brng),
895895
UserWarning)
896896
brng_token = MT19937
897897
return brng_token
898898

899-
def _parse_brng_token_(brng):
899+
cdef irk_brng_t _parse_brng_token_(brng):
900900
cdef irk_brng_t brng_token
901901

902902
if isinstance(brng, str):
@@ -946,7 +946,7 @@ cdef class RandomState:
946946
"""
947947
RandomState(seed=None, brng='MT19937')
948948
949-
Container for the Mersenne Twister pseudo-random number generator.
949+
Container for the Intel(R) MKL-powered (pseudo-)random number generators.
950950
951951
`RandomState` exposes a number of methods for generating random numbers
952952
drawn from a variety of probability distributions. In addition to the
@@ -1008,9 +1008,9 @@ cdef class RandomState:
10081008
PyMem_Free(self.internal_state)
10091009
self.internal_state = NULL
10101010

1011-
def seed(self, seed=None, brng='MT19937'):
1011+
def seed(self, seed=None, brng=None):
10121012
"""
1013-
seed(seed=None, brng='MT19937')
1013+
seed(seed=None, brng=None)
10141014
10151015
Seed the generator.
10161016
@@ -1023,9 +1023,10 @@ cdef class RandomState:
10231023
Seed for `RandomState`.
10241024
Must be convertible to 32 bit unsigned integers.
10251025
brng : {'MT19937', 'SFMT19937', 'MT2203', 'R250', 'WH', 'MCG31',
1026-
'MCG59', 'MRG32K3A', 'PHILOX4X32X10', 'NONDETERM'}, optional
1026+
'MCG59', 'MRG32K3A', 'PHILOX4X32X10', 'NONDETERM', None}, optional
10271027
Basic pseudo-random number generation algorithms, provided by
1028-
Intel MKL. The default choice is 'MT19937' - the Mersenne Twister.
1028+
Intel MKL. Use `brng==None` to keep the `brng` specified to construct
1029+
the class instance.
10291030
10301031
See Also
10311032
--------
@@ -1041,7 +1042,10 @@ cdef class RandomState:
10411042
cdef unsigned int stream_id
10421043
cdef ndarray obj "arrayObject_obj"
10431044

1044-
brng_token, stream_id = _parse_brng_argument(brng);
1045+
if (brng):
1046+
brng_token, stream_id = _parse_brng_argument(brng);
1047+
else:
1048+
brng_token = <irk_brng_t> irk_get_brng_and_stream_mkl(self.internal_state, &stream_id)
10451049
try:
10461050
if seed is None:
10471051
with self.lock:

mkl_random/src/numpy.pxd

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -133,8 +133,6 @@ cdef extern from "numpy/arrayobject.h":
133133

134134
dtype PyArray_DescrFromType(int)
135135

136-
void import_array()
137-
138136
# include functions that were once macros in the new api
139137

140138
int PyArray_NDIM(ndarray arr)
@@ -150,3 +148,5 @@ cdef extern from "numpy/arrayobject.h":
150148
int PyArray_TYPE(ndarray arr)
151149
int PyArray_CHKFLAGS(ndarray arr, int flags)
152150
object PyArray_GETITEM(ndarray arr, char *itemptr)
151+
152+
int _import_array()

mkl_random/src/randomkit.c

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,27 @@ int irk_get_brng_mkl(irk_state *state)
145145
return -1;
146146
}
147147

148+
int irk_get_brng_and_stream_mkl(irk_state *state, unsigned int* stream_id)
149+
{
150+
int i, mkl_brng_id = vslGetStreamStateBrng(state->stream);
151+
152+
if ((VSL_BRNG_MT2203 <= mkl_brng_id) && (mkl_brng_id < VSL_BRNG_MT2203 + SIZE_OF_MT2203_FAMILY)) {
153+
*stream_id = (unsigned int)(mkl_brng_id - VSL_BRNG_MT2203);
154+
mkl_brng_id = VSL_BRNG_MT2203;
155+
} else if ((VSL_BRNG_WH <= mkl_brng_id ) && (mkl_brng_id < VSL_BRNG_WH + SIZE_OF_WH_FAMILY)) {
156+
*stream_id = (unsigned int)(mkl_brng_id - VSL_BRNG_WH);
157+
mkl_brng_id = VSL_BRNG_WH;
158+
}
159+
160+
for(i = 0; i < BRNG_KINDS; i++)
161+
if(mkl_brng_id == brng_list[i]) {
162+
*stream_id = (unsigned int)(0);
163+
return i;
164+
}
165+
166+
return -1;
167+
}
168+
148169
void irk_seed_mkl(irk_state *state, const unsigned int seed, const irk_brng_t brng, const unsigned int stream_id)
149170
{
150171
VSLStreamStatePtr stream_loc;

mkl_random/src/randomkit.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,7 @@ extern int irk_get_stream_size(irk_state *state);
9797
extern void irk_get_state_mkl(irk_state *state, char * buf);
9898
extern int irk_set_state_mkl(irk_state *state, char * buf);
9999
extern int irk_get_brng_mkl(irk_state *state);
100+
extern int irk_get_brng_and_stream_mkl(irk_state *state, unsigned int* stream_id);
100101

101102
extern int irk_leapfrog_stream_mkl(irk_state *state, const int k, const int nstreams);
102103
extern int irk_skipahead_stream_mkl(irk_state *state, const long long int nskip);

0 commit comments

Comments
 (0)