Skip to content

Commit 83761e6

Browse files
authored
Merge pull request #105 from kngwyu/contiguous
Contiguous checking
2 parents 935463e + 123fbb5 commit 83761e6

File tree

8 files changed

+190
-89
lines changed

8 files changed

+190
-89
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ fn main_<'py>(py: Python<'py>) -> PyResult<()> {
8484
let pyarray: &PyArray1<i32> = py
8585
.eval("np.absolute(np.array([-1, -2, -3], dtype='int32'))", Some(&dict), None)?
8686
.extract()?;
87-
let slice = pyarray.as_slice();
87+
let slice = pyarray.as_slice()?;
8888
assert_eq!(slice, &[1, 2, 3]);
8989
Ok(())
9090
}

src/array.rs

Lines changed: 78 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,46 @@ impl<T, D> PyArray<T, D> {
145145
self.as_ptr() as _
146146
}
147147

148+
#[inline(always)]
149+
fn check_flag(&self, flag: c_int) -> bool {
150+
unsafe { (*self.as_array_ptr()).flags & flag == flag }
151+
}
152+
153+
/// Returns `true` if the internal data of the array is C-style contiguous
154+
/// (default of numpy and ndarray) or Fortran-style contiguous.
155+
///
156+
/// # Example
157+
/// ```
158+
/// # extern crate pyo3; extern crate numpy; fn main() {
159+
/// use pyo3::types::IntoPyDict;
160+
/// let gil = pyo3::Python::acquire_gil();
161+
/// let py = gil.python();
162+
/// let array = numpy::PyArray::arange(py, 0, 1, 10);
163+
/// assert!(array.is_contiguous());
164+
/// let locals = [("np", numpy::get_array_module(py).unwrap())].into_py_dict(py);
165+
/// let not_contiguous: &numpy::PyArray1<f32> = py
166+
/// .eval("np.zeros((3, 5))[::2, 4]", Some(locals), None)
167+
/// .unwrap()
168+
/// .downcast_ref()
169+
/// .unwrap();
170+
/// assert!(!not_contiguous.is_contiguous());
171+
/// # }
172+
/// ```
173+
pub fn is_contiguous(&self) -> bool {
174+
self.check_flag(npyffi::NPY_ARRAY_C_CONTIGUOUS)
175+
| self.check_flag(npyffi::NPY_ARRAY_F_CONTIGUOUS)
176+
}
177+
178+
/// Returns `true` if the internal data of the array is Fortran-style contiguous.
179+
pub fn is_fotran_contiguous(&self) -> bool {
180+
self.check_flag(npyffi::NPY_ARRAY_F_CONTIGUOUS)
181+
}
182+
183+
/// Returns `true` if the internal data of the array is C-style contiguous.
184+
pub fn is_c_contiguous(&self) -> bool {
185+
self.check_flag(npyffi::NPY_ARRAY_C_CONTIGUOUS)
186+
}
187+
148188
/// Get `Py<PyArray>` from `&PyArray`, which is the owned wrapper of PyObject.
149189
///
150190
/// You can use this method when you have to avoid lifetime annotation to your function args
@@ -162,7 +202,7 @@ impl<T, D> PyArray<T, D> {
162202
/// }
163203
/// let gil = Python::acquire_gil();
164204
/// let array = return_py_array();
165-
/// assert_eq!(array.as_ref(gil.python()).as_slice(), &[0, 0, 0, 0, 0]);
205+
/// assert_eq!(array.as_ref(gil.python()).as_slice().unwrap(), &[0, 0, 0, 0, 0]);
166206
/// # }
167207
/// ```
168208
pub fn to_owned(&self) -> Py<Self> {
@@ -395,24 +435,43 @@ impl<T: TypeNum, D: Dimension> PyArray<T, D> {
395435
}
396436

397437
/// Get the immutable view of the internal data of `PyArray`, as slice.
438+
///
439+
/// Returns `ErrorKind::NotContiguous` if the internal array is not contiguous.
398440
/// # Example
399441
/// ```
400442
/// # extern crate pyo3; extern crate numpy; fn main() {
401-
/// use numpy::PyArray;
443+
/// use numpy::{PyArray, PyArray1};
444+
/// use pyo3::types::IntoPyDict;
402445
/// let gil = pyo3::Python::acquire_gil();
403-
/// let py_array = PyArray::arange(gil.python(), 0, 4, 1).reshape([2, 2]).unwrap();
404-
/// assert_eq!(py_array.as_slice(), &[0, 1, 2, 3]);
446+
/// let py = gil.python();
447+
/// let py_array = PyArray::arange(py, 0, 4, 1).reshape([2, 2]).unwrap();
448+
/// assert_eq!(py_array.as_slice().unwrap(), &[0, 1, 2, 3]);
449+
/// let locals = [("np", numpy::get_array_module(py).unwrap())].into_py_dict(py);
450+
/// let not_contiguous: &PyArray1<f32> = py
451+
/// .eval("np.zeros((3, 5))[[0, 2], [3, 4]]", Some(locals), None)
452+
/// .unwrap()
453+
/// .downcast_ref()
454+
/// .unwrap();
455+
/// assert!(not_contiguous.as_slice().is_err());
405456
/// # }
406457
/// ```
407-
pub fn as_slice(&self) -> &[T] {
408-
self.type_check_assert();
409-
unsafe { ::std::slice::from_raw_parts(self.data(), self.len()) }
458+
pub fn as_slice(&self) -> Result<&[T], ErrorKind> {
459+
self.type_check()?;
460+
if !self.is_contiguous() {
461+
Err(ErrorKind::NotContiguous)
462+
} else {
463+
Ok(unsafe { ::std::slice::from_raw_parts(self.data(), self.len()) })
464+
}
410465
}
411466

412467
/// Get the mmutable view of the internal data of `PyArray`, as slice.
413-
pub fn as_slice_mut(&self) -> &mut [T] {
414-
self.type_check_assert();
415-
unsafe { ::std::slice::from_raw_parts_mut(self.data(), self.len()) }
468+
pub fn as_slice_mut(&self) -> Result<&mut [T], ErrorKind> {
469+
self.type_check()?;
470+
if !self.is_contiguous() {
471+
Err(ErrorKind::NotContiguous)
472+
} else {
473+
Ok(unsafe { ::std::slice::from_raw_parts_mut(self.data(), self.len()) })
474+
}
416475
}
417476

418477
/// Construct PyArray from `ndarray::ArrayBase`.
@@ -614,10 +673,7 @@ impl<T: TypeNum + Clone, D: Dimension> PyArray<T, D> {
614673
/// # }
615674
/// ```
616675
pub fn to_owned_array(&self) -> Array<T, D> {
617-
unsafe {
618-
let vec = self.as_slice().to_vec();
619-
Array::from_shape_vec_unchecked(self.ndarray_shape(), vec)
620-
}
676+
self.as_array().to_owned()
621677
}
622678
}
623679

@@ -631,7 +687,7 @@ impl<T: TypeNum> PyArray<T, Ix1> {
631687
/// let gil = pyo3::Python::acquire_gil();
632688
/// let array = [1, 2, 3, 4, 5];
633689
/// let pyarray = PyArray::from_slice(gil.python(), &array);
634-
/// assert_eq!(pyarray.as_slice(), &[1, 2, 3, 4, 5]);
690+
/// assert_eq!(pyarray.as_slice().unwrap(), &[1, 2, 3, 4, 5]);
635691
/// # }
636692
/// ```
637693
pub fn from_slice<'py>(py: Python<'py>, slice: &[T]) -> &'py Self {
@@ -652,7 +708,7 @@ impl<T: TypeNum> PyArray<T, Ix1> {
652708
/// let gil = pyo3::Python::acquire_gil();
653709
/// let vec = vec![1, 2, 3, 4, 5];
654710
/// let pyarray = PyArray::from_vec(gil.python(), vec);
655-
/// assert_eq!(pyarray.as_slice(), &[1, 2, 3, 4, 5]);
711+
/// assert_eq!(pyarray.as_slice().unwrap(), &[1, 2, 3, 4, 5]);
656712
/// # }
657713
/// ```
658714
pub fn from_vec<'py>(py: Python<'py>, vec: Vec<T>) -> &'py Self {
@@ -670,7 +726,7 @@ impl<T: TypeNum> PyArray<T, Ix1> {
670726
/// let gil = pyo3::Python::acquire_gil();
671727
/// let vec = vec![1, 2, 3, 4, 5];
672728
/// let pyarray = PyArray::from_iter(gil.python(), vec.iter().map(|&x| x));
673-
/// assert_eq!(pyarray.as_slice(), &[1, 2, 3, 4, 5]);
729+
/// assert_eq!(pyarray.as_slice().unwrap(), &[1, 2, 3, 4, 5]);
674730
/// # }
675731
/// ```
676732
pub fn from_exact_iter(py: Python, iter: impl ExactSizeIterator<Item = T>) -> &Self {
@@ -697,7 +753,7 @@ impl<T: TypeNum> PyArray<T, Ix1> {
697753
/// let gil = pyo3::Python::acquire_gil();
698754
/// let set: BTreeSet<u32> = [4, 3, 2, 5, 1].into_iter().cloned().collect();
699755
/// let pyarray = PyArray::from_iter(gil.python(), set);
700-
/// assert_eq!(pyarray.as_slice(), &[1, 2, 3, 4, 5]);
756+
/// assert_eq!(pyarray.as_slice().unwrap(), &[1, 2, 3, 4, 5]);
701757
/// # }
702758
/// ```
703759
pub fn from_iter(py: Python, iter: impl IntoIterator<Item = T>) -> &Self {
@@ -872,7 +928,7 @@ impl<T: TypeNum, D> PyArray<T, D> {
872928
/// let pyarray_f = PyArray::arange(gil.python(), 2.0, 5.0, 1.0);
873929
/// let pyarray_i = PyArray::<i64, _>::new(gil.python(), [3], false);
874930
/// assert!(pyarray_f.copy_to(pyarray_i).is_ok());
875-
/// assert_eq!(pyarray_i.as_slice(), &[2, 3, 4]);
931+
/// assert_eq!(pyarray_i.as_slice().unwrap(), &[2, 3, 4]);
876932
/// # }
877933
pub fn copy_to<U: TypeNum>(&self, other: &PyArray<U, D>) -> Result<(), ErrorKind> {
878934
let self_ptr = self.as_array_ptr();
@@ -893,7 +949,7 @@ impl<T: TypeNum, D> PyArray<T, D> {
893949
/// let gil = pyo3::Python::acquire_gil();
894950
/// let pyarray_f = PyArray::arange(gil.python(), 2.0, 5.0, 1.0);
895951
/// let pyarray_i = pyarray_f.cast::<i32>(false).unwrap();
896-
/// assert_eq!(pyarray_i.as_slice(), &[2, 3, 4]);
952+
/// assert_eq!(pyarray_i.as_slice().unwrap(), &[2, 3, 4]);
897953
/// # }
898954
pub fn cast<'py, U: TypeNum>(
899955
&'py self,
@@ -980,9 +1036,9 @@ impl<T: TypeNum + AsPrimitive<f64>> PyArray<T, Ix1> {
9801036
/// use numpy::PyArray;
9811037
/// let gil = pyo3::Python::acquire_gil();
9821038
/// let pyarray = PyArray::arange(gil.python(), 2.0, 4.0, 0.5);
983-
/// assert_eq!(pyarray.as_slice(), &[2.0, 2.5, 3.0, 3.5]);
1039+
/// assert_eq!(pyarray.as_slice().unwrap(), &[2.0, 2.5, 3.0, 3.5]);
9841040
/// let pyarray = PyArray::arange(gil.python(), -2, 4, 3);
985-
/// assert_eq!(pyarray.as_slice(), &[-2, 1]);
1041+
/// assert_eq!(pyarray.as_slice().unwrap(), &[-2, 1]);
9861042
/// # }
9871043
pub fn arange<'py>(py: Python<'py>, start: T, stop: T, step: T) -> &'py Self {
9881044
unsafe {

src/convert.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ use npyffi::npy_intp;
2323
/// use numpy::{PyArray, IntoPyArray};
2424
/// let gil = pyo3::Python::acquire_gil();
2525
/// let py_array = vec![1, 2, 3].into_pyarray(gil.python());
26-
/// assert_eq!(py_array.as_slice(), &[1, 2, 3]);
26+
/// assert_eq!(py_array.as_slice().unwrap(), &[1, 2, 3]);
2727
/// assert!(py_array.resize(100).is_err()); // You can't resize owned-by-rust array.
2828
/// # }
2929
/// ```
@@ -75,7 +75,7 @@ where
7575
/// use numpy::{PyArray, ToPyArray};
7676
/// let gil = pyo3::Python::acquire_gil();
7777
/// let py_array = vec![1, 2, 3].to_pyarray(gil.python());
78-
/// assert_eq!(py_array.as_slice(), &[1, 2, 3]);
78+
/// assert_eq!(py_array.as_slice().unwrap(), &[1, 2, 3]);
7979
/// # }
8080
/// ```
8181
pub trait ToPyArray {

src/error.rs

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,8 @@ pub enum ErrorKind {
9595
FromVec { dim1: usize, dim2: usize },
9696
/// Error in numpy -> numpy data conversion
9797
PyToPy(Box<(ArrayShape, ArrayShape)>),
98+
/// The array need to be contiguous to finish the opretion
99+
NotContiguous,
98100
}
99101

100102
impl ErrorKind {
@@ -146,6 +148,7 @@ impl fmt::Display for ErrorKind {
146148
"Cast failed: from=ndarray({}), to=ndarray(dtype={})",
147149
e.0, e.1,
148150
),
151+
ErrorKind::NotContiguous => write!(f, "This array is not contiguous!"),
149152
}
150153
}
151154
}
@@ -154,11 +157,7 @@ impl error::Error for ErrorKind {}
154157

155158
impl From<ErrorKind> for PyErr {
156159
fn from(err: ErrorKind) -> PyErr {
157-
match err {
158-
ErrorKind::PyToRust { .. } | ErrorKind::FromVec { .. } | ErrorKind::PyToPy(_) => {
159-
PyErr::new::<exc::TypeError, _>(format!("{}", err))
160-
}
161-
}
160+
PyErr::new::<exc::TypeError, _>(format!("{}", err))
162161
}
163162
}
164163

@@ -167,10 +166,6 @@ impl IntoPyErr for ErrorKind {
167166
Into::into(self)
168167
}
169168
fn into_pyerr_with<D: fmt::Display>(self, msg: impl FnOnce() -> D) -> PyErr {
170-
match self {
171-
ErrorKind::PyToRust { .. } | ErrorKind::FromVec { .. } | ErrorKind::PyToPy(_) => {
172-
PyErr::new::<exc::TypeError, _>(format!("{}\n context: {}", self, msg()))
173-
}
174-
}
169+
PyErr::new::<exc::TypeError, _>(format!("{}\n context: {}", self, msg()))
175170
}
176171
}

src/npyffi/array.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ const CAPSULE_NAME: &str = "_ARRAY_API";
2828
/// unsafe {
2929
/// PY_ARRAY_API.PyArray_Sort(array.as_array_ptr(), 0, NPY_SORTKIND::NPY_QUICKSORT);
3030
/// }
31-
/// assert_eq!(array.as_slice(), &[2, 3, 4])
31+
/// assert_eq!(array.as_slice().unwrap(), &[2, 3, 4])
3232
/// # }
3333
/// ```
3434
pub static PY_ARRAY_API: PyArrayAPI = PyArrayAPI {

src/npyffi/flags.rs

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
use std::os::raw::c_int;
2+
3+
pub const NPY_ARRAY_C_CONTIGUOUS: c_int = 0x0001;
4+
pub const NPY_ARRAY_F_CONTIGUOUS: c_int = 0x0002;
5+
pub const NPY_ARRAY_OWNDATA: c_int = 0x0004;
6+
pub const NPY_ARRAY_FORCECAST: c_int = 0x0010;
7+
pub const NPY_ARRAY_ENSURECOPY: c_int = 0x0020;
8+
pub const NPY_ARRAY_ENSUREARRAY: c_int = 0x0040;
9+
pub const NPY_ARRAY_ELEMENTSTRIDES: c_int = 0x0080;
10+
pub const NPY_ARRAY_ALIGNED: c_int = 0x0100;
11+
pub const NPY_ARRAY_NOTSWAPPED: c_int = 0x0200;
12+
pub const NPY_ARRAY_WRITEABLE: c_int = 0x0400;
13+
pub const NPY_ARRAY_UPDATEIFCOPY: c_int = 0x1000;
14+
pub const NPY_ARRAY_WRITEBACKIFCOPY: c_int = 0x2000;
15+
pub const NPY_ARRAY_BEHAVED: c_int = NPY_ARRAY_ALIGNED | NPY_ARRAY_WRITEABLE;
16+
pub const NPY_ARRAY_BEHAVED_NS: c_int = NPY_ARRAY_BEHAVED | NPY_ARRAY_NOTSWAPPED;
17+
pub const NPY_ARRAY_CARRAY: c_int = NPY_ARRAY_C_CONTIGUOUS | NPY_ARRAY_BEHAVED;
18+
pub const NPY_ARRAY_CARRAY_RO: c_int = NPY_ARRAY_C_CONTIGUOUS | NPY_ARRAY_ALIGNED;
19+
pub const NPY_ARRAY_FARRAY: c_int = NPY_ARRAY_F_CONTIGUOUS | NPY_ARRAY_BEHAVED;
20+
pub const NPY_ARRAY_FARRAY_RO: c_int = NPY_ARRAY_F_CONTIGUOUS | NPY_ARRAY_ALIGNED;
21+
pub const NPY_ARRAY_DEFAULT: c_int = NPY_ARRAY_CARRAY;
22+
pub const NPY_ARRAY_IN_ARRAY: c_int = NPY_ARRAY_CARRAY_RO;
23+
pub const NPY_ARRAY_OUT_ARRAY: c_int = NPY_ARRAY_CARRAY;
24+
pub const NPY_ARRAY_INOUT_ARRAY: c_int = NPY_ARRAY_CARRAY | NPY_ARRAY_UPDATEIFCOPY;
25+
pub const NPY_ARRAY_INOUT_ARRAY2: c_int = NPY_ARRAY_CARRAY | NPY_ARRAY_WRITEBACKIFCOPY;
26+
pub const NPY_ARRAY_IN_FARRAY: c_int = NPY_ARRAY_FARRAY_RO;
27+
pub const NPY_ARRAY_OUT_FARRAY: c_int = NPY_ARRAY_FARRAY;
28+
pub const NPY_ARRAY_INOUT_FARRAY: c_int = NPY_ARRAY_FARRAY | NPY_ARRAY_UPDATEIFCOPY;
29+
pub const NPY_ARRAY_INOUT_FARRAY2: c_int = NPY_ARRAY_FARRAY | NPY_ARRAY_WRITEBACKIFCOPY;
30+
pub const NPY_ARRAY_UPDATE_ALL: c_int = NPY_ARRAY_C_CONTIGUOUS | NPY_ARRAY_F_CONTIGUOUS;

src/npyffi/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,11 +48,13 @@ macro_rules! impl_api {
4848
}
4949

5050
pub mod array;
51+
pub mod flags;
5152
pub mod objects;
5253
pub mod types;
5354
pub mod ufunc;
5455

5556
pub use self::array::*;
57+
pub use self::flags::*;
5658
pub use self::objects::*;
5759
pub use self::types::*;
5860
pub use self::ufunc::*;

0 commit comments

Comments
 (0)