|
1 |
| -use std::cell::Cell; |
2 | 1 | use std::collections::hash_map::Entry;
|
3 | 2 | use std::ffi::{c_void, CString};
|
| 3 | +use std::mem::forget; |
4 | 4 | use std::os::raw::{c_char, c_int};
|
5 |
| -use std::ptr::null; |
6 | 5 | use std::slice::from_raw_parts;
|
7 | 6 |
|
8 | 7 | use num_integer::gcd;
|
9 |
| -use pyo3::{exceptions::PyTypeError, types::PyCapsule, PyResult, PyTryInto, Python}; |
| 8 | +use pyo3::{ |
| 9 | + exceptions::PyTypeError, once_cell::GILOnceCell, types::PyCapsule, Py, PyResult, PyTryInto, |
| 10 | + Python, |
| 11 | +}; |
10 | 12 | use rustc_hash::FxHashMap;
|
11 | 13 |
|
12 | 14 | use crate::array::get_array_module;
|
@@ -98,25 +100,23 @@ unsafe extern "C" fn release_mut_shared(flags: *mut c_void, array: *mut PyArrayO
|
98 | 100 |
|
99 | 101 | // This global state is a cache used to access the shared borrow checking API from this extension:
|
100 | 102 |
|
101 |
| -struct SharedPtr(Cell<*const Shared>); |
| 103 | +struct SharedPtr(GILOnceCell<*const Shared>); |
| 104 | + |
| 105 | +unsafe impl Send for SharedPtr {} |
102 | 106 |
|
103 | 107 | unsafe impl Sync for SharedPtr {}
|
104 | 108 |
|
105 |
| -static SHARED: SharedPtr = SharedPtr(Cell::new(null())); |
| 109 | +static SHARED: SharedPtr = SharedPtr(GILOnceCell::new()); |
106 | 110 |
|
107 | 111 | fn get_or_insert_shared<'py>(py: Python<'py>) -> PyResult<&'py Shared> {
|
108 |
| - let mut shared = SHARED.0.get(); |
109 |
| - |
110 |
| - if shared.is_null() { |
111 |
| - shared = insert_shared(py)?; |
112 |
| - } |
| 112 | + let shared = SHARED.0.get_or_try_init(py, || insert_shared(py))?; |
113 | 113 |
|
114 | 114 | // SAFETY: We inserted the capsule if it was missing
|
115 | 115 | // and verified that it contains a compatible version.
|
116 |
| - Ok(unsafe { &*shared }) |
| 116 | + Ok(unsafe { &**shared }) |
117 | 117 | }
|
118 | 118 |
|
119 |
| -// This function will publish this extensions version of the shared borrow checking API |
| 119 | +// This function will publish this extension's version of the shared borrow checking API |
120 | 120 | // as a capsule placed at `numpy.core.multiarray._RUST_NUMPY_BORROW_CHECKING_API` and
|
121 | 121 | // immediately initialize the cache used access it from this extension.
|
122 | 122 |
|
@@ -161,9 +161,11 @@ fn insert_shared(py: Python) -> PyResult<*const Shared> {
|
161 | 161 | )));
|
162 | 162 | }
|
163 | 163 |
|
164 |
| - let shared = capsule.pointer() as *const Shared; |
165 |
| - SHARED.0.set(shared); |
166 |
| - Ok(shared) |
| 164 | + // Intentionally leak a reference to the capsule |
| 165 | + // so we can safely cache a pointer into its interior. |
| 166 | + forget(Py::<PyCapsule>::from(capsule)); |
| 167 | + |
| 168 | + Ok(capsule.pointer() as *const Shared) |
167 | 169 | }
|
168 | 170 |
|
169 | 171 | // These entry points will be used to access the shared borrow checking API from this extension:
|
|
0 commit comments