1
1
//! Safe interface for NumPy's random [`BitGenerator`]
2
2
3
- use pyo3:: { ffi, prelude:: * , sync:: GILOnceCell , types:: PyType , PyTypeInfo } ;
3
+ use pyo3:: { ffi, prelude:: * , sync:: GILOnceCell , types:: { PyCapsule , PyType } , PyTypeInfo , exceptions :: PyRuntimeError } ;
4
4
5
- use crate :: npyffi:: get_bitgen_api ;
5
+ use crate :: npyffi:: npy_bitgen ;
6
6
7
- ///! Wrapper for NumPy's random [` BitGenerator`][bg]
7
+ ///! Wrapper for [`np.random. BitGenerator`][bg]
8
8
///!
9
9
///! [bg]: https://numpy.org/doc/stable//reference/random/bit_generators/generated/numpy.random.BitGenerator.html
10
10
#[ repr( transparent) ]
@@ -32,7 +32,27 @@ unsafe impl PyTypeInfo for BitGenerator {
32
32
}
33
33
34
34
/// Methods for [`BitGenerator`]
35
- pub trait BitGeneratorMethods {
35
+ pub trait BitGeneratorMethods < ' py > {
36
+ /// Returns a new [`BitGen`]
37
+ fn bit_gen ( & self ) -> PyResult < BitGen < ' py > > ;
38
+ }
39
+
40
+ impl < ' py > BitGeneratorMethods < ' py > for Bound < ' py , BitGenerator > {
41
+ fn bit_gen ( & self ) -> PyResult < BitGen < ' py > > {
42
+ let capsule = self . as_any ( ) . getattr ( "capsule" ) ?. downcast_into :: < PyCapsule > ( ) ?;
43
+ assert_eq ! ( capsule. name( ) ?, Some ( c"BitGenerator" ) ) ;
44
+ let ptr = capsule. pointer ( ) as * mut npy_bitgen ;
45
+ // SAFETY: the lifetime of `ptr` is derived from the lifetime of `self`
46
+ let ref_ = unsafe { ptr. as_mut :: < ' py > ( ) } . ok_or_else ( || PyRuntimeError :: new_err ( "Invalid BitGenerator capsule" ) ) ?;
47
+ Ok ( BitGen ( ref_) )
48
+ }
49
+ }
50
+
51
+ /// Wrapper for [`npy_bitgen`]
52
+ pub struct BitGen < ' a > ( & ' a mut npy_bitgen ) ;
53
+
54
+ /// Methods for [`BitGen`]
55
+ pub trait BitGenMethods {
36
56
/// Returns the next random unsigned 64 bit integer
37
57
fn next_uint64 ( & self ) -> u64 ;
38
58
/// Returns the next random unsigned 32 bit integer
@@ -43,23 +63,18 @@ pub trait BitGeneratorMethods {
43
63
fn next_raw ( & self ) -> u64 ;
44
64
}
45
65
46
- // TODO: cache npy_bitgen pointer
47
- impl < ' py > BitGeneratorMethods for Bound < ' py , BitGenerator > {
66
+ impl < ' py > BitGenMethods for BitGen < ' py > {
48
67
fn next_uint64 ( & self ) -> u64 {
49
- todo ! ( )
68
+ unsafe { ( self . 0 . next_uint64 ) ( self . 0 . state ) }
50
69
}
51
70
fn next_uint32 ( & self ) -> u32 {
52
- todo ! ( )
71
+ unsafe { ( self . 0 . next_uint32 ) ( self . 0 . state ) }
53
72
}
54
73
fn next_double ( & self ) -> libc:: c_double {
55
- todo ! ( )
74
+ unsafe { ( self . 0 . next_double ) ( self . 0 . state ) }
56
75
}
57
76
fn next_raw ( & self ) -> u64 {
58
- let mut api = get_bitgen_api ( self . as_any ( ) ) . expect ( "Could not get bitgen" ) ;
59
- unsafe {
60
- let api = api. as_mut ( ) ;
61
- ( api. next_raw ) ( api. state )
62
- }
77
+ unsafe { ( self . 0 . next_raw ) ( self . 0 . state ) }
63
78
}
64
79
}
65
80
@@ -71,7 +86,7 @@ mod tests {
71
86
fn test_bitgen ( ) -> PyResult < ( ) > {
72
87
Python :: with_gil ( |py| {
73
88
let default_rng = py. import ( "numpy.random" ) ?. getattr ( "default_rng" ) ?;
74
- let bitgen = default_rng. call0 ( ) ?. getattr ( "bit_generator" ) ?. downcast_into :: < BitGenerator > ( ) ?;
89
+ let bitgen = default_rng. call0 ( ) ?. getattr ( "bit_generator" ) ?. downcast_into :: < BitGenerator > ( ) ?. bit_gen ( ) ? ;
75
90
let res = bitgen. next_raw ( ) ;
76
91
dbg ! ( res) ;
77
92
Ok ( ( ) )
0 commit comments