1- //! Safe interface for NumPy's random [`BitGenerator`][bg]
1+ //! Safe interface for NumPy's random [`BitGenerator`][bg].
2+ //!
3+ //! Using the patterns described in [“Extending `numpy.random`”][ext],
4+ //! you can generate random numbers without holding the GIL,
5+ //! by [acquiring][`PyBitGeneratorMethods::lock`] a [lock][`PyBitGeneratorLock`] for the [`PyBitGenerator`]:
6+ //!
7+ //! ```rust
8+ //! use pyo3::prelude::*;
9+ //! use numpy::random::{PyBitGenerator, PyBitGeneratorMethods as _};
10+ //!
11+ //! let mut bitgen = Python::with_gil(|py| -> PyResult<_> {
12+ //! let default_rng = py.import("numpy.random")?.getattr("default_rng")?.call0()?;
13+ //! let bit_generator = default_rng.getattr("bit_generator")?.downcast_into::<PyBitGenerator>()?;
14+ //! bit_generator.lock()
15+ //! })?;
16+ //! let random_number = bitgen.next_u64();
17+ //! ```
18+ //!
19+ //! With the [`rand`] crate installed, you can also use the [`rand::Rng`] APIs from the [`PyBitGeneratorLock`]:
20+ //!
21+ //! ```rust
22+ //! use rand::Rng as _;
23+ //!
24+ //! if bitgen.random_ratio(1, 1_000_000) {
25+ //! println!("a sure thing");
26+ //! }
27+ //! ```
228//!
329//! [bg]: https://numpy.org/doc/stable//reference/random/bit_generators/generated/numpy.random.BitGenerator.html
30+ //! [ext]: https://numpy.org/doc/stable/reference/random/extending.html
431
532use std:: ptr:: NonNull ;
633
@@ -15,7 +42,9 @@ use pyo3::{
1542
1643use crate :: npyffi:: npy_bitgen;
1744
18- ///! Wrapper for [`np.random.BitGenerator`][bg]
45+ /// Wrapper for [`np.random.BitGenerator`][bg].
46+ ///
47+ /// [bg]: https://numpy.org/doc/stable//reference/random/bit_generators/generated/numpy.random.BitGenerator.html
1948#[ repr( transparent) ]
2049pub struct PyBitGenerator ( PyAny ) ;
2150
@@ -42,13 +71,13 @@ unsafe impl PyTypeInfo for PyBitGenerator {
4271 }
4372}
4473
45- /// Methods for [`PyBitGenerator`]
46- pub trait BitGeneratorMethods {
47- /// Acquire a lock on the BitGenerator to allow calling its methods in
74+ /// Methods for [`PyBitGenerator`].
75+ pub trait PyBitGeneratorMethods {
76+ /// Acquire a lock on the BitGenerator to allow calling its methods in.
4877 fn lock ( & self ) -> PyResult < PyBitGeneratorLock > ;
4978}
5079
51- impl < ' py > BitGeneratorMethods for Bound < ' py , PyBitGenerator > {
80+ impl < ' py > PyBitGeneratorMethods for Bound < ' py , PyBitGenerator > {
5281 fn lock ( & self ) -> PyResult < PyBitGeneratorLock > {
5382 let capsule = self . getattr ( "capsule" ) ?. downcast_into :: < PyCapsule > ( ) ?;
5483 let lock = self . getattr ( "lock" ) ?;
@@ -66,7 +95,10 @@ impl<'py> BitGeneratorMethods for Bound<'py, PyBitGenerator> {
6695 return Err ( PyRuntimeError :: new_err ( "Invalid BitGenerator capsule" ) ) ;
6796 }
6897 } ;
69- Ok ( PyBitGeneratorLock ( non_null, lock. unbind ( ) ) )
98+ Ok ( PyBitGeneratorLock {
99+ raw_bitgen : non_null,
100+ lock : lock. unbind ( ) ,
101+ } )
70102 }
71103}
72104
@@ -77,36 +109,39 @@ impl<'py> TryFrom<&Bound<'py, PyBitGenerator>> for PyBitGeneratorLock {
77109 }
78110}
79111
80- /// [`PyBitGenerator`] lock allowing to access its methods without holding the GIL
81- pub struct PyBitGeneratorLock ( NonNull < npy_bitgen > , Py < PyAny > ) ;
112+ /// [`PyBitGenerator`] lock allowing to access its methods without holding the GIL.
113+ pub struct PyBitGeneratorLock {
114+ raw_bitgen : NonNull < npy_bitgen > ,
115+ lock : Py < PyAny > ,
116+ }
82117
83118// SAFETY for all methods: We hold the BitGenerator lock, so nothing is allowed to change its state
84119impl PyBitGeneratorLock {
85- /// Returns the next random unsigned 64 bit integer
120+ /// Returns the next random unsigned 64 bit integer.
86121 pub fn next_uint64 ( & mut self ) -> u64 {
87122 unsafe {
88- let bitgen = self . 0 . as_mut ( ) ;
123+ let bitgen = self . raw_bitgen . as_mut ( ) ;
89124 ( bitgen. next_uint64 ) ( bitgen. state )
90125 }
91126 }
92- /// Returns the next random unsigned 32 bit integer
127+ /// Returns the next random unsigned 32 bit integer.
93128 pub fn next_uint32 ( & mut self ) -> u32 {
94129 unsafe {
95- let bitgen = self . 0 . as_mut ( ) ;
130+ let bitgen = self . raw_bitgen . as_mut ( ) ;
96131 ( bitgen. next_uint32 ) ( bitgen. state )
97132 }
98133 }
99- /// Returns the next random double
134+ /// Returns the next random double.
100135 pub fn next_double ( & mut self ) -> libc:: c_double {
101136 unsafe {
102- let bitgen = self . 0 . as_mut ( ) ;
137+ let bitgen = self . raw_bitgen . as_mut ( ) ;
103138 ( bitgen. next_double ) ( bitgen. state )
104139 }
105140 }
106- /// Returns the next raw value (can be used for testing)
141+ /// Returns the next raw value (can be used for testing).
107142 pub fn next_raw ( & mut self ) -> u64 {
108143 unsafe {
109- let bitgen = self . 0 . as_mut ( ) ;
144+ let bitgen = self . raw_bitgen . as_mut ( ) ;
110145 ( bitgen. next_raw ) ( bitgen. state )
111146 }
112147 }
@@ -115,7 +150,7 @@ impl PyBitGeneratorLock {
115150impl Drop for PyBitGeneratorLock {
116151 fn drop ( & mut self ) {
117152 let r = Python :: with_gil ( |py| -> PyResult < ( ) > {
118- self . 1 . bind ( py) . getattr ( "release" ) ?. call0 ( ) ?;
153+ self . lock . bind ( py) . getattr ( "release" ) ?. call0 ( ) ?;
119154 Ok ( ( ) )
120155 } ) ;
121156 if let Err ( e) = r {
@@ -142,9 +177,8 @@ mod tests {
142177 use super :: * ;
143178
144179 fn get_bit_generator < ' py > ( py : Python < ' py > ) -> PyResult < Bound < ' py , PyBitGenerator > > {
145- let default_rng = py. import ( "numpy.random" ) ?. getattr ( "default_rng" ) ?;
180+ let default_rng = py. import ( "numpy.random" ) ?. getattr ( "default_rng" ) ?. call0 ( ) ? ;
146181 let bit_generator = default_rng
147- . call0 ( ) ?
148182 . getattr ( "bit_generator" ) ?
149183 . downcast_into :: < PyBitGenerator > ( ) ?;
150184 Ok ( bit_generator)
@@ -163,6 +197,7 @@ mod tests {
163197 #[ test]
164198 fn rand ( ) -> PyResult < ( ) > {
165199 use rand:: Rng as _;
200+
166201 let mut bitgen = Python :: with_gil ( |py| get_bit_generator ( py) ?. lock ( ) ) ?;
167202 assert ! ( bitgen. random_ratio( 1 , 1 ) ) ;
168203 assert ! ( !bitgen. random_ratio( 0 , 1 ) ) ;
0 commit comments