2
2
use libc:: FILE ;
3
3
use pyo3:: ffi:: { self , PyObject , PyTypeObject } ;
4
4
use std:: os:: raw:: * ;
5
- use std:: { cell:: Cell , ptr} ;
5
+ use std:: ptr:: null_mut;
6
+ use std:: sync:: atomic:: { AtomicPtr , Ordering } ;
6
7
7
8
use crate :: npyffi:: * ;
8
9
@@ -12,7 +13,7 @@ const CAPSULE_NAME: &str = "_ARRAY_API";
12
13
/// A global variable which stores a ['capsule'](https://docs.python.org/3/c-api/capsule.html)
13
14
/// pointer to [Numpy Array API](https://numpy.org/doc/stable/reference/c-api/array.html).
14
15
///
15
- /// You can acceess raw c APIs via this variable and its Deref implementation .
16
+ /// You can acceess raw C APIs via this variable.
16
17
///
17
18
/// See [PyArrayAPI](struct.PyArrayAPI.html) for what methods you can use via this variable.
18
19
///
@@ -31,28 +32,35 @@ pub static PY_ARRAY_API: PyArrayAPI = PyArrayAPI::new();
31
32
32
33
/// See [PY_ARRAY_API] for more.
33
34
pub struct PyArrayAPI {
34
- api : Cell < * const * const c_void > ,
35
+ api : AtomicPtr < * const c_void > ,
35
36
}
36
37
37
38
impl PyArrayAPI {
38
39
const fn new ( ) -> Self {
39
40
Self {
40
- api : Cell :: new ( ptr :: null_mut ( ) ) ,
41
+ api : AtomicPtr :: new ( null_mut ( ) ) ,
41
42
}
42
43
}
43
- fn get ( & self , offset : isize ) -> * const * const c_void {
44
- if self . api . get ( ) . is_null ( ) {
45
- Python :: with_gil ( |py| {
46
- let api = get_numpy_api ( py, MOD_NAME , CAPSULE_NAME ) ;
47
- self . api . set ( api) ;
48
- } ) ;
44
+ #[ cold]
45
+ fn init ( & self ) -> * const * const c_void {
46
+ Python :: with_gil ( |py| {
47
+ let mut api = self . api . load ( Ordering :: Relaxed ) as * const * const c_void ;
48
+ if api. is_null ( ) {
49
+ api = get_numpy_api ( py, MOD_NAME , CAPSULE_NAME ) ;
50
+ self . api . store ( api as * mut _ , Ordering :: Release ) ;
51
+ }
52
+ api
53
+ } )
54
+ }
55
+ unsafe fn get ( & self , offset : isize ) -> * const * const c_void {
56
+ let mut api = self . api . load ( Ordering :: Acquire ) as * const * const c_void ;
57
+ if api. is_null ( ) {
58
+ api = self . init ( ) ;
49
59
}
50
- unsafe { self . api . get ( ) . offset ( offset) }
60
+ api. offset ( offset)
51
61
}
52
62
}
53
63
54
- unsafe impl Sync for PyArrayAPI { }
55
-
56
64
impl PyArrayAPI {
57
65
impl_api ! [ 0 ; PyArray_GetNDArrayCVersion ( ) -> c_uint] ;
58
66
impl_api ! [ 40 ; PyArray_SetNumericOps ( dict: * mut PyObject ) -> c_int] ;
0 commit comments