1
- //! Safe interface for NumPy's random [`BitGenerator`][bg]
1
+ //! Safe interface for NumPy's random [`BitGenerator`][bg].
2
+ //!
3
+ //! Using the patterns described in [“Extending `numpy.random`”][ext],
4
+ //! you can generate random numbers without holding the GIL,
5
+ //! by [acquiring][`PyBitGeneratorMethods::lock`] a [lock][`PyBitGeneratorLock`] for the [`PyBitGenerator`]:
6
+ //!
7
+ //! ```rust
8
+ //! use pyo3::prelude::*;
9
+ //! use numpy::random::{PyBitGenerator, PyBitGeneratorMethods as _};
10
+ //!
11
+ //! let mut bitgen = Python::with_gil(|py| -> PyResult<_> {
12
+ //! let default_rng = py.import("numpy.random")?.getattr("default_rng")?.call0()?;
13
+ //! let bit_generator = default_rng.getattr("bit_generator")?.downcast_into::<PyBitGenerator>()?;
14
+ //! bit_generator.lock()
15
+ //! })?;
16
+ //! let random_number = bitgen.next_u64();
17
+ //! ```
18
+ //!
19
+ //! With the [`rand`] crate installed, you can also use the [`rand::Rng`] APIs from the [`PyBitGeneratorLock`]:
20
+ //!
21
+ //! ```rust
22
+ //! use rand::Rng as _;
23
+ //!
24
+ //! if bitgen.random_ratio(1, 1_000_000) {
25
+ //! println!("a sure thing");
26
+ //! }
27
+ //! ```
2
28
//!
3
29
//! [bg]: https://numpy.org/doc/stable//reference/random/bit_generators/generated/numpy.random.BitGenerator.html
30
+ //! [ext]: https://numpy.org/doc/stable/reference/random/extending.html
4
31
5
32
use std:: ptr:: NonNull ;
6
33
@@ -15,7 +42,9 @@ use pyo3::{
15
42
16
43
use crate :: npyffi:: npy_bitgen;
17
44
18
- ///! Wrapper for [`np.random.BitGenerator`][bg]
45
+ /// Wrapper for [`np.random.BitGenerator`][bg].
46
+ ///
47
+ /// [bg]: https://numpy.org/doc/stable//reference/random/bit_generators/generated/numpy.random.BitGenerator.html
19
48
#[ repr( transparent) ]
20
49
pub struct PyBitGenerator ( PyAny ) ;
21
50
@@ -42,13 +71,13 @@ unsafe impl PyTypeInfo for PyBitGenerator {
42
71
}
43
72
}
44
73
45
- /// Methods for [`PyBitGenerator`]
46
- pub trait BitGeneratorMethods {
47
- /// Acquire a lock on the BitGenerator to allow calling its methods in
74
+ /// Methods for [`PyBitGenerator`].
75
+ pub trait PyBitGeneratorMethods {
76
+ /// Acquire a lock on the BitGenerator to allow calling its methods in.
48
77
fn lock ( & self ) -> PyResult < PyBitGeneratorLock > ;
49
78
}
50
79
51
- impl < ' py > BitGeneratorMethods for Bound < ' py , PyBitGenerator > {
80
+ impl < ' py > PyBitGeneratorMethods for Bound < ' py , PyBitGenerator > {
52
81
fn lock ( & self ) -> PyResult < PyBitGeneratorLock > {
53
82
let capsule = self . getattr ( "capsule" ) ?. downcast_into :: < PyCapsule > ( ) ?;
54
83
let lock = self . getattr ( "lock" ) ?;
@@ -66,7 +95,10 @@ impl<'py> BitGeneratorMethods for Bound<'py, PyBitGenerator> {
66
95
return Err ( PyRuntimeError :: new_err ( "Invalid BitGenerator capsule" ) ) ;
67
96
}
68
97
} ;
69
- Ok ( PyBitGeneratorLock ( non_null, lock. unbind ( ) ) )
98
+ Ok ( PyBitGeneratorLock {
99
+ raw_bitgen : non_null,
100
+ lock : lock. unbind ( ) ,
101
+ } )
70
102
}
71
103
}
72
104
@@ -77,36 +109,39 @@ impl<'py> TryFrom<&Bound<'py, PyBitGenerator>> for PyBitGeneratorLock {
77
109
}
78
110
}
79
111
80
- /// [`PyBitGenerator`] lock allowing to access its methods without holding the GIL
81
- pub struct PyBitGeneratorLock ( NonNull < npy_bitgen > , Py < PyAny > ) ;
112
+ /// [`PyBitGenerator`] lock allowing to access its methods without holding the GIL.
113
+ pub struct PyBitGeneratorLock {
114
+ raw_bitgen : NonNull < npy_bitgen > ,
115
+ lock : Py < PyAny > ,
116
+ }
82
117
83
118
// SAFETY for all methods: We hold the BitGenerator lock, so nothing is allowed to change its state
84
119
impl PyBitGeneratorLock {
85
- /// Returns the next random unsigned 64 bit integer
120
+ /// Returns the next random unsigned 64 bit integer.
86
121
pub fn next_uint64 ( & mut self ) -> u64 {
87
122
unsafe {
88
- let bitgen = self . 0 . as_mut ( ) ;
123
+ let bitgen = self . raw_bitgen . as_mut ( ) ;
89
124
( bitgen. next_uint64 ) ( bitgen. state )
90
125
}
91
126
}
92
- /// Returns the next random unsigned 32 bit integer
127
+ /// Returns the next random unsigned 32 bit integer.
93
128
pub fn next_uint32 ( & mut self ) -> u32 {
94
129
unsafe {
95
- let bitgen = self . 0 . as_mut ( ) ;
130
+ let bitgen = self . raw_bitgen . as_mut ( ) ;
96
131
( bitgen. next_uint32 ) ( bitgen. state )
97
132
}
98
133
}
99
- /// Returns the next random double
134
+ /// Returns the next random double.
100
135
pub fn next_double ( & mut self ) -> libc:: c_double {
101
136
unsafe {
102
- let bitgen = self . 0 . as_mut ( ) ;
137
+ let bitgen = self . raw_bitgen . as_mut ( ) ;
103
138
( bitgen. next_double ) ( bitgen. state )
104
139
}
105
140
}
106
- /// Returns the next raw value (can be used for testing)
141
+ /// Returns the next raw value (can be used for testing).
107
142
pub fn next_raw ( & mut self ) -> u64 {
108
143
unsafe {
109
- let bitgen = self . 0 . as_mut ( ) ;
144
+ let bitgen = self . raw_bitgen . as_mut ( ) ;
110
145
( bitgen. next_raw ) ( bitgen. state )
111
146
}
112
147
}
@@ -115,7 +150,7 @@ impl PyBitGeneratorLock {
115
150
impl Drop for PyBitGeneratorLock {
116
151
fn drop ( & mut self ) {
117
152
let r = Python :: with_gil ( |py| -> PyResult < ( ) > {
118
- self . 1 . bind ( py) . getattr ( "release" ) ?. call0 ( ) ?;
153
+ self . lock . bind ( py) . getattr ( "release" ) ?. call0 ( ) ?;
119
154
Ok ( ( ) )
120
155
} ) ;
121
156
if let Err ( e) = r {
@@ -142,9 +177,8 @@ mod tests {
142
177
use super :: * ;
143
178
144
179
fn get_bit_generator < ' py > ( py : Python < ' py > ) -> PyResult < Bound < ' py , PyBitGenerator > > {
145
- let default_rng = py. import ( "numpy.random" ) ?. getattr ( "default_rng" ) ?;
180
+ let default_rng = py. import ( "numpy.random" ) ?. getattr ( "default_rng" ) ?. call0 ( ) ? ;
146
181
let bit_generator = default_rng
147
- . call0 ( ) ?
148
182
. getattr ( "bit_generator" ) ?
149
183
. downcast_into :: < PyBitGenerator > ( ) ?;
150
184
Ok ( bit_generator)
@@ -163,6 +197,7 @@ mod tests {
163
197
#[ test]
164
198
fn rand ( ) -> PyResult < ( ) > {
165
199
use rand:: Rng as _;
200
+
166
201
let mut bitgen = Python :: with_gil ( |py| get_bit_generator ( py) ?. lock ( ) ) ?;
167
202
assert ! ( bitgen. random_ratio( 1 , 1 ) ) ;
168
203
assert ! ( !bitgen. random_ratio( 0 , 1 ) ) ;
0 commit comments