Skip to content

Commit 6d6084f

Browse files
authored
Merge pull request #222 from adamreichold/sync-api-globals
Make API globals thread safe using atomics
2 parents 131bb5f + a0ecc3f commit 6d6084f

File tree

2 files changed

+41
-25
lines changed

2 files changed

+41
-25
lines changed

src/npyffi/array.rs

Lines changed: 21 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@
22
use libc::FILE;
33
use pyo3::ffi::{self, PyObject, PyTypeObject};
44
use std::os::raw::*;
5-
use std::{cell::Cell, ptr};
5+
use std::ptr::null_mut;
6+
use std::sync::atomic::{AtomicPtr, Ordering};
67

78
use crate::npyffi::*;
89

@@ -12,7 +13,7 @@ const CAPSULE_NAME: &str = "_ARRAY_API";
1213
/// A global variable which stores a ['capsule'](https://docs.python.org/3/c-api/capsule.html)
1314
/// pointer to [Numpy Array API](https://numpy.org/doc/stable/reference/c-api/array.html).
1415
///
15-
/// You can acceess raw c APIs via this variable and its Deref implementation.
16+
/// You can acceess raw C APIs via this variable.
1617
///
1718
/// See [PyArrayAPI](struct.PyArrayAPI.html) for what methods you can use via this variable.
1819
///
@@ -31,28 +32,35 @@ pub static PY_ARRAY_API: PyArrayAPI = PyArrayAPI::new();
3132

3233
/// See [PY_ARRAY_API] for more.
3334
pub struct PyArrayAPI {
34-
api: Cell<*const *const c_void>,
35+
api: AtomicPtr<*const c_void>,
3536
}
3637

3738
impl PyArrayAPI {
3839
const fn new() -> Self {
3940
Self {
40-
api: Cell::new(ptr::null_mut()),
41+
api: AtomicPtr::new(null_mut()),
4142
}
4243
}
43-
fn get(&self, offset: isize) -> *const *const c_void {
44-
if self.api.get().is_null() {
45-
Python::with_gil(|py| {
46-
let api = get_numpy_api(py, MOD_NAME, CAPSULE_NAME);
47-
self.api.set(api);
48-
});
44+
#[cold]
45+
fn init(&self) -> *const *const c_void {
46+
Python::with_gil(|py| {
47+
let mut api = self.api.load(Ordering::Relaxed) as *const *const c_void;
48+
if api.is_null() {
49+
api = get_numpy_api(py, MOD_NAME, CAPSULE_NAME);
50+
self.api.store(api as *mut _, Ordering::Release);
51+
}
52+
api
53+
})
54+
}
55+
unsafe fn get(&self, offset: isize) -> *const *const c_void {
56+
let mut api = self.api.load(Ordering::Acquire) as *const *const c_void;
57+
if api.is_null() {
58+
api = self.init();
4959
}
50-
unsafe { self.api.get().offset(offset) }
60+
api.offset(offset)
5161
}
5262
}
5363

54-
unsafe impl Sync for PyArrayAPI {}
55-
5664
impl PyArrayAPI {
5765
impl_api![0; PyArray_GetNDArrayCVersion() -> c_uint];
5866
impl_api![40; PyArray_SetNumericOps(dict: *mut PyObject) -> c_int];

src/npyffi/ufunc.rs

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
//! Low-Level binding for [UFunc API](https://numpy.org/doc/stable/reference/c-api/ufunc.html)
22
33
use std::os::raw::*;
4-
use std::{cell::Cell, ptr};
4+
use std::ptr::null_mut;
5+
use std::sync::atomic::{AtomicPtr, Ordering};
56

67
use pyo3::ffi::PyObject;
78
use pyo3::Python;
@@ -18,28 +19,35 @@ const CAPSULE_NAME: &str = "_UFUNC_API";
1819
pub static PY_UFUNC_API: PyUFuncAPI = PyUFuncAPI::new();
1920

2021
pub struct PyUFuncAPI {
21-
api: Cell<*const *const c_void>,
22+
api: AtomicPtr<*const c_void>,
2223
}
2324

2425
impl PyUFuncAPI {
2526
const fn new() -> Self {
2627
Self {
27-
api: Cell::new(ptr::null_mut()),
28+
api: AtomicPtr::new(null_mut()),
2829
}
2930
}
30-
fn get(&self, offset: isize) -> *const *const c_void {
31-
if self.api.get().is_null() {
32-
Python::with_gil(|py| {
33-
let api = get_numpy_api(py, MOD_NAME, CAPSULE_NAME);
34-
self.api.set(api);
35-
});
31+
#[cold]
32+
fn init(&self) -> *const *const c_void {
33+
Python::with_gil(|py| {
34+
let mut api = self.api.load(Ordering::Relaxed) as *const *const c_void;
35+
if api.is_null() {
36+
api = get_numpy_api(py, MOD_NAME, CAPSULE_NAME);
37+
self.api.store(api as *mut _, Ordering::Release);
38+
}
39+
api
40+
})
41+
}
42+
unsafe fn get(&self, offset: isize) -> *const *const c_void {
43+
let mut api = self.api.load(Ordering::Acquire) as *const *const c_void;
44+
if api.is_null() {
45+
api = self.init();
3646
}
37-
unsafe { self.api.get().offset(offset) }
47+
api.offset(offset)
3848
}
3949
}
4050

41-
unsafe impl Sync for PyUFuncAPI {}
42-
4351
impl PyUFuncAPI {
4452
impl_api![1; PyUFunc_FromFuncAndData(func: *mut PyUFuncGenericFunction, data: *mut *mut c_void, types: *mut c_char, ntypes: c_int, nin: c_int, nout: c_int, identity: c_int, name: *const c_char, doc: *const c_char, unused: c_int) -> *mut PyObject];
4553
impl_api![2; PyUFunc_RegisterLoopForType(ufunc: *mut PyUFuncObject, usertype: c_int, function: PyUFuncGenericFunction, arg_types: *mut c_int, data: *mut c_void) -> c_int];

0 commit comments

Comments
 (0)