Skip to content

Commit f64d8ec

Browse files
committed
Use GILOnceCell when caching the dynamic borrow checking capsule API.
1 parent 4bbadda commit f64d8ec

File tree

1 file changed

+17
-15
lines changed

1 file changed

+17
-15
lines changed

src/borrow/shared.rs

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
1-
use std::cell::Cell;
21
use std::collections::hash_map::Entry;
32
use std::ffi::{c_void, CString};
3+
use std::mem::forget;
44
use std::os::raw::{c_char, c_int};
5-
use std::ptr::null;
65
use std::slice::from_raw_parts;
76

87
use num_integer::gcd;
9-
use pyo3::{exceptions::PyTypeError, types::PyCapsule, PyResult, PyTryInto, Python};
8+
use pyo3::{
9+
exceptions::PyTypeError, once_cell::GILOnceCell, types::PyCapsule, Py, PyResult, PyTryInto,
10+
Python,
11+
};
1012
use rustc_hash::FxHashMap;
1113

1214
use crate::array::get_array_module;
@@ -98,25 +100,23 @@ unsafe extern "C" fn release_mut_shared(flags: *mut c_void, array: *mut PyArrayO
98100

99101
// This global state is a cache used to access the shared borrow checking API from this extension:
100102

101-
struct SharedPtr(Cell<*const Shared>);
103+
struct SharedPtr(GILOnceCell<*const Shared>);
104+
105+
unsafe impl Send for SharedPtr {}
102106

103107
unsafe impl Sync for SharedPtr {}
104108

105-
static SHARED: SharedPtr = SharedPtr(Cell::new(null()));
109+
static SHARED: SharedPtr = SharedPtr(GILOnceCell::new());
106110

107111
fn get_or_insert_shared<'py>(py: Python<'py>) -> PyResult<&'py Shared> {
108-
let mut shared = SHARED.0.get();
109-
110-
if shared.is_null() {
111-
shared = insert_shared(py)?;
112-
}
112+
let shared = SHARED.0.get_or_try_init(py, || insert_shared(py))?;
113113

114114
// SAFETY: We inserted the capsule if it was missing
115115
// and verified that it contains a compatible version.
116-
Ok(unsafe { &*shared })
116+
Ok(unsafe { &**shared })
117117
}
118118

119-
// This function will publish this extensions version of the shared borrow checking API
119+
// This function will publish this extension's version of the shared borrow checking API
120120
// as a capsule placed at `numpy.core.multiarray._RUST_NUMPY_BORROW_CHECKING_API` and
121121
// immediately initialize the cache used access it from this extension.
122122

@@ -161,9 +161,11 @@ fn insert_shared(py: Python) -> PyResult<*const Shared> {
161161
)));
162162
}
163163

164-
let shared = capsule.pointer() as *const Shared;
165-
SHARED.0.set(shared);
166-
Ok(shared)
164+
// Intentionally leak a reference to the capsule
165+
// so we can safely cache a pointer into its interior.
166+
forget(Py::<PyCapsule>::from(capsule));
167+
168+
Ok(capsule.pointer() as *const Shared)
167169
}
168170

169171
// These entry points will be used to access the shared borrow checking API from this extension:

0 commit comments

Comments
 (0)