Skip to content

Commit 71ce8be

Browse files
committed
manually drop and capsule
1 parent 876001b commit 71ce8be

File tree

1 file changed

+47
-38
lines changed

1 file changed

+47
-38
lines changed

src/random.rs

Lines changed: 47 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -74,13 +74,13 @@ unsafe impl PyTypeInfo for PyBitGenerator {
7474
}
7575

7676
/// Methods for [`PyBitGenerator`].
77-
pub trait PyBitGeneratorMethods {
77+
pub trait PyBitGeneratorMethods<'py> {
7878
/// Acquire a lock on the BitGenerator to allow calling its methods in.
79-
fn lock(&self) -> PyResult<PyBitGeneratorGuard>;
79+
fn lock(&self) -> PyResult<PyBitGeneratorGuard<'py>>;
8080
}
8181

82-
impl<'py> PyBitGeneratorMethods for Bound<'py, PyBitGenerator> {
83-
fn lock(&self) -> PyResult<PyBitGeneratorGuard> {
82+
impl<'py> PyBitGeneratorMethods<'py> for Bound<'py, PyBitGenerator> {
83+
fn lock(&self) -> PyResult<PyBitGeneratorGuard<'py>> {
8484
let capsule = self.getattr("capsule")?.downcast_into::<PyCapsule>()?;
8585
let lock = self.getattr("lock")?;
8686
if lock.call_method0("locked")?.extract()? {
@@ -99,27 +99,44 @@ impl<'py> PyBitGeneratorMethods for Bound<'py, PyBitGenerator> {
9999
};
100100
Ok(PyBitGeneratorGuard {
101101
raw_bitgen: non_null,
102-
lock: lock.unbind(),
102+
capsule,
103+
lock,
103104
})
104105
}
105106
}
106107

107-
impl<'py> TryFrom<&Bound<'py, PyBitGenerator>> for PyBitGeneratorGuard {
108+
impl<'py> TryFrom<&Bound<'py, PyBitGenerator>> for PyBitGeneratorGuard<'py> {
108109
type Error = PyErr;
109110
fn try_from(value: &Bound<'py, PyBitGenerator>) -> Result<Self, Self::Error> {
110111
value.lock()
111112
}
112113
}
113114

114115
/// [`PyBitGenerator`] lock allowing to access its methods without holding the GIL.
115-
pub struct PyBitGeneratorGuard {
116+
pub struct PyBitGeneratorGuard<'py> {
116117
raw_bitgen: NonNull<npy_bitgen>,
117-
lock: Py<PyAny>,
118+
capsule: Bound<'py, PyCapsule>,
119+
lock: Bound<'py, PyAny>,
120+
}
121+
122+
unsafe impl Send for PyBitGeneratorGuard<'_> {}
123+
124+
impl Drop for PyBitGeneratorGuard<'_> {
125+
fn drop(&mut self) {
126+
// ignore errors. This includes when `try_drop` was called manually
127+
let _ = self.lock.call_method0("release");
128+
}
118129
}
119130

120131
// SAFETY: We hold the `BitGenerator.lock`,
121132
// so nothing apart from us is allowed to change its state.
122-
impl PyBitGeneratorGuard {
133+
impl PyBitGeneratorGuard<'_> {
134+
/// Drop the lock, allowing access to.
135+
pub fn try_drop(self) -> PyResult<()> {
136+
self.lock.call_method0("release")?;
137+
Ok(())
138+
}
139+
123140
/// Returns the next random unsigned 64 bit integer.
124141
pub fn next_uint64(&mut self) -> u64 {
125142
unsafe {
@@ -150,20 +167,8 @@ impl PyBitGeneratorGuard {
150167
}
151168
}
152169

153-
impl Drop for PyBitGeneratorGuard {
154-
fn drop(&mut self) {
155-
let r = Python::with_gil(|py| -> PyResult<()> {
156-
self.lock.bind(py).call_method0("release")?;
157-
Ok(())
158-
});
159-
if let Err(e) = r {
160-
eprintln!("Failed to release BitGenerator lock: {e}");
161-
}
162-
}
163-
}
164-
165170
#[cfg(feature = "rand")]
166-
impl rand::RngCore for PyBitGeneratorGuard {
171+
impl rand::RngCore for PyBitGeneratorGuard<'_> {
167172
fn next_u32(&mut self) -> u32 {
168173
self.next_uint32()
169174
}
@@ -190,10 +195,14 @@ mod tests {
190195
/// Test the primary use case: acquire the lock, release the GIL, then use the lock
191196
#[test]
192197
fn use_outside_gil() -> PyResult<()> {
193-
let mut bitgen = Python::with_gil(|py| get_bit_generator(py)?.lock())?;
194-
let _ = bitgen.next_raw();
195-
std::mem::drop(bitgen);
196-
Ok(())
198+
Python::with_gil(|py| {
199+
let mut bitgen = get_bit_generator(py)?.lock()?;
200+
py.allow_threads(|| {
201+
let _ = bitgen.next_raw();
202+
});
203+
assert!(bitgen.try_drop().is_ok());
204+
Ok(())
205+
})
197206
}
198207

199208
/// Test that the `rand::Rng` APIs work
@@ -202,11 +211,15 @@ mod tests {
202211
fn rand() -> PyResult<()> {
203212
use rand::Rng as _;
204213

205-
let mut bitgen = Python::with_gil(|py| get_bit_generator(py)?.lock())?;
206-
assert!(bitgen.random_ratio(1, 1));
207-
assert!(!bitgen.random_ratio(0, 1));
208-
std::mem::drop(bitgen);
209-
Ok(())
214+
Python::with_gil(|py| {
215+
let mut bitgen = get_bit_generator(py)?.lock()?;
216+
py.allow_threads(|| {
217+
assert!(bitgen.random_ratio(1, 1));
218+
assert!(!bitgen.random_ratio(0, 1));
219+
});
220+
assert!(bitgen.try_drop().is_ok());
221+
Ok(())
222+
})
210223
}
211224

212225
/// Test that releasing the lock works while holding the GIL
@@ -216,11 +229,7 @@ mod tests {
216229
let generator = get_bit_generator(py)?;
217230
let mut bitgen = generator.lock()?;
218231
let _ = bitgen.next_raw();
219-
std::mem::drop(bitgen);
220-
assert!(!generator
221-
.getattr("lock")?
222-
.call_method0("locked")?
223-
.extract()?);
232+
assert!(bitgen.try_drop().is_ok());
224233
Ok(())
225234
})
226235
}
@@ -229,9 +238,9 @@ mod tests {
229238
fn double_lock_fails() -> PyResult<()> {
230239
Python::with_gil(|py| {
231240
let generator = get_bit_generator(py)?;
232-
let d1 = generator.lock()?;
241+
let bitgen = generator.lock()?;
233242
assert!(generator.lock().is_err());
234-
std::mem::drop(d1);
243+
assert!(bitgen.try_drop().is_ok());
235244
Ok(())
236245
})
237246
}

0 commit comments

Comments
 (0)