Skip to content

Commit eed5b19

Browse files
committed
implement rand
1 parent 37d360e commit eed5b19

File tree

2 files changed

+36
-6
lines changed

2 files changed

+36
-6
lines changed

Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ num-integer = "0.1"
2323
num-traits = "0.2"
2424
ndarray = ">= 0.15, < 0.17"
2525
pyo3 = { version = "0.25.0", default-features = false, features = ["macros"] }
26+
rand = { version = "0.9.1", default-features = false, optional = true }
2627
rustc-hash = "2.0"
2728

2829
[dev-dependencies]

src/random.rs

Lines changed: 35 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -77,18 +77,47 @@ impl<'py> BitGen<'py> {
7777
}
7878
}
7979

80+
#[cfg(feature = "rand")]
81+
impl rand::RngCore for BitGen<'_> {
82+
fn next_u32(&mut self) -> u32 {
83+
self.next_uint32()
84+
}
85+
fn next_u64(&mut self) -> u64 {
86+
self.next_uint64()
87+
}
88+
fn fill_bytes(&mut self, dst: &mut [u8]) {
89+
rand::rand_core::impls::fill_bytes_via_next(self, dst)
90+
}
91+
}
92+
8093
#[cfg(test)]
8194
mod tests {
8295
use super::*;
8396

97+
fn get_bit_generator<'py>(py: Python<'py>) -> PyResult<Bound<'py, BitGenerator>> {
98+
let default_rng = py.import("numpy.random")?.getattr("default_rng")?;
99+
let bit_generator = default_rng.call0()?.getattr("bit_generator")?.downcast_into::<BitGenerator>()?;
100+
Ok(bit_generator)
101+
}
102+
84103
#[test]
85-
fn test_bitgen() -> PyResult<()> {
104+
fn bitgen() -> PyResult<()> {
86105
Python::with_gil(|py| {
87-
let default_rng = py.import("numpy.random")?.getattr("default_rng")?;
88-
let bitgen = default_rng.call0()?.getattr("bit_generator")?.downcast_into::<BitGenerator>()?.bit_gen()?;
89-
let res = bitgen.next_raw();
90-
dbg!(res);
106+
let bitgen = get_bit_generator(py)?.bit_gen()?;
107+
let _ = bitgen.next_raw();
91108
Ok(())
92109
})
93-
}
110+
}
111+
112+
#[cfg(feature = "rand")]
113+
#[test]
114+
fn rand() -> PyResult<()> {
115+
use rand::Rng as _;
116+
117+
Python::with_gil(|py| {
118+
let mut bitgen = get_bit_generator(py)?.bit_gen()?;
119+
let _ = bitgen.random_ratio(2, 3);
120+
Ok(())
121+
})
122+
}
94123
}

0 commit comments

Comments
 (0)