Skip to content

Commit f52b2fa

Browse files
committed
safer: don’t allow trying to get BitGen from any PyAny
1 parent d93a264 commit f52b2fa

File tree

2 files changed

+31
-24
lines changed

2 files changed

+31
-24
lines changed

src/npyffi/random.rs

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
1-
use std::{ffi::c_void, ptr::NonNull};
1+
use std::ffi::c_void;
22

3-
use pyo3::{exceptions::PyRuntimeError, prelude::*, types::PyCapsule};
43

54
#[repr(C)]
65
#[derive(Debug, Clone, Copy)] // TODO: can it be Clone and/or Copy?
@@ -11,10 +10,3 @@ pub struct npy_bitgen {
1110
pub next_double: unsafe extern "C" fn(*mut c_void) -> libc::c_double, //nogil
1211
pub next_raw: unsafe extern "C" fn(*mut c_void) -> super::npy_uint64, //nogil
1312
}
14-
15-
pub fn get_bitgen_api<'py>(bitgen: &Bound<'py, PyAny>) -> PyResult<NonNull<npy_bitgen>> {
16-
let capsule = bitgen.getattr("capsule")?.downcast_into::<PyCapsule>()?;
17-
assert_eq!(capsule.name()?, Some(c"BitGenerator"));
18-
let ptr = capsule.pointer() as *mut npy_bitgen;
19-
NonNull::new(ptr).ok_or_else(|| PyRuntimeError::new_err("Invalid BitGenerator capsule"))
20-
}

src/random.rs

Lines changed: 30 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
//! Safe interface for NumPy's random [`BitGenerator`]
22
3-
use pyo3::{ffi, prelude::*, sync::GILOnceCell, types::PyType, PyTypeInfo};
3+
use pyo3::{ffi, prelude::*, sync::GILOnceCell, types::{PyCapsule, PyType}, PyTypeInfo, exceptions::PyRuntimeError};
44

5-
use crate::npyffi::get_bitgen_api;
5+
use crate::npyffi::npy_bitgen;
66

7-
///! Wrapper for NumPy's random [`BitGenerator`][bg]
7+
///! Wrapper for [`np.random.BitGenerator`][bg]
88
///!
99
///! [bg]: https://numpy.org/doc/stable//reference/random/bit_generators/generated/numpy.random.BitGenerator.html
1010
#[repr(transparent)]
@@ -32,7 +32,27 @@ unsafe impl PyTypeInfo for BitGenerator {
3232
}
3333

3434
/// Methods for [`BitGenerator`]
35-
pub trait BitGeneratorMethods {
35+
pub trait BitGeneratorMethods<'py> {
36+
/// Returns a new [`BitGen`]
37+
fn bit_gen(&self) -> PyResult<BitGen<'py>>;
38+
}
39+
40+
impl<'py> BitGeneratorMethods<'py> for Bound<'py, BitGenerator> {
41+
fn bit_gen(&self) -> PyResult<BitGen<'py>> {
42+
let capsule = self.as_any().getattr("capsule")?.downcast_into::<PyCapsule>()?;
43+
assert_eq!(capsule.name()?, Some(c"BitGenerator"));
44+
let ptr = capsule.pointer() as *mut npy_bitgen;
45+
// SAFETY: the lifetime of `ptr` is derived from the lifetime of `self`
46+
let ref_ = unsafe { ptr.as_mut::<'py>() }.ok_or_else(|| PyRuntimeError::new_err("Invalid BitGenerator capsule"))?;
47+
Ok(BitGen(ref_))
48+
}
49+
}
50+
51+
/// Wrapper for [`npy_bitgen`]
52+
pub struct BitGen<'a>(&'a mut npy_bitgen);
53+
54+
/// Methods for [`BitGen`]
55+
pub trait BitGenMethods {
3656
/// Returns the next random unsigned 64 bit integer
3757
fn next_uint64(&self) -> u64;
3858
/// Returns the next random unsigned 32 bit integer
@@ -43,23 +63,18 @@ pub trait BitGeneratorMethods {
4363
fn next_raw(&self) -> u64;
4464
}
4565

46-
// TODO: cache npy_bitgen pointer
47-
impl<'py> BitGeneratorMethods for Bound<'py, BitGenerator> {
66+
impl<'py> BitGenMethods for BitGen<'py> {
4867
fn next_uint64(&self) -> u64 {
49-
todo!()
68+
unsafe { (self.0.next_uint64)(self.0.state) }
5069
}
5170
fn next_uint32(&self) -> u32 {
52-
todo!()
71+
unsafe { (self.0.next_uint32)(self.0.state) }
5372
}
5473
fn next_double(&self) -> libc::c_double {
55-
todo!()
74+
unsafe { (self.0.next_double)(self.0.state) }
5675
}
5776
fn next_raw(&self) -> u64 {
58-
let mut api = get_bitgen_api(self.as_any()).expect("Could not get bitgen");
59-
unsafe {
60-
let api = api.as_mut();
61-
(api.next_raw)(api.state)
62-
}
77+
unsafe { (self.0.next_raw)(self.0.state) }
6378
}
6479
}
6580

@@ -71,7 +86,7 @@ mod tests {
7186
fn test_bitgen() -> PyResult<()> {
7287
Python::with_gil(|py| {
7388
let default_rng = py.import("numpy.random")?.getattr("default_rng")?;
74-
let bitgen = default_rng.call0()?.getattr("bit_generator")?.downcast_into::<BitGenerator>()?;
89+
let bitgen = default_rng.call0()?.getattr("bit_generator")?.downcast_into::<BitGenerator>()?.bit_gen()?;
7590
let res = bitgen.next_raw();
7691
dbg!(res);
7792
Ok(())

0 commit comments

Comments
 (0)