Skip to content

Commit b611943

Browse files
committed
fix and test
1 parent 07e2416 commit b611943

File tree

2 files changed

+56
-11
lines changed

2 files changed

+56
-11
lines changed

src/npyffi/random.rs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,13 @@ use pyo3::{exceptions::PyRuntimeError, prelude::*, types::PyCapsule};
66
#[derive(Debug, Clone, Copy)]
77
pub struct npy_bitgen {
88
pub state: *mut c_void,
9-
pub next_uint64: NonNull<unsafe extern "C" fn(*mut c_void) -> super::npy_uint64>, //nogil
10-
pub next_uint32: NonNull<unsafe extern "C" fn(*mut c_void) -> super::npy_uint32>, //nogil
11-
pub next_double: NonNull<unsafe extern "C" fn(*mut c_void) -> libc::c_double>, //nogil
12-
pub next_raw: NonNull<unsafe extern "C" fn(*mut c_void) -> super::npy_uint64>, //nogil
9+
pub next_uint64: unsafe extern "C" fn(*mut c_void) -> super::npy_uint64, //nogil
10+
pub next_uint32: unsafe extern "C" fn(*mut c_void) -> super::npy_uint32, //nogil
11+
pub next_double: unsafe extern "C" fn(*mut c_void) -> libc::c_double, //nogil
12+
pub next_raw: unsafe extern "C" fn(*mut c_void) -> super::npy_uint64, //nogil
1313
}
1414

15-
pub fn get_bitgen_api<'py>(bitgen: Bound<'py, PyAny>) -> PyResult<NonNull<npy_bitgen>> {
15+
pub fn get_bitgen_api<'py>(bitgen: &Bound<'py, PyAny>) -> PyResult<NonNull<npy_bitgen>> {
1616
let capsule = bitgen.getattr("capsule")?.downcast_into::<PyCapsule>()?;
1717
assert_eq!(capsule.name()?, Some(c"BitGenerator"));
1818
let ptr = capsule.pointer() as *mut npy_bitgen;

src/random.rs

Lines changed: 51 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,12 @@
1-
//! Safe interface for NumPy's random [`BitGenerator`][]
2-
//!
3-
//! `BitGenerator`: https://numpy.org/doc/stable//reference/random/bit_generators/generated/numpy.random.BitGenerator.html
1+
//! Safe interface for NumPy's random [`BitGenerator`]
42
53
use pyo3::{ffi, prelude::*, sync::GILOnceCell, types::PyType, PyTypeInfo};
64

75
use crate::npyffi::get_bitgen_api;
86

9-
///! Wrapper for NumPy's random [`BitGenerator`][]
10-
///
11-
///! [BitGenerator]: https://numpy.org/doc/stable/reference/random/bit_generators/generated/numpy.random.BitGenerator.html
7+
///! Wrapper for NumPy's random [`BitGenerator`][bg]
8+
///!
9+
///! [bg]: https://numpy.org/doc/stable//reference/random/bit_generators/generated/numpy.random.BitGenerator.html
1210
#[repr(transparent)]
1311
pub struct BitGenerator(PyAny);
1412

@@ -33,3 +31,50 @@ unsafe impl PyTypeInfo for BitGenerator {
3331
}
3432
}
3533

34+
/// Methods for [`BitGenerator`]
35+
pub trait BitGeneratorMethods {
36+
/// Returns the next random unsigned 64 bit integer
37+
fn next_uint64(&self) -> u64;
38+
/// Returns the next random unsigned 32 bit integer
39+
fn next_uint32(&self) -> u32;
40+
/// Returns the next random double
41+
fn next_double(&self) -> libc::c_double;
42+
/// Returns the next raw value (can be used for testing)
43+
fn next_raw(&self) -> u64;
44+
}
45+
46+
// TODO: cache npy_bitgen pointer
47+
impl<'py> BitGeneratorMethods for Bound<'py, BitGenerator> {
48+
fn next_uint64(&self) -> u64 {
49+
todo!()
50+
}
51+
fn next_uint32(&self) -> u32 {
52+
todo!()
53+
}
54+
fn next_double(&self) -> libc::c_double {
55+
todo!()
56+
}
57+
fn next_raw(&self) -> u64 {
58+
let mut api = get_bitgen_api(self.as_any()).expect("Could not get bitgen");
59+
unsafe {
60+
let api = api.as_mut();
61+
(api.next_raw)(api.state)
62+
}
63+
}
64+
}
65+
66+
#[cfg(test)]
67+
mod tests {
68+
use super::*;
69+
70+
#[test]
71+
fn test_bitgen() -> PyResult<()> {
72+
Python::with_gil(|py| {
73+
let default_rng = py.import("numpy.random")?.getattr("default_rng")?;
74+
let bitgen = default_rng.call0()?.getattr("bit_generator")?.downcast_into::<BitGenerator>()?;
75+
let res = bitgen.next_raw();
76+
dbg!(res);
77+
Ok(())
78+
})
79+
}
80+
}

0 commit comments

Comments
 (0)