|
9 | 9 | clippy::missing_safety_doc
|
10 | 10 | )]
|
11 | 11 |
|
12 |
| -use pyo3::{ffi, Python}; |
13 |
| -use std::ffi::CString; |
| 12 | +use std::mem::forget; |
14 | 13 | use std::os::raw::c_void;
|
15 |
| -use std::ptr::null_mut; |
16 |
| - |
17 |
| -fn get_numpy_api(_py: Python, module: &str, capsule: &str) -> *const *const c_void { |
18 |
| - let module = CString::new(module).unwrap(); |
19 |
| - let capsule = CString::new(capsule).unwrap(); |
20 |
| - unsafe { |
21 |
| - let module = ffi::PyImport_ImportModule(module.as_ptr()); |
22 |
| - assert!(!module.is_null(), "Failed to import NumPy module"); |
23 |
| - let capsule = ffi::PyObject_GetAttrString(module as _, capsule.as_ptr()); |
24 |
| - assert!(!capsule.is_null(), "Failed to get NumPy API capsule"); |
25 |
| - ffi::PyCapsule_GetPointer(capsule, null_mut()) as _ |
26 |
| - } |
| 14 | + |
| 15 | +use pyo3::{ |
| 16 | + types::{PyCapsule, PyModule}, |
| 17 | + Py, PyResult, PyTryInto, Python, |
| 18 | +}; |
| 19 | + |
| 20 | +fn get_numpy_api(py: Python, module: &str, capsule: &str) -> PyResult<*const *const c_void> { |
| 21 | + let module = PyModule::import(py, module)?; |
| 22 | + let capsule: &PyCapsule = module.getattr(capsule)?.try_into()?; |
| 23 | + |
| 24 | + let api = capsule.pointer() as *const *const c_void; |
| 25 | + |
| 26 | + // Intentionally leak a reference to the capsule |
| 27 | + // so we can safely cache a pointer into its interior. |
| 28 | + forget(Py::<PyCapsule>::from(capsule)); |
| 29 | + |
| 30 | + Ok(api) |
27 | 31 | }
|
28 | 32 |
|
29 | 33 | // Implements wrappers for NumPy's Array and UFunc API
|
|
0 commit comments