Skip to content

Commit 2aa3d90

Browse files
committed
guard
1 parent 1be6838 commit 2aa3d90

File tree

1 file changed

+18
-12
lines changed

1 file changed

+18
-12
lines changed

src/random.rs

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
//!
33
//! Using the patterns described in [“Extending `numpy.random`”][ext],
44
//! you can generate random numbers without holding the GIL,
5-
//! by [acquiring][`PyBitGeneratorMethods::lock`] a [lock][`PyBitGeneratorLock`] for the [`PyBitGenerator`]:
5+
//! by [acquiring][`PyBitGeneratorMethods::lock`] a lock [guard][`PyBitGeneratorGuard`] for the [`PyBitGenerator`]:
66
//!
77
//! ```rust
88
//! use pyo3::prelude::*;
@@ -16,7 +16,7 @@
1616
//! let random_number = bitgen.next_u64();
1717
//! ```
1818
//!
19-
//! With the [`rand`] crate installed, you can also use the [`rand::Rng`] APIs from the [`PyBitGeneratorLock`]:
19+
//! With the [`rand`] crate installed, you can also use the [`rand::Rng`] APIs from the [`PyBitGeneratorGuard`]:
2020
//!
2121
//! ```rust
2222
//! use rand::Rng as _;
@@ -76,11 +76,11 @@ unsafe impl PyTypeInfo for PyBitGenerator {
7676
/// Methods for [`PyBitGenerator`].
7777
pub trait PyBitGeneratorMethods {
7878
/// Acquire a lock on the BitGenerator to allow calling its methods in.
79-
fn lock(&self) -> PyResult<PyBitGeneratorLock>;
79+
fn lock(&self) -> PyResult<PyBitGeneratorGuard>;
8080
}
8181

8282
impl<'py> PyBitGeneratorMethods for Bound<'py, PyBitGenerator> {
83-
fn lock(&self) -> PyResult<PyBitGeneratorLock> {
83+
fn lock(&self) -> PyResult<PyBitGeneratorGuard> {
8484
let capsule = self.getattr("capsule")?.downcast_into::<PyCapsule>()?;
8585
let lock = self.getattr("lock")?;
8686
if lock.getattr("locked")?.call0()?.extract()? {
@@ -97,28 +97,29 @@ impl<'py> PyBitGeneratorMethods for Bound<'py, PyBitGenerator> {
9797
return Err(PyRuntimeError::new_err("Invalid BitGenerator capsule"));
9898
}
9999
};
100-
Ok(PyBitGeneratorLock {
100+
Ok(PyBitGeneratorGuard {
101101
raw_bitgen: non_null,
102102
lock: lock.unbind(),
103103
})
104104
}
105105
}
106106

107-
impl<'py> TryFrom<&Bound<'py, PyBitGenerator>> for PyBitGeneratorLock {
107+
impl<'py> TryFrom<&Bound<'py, PyBitGenerator>> for PyBitGeneratorGuard {
108108
type Error = PyErr;
109109
fn try_from(value: &Bound<'py, PyBitGenerator>) -> Result<Self, Self::Error> {
110110
value.lock()
111111
}
112112
}
113113

114114
/// [`PyBitGenerator`] lock allowing to access its methods without holding the GIL.
115-
pub struct PyBitGeneratorLock {
115+
pub struct PyBitGeneratorGuard {
116116
raw_bitgen: NonNull<npy_bitgen>,
117117
lock: Py<PyAny>,
118118
}
119119

120-
// SAFETY for all methods: We hold the BitGenerator lock, so nothing is allowed to change its state
121-
impl PyBitGeneratorLock {
120+
// SAFETY: We hold the `BitGenerator.lock`,
121+
// so nothing apart from us is allowed to change its state.
122+
impl PyBitGeneratorGuard {
122123
/// Returns the next random unsigned 64 bit integer.
123124
pub fn next_uint64(&mut self) -> u64 {
124125
unsafe {
@@ -149,7 +150,7 @@ impl PyBitGeneratorLock {
149150
}
150151
}
151152

152-
impl Drop for PyBitGeneratorLock {
153+
impl Drop for PyBitGeneratorGuard {
153154
fn drop(&mut self) {
154155
let r = Python::with_gil(|py| -> PyResult<()> {
155156
self.lock.bind(py).getattr("release")?.call0()?;
@@ -162,7 +163,7 @@ impl Drop for PyBitGeneratorLock {
162163
}
163164

164165
#[cfg(feature = "rand")]
165-
impl rand::RngCore for PyBitGeneratorLock {
166+
impl rand::RngCore for PyBitGeneratorGuard {
166167
fn next_u32(&mut self) -> u32 {
167168
self.next_uint32()
168169
}
@@ -207,14 +208,19 @@ mod tests {
207208
Ok(())
208209
}
209210

210-
/// Test that dropping the lock works while holding the GIL
211+
/// Test that releasing the lock works while holding the GIL
211212
#[test]
212213
fn unlock_with_held_gil() -> PyResult<()> {
213214
Python::with_gil(|py| {
214215
let generator = get_bit_generator(py)?;
215216
let mut bitgen = generator.lock()?;
216217
let _ = bitgen.next_raw();
217218
std::mem::drop(bitgen);
219+
assert!(!generator
220+
.getattr("lock")?
221+
.getattr("locked")?
222+
.call0()?
223+
.extract()?);
218224
Ok(())
219225
})
220226
}

0 commit comments

Comments
 (0)