Skip to content

Commit 8ae2634

Browse files
committed
Instead of slowing everything down via atomics, utilize that we already hold the GIL when accessing PY_ARRAY_API and PY_UFUNC_API.
1 parent bc4bf35 commit 8ae2634

File tree

7 files changed

+135
-108
lines changed

7 files changed

+135
-108
lines changed

src/array.rs

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -120,8 +120,8 @@ unsafe impl<T: Element, D: Dimension> PyTypeInfo for PyArray<T, D> {
120120
const MODULE: ::std::option::Option<&'static str> = Some("numpy");
121121

122122
#[inline]
123-
fn type_object_raw(_py: Python) -> *mut ffi::PyTypeObject {
124-
unsafe { npyffi::PY_ARRAY_API.get_type_object(npyffi::NpyTypes::PyArray_Type) }
123+
fn type_object_raw(py: Python) -> *mut ffi::PyTypeObject {
124+
unsafe { npyffi::PY_ARRAY_API.get_type_object(py, npyffi::NpyTypes::PyArray_Type) }
125125
}
126126

127127
fn is_type_of(ob: &PyAny) -> bool {
@@ -144,7 +144,7 @@ impl<'py, T: Element, D: Dimension> FromPyObject<'py> for &'py PyArray<T, D> {
144144
// 3. Checks if the dimension is same as D
145145
fn extract(ob: &'py PyAny) -> PyResult<Self> {
146146
let array = unsafe {
147-
if npyffi::PyArray_Check(ob.as_ptr()) == 0 {
147+
if npyffi::PyArray_Check(ob.py(), ob.as_ptr()) == 0 {
148148
return Err(PyDowncastError::new(ob, "PyArray<T, D>").into());
149149
}
150150
&*(ob as *const PyAny as *const PyArray<T, D>)
@@ -376,7 +376,6 @@ impl<T: Element, D: Dimension> PyArray<T, D> {
376376
}
377377

378378
fn ndarray_shape_ptr(&self) -> (StrideShape<D>, *mut T, InvertedAxises) {
379-
const ERR_MSG: &str = "PyArray::ndarray_shape: dimension mismatching";
380379
let shape_slice = self.shape();
381380
let shape: Shape<_> = Dim(self.dims()).into();
382381
let sizeof_t = mem::size_of::<T>();
@@ -400,7 +399,8 @@ impl<T: Element, D: Dimension> PyArray<T, D> {
400399
new_strides[i] = strides[i] as usize / sizeof_t;
401400
}
402401
}
403-
let st = D::from_dimension(&Dim(new_strides)).expect(ERR_MSG);
402+
let st = D::from_dimension(&Dim(new_strides))
403+
.expect("PyArray::ndarray_shape: dimension mismatching");
404404
(shape.strides(st), data_ptr, InvertedAxises(inverted_axises))
405405
}
406406

@@ -461,7 +461,8 @@ impl<T: Element, D: Dimension> PyArray<T, D> {
461461
{
462462
let dims = dims.into_dimension();
463463
let ptr = PY_ARRAY_API.PyArray_NewFromDescr(
464-
PY_ARRAY_API.get_type_object(npyffi::NpyTypes::PyArray_Type),
464+
py,
465+
PY_ARRAY_API.get_type_object(py, npyffi::NpyTypes::PyArray_Type),
465466
T::get_dtype(py).into_dtype_ptr(),
466467
dims.ndim_cint(),
467468
dims.as_dims_ptr(),
@@ -485,7 +486,8 @@ impl<T: Element, D: Dimension> PyArray<T, D> {
485486
{
486487
let dims = dims.into_dimension();
487488
let ptr = PY_ARRAY_API.PyArray_NewFromDescr(
488-
PY_ARRAY_API.get_type_object(npyffi::NpyTypes::PyArray_Type),
489+
py,
490+
PY_ARRAY_API.get_type_object(py, npyffi::NpyTypes::PyArray_Type),
489491
T::get_dtype(py).into_dtype_ptr(),
490492
dims.ndim_cint(),
491493
dims.as_dims_ptr(),
@@ -496,6 +498,7 @@ impl<T: Element, D: Dimension> PyArray<T, D> {
496498
);
497499

498500
PY_ARRAY_API.PyArray_SetBaseObject(
501+
py,
499502
ptr as *mut npyffi::PyArrayObject,
500503
container as *mut ffi::PyObject,
501504
);
@@ -600,6 +603,7 @@ impl<T: Element, D: Dimension> PyArray<T, D> {
600603
let dims = dims.into_dimension();
601604
unsafe {
602605
let ptr = PY_ARRAY_API.PyArray_Zeros(
606+
py,
603607
dims.ndim_cint(),
604608
dims.as_dims_ptr(),
605609
T::get_dtype(py).into_dtype_ptr(),
@@ -1046,7 +1050,7 @@ impl<T: Element> PyArray<T, Ix1> {
10461050
/// });
10471051
/// ```
10481052
pub fn resize(&self, new_elems: usize) -> PyResult<()> {
1049-
self.resize_([new_elems], 1, NPY_ORDER::NPY_ANYORDER)
1053+
self.resize_(self.py(), [new_elems], 1, NPY_ORDER::NPY_ANYORDER)
10501054
}
10511055

10521056
/// Iterates all elements of this array.
@@ -1059,6 +1063,7 @@ impl<T: Element> PyArray<T, Ix1> {
10591063

10601064
fn resize_<D: IntoDimension>(
10611065
&self,
1066+
py: Python,
10621067
dims: D,
10631068
check_ref: c_int,
10641069
order: NPY_ORDER,
@@ -1067,6 +1072,7 @@ impl<T: Element> PyArray<T, Ix1> {
10671072
let mut np_dims = dims.to_npy_dims();
10681073
let res = unsafe {
10691074
PY_ARRAY_API.PyArray_Resize(
1075+
py,
10701076
self.as_array_ptr(),
10711077
&mut np_dims as *mut npyffi::PyArray_Dims,
10721078
check_ref,
@@ -1175,7 +1181,7 @@ impl<T: Element, D> PyArray<T, D> {
11751181
pub fn copy_to<U: Element>(&self, other: &PyArray<U, D>) -> PyResult<()> {
11761182
let self_ptr = self.as_array_ptr();
11771183
let other_ptr = other.as_array_ptr();
1178-
let result = unsafe { PY_ARRAY_API.PyArray_CopyInto(other_ptr, self_ptr) };
1184+
let result = unsafe { PY_ARRAY_API.PyArray_CopyInto(self.py(), other_ptr, self_ptr) };
11791185
if result == -1 {
11801186
Err(PyErr::fetch(self.py()))
11811187
} else {
@@ -1197,6 +1203,7 @@ impl<T: Element, D> PyArray<T, D> {
11971203
pub fn cast<'py, U: Element>(&'py self, is_fortran: bool) -> PyResult<&'py PyArray<U, D>> {
11981204
let ptr = unsafe {
11991205
PY_ARRAY_API.PyArray_CastToType(
1206+
self.py(),
12001207
self.as_array_ptr(),
12011208
U::get_dtype(self.py()).into_dtype_ptr(),
12021209
if is_fortran { -1 } else { 0 },
@@ -1250,6 +1257,7 @@ impl<T: Element, D> PyArray<T, D> {
12501257
let mut np_dims = dims.to_npy_dims();
12511258
let ptr = unsafe {
12521259
PY_ARRAY_API.PyArray_Newshape(
1260+
self.py(),
12531261
self.as_array_ptr(),
12541262
&mut np_dims as *mut npyffi::PyArray_Dims,
12551263
order,
@@ -1281,6 +1289,7 @@ impl<T: Element + AsPrimitive<f64>> PyArray<T, Ix1> {
12811289
pub fn arange(py: Python, start: T, stop: T, step: T) -> &Self {
12821290
unsafe {
12831291
let ptr = PY_ARRAY_API.PyArray_Arange(
1292+
py,
12841293
start.as_(),
12851294
stop.as_(),
12861295
step.as_(),

src/dtype.rs

Lines changed: 33 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,12 @@ use std::ptr;
66

77
use num_traits::{Bounded, Zero};
88
use pyo3::{
9+
exceptions::{PyIndexError, PyValueError},
910
ffi::{self, PyTuple_Size},
10-
prelude::*,
11-
pyobject_native_type_core,
11+
pyobject_native_type_extract, pyobject_native_type_named,
1212
types::{PyDict, PyTuple, PyType},
13-
AsPyPointer, FromPyObject, FromPyPointer, PyNativeType,
13+
AsPyPointer, FromPyObject, FromPyPointer, IntoPyPointer, PyAny, PyNativeType, PyObject,
14+
PyResult, PyTypeInfo, Python, ToPyObject,
1415
};
1516

1617
use crate::npyffi::{
@@ -19,7 +20,6 @@ use crate::npyffi::{
1920
};
2021

2122
pub use num_complex::{Complex32, Complex64};
22-
use pyo3::exceptions::{PyIndexError, PyValueError};
2323

2424
/// Binding of [`numpy.dtype`](https://numpy.org/doc/stable/reference/generated/numpy.dtype.html).
2525
///
@@ -38,20 +38,31 @@ use pyo3::exceptions::{PyIndexError, PyValueError};
3838
/// ```
3939
pub struct PyArrayDescr(PyAny);
4040

41-
pyobject_native_type_core!(
42-
PyArrayDescr,
43-
*PY_ARRAY_API.get_type_object(NpyTypes::PyArrayDescr_Type),
44-
#module=Some("numpy"),
45-
#checkfunction=arraydescr_check
46-
);
47-
48-
unsafe fn arraydescr_check(op: *mut ffi::PyObject) -> c_int {
49-
ffi::PyObject_TypeCheck(
50-
op,
51-
PY_ARRAY_API.get_type_object(NpyTypes::PyArrayDescr_Type),
52-
)
41+
pyobject_native_type_named!(PyArrayDescr);
42+
43+
unsafe impl PyTypeInfo for PyArrayDescr {
44+
type AsRefTarget = Self;
45+
46+
const NAME: &'static str = "PyArrayDescr";
47+
const MODULE: ::std::option::Option<&'static str> = Some("numpy");
48+
49+
#[inline]
50+
fn type_object_raw(py: Python) -> *mut ffi::PyTypeObject {
51+
unsafe { PY_ARRAY_API.get_type_object(py, NpyTypes::PyArrayDescr_Type) }
52+
}
53+
54+
fn is_type_of(ob: &PyAny) -> bool {
55+
unsafe {
56+
ffi::PyObject_TypeCheck(
57+
ob.as_ptr(),
58+
PY_ARRAY_API.get_type_object(ob.py(), NpyTypes::PyArrayDescr_Type),
59+
) > 0
60+
}
61+
}
5362
}
5463

64+
pyobject_native_type_extract!(PyArrayDescr);
65+
5566
/// Returns the type descriptor ("dtype") for a registered type.
5667
pub fn dtype<T: Element>(py: Python) -> &PyArrayDescr {
5768
T::get_dtype(py)
@@ -70,7 +81,7 @@ impl PyArrayDescr {
7081
let mut descr: *mut PyArray_Descr = ptr::null_mut();
7182
unsafe {
7283
// None is an invalid input here and is not converted to NPY_DEFAULT_TYPE
73-
PY_ARRAY_API.PyArray_DescrConverter2(obj.as_ptr(), &mut descr as *mut _);
84+
PY_ARRAY_API.PyArray_DescrConverter2(py, obj.as_ptr(), &mut descr as *mut _);
7485
py.from_owned_ptr_or_err(descr as _)
7586
}
7687
}
@@ -99,12 +110,15 @@ impl PyArrayDescr {
99110

100111
/// Returns true if two type descriptors are equivalent.
101112
pub fn is_equiv_to(&self, other: &Self) -> bool {
102-
unsafe { PY_ARRAY_API.PyArray_EquivTypes(self.as_dtype_ptr(), other.as_dtype_ptr()) != 0 }
113+
unsafe {
114+
PY_ARRAY_API.PyArray_EquivTypes(self.py(), self.as_dtype_ptr(), other.as_dtype_ptr())
115+
!= 0
116+
}
103117
}
104118

105119
fn from_npy_type(py: Python, npy_type: NPY_TYPES) -> &Self {
106120
unsafe {
107-
let descr = PY_ARRAY_API.PyArray_DescrFromType(npy_type as _);
121+
let descr = PY_ARRAY_API.PyArray_DescrFromType(py, npy_type as _);
108122
py.from_owned_ptr(descr as _)
109123
}
110124
}

src/npyffi/array.rs

Lines changed: 26 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
//! Low-Level binding for [Array API](https://numpy.org/doc/stable/reference/c-api/array.html)
22
use libc::FILE;
3-
use pyo3::ffi::{self, PyObject, PyTypeObject};
3+
use std::cell::Cell;
44
use std::os::raw::*;
5-
use std::ptr::null_mut;
6-
use std::sync::atomic::{AtomicPtr, Ordering};
5+
use std::ptr::null;
6+
7+
use pyo3::ffi::{self, PyObject, PyTypeObject};
78

89
use crate::npyffi::*;
910

@@ -23,7 +24,7 @@ const CAPSULE_NAME: &str = "_ARRAY_API";
2324
/// pyo3::Python::with_gil(|py| {
2425
/// let array = PyArray::from_slice(py, &[3, 2, 4]);
2526
/// unsafe {
26-
/// PY_ARRAY_API.PyArray_Sort(array.as_array_ptr(), 0, NPY_SORTKIND::NPY_QUICKSORT);
27+
/// PY_ARRAY_API.PyArray_Sort(py, array.as_array_ptr(), 0, NPY_SORTKIND::NPY_QUICKSORT);
2728
/// }
2829
/// assert_eq!(array.readonly().as_slice().unwrap(), &[2, 3, 4]);
2930
/// })
@@ -32,30 +33,29 @@ pub static PY_ARRAY_API: PyArrayAPI = PyArrayAPI::new();
3233

3334
/// See [PY_ARRAY_API] for more.
3435
pub struct PyArrayAPI {
35-
api: AtomicPtr<*const c_void>,
36+
api: Cell<*const *const c_void>,
3637
}
3738

39+
unsafe impl Send for PyArrayAPI {}
40+
41+
unsafe impl Sync for PyArrayAPI {}
42+
3843
impl PyArrayAPI {
3944
const fn new() -> Self {
4045
Self {
41-
api: AtomicPtr::new(null_mut()),
46+
api: Cell::new(null()),
4247
}
4348
}
4449
#[cold]
45-
fn init(&self) -> *const *const c_void {
46-
Python::with_gil(|py| {
47-
let mut api = self.api.load(Ordering::Relaxed) as *const *const c_void;
48-
if api.is_null() {
49-
api = get_numpy_api(py, MOD_NAME, CAPSULE_NAME);
50-
self.api.store(api as *mut _, Ordering::Release);
51-
}
52-
api
53-
})
50+
fn init(&self, py: Python) -> *const *const c_void {
51+
let api = get_numpy_api(py, MOD_NAME, CAPSULE_NAME);
52+
self.api.set(api);
53+
api
5454
}
55-
unsafe fn get(&self, offset: isize) -> *const *const c_void {
56-
let mut api = self.api.load(Ordering::Acquire) as *const *const c_void;
55+
unsafe fn get(&self, py: Python, offset: isize) -> *const *const c_void {
56+
let mut api = self.api.get();
5757
if api.is_null() {
58-
api = self.init();
58+
api = self.init(py);
5959
}
6060
api.offset(offset)
6161
}
@@ -331,9 +331,9 @@ macro_rules! impl_array_type {
331331
pub enum NpyTypes { $($tname),* }
332332
impl PyArrayAPI {
333333
/// Get the pointer of the type object that `self` refers.
334-
pub unsafe fn get_type_object(&self, ty: NpyTypes) -> *mut PyTypeObject {
334+
pub unsafe fn get_type_object(&self, py: Python, ty: NpyTypes) -> *mut PyTypeObject {
335335
match ty {
336-
$( NpyTypes::$tname => *(self.get($offset)) as _ ),*
336+
$( NpyTypes::$tname => *(self.get(py, $offset)) as _ ),*
337337
}
338338
}
339339
}
@@ -384,14 +384,14 @@ impl_array_type! {
384384

385385
/// Checks that `op` is an instance of `PyArray` or not.
386386
#[allow(non_snake_case)]
387-
pub unsafe fn PyArray_Check(op: *mut PyObject) -> c_int {
388-
ffi::PyObject_TypeCheck(op, PY_ARRAY_API.get_type_object(NpyTypes::PyArray_Type))
387+
pub unsafe fn PyArray_Check(py: Python, op: *mut PyObject) -> c_int {
388+
ffi::PyObject_TypeCheck(op, PY_ARRAY_API.get_type_object(py, NpyTypes::PyArray_Type))
389389
}
390390

391391
/// Checks that `op` is an exact instance of `PyArray` or not.
392392
#[allow(non_snake_case)]
393-
pub unsafe fn PyArray_CheckExact(op: *mut PyObject) -> c_int {
394-
(ffi::Py_TYPE(op) == PY_ARRAY_API.get_type_object(NpyTypes::PyArray_Type)) as _
393+
pub unsafe fn PyArray_CheckExact(py: Python, op: *mut PyObject) -> c_int {
394+
(ffi::Py_TYPE(op) == PY_ARRAY_API.get_type_object(py, NpyTypes::PyArray_Type)) as _
395395
}
396396

397397
// these are under `#if NPY_USE_PYMEM == 1` which seems to be always defined as 1
@@ -406,9 +406,9 @@ mod tests {
406406

407407
#[test]
408408
fn call_api() {
409-
pyo3::Python::with_gil(|_py| unsafe {
409+
pyo3::Python::with_gil(|py| unsafe {
410410
assert_eq!(
411-
PY_ARRAY_API.PyArray_MultiplyIntList([1, 2, 3].as_mut_ptr(), 3),
411+
PY_ARRAY_API.PyArray_MultiplyIntList(py, [1, 2, 3].as_mut_ptr(), 3),
412412
6
413413
);
414414
})

src/npyffi/mod.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,8 @@ fn get_numpy_api(_py: Python, module: &str, capsule: &str) -> *const *const c_vo
2828
macro_rules! impl_api {
2929
[$offset: expr; $fname: ident ( $($arg: ident : $t: ty),* ) $( -> $ret: ty )* ] => {
3030
#[allow(non_snake_case)]
31-
pub unsafe fn $fname(&self, $($arg : $t), *) $( -> $ret )* {
32-
let fptr = self.get($offset)
31+
pub unsafe fn $fname(&self, py: Python, $($arg : $t), *) $( -> $ret )* {
32+
let fptr = self.get(py, $offset)
3333
as *const extern fn ($($arg : $t), *) $( -> $ret )*;
3434
(*fptr)($($arg), *)
3535
}

0 commit comments

Comments
 (0)