Skip to content

Commit 9517ed9

Browse files
Icxoluadamreichold
authored andcommitted
convert PyUntypedArray to Bound API
1 parent 744a3f3 commit 9517ed9

File tree

3 files changed

+243
-22
lines changed

3 files changed

+243
-22
lines changed

src/array.rs

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ use crate::error::{
3333
};
3434
use crate::npyffi::{self, npy_intp, NPY_ORDER, PY_ARRAY_API};
3535
use crate::slice_container::PySliceContainer;
36-
use crate::untyped_array::PyUntypedArray;
36+
use crate::untyped_array::{PyUntypedArray, PyUntypedArrayMethods};
3737

3838
/// A safe, statically-typed wrapper for NumPy's [`ndarray`][ndarray] class.
3939
///
@@ -1480,6 +1480,20 @@ unsafe fn clone_elements<T: Element>(elems: &[T], data_ptr: &mut *mut T) {
14801480
}
14811481
}
14821482

1483+
/// Implementation of functionality for [`PyArray<T, D>`].
1484+
#[doc(alias = "PyArray")]
1485+
pub trait PyArrayMethods<'py, T, D>: PyUntypedArrayMethods<'py> {
1486+
/// Access an untyped representation of this array.
1487+
fn as_untyped(&self) -> &Bound<'py, PyUntypedArray>;
1488+
}
1489+
1490+
impl<'py, T, D> PyArrayMethods<'py, T, D> for Bound<'py, PyArray<T, D>> {
1491+
#[inline(always)]
1492+
fn as_untyped(&self) -> &Bound<'py, PyUntypedArray> {
1493+
unsafe { self.downcast_unchecked() }
1494+
}
1495+
}
1496+
14831497
#[cfg(test)]
14841498
mod tests {
14851499
use super::*;

src/lib.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ pub use nalgebra;
9393

9494
pub use crate::array::{
9595
get_array_module, PyArray, PyArray0, PyArray1, PyArray2, PyArray3, PyArray4, PyArray5,
96-
PyArray6, PyArrayDyn,
96+
PyArray6, PyArrayDyn, PyArrayMethods,
9797
};
9898
pub use crate::array_like::{
9999
AllowTypeChange, PyArrayLike, PyArrayLike0, PyArrayLike1, PyArrayLike2, PyArrayLike3,
@@ -111,7 +111,7 @@ pub use crate::error::{BorrowError, FromVecError, NotContiguousError};
111111
pub use crate::npyffi::{PY_ARRAY_API, PY_UFUNC_API};
112112
pub use crate::strings::{PyFixedString, PyFixedUnicode};
113113
pub use crate::sum_products::{dot, einsum, inner};
114-
pub use crate::untyped_array::PyUntypedArray;
114+
pub use crate::untyped_array::{PyUntypedArray, PyUntypedArrayMethods};
115115

116116
pub use ndarray::{array, Ix1, Ix2, Ix3, Ix4, Ix5, Ix6, IxDyn};
117117

src/untyped_array.rs

Lines changed: 226 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
//! Safe, untyped interface for NumPy's [N-dimensional arrays][ndarray]
22
//!
33
//! [ndarray]: https://numpy.org/doc/stable/reference/arrays.ndarray.html
4-
use std::{os::raw::c_int, slice};
4+
use std::slice;
55

66
use pyo3::{
7-
ffi, pyobject_native_type_extract, pyobject_native_type_named, AsPyPointer, IntoPy, PyAny,
8-
PyNativeType, PyObject, PyTypeInfo, Python,
7+
ffi, pyobject_native_type_extract, pyobject_native_type_named, types::PyAnyMethods,
8+
AsPyPointer, Bound, IntoPy, PyAny, PyNativeType, PyObject, PyTypeInfo, Python,
99
};
1010

11+
use crate::array::{PyArray, PyArrayMethods};
1112
use crate::cold;
1213
use crate::dtype::PyArrayDescr;
1314
use crate::npyffi;
@@ -68,7 +69,7 @@ unsafe impl PyTypeInfo for PyUntypedArray {
6869
unsafe { npyffi::PY_ARRAY_API.get_type_object(py, npyffi::NpyTypes::PyArray_Type) }
6970
}
7071

71-
fn is_type_of(ob: &PyAny) -> bool {
72+
fn is_type_of_bound(ob: &Bound<'_, PyAny>) -> bool {
7273
unsafe { npyffi::PyArray_Check(ob.py(), ob.as_ptr()) != 0 }
7374
}
7475
}
@@ -87,7 +88,7 @@ impl PyUntypedArray {
8788
/// Returns a raw pointer to the underlying [`PyArrayObject`][npyffi::PyArrayObject].
8889
#[inline]
8990
pub fn as_array_ptr(&self) -> *mut npyffi::PyArrayObject {
90-
self.as_ptr() as _
91+
self.as_borrowed().as_array_ptr()
9192
}
9293

9394
/// Returns the `dtype` of the array.
@@ -109,16 +110,9 @@ impl PyUntypedArray {
109110
///
110111
/// [ndarray-dtype]: https://numpy.org/doc/stable/reference/generated/numpy.ndarray.dtype.html
111112
/// [PyArray_DTYPE]: https://numpy.org/doc/stable/reference/c-api/array.html#c.PyArray_DTYPE
113+
#[inline]
112114
pub fn dtype(&self) -> &PyArrayDescr {
113-
unsafe {
114-
let descr_ptr = (*self.as_array_ptr()).descr;
115-
self.py().from_borrowed_ptr(descr_ptr as _)
116-
}
117-
}
118-
119-
#[inline(always)]
120-
pub(crate) fn check_flags(&self, flags: c_int) -> bool {
121-
unsafe { (*self.as_array_ptr()).flags & flags != 0 }
115+
self.as_borrowed().dtype().into_gil_ref()
122116
}
123117

124118
/// Returns `true` if the internal data of the array is contiguous,
@@ -142,18 +136,21 @@ impl PyUntypedArray {
142136
/// assert!(!view.is_contiguous());
143137
/// });
144138
/// ```
139+
#[inline]
145140
pub fn is_contiguous(&self) -> bool {
146-
self.check_flags(npyffi::NPY_ARRAY_C_CONTIGUOUS | npyffi::NPY_ARRAY_F_CONTIGUOUS)
141+
self.as_borrowed().is_contiguous()
147142
}
148143

149144
/// Returns `true` if the internal data of the array is Fortran-style/column-major contiguous.
145+
#[inline]
150146
pub fn is_fortran_contiguous(&self) -> bool {
151-
self.check_flags(npyffi::NPY_ARRAY_F_CONTIGUOUS)
147+
self.as_borrowed().is_fortran_contiguous()
152148
}
153149

154150
/// Returns `true` if the internal data of the array is C-style/row-major contiguous.
151+
#[inline]
155152
pub fn is_c_contiguous(&self) -> bool {
156-
self.check_flags(npyffi::NPY_ARRAY_C_CONTIGUOUS)
153+
self.as_borrowed().is_c_contiguous()
157154
}
158155

159156
/// Returns the number of dimensions of the array.
@@ -177,7 +174,7 @@ impl PyUntypedArray {
177174
/// [PyArray_NDIM]: https://numpy.org/doc/stable/reference/c-api/array.html#c.PyArray_NDIM
178175
#[inline]
179176
pub fn ndim(&self) -> usize {
180-
unsafe { (*self.as_array_ptr()).nd as usize }
177+
self.as_borrowed().ndim()
181178
}
182179

183180
/// Returns a slice indicating how many bytes to advance when iterating along each axis.
@@ -246,12 +243,222 @@ impl PyUntypedArray {
246243
}
247244

248245
/// Calculates the total number of elements in the array.
246+
#[inline]
249247
pub fn len(&self) -> usize {
250-
self.shape().iter().product()
248+
self.as_borrowed().len()
251249
}
252250

253251
/// Returns `true` if the there are no elements in the array.
252+
#[inline]
254253
pub fn is_empty(&self) -> bool {
254+
self.as_borrowed().is_empty()
255+
}
256+
}
257+
258+
/// Implementation of functionality for [`PyUntypedArray`].
259+
#[doc(alias = "PyUntypedArray")]
260+
pub trait PyUntypedArrayMethods<'py>: sealed::Sealed {
261+
/// Returns a raw pointer to the underlying [`PyArrayObject`][npyffi::PyArrayObject].
262+
fn as_array_ptr(&self) -> *mut npyffi::PyArrayObject;
263+
264+
/// Returns the `dtype` of the array.
265+
///
266+
/// See also [`ndarray.dtype`][ndarray-dtype] and [`PyArray_DTYPE`][PyArray_DTYPE].
267+
///
268+
/// # Example
269+
///
270+
/// ```
271+
/// use numpy::{dtype, PyArray};
272+
/// use pyo3::Python;
273+
///
274+
/// Python::with_gil(|py| {
275+
/// let array = PyArray::from_vec(py, vec![1_i32, 2, 3]);
276+
///
277+
/// assert!(array.dtype().is_equiv_to(dtype::<i32>(py)));
278+
/// });
279+
/// ```
280+
///
281+
/// [ndarray-dtype]: https://numpy.org/doc/stable/reference/generated/numpy.ndarray.dtype.html
282+
/// [PyArray_DTYPE]: https://numpy.org/doc/stable/reference/c-api/array.html#c.PyArray_DTYPE
283+
fn dtype(&self) -> Bound<'py, PyArrayDescr>;
284+
285+
/// Returns `true` if the internal data of the array is contiguous,
286+
/// indepedently of whether C-style/row-major or Fortran-style/column-major.
287+
///
288+
/// # Example
289+
///
290+
/// ```
291+
/// use numpy::PyArray1;
292+
/// use pyo3::{types::IntoPyDict, Python};
293+
///
294+
/// Python::with_gil(|py| {
295+
/// let array = PyArray1::arange(py, 0, 10, 1);
296+
/// assert!(array.is_contiguous());
297+
///
298+
/// let view = py
299+
/// .eval("array[::2]", None, Some([("array", array)].into_py_dict(py)))
300+
/// .unwrap()
301+
/// .downcast::<PyArray1<i32>>()
302+
/// .unwrap();
303+
/// assert!(!view.is_contiguous());
304+
/// });
305+
/// ```
306+
fn is_contiguous(&self) -> bool {
307+
unsafe {
308+
check_flags(
309+
&*self.as_array_ptr(),
310+
npyffi::NPY_ARRAY_C_CONTIGUOUS | npyffi::NPY_ARRAY_F_CONTIGUOUS,
311+
)
312+
}
313+
}
314+
315+
/// Returns `true` if the internal data of the array is Fortran-style/column-major contiguous.
316+
fn is_fortran_contiguous(&self) -> bool {
317+
unsafe { check_flags(&*self.as_array_ptr(), npyffi::NPY_ARRAY_F_CONTIGUOUS) }
318+
}
319+
320+
/// Returns `true` if the internal data of the array is C-style/row-major contiguous.
321+
fn is_c_contiguous(&self) -> bool {
322+
unsafe { check_flags(&*self.as_array_ptr(), npyffi::NPY_ARRAY_C_CONTIGUOUS) }
323+
}
324+
325+
/// Returns the number of dimensions of the array.
326+
///
327+
/// See also [`ndarray.ndim`][ndarray-ndim] and [`PyArray_NDIM`][PyArray_NDIM].
328+
///
329+
/// # Example
330+
///
331+
/// ```
332+
/// use numpy::PyArray3;
333+
/// use pyo3::Python;
334+
///
335+
/// Python::with_gil(|py| {
336+
/// let arr = PyArray3::<f64>::zeros(py, [4, 5, 6], false);
337+
///
338+
/// assert_eq!(arr.ndim(), 3);
339+
/// });
340+
/// ```
341+
///
342+
/// [ndarray-ndim]: https://numpy.org/doc/stable/reference/generated/numpy.ndarray.ndim.html
343+
/// [PyArray_NDIM]: https://numpy.org/doc/stable/reference/c-api/array.html#c.PyArray_NDIM
344+
#[inline]
345+
fn ndim(&self) -> usize {
346+
unsafe { (*self.as_array_ptr()).nd as usize }
347+
}
348+
349+
/// Returns a slice indicating how many bytes to advance when iterating along each axis.
350+
///
351+
/// See also [`ndarray.strides`][ndarray-strides] and [`PyArray_STRIDES`][PyArray_STRIDES].
352+
///
353+
/// # Example
354+
///
355+
/// ```
356+
/// use numpy::PyArray3;
357+
/// use pyo3::Python;
358+
///
359+
/// Python::with_gil(|py| {
360+
/// let arr = PyArray3::<f64>::zeros(py, [4, 5, 6], false);
361+
///
362+
/// assert_eq!(arr.strides(), &[240, 48, 8]);
363+
/// });
364+
/// ```
365+
/// [ndarray-strides]: https://numpy.org/doc/stable/reference/generated/numpy.ndarray.strides.html
366+
/// [PyArray_STRIDES]: https://numpy.org/doc/stable/reference/c-api/array.html#c.PyArray_STRIDES
367+
#[inline]
368+
fn strides(&self) -> &[isize] {
369+
let n = self.ndim();
370+
if n == 0 {
371+
cold();
372+
return &[];
373+
}
374+
let ptr = self.as_array_ptr();
375+
unsafe {
376+
let p = (*ptr).strides;
377+
slice::from_raw_parts(p, n)
378+
}
379+
}
380+
381+
/// Returns a slice which contains dimmensions of the array.
382+
///
383+
/// See also [`ndarray.shape`][ndaray-shape] and [`PyArray_DIMS`][PyArray_DIMS].
384+
///
385+
/// # Example
386+
///
387+
/// ```
388+
/// use numpy::PyArray3;
389+
/// use pyo3::Python;
390+
///
391+
/// Python::with_gil(|py| {
392+
/// let arr = PyArray3::<f64>::zeros(py, [4, 5, 6], false);
393+
///
394+
/// assert_eq!(arr.shape(), &[4, 5, 6]);
395+
/// });
396+
/// ```
397+
///
398+
/// [ndarray-shape]: https://numpy.org/doc/stable/reference/generated/numpy.ndarray.shape.html
399+
/// [PyArray_DIMS]: https://numpy.org/doc/stable/reference/c-api/array.html#c.PyArray_DIMS
400+
#[inline]
401+
fn shape(&self) -> &[usize] {
402+
let n = self.ndim();
403+
if n == 0 {
404+
cold();
405+
return &[];
406+
}
407+
let ptr = self.as_array_ptr();
408+
unsafe {
409+
let p = (*ptr).dimensions as *mut usize;
410+
slice::from_raw_parts(p, n)
411+
}
412+
}
413+
414+
/// Calculates the total number of elements in the array.
415+
fn len(&self) -> usize {
416+
self.shape().iter().product()
417+
}
418+
419+
/// Returns `true` if the there are no elements in the array.
420+
fn is_empty(&self) -> bool {
255421
self.shape().iter().any(|dim| *dim == 0)
256422
}
257423
}
424+
425+
fn check_flags(obj: &npyffi::PyArrayObject, flags: i32) -> bool {
426+
obj.flags & flags != 0
427+
}
428+
429+
impl<'py> PyUntypedArrayMethods<'py> for Bound<'py, PyUntypedArray> {
430+
#[inline]
431+
fn as_array_ptr(&self) -> *mut npyffi::PyArrayObject {
432+
self.as_ptr().cast()
433+
}
434+
435+
fn dtype(&self) -> Bound<'py, PyArrayDescr> {
436+
unsafe {
437+
let descr_ptr = (*self.as_array_ptr()).descr;
438+
Bound::from_borrowed_ptr(self.py(), descr_ptr.cast()).downcast_into_unchecked()
439+
}
440+
}
441+
}
442+
443+
// We won't be able to provide a `Deref` impl from `Bound<'_, PyArray<T, D>>` to
444+
// `Bound<'_, PyUntypedArray>`, so this seems to be the next best thing to do
445+
impl<'py, T, D> PyUntypedArrayMethods<'py> for Bound<'py, PyArray<T, D>> {
446+
#[inline]
447+
fn as_array_ptr(&self) -> *mut npyffi::PyArrayObject {
448+
self.as_untyped().as_array_ptr()
449+
}
450+
451+
#[inline]
452+
fn dtype(&self) -> Bound<'py, PyArrayDescr> {
453+
self.as_untyped().dtype()
454+
}
455+
}
456+
457+
mod sealed {
458+
use super::{PyArray, PyUntypedArray};
459+
460+
pub trait Sealed {}
461+
462+
impl Sealed for pyo3::Bound<'_, PyUntypedArray> {}
463+
impl<T, D> Sealed for pyo3::Bound<'_, PyArray<T, D>> {}
464+
}

0 commit comments

Comments
 (0)