Skip to content

Commit 06d6ce1

Browse files
committed
WIP bitgen
1 parent 8bfbb27 commit 06d6ce1

File tree

7 files changed

+62
-2
lines changed

7 files changed

+62
-2
lines changed

.vscode/settings.json

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
{
2+
"rust-analyzer.cargo.features": "all"
3+
}

examples/simple/src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ fn rust_ext<'py>(m: &Bound<'py, PyModule>) -> PyResult<()> {
113113
// This crate follows a strongly-typed approach to wrapping NumPy arrays
114114
// while Python API are often expected to work with multiple element types.
115115
//
116-
// That kind of limited polymorphis can be recovered by accepting an enumerated type
116+
// That kind of limited polymorphism can be recovered by accepting an enumerated type
117117
// covering the supported element types and dispatching into a generic implementation.
118118
#[derive(FromPyObject)]
119119
enum SupportedArray<'py> {

src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ pub mod datetime;
7979
mod dtype;
8080
mod error;
8181
pub mod npyffi;
82+
pub mod random;
8283
mod slice_container;
8384
mod strings;
8485
mod sum_products;

src/npyffi/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,11 +94,13 @@ macro_rules! impl_api {
9494
pub mod array;
9595
pub mod flags;
9696
pub mod objects;
97+
pub mod random;
9798
pub mod types;
9899
pub mod ufunc;
99100

100101
pub use self::array::*;
101102
pub use self::flags::*;
102103
pub use self::objects::*;
104+
pub use self::random::*;
103105
pub use self::types::*;
104106
pub use self::ufunc::*;

src/npyffi/objects.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
//! Low-Lebel binding for NumPy C API C-objects
1+
//! Low-Level binding for NumPy C API C-objects
22
//!
33
//! <https://numpy.org/doc/stable/reference/c-api/types-and-structures.html>
44
#![allow(non_camel_case_types)]

src/npyffi/random.rs

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
use std::{ffi::c_void, ptr::NonNull};
2+
3+
use pyo3::{prelude::*, types::PyCapsule};
4+
5+
#[repr(C)]
6+
#[derive(Debug, Clone, Copy)]
7+
pub struct npy_bitgen {
8+
pub state: *mut c_void,
9+
pub next_uint64: NonNull<unsafe extern "C" fn(*mut c_void) -> super::npy_uint64>, //nogil
10+
pub next_uint32: NonNull<unsafe extern "C" fn(*mut c_void) -> super::npy_uint32>, //nogil
11+
pub next_double: NonNull<unsafe extern "C" fn(*mut c_void) -> libc::c_double>, //nogil
12+
pub next_raw: NonNull<unsafe extern "C" fn(*mut c_void) -> super::npy_uint64>, //nogil
13+
}
14+
15+
pub fn get_bitgen_api<'py>(bitgen: Bound<'py, PyAny>) -> PyResult<*mut npy_bitgen> {
16+
let capsule = bitgen.getattr("capsule")?.downcast_into::<PyCapsule>()?;
17+
assert_eq!(capsule.name()?, Some(c"BitGenerator"));
18+
Ok(capsule.pointer() as *mut npy_bitgen)
19+
}

src/random.rs

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
//! Safe interface for NumPy's random [`BitGenerator`][]
2+
//!
3+
//! `BitGenerator`: https://numpy.org/doc/stable//reference/random/bit_generators/generated/numpy.random.BitGenerator.html
4+
5+
use pyo3::{ffi, prelude::*, sync::GILOnceCell, types::PyType, PyTypeInfo};
6+
7+
use crate::npyffi::get_bitgen_api;
8+
9+
///! Wrapper for NumPy's random [`BitGenerator`][]
10+
///
11+
///! [BitGenerator]: https://numpy.org/doc/stable/reference/random/bit_generators/generated/numpy.random.BitGenerator.html
12+
#[repr(transparent)]
13+
pub struct BitGenerator(PyAny);
14+
15+
unsafe impl PyTypeInfo for BitGenerator {
16+
const NAME: &'static str = "BitGenerator";
17+
const MODULE: Option<&'static str> = Some("numpy.random");
18+
19+
fn type_object_raw<'py>(py: Python<'py>) -> *mut ffi::PyTypeObject {
20+
const CLS: GILOnceCell<Py<PyType>> = GILOnceCell::new();
21+
let cls = CLS
22+
.get_or_try_init::<_, PyErr>(py, || {
23+
Ok(py
24+
.import("numpy.random")?
25+
.getattr("BitGenerator")?
26+
.downcast_into::<PyType>()?
27+
.unbind())
28+
})
29+
.expect("Failed to get BitGenerator type object")
30+
.clone_ref(py)
31+
.into_bound(py);
32+
cls.as_type_ptr()
33+
}
34+
}
35+

0 commit comments

Comments
 (0)