@@ -13,14 +13,14 @@ cimport numpy as np
13
13
cimport cython
14
14
15
15
from libc cimport string
16
- from libc .stdint cimport (uint8_t , uint16_t , uint32_t , uint64_t , int8_t ,
17
- int16_t , int32_t , int64_t , intptr_t )
16
+ from libc .stdint cimport (uint8_t , uint16_t , uint32_t , uint64_t ,
17
+ int8_t , int16_t , int32_t , int64_t , intptr_t )
18
+
18
19
from cpython cimport Py_INCREF
19
- from cpython .mem cimport PyMem_Malloc , PyMem_Free
20
20
21
21
import randomstate
22
22
from binomial cimport binomial_t
23
- from cython_overrides cimport PyFloat_AsDouble , PyInt_AsLong , PyErr_Occurred , PyErr_Clear
23
+ from cython_overrides cimport PyFloat_AsDouble , PyInt_AsLong
24
24
from randomstate .entropy import random_entropy
25
25
26
26
np .import_array ()
@@ -234,7 +234,9 @@ cdef class RandomState:
234
234
raise ValueError ("Seed must be between 0 and 4294967295" )
235
235
obj = obj .astype ('L' , casting = 'unsafe' )
236
236
with self .lock :
237
- set_seed_by_array (& self .rng_state , < unsigned long * > np .PyArray_DATA (obj ), np .PyArray_DIM (obj , 0 ))
237
+ set_seed_by_array (& self .rng_state ,
238
+ < unsigned long * > np .PyArray_DATA (obj ),
239
+ np .PyArray_DIM (obj , 0 ))
238
240
self ._reset_state_variables ()
239
241
240
242
ELIF RS_RNG_SEED == 1 :
@@ -263,12 +265,36 @@ cdef class RandomState:
263
265
--------
264
266
RandomState
265
267
"""
266
- if seed is not None :
267
- if seed < 0 :
268
- raise ValueError ('seed < 0' )
269
- else :
270
- self .__seed = seed = _generate_seed (RS_SEED_NBYTES )
271
- set_seed (& self .rng_state , seed )
268
+ try :
269
+ if seed is not None :
270
+ idx = operator .index (seed )
271
+ if idx < 0 :
272
+ raise ValueError ('seed < 0' )
273
+ else :
274
+ self .__seed = seed = _generate_seed (RS_SEED_NBYTES )
275
+ set_seed (& self .rng_state , seed )
276
+ except TypeError :
277
+ IF RS_SEED_ARRAY_BITS == 32 :
278
+ seed = np .asarray (seed ).astype (np .int64 , casting = 'safe' )
279
+ if ((seed > int (2 ** 32 - 1 )) | (seed < 0 )).any ():
280
+ raise ValueError ("Seed values must be between 0 and "
281
+ "4294967295 (2**32-1)" )
282
+ seed = seed .astype (np .uint32 , casting = 'unsafe' )
283
+ with self .lock :
284
+ set_seed_by_array (& self .rng_state ,
285
+ < uint32_t * > np .PyArray_DATA (seed ),
286
+ np .PyArray_DIM (seed , 0 ))
287
+ ELSE :
288
+ seed = np .asarray (seed ).astype (np .object , casting = 'safe' )
289
+ if ((seed > int (2 ** 64 - 1 )) | (seed < 0 )).any ():
290
+ raise ValueError ("Seed values must be between 0 and "
291
+ "18446744073709551616 (2**64-1)" )
292
+ seed = seed .astype (np .uint64 , casting = 'unsafe' )
293
+ with self .lock :
294
+ set_seed_by_array (& self .rng_state ,
295
+ < uint64_t * > np .PyArray_DATA (seed ),
296
+ np .PyArray_DIM (seed , 0 ))
297
+ self .__seed = seed
272
298
self ._reset_state_variables ()
273
299
274
300
ELSE :
0 commit comments