Skip to content

Commit d7fdc43

Browse files
committed
Intentionally leak a reference to the NumPy capsule API
1 parent f64d8ec commit d7fdc43

File tree

3 files changed

+22
-16
lines changed

3 files changed

+22
-16
lines changed

src/npyffi/array.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,8 @@ impl PyArrayAPI {
4848
unsafe fn get(&self, py: Python, offset: isize) -> *const *const c_void {
4949
let api = self
5050
.0
51-
.get_or_init(py, || get_numpy_api(py, MOD_NAME, CAPSULE_NAME));
51+
.get_or_try_init(py, || get_numpy_api(py, MOD_NAME, CAPSULE_NAME))
52+
.expect("Failed to access NumPy array API capsule");
5253

5354
api.offset(offset)
5455
}

src/npyffi/mod.rs

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -9,21 +9,25 @@
99
clippy::missing_safety_doc
1010
)]
1111

12-
use pyo3::{ffi, Python};
13-
use std::ffi::CString;
12+
use std::mem::forget;
1413
use std::os::raw::c_void;
15-
use std::ptr::null_mut;
16-
17-
fn get_numpy_api(_py: Python, module: &str, capsule: &str) -> *const *const c_void {
18-
let module = CString::new(module).unwrap();
19-
let capsule = CString::new(capsule).unwrap();
20-
unsafe {
21-
let module = ffi::PyImport_ImportModule(module.as_ptr());
22-
assert!(!module.is_null(), "Failed to import NumPy module");
23-
let capsule = ffi::PyObject_GetAttrString(module as _, capsule.as_ptr());
24-
assert!(!capsule.is_null(), "Failed to get NumPy API capsule");
25-
ffi::PyCapsule_GetPointer(capsule, null_mut()) as _
26-
}
14+
15+
use pyo3::{
16+
types::{PyCapsule, PyModule},
17+
Py, PyResult, PyTryInto, Python,
18+
};
19+
20+
fn get_numpy_api(py: Python, module: &str, capsule: &str) -> PyResult<*const *const c_void> {
21+
let module = PyModule::import(py, module)?;
22+
let capsule: &PyCapsule = module.getattr(capsule)?.try_into()?;
23+
24+
let api = capsule.pointer() as *const *const c_void;
25+
26+
// Intentionally leak a reference to the capsule
27+
// so we can safely cache a pointer into its interior.
28+
forget(Py::<PyCapsule>::from(capsule));
29+
30+
Ok(api)
2731
}
2832

2933
// Implements wrappers for NumPy's Array and UFunc API

src/npyffi/ufunc.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,8 @@ impl PyUFuncAPI {
2323
unsafe fn get(&self, py: Python, offset: isize) -> *const *const c_void {
2424
let api = self
2525
.0
26-
.get_or_init(py, || get_numpy_api(py, MOD_NAME, CAPSULE_NAME));
26+
.get_or_try_init(py, || get_numpy_api(py, MOD_NAME, CAPSULE_NAME))
27+
.expect("Failed to access NumPy ufunc API capsule");
2728

2829
api.offset(offset)
2930
}

0 commit comments

Comments
 (0)