Skip to content

Commit 744a3f3

Browse files
Icxoluadamreichold
authored andcommitted
migration to pyo3 0.21 beta using the new Bound API
This does still use GIL Refs in numpy's API but switches our internals to use the Bound API where appropriate.
1 parent 32740b3 commit 744a3f3

File tree

15 files changed

+84
-72
lines changed

15 files changed

+84
-72
lines changed

Cargo.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,11 @@ num-complex = ">= 0.2, < 0.5"
2222
num-integer = "0.1"
2323
num-traits = "0.2"
2424
ndarray = ">= 0.13, < 0.16"
25-
pyo3 = { version = "0.20", default-features = false, features = ["macros"] }
25+
pyo3 = { version = "0.21.0-beta", default-features = false, features = ["gil-refs", "macros"] }
2626
rustc-hash = "1.1"
2727

2828
[dev-dependencies]
29-
pyo3 = { version = "0.20", default-features = false, features = ["auto-initialize"] }
29+
pyo3 = { version = "0.21.0-beta", default-features = false, features = ["auto-initialize", "gil-refs"] }
3030
nalgebra = { version = "0.32", default-features = false, features = ["std"] }
3131

3232
[package.metadata.docs.rs]

examples/linalg/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ name = "rust_linalg"
99
crate-type = ["cdylib"]
1010

1111
[dependencies]
12-
pyo3 = { version = "0.20", features = ["extension-module"] }
12+
pyo3 = { version = "0.21.0-beta", features = ["extension-module"] }
1313
numpy = { path = "../.." }
1414
ndarray-linalg = { version = "0.14.1", features = ["openblas-system"] }
1515

examples/parallel/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ name = "rust_parallel"
99
crate-type = ["cdylib"]
1010

1111
[dependencies]
12-
pyo3 = { version = "0.20", features = ["extension-module", "multiple-pymethods"] }
12+
pyo3 = { version = "0.21.0-beta", features = ["extension-module", "multiple-pymethods"] }
1313
numpy = { path = "../.." }
1414
ndarray = { version = "0.15", features = ["rayon", "blas"] }
1515
blas-src = { version = "0.8", features = ["openblas"] }

examples/simple/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ name = "rust_ext"
99
crate-type = ["cdylib"]
1010

1111
[dependencies]
12-
pyo3 = { version = "0.20", features = ["extension-module", "abi3-py37"] }
12+
pyo3 = { version = "0.21.0-beta", features = ["extension-module", "abi3-py37"] }
1313
numpy = { path = "../.." }
1414

1515
[workspace]

src/array.rs

Lines changed: 25 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,10 @@ use ndarray::{
1717
};
1818
use num_traits::AsPrimitive;
1919
use pyo3::{
20-
ffi, pyobject_native_type_base, types::PyModule, AsPyPointer, FromPyObject, IntoPy, Py, PyAny,
21-
PyClassInitializer, PyDowncastError, PyErr, PyNativeType, PyObject, PyResult, PyTypeInfo,
22-
Python, ToPyObject,
20+
ffi, pyobject_native_type_base,
21+
types::{DerefToPyAny, PyAnyMethods, PyModule},
22+
AsPyPointer, Bound, DowncastError, FromPyObject, IntoPy, Py, PyAny, PyErr, PyNativeType,
23+
PyObject, PyResult, PyTypeInfo, Python, ToPyObject,
2324
};
2425

2526
use crate::borrow::{PyReadonlyArray, PyReadwriteArray};
@@ -118,21 +119,21 @@ pub type PyArray6<T> = PyArray<T, Ix6>;
118119
pub type PyArrayDyn<T> = PyArray<T, IxDyn>;
119120

120121
/// Returns a handle to NumPy's multiarray module.
121-
pub fn get_array_module<'py>(py: Python<'py>) -> PyResult<&PyModule> {
122-
PyModule::import(py, npyffi::array::MOD_NAME)
122+
pub fn get_array_module<'py>(py: Python<'py>) -> PyResult<Bound<'_, PyModule>> {
123+
PyModule::import_bound(py, npyffi::array::MOD_NAME)
123124
}
124125

125-
unsafe impl<T: Element, D: Dimension> PyTypeInfo for PyArray<T, D> {
126-
type AsRefTarget = Self;
126+
impl<T, D> DerefToPyAny for PyArray<T, D> {}
127127

128+
unsafe impl<T: Element, D: Dimension> PyTypeInfo for PyArray<T, D> {
128129
const NAME: &'static str = "PyArray<T, D>";
129130
const MODULE: Option<&'static str> = Some("numpy");
130131

131132
fn type_object_raw<'py>(py: Python<'py>) -> *mut ffi::PyTypeObject {
132133
unsafe { npyffi::PY_ARRAY_API.get_type_object(py, npyffi::NpyTypes::PyArray_Type) }
133134
}
134135

135-
fn is_type_of(ob: &PyAny) -> bool {
136+
fn is_type_of_bound(ob: &Bound<'_, PyAny>) -> bool {
136137
Self::extract::<IgnoreError>(ob).is_ok()
137138
}
138139
}
@@ -189,8 +190,11 @@ impl<T, D> IntoPy<PyObject> for PyArray<T, D> {
189190
}
190191

191192
impl<'py, T: Element, D: Dimension> FromPyObject<'py> for &'py PyArray<T, D> {
192-
fn extract(ob: &'py PyAny) -> PyResult<Self> {
193+
fn extract_bound(ob: &Bound<'py, PyAny>) -> PyResult<Self> {
194+
#[allow(clippy::map_clone)] // due to MSRV
193195
PyArray::extract(ob)
196+
.map(Clone::clone)
197+
.map(Bound::into_gil_ref)
194198
}
195199
}
196200

@@ -251,28 +255,30 @@ impl<T, D> PyArray<T, D> {
251255
}
252256

253257
impl<T: Element, D: Dimension> PyArray<T, D> {
254-
fn extract<'py, E>(ob: &'py PyAny) -> Result<&'py Self, E>
258+
fn extract<'a, 'py, E>(ob: &'a Bound<'py, PyAny>) -> Result<&'a Bound<'py, Self>, E>
255259
where
256-
E: From<PyDowncastError<'py>> + From<DimensionalityError> + From<TypeError<'py>>,
260+
E: From<DowncastError<'a, 'py>> + From<DimensionalityError> + From<TypeError<'a>>,
257261
{
258262
// Check if the object is an array.
259263
let array = unsafe {
260264
if npyffi::PyArray_Check(ob.py(), ob.as_ptr()) == 0 {
261-
return Err(PyDowncastError::new(ob, Self::NAME).into());
265+
return Err(DowncastError::new(ob, <Self as PyTypeInfo>::NAME).into());
262266
}
263-
&*(ob as *const PyAny as *const Self)
267+
ob.downcast_unchecked()
264268
};
265269

270+
let arr_gil_ref: &PyArray<T, D> = array.as_gil_ref();
271+
266272
// Check if the dimensionality matches `D`.
267-
let src_ndim = array.ndim();
273+
let src_ndim = arr_gil_ref.ndim();
268274
if let Some(dst_ndim) = D::NDIM {
269275
if src_ndim != dst_ndim {
270276
return Err(DimensionalityError::new(src_ndim, dst_ndim).into());
271277
}
272278
}
273279

274280
// Check if the element type matches `T`.
275-
let src_dtype = array.dtype();
281+
let src_dtype = arr_gil_ref.dtype();
276282
let dst_dtype = T::get_dtype(ob.py());
277283
if !src_dtype.is_equiv_to(dst_dtype) {
278284
return Err(TypeError::new(src_dtype, dst_dtype).into());
@@ -399,11 +405,11 @@ impl<T: Element, D: Dimension> PyArray<T, D> {
399405
data_ptr: *const T,
400406
container: PySliceContainer,
401407
) -> &'py Self {
402-
let container = PyClassInitializer::from(container)
403-
.create_cell(py)
404-
.expect("Failed to create slice container");
408+
let container = Bound::new(py, container)
409+
.expect("Failed to create slice container")
410+
.into_ptr();
405411

406-
Self::new_with_data(py, dims, strides, data_ptr, container as *mut PyAny)
412+
Self::new_with_data(py, dims, strides, data_ptr, container.cast())
407413
}
408414

409415
/// Creates a NumPy array backed by `array` and ties its ownership to the Python object `container`.

src/array_like.rs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,12 @@ use std::marker::PhantomData;
22
use std::ops::Deref;
33

44
use ndarray::{Array1, Dimension, Ix0, Ix1, Ix2, Ix3, Ix4, Ix5, Ix6, IxDyn};
5-
use pyo3::{intern, sync::GILOnceCell, types::PyDict, FromPyObject, Py, PyAny, PyResult};
5+
use pyo3::{
6+
intern,
7+
sync::GILOnceCell,
8+
types::{PyAnyMethods, PyDict},
9+
FromPyObject, Py, PyAny, PyResult,
10+
};
611

712
use crate::sealed::Sealed;
813
use crate::{get_array_module, Element, IntoPyArray, PyArray, PyReadonlyArray};

src/borrow/shared.rs

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,8 @@ use std::os::raw::{c_char, c_int};
55
use std::slice::from_raw_parts;
66

77
use num_integer::gcd;
8-
use pyo3::{
9-
exceptions::PyTypeError, once_cell::GILOnceCell, types::PyCapsule, Py, PyResult, PyTryInto,
10-
Python,
11-
};
8+
use pyo3::types::{PyAnyMethods, PyCapsuleMethods};
9+
use pyo3::{exceptions::PyTypeError, sync::GILOnceCell, types::PyCapsule, PyResult, Python};
1210
use rustc_hash::FxHashMap;
1311

1412
use crate::array::get_array_module;
@@ -124,8 +122,8 @@ fn get_or_insert_shared<'py>(py: Python<'py>) -> PyResult<&'py Shared> {
124122
fn insert_shared<'py>(py: Python<'py>) -> PyResult<*const Shared> {
125123
let module = get_array_module(py)?;
126124

127-
let capsule: &PyCapsule = match module.getattr("_RUST_NUMPY_BORROW_CHECKING_API") {
128-
Ok(capsule) => PyTryInto::try_into(capsule)?,
125+
let capsule = match module.getattr("_RUST_NUMPY_BORROW_CHECKING_API") {
126+
Ok(capsule) => capsule.downcast_into::<PyCapsule>()?,
129127
Err(_err) => {
130128
let flags: *mut BorrowFlags = Box::into_raw(Box::default());
131129

@@ -138,7 +136,7 @@ fn insert_shared<'py>(py: Python<'py>) -> PyResult<*const Shared> {
138136
release_mut: release_mut_shared,
139137
};
140138

141-
let capsule = PyCapsule::new_with_destructor(
139+
let capsule = PyCapsule::new_bound_with_destructor(
142140
py,
143141
shared,
144142
Some(CString::new("_RUST_NUMPY_BORROW_CHECKING_API").unwrap()),
@@ -147,25 +145,27 @@ fn insert_shared<'py>(py: Python<'py>) -> PyResult<*const Shared> {
147145
let _ = unsafe { Box::from_raw(shared.flags as *mut BorrowFlags) };
148146
},
149147
)?;
150-
module.setattr("_RUST_NUMPY_BORROW_CHECKING_API", capsule)?;
148+
module.setattr("_RUST_NUMPY_BORROW_CHECKING_API", &capsule)?;
151149
capsule
152150
}
153151
};
154152

155153
// SAFETY: All versions of the shared borrow checking API start with a version field.
156-
let version = unsafe { *(capsule.pointer() as *mut u64) };
154+
let version = unsafe { *capsule.pointer().cast::<u64>() };
157155
if version < 1 {
158156
return Err(PyTypeError::new_err(format!(
159157
"Version {} of borrow checking API is not supported by this version of rust-numpy",
160158
version
161159
)));
162160
}
163161

162+
let ptr = capsule.pointer();
163+
164164
// Intentionally leak a reference to the capsule
165165
// so we can safely cache a pointer into its interior.
166-
forget(Py::<PyCapsule>::from(capsule));
166+
forget(capsule);
167167

168-
Ok(capsule.pointer() as *const Shared)
168+
Ok(ptr.cast())
169169
}
170170

171171
// These entry points will be used to access the shared borrow checking API from this extension:

src/dtype.rs

Lines changed: 24 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,8 @@ use pyo3::{
1111
exceptions::{PyIndexError, PyValueError},
1212
ffi::{self, PyTuple_Size},
1313
pyobject_native_type_extract, pyobject_native_type_named,
14-
types::{PyDict, PyTuple, PyType},
15-
AsPyPointer, FromPyObject, FromPyPointer, PyAny, PyNativeType, PyObject, PyResult, PyTypeInfo,
16-
Python, ToPyObject,
14+
types::{PyAnyMethods, PyDict, PyDictMethods, PyTuple, PyType},
15+
AsPyPointer, Borrowed, PyAny, PyNativeType, PyObject, PyResult, PyTypeInfo, Python, ToPyObject,
1716
};
1817
#[cfg(feature = "half")]
1918
use pyo3::{sync::GILOnceCell, IntoPy, Py};
@@ -53,8 +52,6 @@ pub struct PyArrayDescr(PyAny);
5352
pyobject_native_type_named!(PyArrayDescr);
5453

5554
unsafe impl PyTypeInfo for PyArrayDescr {
56-
type AsRefTarget = Self;
57-
5855
const NAME: &'static str = "PyArrayDescr";
5956
const MODULE: Option<&'static str> = Some("numpy");
6057

@@ -249,7 +246,9 @@ impl PyArrayDescr {
249246
if !self.has_subarray() {
250247
self
251248
} else {
249+
#[allow(deprecated)]
252250
unsafe {
251+
use pyo3::FromPyPointer;
253252
Self::from_borrowed_ptr(self.py(), (*(*self.as_dtype_ptr()).subarray).base as _)
254253
}
255254
}
@@ -267,11 +266,9 @@ impl PyArrayDescr {
267266
Vec::new()
268267
} else {
269268
// NumPy guarantees that shape is a tuple of non-negative integers so this should never panic.
270-
unsafe {
271-
PyTuple::from_borrowed_ptr(self.py(), (*(*self.as_dtype_ptr()).subarray).shape)
272-
}
273-
.extract()
274-
.unwrap()
269+
unsafe { Borrowed::from_ptr(self.py(), (*(*self.as_dtype_ptr()).subarray).shape) }
270+
.extract()
271+
.unwrap()
275272
}
276273
}
277274

@@ -329,8 +326,8 @@ impl PyArrayDescr {
329326
if !self.has_fields() {
330327
return None;
331328
}
332-
let names = unsafe { PyTuple::from_borrowed_ptr(self.py(), (*self.as_dtype_ptr()).names) };
333-
FromPyObject::extract(names).ok()
329+
let names = unsafe { Borrowed::from_ptr(self.py(), (*self.as_dtype_ptr()).names) };
330+
names.extract().ok()
334331
}
335332

336333
/// Returns the type descriptor and offset of the field with the given name.
@@ -349,17 +346,22 @@ impl PyArrayDescr {
349346
"cannot get field information: type descriptor has no fields",
350347
));
351348
}
352-
let dict = unsafe { PyDict::from_borrowed_ptr(self.py(), (*self.as_dtype_ptr()).fields) };
349+
let dict = unsafe { Borrowed::from_ptr(self.py(), (*self.as_dtype_ptr()).fields) };
350+
let dict = unsafe { dict.downcast_unchecked::<PyDict>() };
353351
// NumPy guarantees that fields are tuples of proper size and type, so this should never panic.
354352
let tuple = dict
355353
.get_item(name)?
356354
.ok_or_else(|| PyIndexError::new_err(name.to_owned()))?
357-
.downcast::<PyTuple>()
355+
.downcast_into::<PyTuple>()
358356
.unwrap();
359357
// Note that we cannot just extract the entire tuple since the third element can be a title.
360-
let dtype = FromPyObject::extract(tuple.as_ref().get_item(0).unwrap()).unwrap();
361-
let offset = FromPyObject::extract(tuple.as_ref().get_item(1).unwrap()).unwrap();
362-
Ok((dtype, offset))
358+
let dtype = tuple
359+
.get_item(0)
360+
.unwrap()
361+
.downcast_into::<PyArrayDescr>()
362+
.unwrap();
363+
let offset = tuple.get_item(1).unwrap().extract().unwrap();
364+
Ok((dtype.into_gil_ref(), offset))
363365
}
364366
}
365367

@@ -548,8 +550,8 @@ mod tests {
548550

549551
#[test]
550552
fn test_dtype_names() {
551-
fn type_name<'py, T: Element>(py: Python<'py>) -> &str {
552-
dtype::<T>(py).typeobj().name().unwrap()
553+
fn type_name<'py, T: Element>(py: Python<'py>) -> String {
554+
dtype::<T>(py).typeobj().qualname().unwrap()
553555
}
554556
Python::with_gil(|py| {
555557
assert_eq!(type_name::<bool>(py), "bool_");
@@ -589,7 +591,7 @@ mod tests {
589591

590592
assert_eq!(dt.num(), NPY_TYPES::NPY_DOUBLE as c_int);
591593
assert_eq!(dt.flags(), 0);
592-
assert_eq!(dt.typeobj().name().unwrap(), "float64");
594+
assert_eq!(dt.typeobj().qualname().unwrap(), "float64");
593595
assert_eq!(dt.char(), b'd');
594596
assert_eq!(dt.kind(), b'f');
595597
assert_eq!(dt.byteorder(), b'=');
@@ -625,7 +627,7 @@ mod tests {
625627

626628
assert_eq!(dt.num(), NPY_TYPES::NPY_VOID as c_int);
627629
assert_eq!(dt.flags(), 0);
628-
assert_eq!(dt.typeobj().name().unwrap(), "void");
630+
assert_eq!(dt.typeobj().qualname().unwrap(), "void");
629631
assert_eq!(dt.char(), b'V');
630632
assert_eq!(dt.kind(), b'V');
631633
assert_eq!(dt.byteorder(), b'|');
@@ -663,7 +665,7 @@ mod tests {
663665
assert_ne!(dt.flags() & NPY_ITEM_HASOBJECT, 0);
664666
assert_ne!(dt.flags() & NPY_NEEDS_PYAPI, 0);
665667
assert_ne!(dt.flags() & NPY_ALIGNED_STRUCT, 0);
666-
assert_eq!(dt.typeobj().name().unwrap(), "void");
668+
assert_eq!(dt.typeobj().qualname().unwrap(), "void");
667669
assert_eq!(dt.char(), b'V');
668670
assert_eq!(dt.kind(), b'V');
669671
assert_eq!(dt.byteorder(), b'|');

src/npyffi/array.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ use std::os::raw::*;
99
use libc::FILE;
1010
use pyo3::{
1111
ffi::{self, PyObject, PyTypeObject},
12-
once_cell::GILOnceCell,
12+
sync::GILOnceCell,
1313
};
1414

1515
use crate::npyffi::*;

src/npyffi/mod.rs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,23 +13,23 @@ use std::mem::forget;
1313
use std::os::raw::c_void;
1414

1515
use pyo3::{
16-
types::{PyCapsule, PyModule},
17-
Py, PyResult, PyTryInto, Python,
16+
types::{PyAnyMethods, PyCapsule, PyCapsuleMethods, PyModule},
17+
PyResult, Python,
1818
};
1919

2020
fn get_numpy_api<'py>(
2121
py: Python<'py>,
2222
module: &str,
2323
capsule: &str,
2424
) -> PyResult<*const *const c_void> {
25-
let module = PyModule::import(py, module)?;
26-
let capsule: &PyCapsule = PyTryInto::try_into(module.getattr(capsule)?)?;
25+
let module = PyModule::import_bound(py, module)?;
26+
let capsule = module.getattr(capsule)?.downcast_into::<PyCapsule>()?;
2727

2828
let api = capsule.pointer() as *const *const c_void;
2929

3030
// Intentionally leak a reference to the capsule
3131
// so we can safely cache a pointer into its interior.
32-
forget(Py::<PyCapsule>::from(capsule));
32+
forget(capsule);
3333

3434
Ok(api)
3535
}

0 commit comments

Comments
 (0)