Skip to content

Commit a0b9ec5

Browse files
committed
make into lock
1 parent bde2553 commit a0b9ec5

File tree

1 file changed

+93
-34
lines changed

1 file changed

+93
-34
lines changed

src/random.rs

Lines changed: 93 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,8 @@
1-
//! Safe interface for NumPy's random [`BitGenerator`]
1+
//! Safe interface for NumPy's random [`BitGenerator`][bg]
2+
//!
3+
//! [bg]: https://numpy.org/doc/stable//reference/random/bit_generators/generated/numpy.random.BitGenerator.html
4+
5+
use std::ptr::NonNull;
26

37
use pyo3::{
48
exceptions::PyRuntimeError,
@@ -12,8 +16,6 @@ use pyo3::{
1216
use crate::npyffi::npy_bitgen;
1317

1418
///! Wrapper for [`np.random.BitGenerator`][bg]
15-
///!
16-
///! [bg]: https://numpy.org/doc/stable//reference/random/bit_generators/generated/numpy.random.BitGenerator.html
1719
#[repr(transparent)]
1820
pub struct PyBitGenerator(PyAny);
1921

@@ -40,55 +42,90 @@ unsafe impl PyTypeInfo for PyBitGenerator {
4042
}
4143
}
4244

43-
/// Methods for [`BitGenerator`]
44-
pub trait BitGeneratorMethods<'py> {
45-
/// Returns a new [`BitGen`]
46-
fn bit_gen(&self) -> PyResult<BitGen<'py>>;
45+
/// Methods for [`PyBitGenerator`]
46+
pub trait BitGeneratorMethods {
47+
/// Acquire a lock on the BitGenerator to allow calling its methods in
48+
fn lock(&self) -> PyResult<PyBitGeneratorLock>;
4749
}
4850

49-
impl<'py> BitGeneratorMethods<'py> for Bound<'py, PyBitGenerator> {
50-
fn bit_gen(&self) -> PyResult<BitGen<'py>> {
51+
impl<'py> BitGeneratorMethods for Bound<'py, PyBitGenerator> {
52+
fn lock(&self) -> PyResult<PyBitGeneratorLock> {
5153
let capsule = self.getattr("capsule")?.downcast_into::<PyCapsule>()?;
54+
let lock = self.getattr("lock")?;
55+
if lock.getattr("locked")?.call0()?.extract()? {
56+
return Err(PyRuntimeError::new_err("BitGenerator is already locked"));
57+
}
58+
lock.getattr("acquire")?.call0()?;
59+
5260
assert_eq!(capsule.name()?, Some(c"BitGenerator"));
5361
let ptr = capsule.pointer() as *mut npy_bitgen;
54-
// SAFETY: the lifetime of `ptr` is derived from the lifetime of `self`
55-
let ref_ = unsafe { ptr.as_mut::<'py>() }
56-
.ok_or_else(|| PyRuntimeError::new_err("Invalid BitGenerator capsule"))?;
57-
Ok(BitGen(ref_))
62+
let non_null = match NonNull::new(ptr) {
63+
Some(non_null) => non_null,
64+
None => {
65+
lock.getattr("release")?.call0()?;
66+
return Err(PyRuntimeError::new_err("Invalid BitGenerator capsule"));
67+
}
68+
};
69+
Ok(PyBitGeneratorLock(non_null, lock.unbind()))
5870
}
5971
}
6072

61-
impl<'py> TryFrom<&Bound<'py, PyBitGenerator>> for BitGen<'py> {
73+
impl<'py> TryFrom<&Bound<'py, PyBitGenerator>> for PyBitGeneratorLock {
6274
type Error = PyErr;
6375
fn try_from(value: &Bound<'py, PyBitGenerator>) -> Result<Self, Self::Error> {
64-
value.bit_gen()
76+
value.lock()
6577
}
6678
}
6779

68-
/// Wrapper for [`npy_bitgen`]
69-
pub struct BitGen<'a>(&'a mut npy_bitgen);
80+
/// [`PyBitGenerator`] lock allowing to access its methods without holding the GIL
81+
pub struct PyBitGeneratorLock(NonNull<npy_bitgen>, Py<PyAny>);
7082

71-
impl<'py> BitGen<'py> {
83+
// SAFETY for all methods: We hold the BitGenerator lock, so nothing is allowed to change its state
84+
impl PyBitGeneratorLock {
7285
/// Returns the next random unsigned 64 bit integer
73-
pub fn next_uint64(&self) -> u64 {
74-
unsafe { (self.0.next_uint64)(self.0.state) }
86+
pub fn next_uint64(&mut self) -> u64 {
87+
unsafe {
88+
let bitgen = self.0.as_mut();
89+
(bitgen.next_uint64)(bitgen.state)
90+
}
7591
}
7692
/// Returns the next random unsigned 32 bit integer
77-
pub fn next_uint32(&self) -> u32 {
78-
unsafe { (self.0.next_uint32)(self.0.state) }
93+
pub fn next_uint32(&mut self) -> u32 {
94+
unsafe {
95+
let bitgen = self.0.as_mut();
96+
(bitgen.next_uint32)(bitgen.state)
97+
}
7998
}
8099
/// Returns the next random double
81-
pub fn next_double(&self) -> libc::c_double {
82-
unsafe { (self.0.next_double)(self.0.state) }
100+
pub fn next_double(&mut self) -> libc::c_double {
101+
unsafe {
102+
let bitgen = self.0.as_mut();
103+
(bitgen.next_double)(bitgen.state)
104+
}
83105
}
84106
/// Returns the next raw value (can be used for testing)
85-
pub fn next_raw(&self) -> u64 {
86-
unsafe { (self.0.next_raw)(self.0.state) }
107+
pub fn next_raw(&mut self) -> u64 {
108+
unsafe {
109+
let bitgen = self.0.as_mut();
110+
(bitgen.next_raw)(bitgen.state)
111+
}
112+
}
113+
}
114+
115+
impl Drop for PyBitGeneratorLock {
116+
fn drop(&mut self) {
117+
let r = Python::with_gil(|py| -> PyResult<()> {
118+
self.1.bind(py).getattr("release")?.call0()?;
119+
Ok(())
120+
});
121+
if let Err(e) = r {
122+
eprintln!("Failed to release BitGenerator lock: {e}");
123+
}
87124
}
88125
}
89126

90127
#[cfg(feature = "rand")]
91-
impl rand::RngCore for BitGen<'_> {
128+
impl rand::RngCore for PyBitGeneratorLock {
92129
fn next_u32(&mut self) -> u32 {
93130
self.next_uint32()
94131
}
@@ -115,21 +152,43 @@ mod tests {
115152

116153
#[test]
117154
fn bitgen() -> PyResult<()> {
118-
Python::with_gil(|py| {
119-
let bitgen = get_bit_generator(py)?.bit_gen()?;
120-
let _ = bitgen.next_raw();
121-
Ok(())
122-
})
155+
let mut bitgen = Python::with_gil(|py| get_bit_generator(py)?.lock())?;
156+
let _ = bitgen.next_raw();
157+
std::mem::drop(bitgen);
158+
Ok(())
123159
}
124160

161+
/// Test that the `rand::Rng` APIs work
125162
#[cfg(feature = "rand")]
126163
#[test]
127164
fn rand() -> PyResult<()> {
128165
use rand::Rng as _;
166+
let mut bitgen = Python::with_gil(|py| get_bit_generator(py)?.lock())?;
167+
assert!(bitgen.random_ratio(1, 1));
168+
assert!(!bitgen.random_ratio(0, 1));
169+
std::mem::drop(bitgen);
170+
Ok(())
171+
}
129172

173+
/// Test that dropping the lock works while holding the GIL
174+
#[test]
175+
fn unlock_with_held_gil() -> PyResult<()> {
176+
Python::with_gil(|py| {
177+
let generator = get_bit_generator(py)?;
178+
let mut bitgen = generator.lock()?;
179+
let _ = bitgen.next_raw();
180+
std::mem::drop(bitgen);
181+
Ok(())
182+
})
183+
}
184+
185+
#[test]
186+
fn double_lock_fails() -> PyResult<()> {
130187
Python::with_gil(|py| {
131-
let mut bitgen = get_bit_generator(py)?.bit_gen()?;
132-
let _ = bitgen.random_ratio(2, 3);
188+
let generator = get_bit_generator(py)?;
189+
let d1 = generator.lock()?;
190+
assert!(generator.lock().is_err());
191+
std::mem::drop(d1);
133192
Ok(())
134193
})
135194
}

0 commit comments

Comments
 (0)