1- //! Safe interface for NumPy's random [`BitGenerator`]
1+ //! Safe interface for NumPy's random [`BitGenerator`][bg]
2+ //!
3+ //! [bg]: https://numpy.org/doc/stable//reference/random/bit_generators/generated/numpy.random.BitGenerator.html
4+
5+ use std:: ptr:: NonNull ;
26
37use pyo3:: {
48 exceptions:: PyRuntimeError ,
@@ -12,8 +16,6 @@ use pyo3::{
1216use crate :: npyffi:: npy_bitgen;
1317
1418///! Wrapper for [`np.random.BitGenerator`][bg]
15- ///!
16- ///! [bg]: https://numpy.org/doc/stable//reference/random/bit_generators/generated/numpy.random.BitGenerator.html
1719#[ repr( transparent) ]
1820pub struct PyBitGenerator ( PyAny ) ;
1921
@@ -40,55 +42,90 @@ unsafe impl PyTypeInfo for PyBitGenerator {
4042 }
4143}
4244
43- /// Methods for [`BitGenerator `]
44- pub trait BitGeneratorMethods < ' py > {
45- /// Returns a new [`BitGen`]
46- fn bit_gen ( & self ) -> PyResult < BitGen < ' py > > ;
45+ /// Methods for [`PyBitGenerator `]
46+ pub trait BitGeneratorMethods {
47+ /// Acquire a lock on the BitGenerator to allow calling its methods in
48+ fn lock ( & self ) -> PyResult < PyBitGeneratorLock > ;
4749}
4850
49- impl < ' py > BitGeneratorMethods < ' py > for Bound < ' py , PyBitGenerator > {
50- fn bit_gen ( & self ) -> PyResult < BitGen < ' py > > {
51+ impl < ' py > BitGeneratorMethods for Bound < ' py , PyBitGenerator > {
52+ fn lock ( & self ) -> PyResult < PyBitGeneratorLock > {
5153 let capsule = self . getattr ( "capsule" ) ?. downcast_into :: < PyCapsule > ( ) ?;
54+ let lock = self . getattr ( "lock" ) ?;
55+ if lock. getattr ( "locked" ) ?. call0 ( ) ?. extract ( ) ? {
56+ return Err ( PyRuntimeError :: new_err ( "BitGenerator is already locked" ) ) ;
57+ }
58+ lock. getattr ( "acquire" ) ?. call0 ( ) ?;
59+
5260 assert_eq ! ( capsule. name( ) ?, Some ( c"BitGenerator" ) ) ;
5361 let ptr = capsule. pointer ( ) as * mut npy_bitgen ;
54- // SAFETY: the lifetime of `ptr` is derived from the lifetime of `self`
55- let ref_ = unsafe { ptr. as_mut :: < ' py > ( ) }
56- . ok_or_else ( || PyRuntimeError :: new_err ( "Invalid BitGenerator capsule" ) ) ?;
57- Ok ( BitGen ( ref_) )
62+ let non_null = match NonNull :: new ( ptr) {
63+ Some ( non_null) => non_null,
64+ None => {
65+ lock. getattr ( "release" ) ?. call0 ( ) ?;
66+ return Err ( PyRuntimeError :: new_err ( "Invalid BitGenerator capsule" ) ) ;
67+ }
68+ } ;
69+ Ok ( PyBitGeneratorLock ( non_null, lock. unbind ( ) ) )
5870 }
5971}
6072
61- impl < ' py > TryFrom < & Bound < ' py , PyBitGenerator > > for BitGen < ' py > {
73+ impl < ' py > TryFrom < & Bound < ' py , PyBitGenerator > > for PyBitGeneratorLock {
6274 type Error = PyErr ;
6375 fn try_from ( value : & Bound < ' py , PyBitGenerator > ) -> Result < Self , Self :: Error > {
64- value. bit_gen ( )
76+ value. lock ( )
6577 }
6678}
6779
68- /// Wrapper for [`npy_bitgen`]
69- pub struct BitGen < ' a > ( & ' a mut npy_bitgen ) ;
80+ /// [`PyBitGenerator`] lock allowing to access its methods without holding the GIL
81+ pub struct PyBitGeneratorLock ( NonNull < npy_bitgen > , Py < PyAny > ) ;
7082
71- impl < ' py > BitGen < ' py > {
83+ // SAFETY for all methods: We hold the BitGenerator lock, so nothing is allowed to change its state
84+ impl PyBitGeneratorLock {
7285 /// Returns the next random unsigned 64 bit integer
73- pub fn next_uint64 ( & self ) -> u64 {
74- unsafe { ( self . 0 . next_uint64 ) ( self . 0 . state ) }
86+ pub fn next_uint64 ( & mut self ) -> u64 {
87+ unsafe {
88+ let bitgen = self . 0 . as_mut ( ) ;
89+ ( bitgen. next_uint64 ) ( bitgen. state )
90+ }
7591 }
7692 /// Returns the next random unsigned 32 bit integer
77- pub fn next_uint32 ( & self ) -> u32 {
78- unsafe { ( self . 0 . next_uint32 ) ( self . 0 . state ) }
93+ pub fn next_uint32 ( & mut self ) -> u32 {
94+ unsafe {
95+ let bitgen = self . 0 . as_mut ( ) ;
96+ ( bitgen. next_uint32 ) ( bitgen. state )
97+ }
7998 }
8099 /// Returns the next random double
81- pub fn next_double ( & self ) -> libc:: c_double {
82- unsafe { ( self . 0 . next_double ) ( self . 0 . state ) }
100+ pub fn next_double ( & mut self ) -> libc:: c_double {
101+ unsafe {
102+ let bitgen = self . 0 . as_mut ( ) ;
103+ ( bitgen. next_double ) ( bitgen. state )
104+ }
83105 }
84106 /// Returns the next raw value (can be used for testing)
85- pub fn next_raw ( & self ) -> u64 {
86- unsafe { ( self . 0 . next_raw ) ( self . 0 . state ) }
107+ pub fn next_raw ( & mut self ) -> u64 {
108+ unsafe {
109+ let bitgen = self . 0 . as_mut ( ) ;
110+ ( bitgen. next_raw ) ( bitgen. state )
111+ }
112+ }
113+ }
114+
115+ impl Drop for PyBitGeneratorLock {
116+ fn drop ( & mut self ) {
117+ let r = Python :: with_gil ( |py| -> PyResult < ( ) > {
118+ self . 1 . bind ( py) . getattr ( "release" ) ?. call0 ( ) ?;
119+ Ok ( ( ) )
120+ } ) ;
121+ if let Err ( e) = r {
122+ eprintln ! ( "Failed to release BitGenerator lock: {e}" ) ;
123+ }
87124 }
88125}
89126
90127#[ cfg( feature = "rand" ) ]
91- impl rand:: RngCore for BitGen < ' _ > {
128+ impl rand:: RngCore for PyBitGeneratorLock {
92129 fn next_u32 ( & mut self ) -> u32 {
93130 self . next_uint32 ( )
94131 }
@@ -115,21 +152,43 @@ mod tests {
115152
116153 #[ test]
117154 fn bitgen ( ) -> PyResult < ( ) > {
118- Python :: with_gil ( |py| {
119- let bitgen = get_bit_generator ( py) ?. bit_gen ( ) ?;
120- let _ = bitgen. next_raw ( ) ;
121- Ok ( ( ) )
122- } )
155+ let mut bitgen = Python :: with_gil ( |py| get_bit_generator ( py) ?. lock ( ) ) ?;
156+ let _ = bitgen. next_raw ( ) ;
157+ std:: mem:: drop ( bitgen) ;
158+ Ok ( ( ) )
123159 }
124160
161+ /// Test that the `rand::Rng` APIs work
125162 #[ cfg( feature = "rand" ) ]
126163 #[ test]
127164 fn rand ( ) -> PyResult < ( ) > {
128165 use rand:: Rng as _;
166+ let mut bitgen = Python :: with_gil ( |py| get_bit_generator ( py) ?. lock ( ) ) ?;
167+ assert ! ( bitgen. random_ratio( 1 , 1 ) ) ;
168+ assert ! ( !bitgen. random_ratio( 0 , 1 ) ) ;
169+ std:: mem:: drop ( bitgen) ;
170+ Ok ( ( ) )
171+ }
129172
173+ /// Test that dropping the lock works while holding the GIL
174+ #[ test]
175+ fn unlock_with_held_gil ( ) -> PyResult < ( ) > {
176+ Python :: with_gil ( |py| {
177+ let generator = get_bit_generator ( py) ?;
178+ let mut bitgen = generator. lock ( ) ?;
179+ let _ = bitgen. next_raw ( ) ;
180+ std:: mem:: drop ( bitgen) ;
181+ Ok ( ( ) )
182+ } )
183+ }
184+
185+ #[ test]
186+ fn double_lock_fails ( ) -> PyResult < ( ) > {
130187 Python :: with_gil ( |py| {
131- let mut bitgen = get_bit_generator ( py) ?. bit_gen ( ) ?;
132- let _ = bitgen. random_ratio ( 2 , 3 ) ;
188+ let generator = get_bit_generator ( py) ?;
189+ let d1 = generator. lock ( ) ?;
190+ assert ! ( generator. lock( ) . is_err( ) ) ;
191+ std:: mem:: drop ( d1) ;
133192 Ok ( ( ) )
134193 } )
135194 }
0 commit comments