1
1
//! Safe interface for NumPy's random [`BitGenerator`]
2
2
3
- use pyo3:: { ffi, prelude:: * , sync:: GILOnceCell , types:: { PyCapsule , PyType } , PyTypeInfo , exceptions:: PyRuntimeError } ;
3
+ use pyo3:: {
4
+ exceptions:: PyRuntimeError ,
5
+ ffi,
6
+ prelude:: * ,
7
+ sync:: GILOnceCell ,
8
+ types:: { PyCapsule , PyType } ,
9
+ PyTypeInfo ,
10
+ } ;
4
11
5
12
use crate :: npyffi:: npy_bitgen;
6
13
@@ -39,11 +46,15 @@ pub trait BitGeneratorMethods<'py> {
39
46
40
47
impl < ' py > BitGeneratorMethods < ' py > for Bound < ' py , BitGenerator > {
41
48
fn bit_gen ( & self ) -> PyResult < BitGen < ' py > > {
42
- let capsule = self . as_any ( ) . getattr ( "capsule" ) ?. downcast_into :: < PyCapsule > ( ) ?;
49
+ let capsule = self
50
+ . as_any ( )
51
+ . getattr ( "capsule" ) ?
52
+ . downcast_into :: < PyCapsule > ( ) ?;
43
53
assert_eq ! ( capsule. name( ) ?, Some ( c"BitGenerator" ) ) ;
44
54
let ptr = capsule. pointer ( ) as * mut npy_bitgen ;
45
55
// SAFETY: the lifetime of `ptr` is derived from the lifetime of `self`
46
- let ref_ = unsafe { ptr. as_mut :: < ' py > ( ) } . ok_or_else ( || PyRuntimeError :: new_err ( "Invalid BitGenerator capsule" ) ) ?;
56
+ let ref_ = unsafe { ptr. as_mut :: < ' py > ( ) }
57
+ . ok_or_else ( || PyRuntimeError :: new_err ( "Invalid BitGenerator capsule" ) ) ?;
47
58
Ok ( BitGen ( ref_) )
48
59
}
49
60
}
@@ -93,13 +104,16 @@ impl rand::RngCore for BitGen<'_> {
93
104
#[ cfg( test) ]
94
105
mod tests {
95
106
use super :: * ;
96
-
107
+
97
108
fn get_bit_generator < ' py > ( py : Python < ' py > ) -> PyResult < Bound < ' py , BitGenerator > > {
98
109
let default_rng = py. import ( "numpy.random" ) ?. getattr ( "default_rng" ) ?;
99
- let bit_generator = default_rng. call0 ( ) ?. getattr ( "bit_generator" ) ?. downcast_into :: < BitGenerator > ( ) ?;
110
+ let bit_generator = default_rng
111
+ . call0 ( ) ?
112
+ . getattr ( "bit_generator" ) ?
113
+ . downcast_into :: < BitGenerator > ( ) ?;
100
114
Ok ( bit_generator)
101
115
}
102
-
116
+
103
117
#[ test]
104
118
fn bitgen ( ) -> PyResult < ( ) > {
105
119
Python :: with_gil ( |py| {
@@ -113,7 +127,7 @@ mod tests {
113
127
#[ test]
114
128
fn rand ( ) -> PyResult < ( ) > {
115
129
use rand:: Rng as _;
116
-
130
+
117
131
Python :: with_gil ( |py| {
118
132
let mut bitgen = get_bit_generator ( py) ?. bit_gen ( ) ?;
119
133
let _ = bitgen. random_ratio ( 2 , 3 ) ;
0 commit comments