1
- //! Safe interface for NumPy's random [`BitGenerator`]
1
+ //! Safe interface for NumPy's random [`BitGenerator`][bg]
2
+ //!
3
+ //! [bg]: https://numpy.org/doc/stable//reference/random/bit_generators/generated/numpy.random.BitGenerator.html
4
+
5
+ use std:: ptr:: NonNull ;
2
6
3
7
use pyo3:: {
4
8
exceptions:: PyRuntimeError ,
@@ -12,8 +16,6 @@ use pyo3::{
12
16
use crate :: npyffi:: npy_bitgen;
13
17
14
18
///! Wrapper for [`np.random.BitGenerator`][bg]
15
- ///!
16
- ///! [bg]: https://numpy.org/doc/stable//reference/random/bit_generators/generated/numpy.random.BitGenerator.html
17
19
#[ repr( transparent) ]
18
20
pub struct PyBitGenerator ( PyAny ) ;
19
21
@@ -40,55 +42,90 @@ unsafe impl PyTypeInfo for PyBitGenerator {
40
42
}
41
43
}
42
44
43
- /// Methods for [`BitGenerator `]
44
- pub trait BitGeneratorMethods < ' py > {
45
- /// Returns a new [`BitGen`]
46
- fn bit_gen ( & self ) -> PyResult < BitGen < ' py > > ;
45
+ /// Methods for [`PyBitGenerator `]
46
+ pub trait BitGeneratorMethods {
47
+ /// Acquire a lock on the BitGenerator to allow calling its methods in
48
+ fn lock ( & self ) -> PyResult < PyBitGeneratorLock > ;
47
49
}
48
50
49
- impl < ' py > BitGeneratorMethods < ' py > for Bound < ' py , PyBitGenerator > {
50
- fn bit_gen ( & self ) -> PyResult < BitGen < ' py > > {
51
+ impl < ' py > BitGeneratorMethods for Bound < ' py , PyBitGenerator > {
52
+ fn lock ( & self ) -> PyResult < PyBitGeneratorLock > {
51
53
let capsule = self . getattr ( "capsule" ) ?. downcast_into :: < PyCapsule > ( ) ?;
54
+ let lock = self . getattr ( "lock" ) ?;
55
+ if lock. getattr ( "locked" ) ?. call0 ( ) ?. extract ( ) ? {
56
+ return Err ( PyRuntimeError :: new_err ( "BitGenerator is already locked" ) ) ;
57
+ }
58
+ lock. getattr ( "acquire" ) ?. call0 ( ) ?;
59
+
52
60
assert_eq ! ( capsule. name( ) ?, Some ( c"BitGenerator" ) ) ;
53
61
let ptr = capsule. pointer ( ) as * mut npy_bitgen ;
54
- // SAFETY: the lifetime of `ptr` is derived from the lifetime of `self`
55
- let ref_ = unsafe { ptr. as_mut :: < ' py > ( ) }
56
- . ok_or_else ( || PyRuntimeError :: new_err ( "Invalid BitGenerator capsule" ) ) ?;
57
- Ok ( BitGen ( ref_) )
62
+ let non_null = match NonNull :: new ( ptr) {
63
+ Some ( non_null) => non_null,
64
+ None => {
65
+ lock. getattr ( "release" ) ?. call0 ( ) ?;
66
+ return Err ( PyRuntimeError :: new_err ( "Invalid BitGenerator capsule" ) ) ;
67
+ }
68
+ } ;
69
+ Ok ( PyBitGeneratorLock ( non_null, lock. unbind ( ) ) )
58
70
}
59
71
}
60
72
61
- impl < ' py > TryFrom < & Bound < ' py , PyBitGenerator > > for BitGen < ' py > {
73
+ impl < ' py > TryFrom < & Bound < ' py , PyBitGenerator > > for PyBitGeneratorLock {
62
74
type Error = PyErr ;
63
75
fn try_from ( value : & Bound < ' py , PyBitGenerator > ) -> Result < Self , Self :: Error > {
64
- value. bit_gen ( )
76
+ value. lock ( )
65
77
}
66
78
}
67
79
68
- /// Wrapper for [`npy_bitgen`]
69
- pub struct BitGen < ' a > ( & ' a mut npy_bitgen ) ;
80
+ /// [`PyBitGenerator`] lock allowing to access its methods without holding the GIL
81
+ pub struct PyBitGeneratorLock ( NonNull < npy_bitgen > , Py < PyAny > ) ;
70
82
71
- impl < ' py > BitGen < ' py > {
83
+ // SAFETY for all methods: We hold the BitGenerator lock, so nothing is allowed to change its state
84
+ impl PyBitGeneratorLock {
72
85
/// Returns the next random unsigned 64 bit integer
73
- pub fn next_uint64 ( & self ) -> u64 {
74
- unsafe { ( self . 0 . next_uint64 ) ( self . 0 . state ) }
86
+ pub fn next_uint64 ( & mut self ) -> u64 {
87
+ unsafe {
88
+ let bitgen = self . 0 . as_mut ( ) ;
89
+ ( bitgen. next_uint64 ) ( bitgen. state )
90
+ }
75
91
}
76
92
/// Returns the next random unsigned 32 bit integer
77
- pub fn next_uint32 ( & self ) -> u32 {
78
- unsafe { ( self . 0 . next_uint32 ) ( self . 0 . state ) }
93
+ pub fn next_uint32 ( & mut self ) -> u32 {
94
+ unsafe {
95
+ let bitgen = self . 0 . as_mut ( ) ;
96
+ ( bitgen. next_uint32 ) ( bitgen. state )
97
+ }
79
98
}
80
99
/// Returns the next random double
81
- pub fn next_double ( & self ) -> libc:: c_double {
82
- unsafe { ( self . 0 . next_double ) ( self . 0 . state ) }
100
+ pub fn next_double ( & mut self ) -> libc:: c_double {
101
+ unsafe {
102
+ let bitgen = self . 0 . as_mut ( ) ;
103
+ ( bitgen. next_double ) ( bitgen. state )
104
+ }
83
105
}
84
106
/// Returns the next raw value (can be used for testing)
85
- pub fn next_raw ( & self ) -> u64 {
86
- unsafe { ( self . 0 . next_raw ) ( self . 0 . state ) }
107
+ pub fn next_raw ( & mut self ) -> u64 {
108
+ unsafe {
109
+ let bitgen = self . 0 . as_mut ( ) ;
110
+ ( bitgen. next_raw ) ( bitgen. state )
111
+ }
112
+ }
113
+ }
114
+
115
+ impl Drop for PyBitGeneratorLock {
116
+ fn drop ( & mut self ) {
117
+ let r = Python :: with_gil ( |py| -> PyResult < ( ) > {
118
+ self . 1 . bind ( py) . getattr ( "release" ) ?. call0 ( ) ?;
119
+ Ok ( ( ) )
120
+ } ) ;
121
+ if let Err ( e) = r {
122
+ eprintln ! ( "Failed to release BitGenerator lock: {e}" ) ;
123
+ }
87
124
}
88
125
}
89
126
90
127
#[ cfg( feature = "rand" ) ]
91
- impl rand:: RngCore for BitGen < ' _ > {
128
+ impl rand:: RngCore for PyBitGeneratorLock {
92
129
fn next_u32 ( & mut self ) -> u32 {
93
130
self . next_uint32 ( )
94
131
}
@@ -115,21 +152,43 @@ mod tests {
115
152
116
153
#[ test]
117
154
fn bitgen ( ) -> PyResult < ( ) > {
118
- Python :: with_gil ( |py| {
119
- let bitgen = get_bit_generator ( py) ?. bit_gen ( ) ?;
120
- let _ = bitgen. next_raw ( ) ;
121
- Ok ( ( ) )
122
- } )
155
+ let mut bitgen = Python :: with_gil ( |py| get_bit_generator ( py) ?. lock ( ) ) ?;
156
+ let _ = bitgen. next_raw ( ) ;
157
+ std:: mem:: drop ( bitgen) ;
158
+ Ok ( ( ) )
123
159
}
124
160
161
+ /// Test that the `rand::Rng` APIs work
125
162
#[ cfg( feature = "rand" ) ]
126
163
#[ test]
127
164
fn rand ( ) -> PyResult < ( ) > {
128
165
use rand:: Rng as _;
166
+ let mut bitgen = Python :: with_gil ( |py| get_bit_generator ( py) ?. lock ( ) ) ?;
167
+ assert ! ( bitgen. random_ratio( 1 , 1 ) ) ;
168
+ assert ! ( !bitgen. random_ratio( 0 , 1 ) ) ;
169
+ std:: mem:: drop ( bitgen) ;
170
+ Ok ( ( ) )
171
+ }
129
172
173
+ /// Test that dropping the lock works while holding the GIL
174
+ #[ test]
175
+ fn unlock_with_held_gil ( ) -> PyResult < ( ) > {
176
+ Python :: with_gil ( |py| {
177
+ let generator = get_bit_generator ( py) ?;
178
+ let mut bitgen = generator. lock ( ) ?;
179
+ let _ = bitgen. next_raw ( ) ;
180
+ std:: mem:: drop ( bitgen) ;
181
+ Ok ( ( ) )
182
+ } )
183
+ }
184
+
185
+ #[ test]
186
+ fn double_lock_fails ( ) -> PyResult < ( ) > {
130
187
Python :: with_gil ( |py| {
131
- let mut bitgen = get_bit_generator ( py) ?. bit_gen ( ) ?;
132
- let _ = bitgen. random_ratio ( 2 , 3 ) ;
188
+ let generator = get_bit_generator ( py) ?;
189
+ let d1 = generator. lock ( ) ?;
190
+ assert ! ( generator. lock( ) . is_err( ) ) ;
191
+ std:: mem:: drop ( d1) ;
133
192
Ok ( ( ) )
134
193
} )
135
194
}
0 commit comments