Skip to content

Commit ee32246

Browse files
committed
docs
1 parent a0b9ec5 commit ee32246

File tree

1 file changed

+55
-20
lines changed

1 file changed

+55
-20
lines changed

src/random.rs

Lines changed: 55 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,33 @@
1-
//! Safe interface for NumPy's random [`BitGenerator`][bg]
1+
//! Safe interface for NumPy's random [`BitGenerator`][bg].
2+
//!
3+
//! Using the patterns described in [“Extending `numpy.random`”][ext],
4+
//! you can generate random numbers without holding the GIL,
5+
//! by [acquiring][`PyBitGeneratorMethods::lock`] a [lock][`PyBitGeneratorLock`] for the [`PyBitGenerator`]:
6+
//!
7+
//! ```rust
8+
//! use pyo3::prelude::*;
9+
//! use numpy::random::{PyBitGenerator, PyBitGeneratorMethods as _};
10+
//!
11+
//! let mut bitgen = Python::with_gil(|py| -> PyResult<_> {
12+
//! let default_rng = py.import("numpy.random")?.getattr("default_rng")?.call0()?;
13+
//! let bit_generator = default_rng.getattr("bit_generator")?.downcast_into::<PyBitGenerator>()?;
14+
//! bit_generator.lock()
15+
//! })?;
16+
//! let random_number = bitgen.next_u64();
17+
//! ```
18+
//!
19+
//! With the [`rand`] crate installed, you can also use the [`rand::Rng`] APIs from the [`PyBitGeneratorLock`]:
20+
//!
21+
//! ```rust
22+
//! use rand::Rng as _;
23+
//!
24+
//! if bitgen.random_ratio(1, 1_000_000) {
25+
//! println!("a sure thing");
26+
//! }
27+
//! ```
228
//!
329
//! [bg]: https://numpy.org/doc/stable//reference/random/bit_generators/generated/numpy.random.BitGenerator.html
30+
//! [ext]: https://numpy.org/doc/stable/reference/random/extending.html
431
532
use std::ptr::NonNull;
633

@@ -15,7 +42,9 @@ use pyo3::{
1542

1643
use crate::npyffi::npy_bitgen;
1744

18-
///! Wrapper for [`np.random.BitGenerator`][bg]
45+
/// Wrapper for [`np.random.BitGenerator`][bg].
46+
///
47+
/// [bg]: https://numpy.org/doc/stable//reference/random/bit_generators/generated/numpy.random.BitGenerator.html
1948
#[repr(transparent)]
2049
pub struct PyBitGenerator(PyAny);
2150

@@ -42,13 +71,13 @@ unsafe impl PyTypeInfo for PyBitGenerator {
4271
}
4372
}
4473

45-
/// Methods for [`PyBitGenerator`]
46-
pub trait BitGeneratorMethods {
47-
/// Acquire a lock on the BitGenerator to allow calling its methods in
74+
/// Methods for [`PyBitGenerator`].
75+
pub trait PyBitGeneratorMethods {
76+
/// Acquire a lock on the BitGenerator to allow calling its methods in.
4877
fn lock(&self) -> PyResult<PyBitGeneratorLock>;
4978
}
5079

51-
impl<'py> BitGeneratorMethods for Bound<'py, PyBitGenerator> {
80+
impl<'py> PyBitGeneratorMethods for Bound<'py, PyBitGenerator> {
5281
fn lock(&self) -> PyResult<PyBitGeneratorLock> {
5382
let capsule = self.getattr("capsule")?.downcast_into::<PyCapsule>()?;
5483
let lock = self.getattr("lock")?;
@@ -66,7 +95,10 @@ impl<'py> BitGeneratorMethods for Bound<'py, PyBitGenerator> {
6695
return Err(PyRuntimeError::new_err("Invalid BitGenerator capsule"));
6796
}
6897
};
69-
Ok(PyBitGeneratorLock(non_null, lock.unbind()))
98+
Ok(PyBitGeneratorLock {
99+
raw_bitgen: non_null,
100+
lock: lock.unbind(),
101+
})
70102
}
71103
}
72104

@@ -77,36 +109,39 @@ impl<'py> TryFrom<&Bound<'py, PyBitGenerator>> for PyBitGeneratorLock {
77109
}
78110
}
79111

80-
/// [`PyBitGenerator`] lock allowing to access its methods without holding the GIL
81-
pub struct PyBitGeneratorLock(NonNull<npy_bitgen>, Py<PyAny>);
112+
/// [`PyBitGenerator`] lock allowing to access its methods without holding the GIL.
113+
pub struct PyBitGeneratorLock {
114+
raw_bitgen: NonNull<npy_bitgen>,
115+
lock: Py<PyAny>,
116+
}
82117

83118
// SAFETY for all methods: We hold the BitGenerator lock, so nothing is allowed to change its state
84119
impl PyBitGeneratorLock {
85-
/// Returns the next random unsigned 64 bit integer
120+
/// Returns the next random unsigned 64 bit integer.
86121
pub fn next_uint64(&mut self) -> u64 {
87122
unsafe {
88-
let bitgen = self.0.as_mut();
123+
let bitgen = self.raw_bitgen.as_mut();
89124
(bitgen.next_uint64)(bitgen.state)
90125
}
91126
}
92-
/// Returns the next random unsigned 32 bit integer
127+
/// Returns the next random unsigned 32 bit integer.
93128
pub fn next_uint32(&mut self) -> u32 {
94129
unsafe {
95-
let bitgen = self.0.as_mut();
130+
let bitgen = self.raw_bitgen.as_mut();
96131
(bitgen.next_uint32)(bitgen.state)
97132
}
98133
}
99-
/// Returns the next random double
134+
/// Returns the next random double.
100135
pub fn next_double(&mut self) -> libc::c_double {
101136
unsafe {
102-
let bitgen = self.0.as_mut();
137+
let bitgen = self.raw_bitgen.as_mut();
103138
(bitgen.next_double)(bitgen.state)
104139
}
105140
}
106-
/// Returns the next raw value (can be used for testing)
141+
/// Returns the next raw value (can be used for testing).
107142
pub fn next_raw(&mut self) -> u64 {
108143
unsafe {
109-
let bitgen = self.0.as_mut();
144+
let bitgen = self.raw_bitgen.as_mut();
110145
(bitgen.next_raw)(bitgen.state)
111146
}
112147
}
@@ -115,7 +150,7 @@ impl PyBitGeneratorLock {
115150
impl Drop for PyBitGeneratorLock {
116151
fn drop(&mut self) {
117152
let r = Python::with_gil(|py| -> PyResult<()> {
118-
self.1.bind(py).getattr("release")?.call0()?;
153+
self.lock.bind(py).getattr("release")?.call0()?;
119154
Ok(())
120155
});
121156
if let Err(e) = r {
@@ -142,9 +177,8 @@ mod tests {
142177
use super::*;
143178

144179
fn get_bit_generator<'py>(py: Python<'py>) -> PyResult<Bound<'py, PyBitGenerator>> {
145-
let default_rng = py.import("numpy.random")?.getattr("default_rng")?;
180+
let default_rng = py.import("numpy.random")?.getattr("default_rng")?.call0()?;
146181
let bit_generator = default_rng
147-
.call0()?
148182
.getattr("bit_generator")?
149183
.downcast_into::<PyBitGenerator>()?;
150184
Ok(bit_generator)
@@ -163,6 +197,7 @@ mod tests {
163197
#[test]
164198
fn rand() -> PyResult<()> {
165199
use rand::Rng as _;
200+
166201
let mut bitgen = Python::with_gil(|py| get_bit_generator(py)?.lock())?;
167202
assert!(bitgen.random_ratio(1, 1));
168203
assert!(!bitgen.random_ratio(0, 1));

0 commit comments

Comments
 (0)