Skip to content

Commit 2b411d2

Browse files
committed
Add is_*contiguous methods + Modify as_slice* returns Result
- To resolve #101, we check if the array is contiguous when getting a slice
1 parent 935463e commit 2b411d2

File tree

5 files changed

+141
-73
lines changed

5 files changed

+141
-73
lines changed

src/array.rs

Lines changed: 33 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,24 @@ impl<T, D> PyArray<T, D> {
145145
self.as_ptr() as _
146146
}
147147

148+
#[inline(always)]
149+
fn check_flag(&self, flag: c_int) -> bool {
150+
unsafe { (*self.as_array_ptr()).flags & flag == flag }
151+
}
152+
153+
pub fn is_contiguous(&self) -> bool {
154+
self.check_flag(npyffi::NPY_ARRAY_C_CONTIGUOUS)
155+
| self.check_flag(npyffi::NPY_ARRAY_F_CONTIGUOUS)
156+
}
157+
158+
pub fn is_forran_contiguous(&self) -> bool {
159+
self.check_flag(npyffi::NPY_ARRAY_F_CONTIGUOUS)
160+
}
161+
162+
pub fn is_c_contiguous(&self) -> bool {
163+
self.check_flag(npyffi::NPY_ARRAY_C_CONTIGUOUS)
164+
}
165+
148166
/// Get `Py<PyArray>` from `&PyArray`, which is the owned wrapper of PyObject.
149167
///
150168
/// You can use this method when you have to avoid lifetime annotation to your function args
@@ -404,15 +422,23 @@ impl<T: TypeNum, D: Dimension> PyArray<T, D> {
404422
/// assert_eq!(py_array.as_slice(), &[0, 1, 2, 3]);
405423
/// # }
406424
/// ```
407-
pub fn as_slice(&self) -> &[T] {
408-
self.type_check_assert();
409-
unsafe { ::std::slice::from_raw_parts(self.data(), self.len()) }
425+
pub fn as_slice(&self) -> Result<&[T], ErrorKind> {
426+
self.type_check()?;
427+
if !self.is_contiguous() {
428+
Err(ErrorKind::NotContiguous)
429+
} else {
430+
Ok(unsafe { ::std::slice::from_raw_parts(self.data(), self.len()) })
431+
}
410432
}
411433

412434
/// Get the mmutable view of the internal data of `PyArray`, as slice.
413-
pub fn as_slice_mut(&self) -> &mut [T] {
414-
self.type_check_assert();
415-
unsafe { ::std::slice::from_raw_parts_mut(self.data(), self.len()) }
435+
pub fn as_slice_mut(&self) -> Result<&mut [T], ErrorKind> {
436+
self.type_check()?;
437+
if !self.is_contiguous() {
438+
Err(ErrorKind::NotContiguous)
439+
} else {
440+
Ok(unsafe { ::std::slice::from_raw_parts_mut(self.data(), self.len()) })
441+
}
416442
}
417443

418444
/// Construct PyArray from `ndarray::ArrayBase`.
@@ -614,10 +640,7 @@ impl<T: TypeNum + Clone, D: Dimension> PyArray<T, D> {
614640
/// # }
615641
/// ```
616642
pub fn to_owned_array(&self) -> Array<T, D> {
617-
unsafe {
618-
let vec = self.as_slice().to_vec();
619-
Array::from_shape_vec_unchecked(self.ndarray_shape(), vec)
620-
}
643+
self.as_array().to_owned()
621644
}
622645
}
623646

src/error.rs

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,8 @@ pub enum ErrorKind {
9595
FromVec { dim1: usize, dim2: usize },
9696
/// Error in numpy -> numpy data conversion
9797
PyToPy(Box<(ArrayShape, ArrayShape)>),
98+
/// The array need to be contiguous to finish the opretion
99+
NotContiguous,
98100
}
99101

100102
impl ErrorKind {
@@ -146,6 +148,7 @@ impl fmt::Display for ErrorKind {
146148
"Cast failed: from=ndarray({}), to=ndarray(dtype={})",
147149
e.0, e.1,
148150
),
151+
ErrorKind::NotContiguous => write!(f, "This array is not contiguous!"),
149152
}
150153
}
151154
}
@@ -154,11 +157,7 @@ impl error::Error for ErrorKind {}
154157

155158
impl From<ErrorKind> for PyErr {
156159
fn from(err: ErrorKind) -> PyErr {
157-
match err {
158-
ErrorKind::PyToRust { .. } | ErrorKind::FromVec { .. } | ErrorKind::PyToPy(_) => {
159-
PyErr::new::<exc::TypeError, _>(format!("{}", err))
160-
}
161-
}
160+
PyErr::new::<exc::TypeError, _>(format!("{}", err))
162161
}
163162
}
164163

@@ -167,10 +166,6 @@ impl IntoPyErr for ErrorKind {
167166
Into::into(self)
168167
}
169168
fn into_pyerr_with<D: fmt::Display>(self, msg: impl FnOnce() -> D) -> PyErr {
170-
match self {
171-
ErrorKind::PyToRust { .. } | ErrorKind::FromVec { .. } | ErrorKind::PyToPy(_) => {
172-
PyErr::new::<exc::TypeError, _>(format!("{}\n context: {}", self, msg()))
173-
}
174-
}
169+
PyErr::new::<exc::TypeError, _>(format!("{}\n context: {}", self, msg()))
175170
}
176171
}

src/npyffi/flags.rs

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
use std::os::raw::c_int;
2+
3+
pub const NPY_ARRAY_C_CONTIGUOUS: c_int = 0x0001;
4+
pub const NPY_ARRAY_F_CONTIGUOUS: c_int = 0x0002;
5+
pub const NPY_ARRAY_OWNDATA: c_int = 0x0004;
6+
pub const NPY_ARRAY_FORCECAST: c_int = 0x0010;
7+
pub const NPY_ARRAY_ENSURECOPY: c_int = 0x0020;
8+
pub const NPY_ARRAY_ENSUREARRAY: c_int = 0x0040;
9+
pub const NPY_ARRAY_ELEMENTSTRIDES: c_int = 0x0080;
10+
pub const NPY_ARRAY_ALIGNED: c_int = 0x0100;
11+
pub const NPY_ARRAY_NOTSWAPPED: c_int = 0x0200;
12+
pub const NPY_ARRAY_WRITEABLE: c_int = 0x0400;
13+
pub const NPY_ARRAY_UPDATEIFCOPY: c_int = 0x1000;
14+
pub const NPY_ARRAY_WRITEBACKIFCOPY: c_int = 0x2000;
15+
pub const NPY_ARRAY_BEHAVED: c_int = NPY_ARRAY_ALIGNED | NPY_ARRAY_WRITEABLE;
16+
pub const NPY_ARRAY_BEHAVED_NS: c_int = NPY_ARRAY_BEHAVED | NPY_ARRAY_NOTSWAPPED;
17+
pub const NPY_ARRAY_CARRAY: c_int = NPY_ARRAY_C_CONTIGUOUS | NPY_ARRAY_BEHAVED;
18+
pub const NPY_ARRAY_CARRAY_RO: c_int = NPY_ARRAY_C_CONTIGUOUS | NPY_ARRAY_ALIGNED;
19+
pub const NPY_ARRAY_FARRAY: c_int = NPY_ARRAY_F_CONTIGUOUS | NPY_ARRAY_BEHAVED;
20+
pub const NPY_ARRAY_FARRAY_RO: c_int = NPY_ARRAY_F_CONTIGUOUS | NPY_ARRAY_ALIGNED;
21+
pub const NPY_ARRAY_DEFAULT: c_int = NPY_ARRAY_CARRAY;
22+
pub const NPY_ARRAY_IN_ARRAY: c_int = NPY_ARRAY_CARRAY_RO;
23+
pub const NPY_ARRAY_OUT_ARRAY: c_int = NPY_ARRAY_CARRAY;
24+
pub const NPY_ARRAY_INOUT_ARRAY: c_int = NPY_ARRAY_CARRAY | NPY_ARRAY_UPDATEIFCOPY;
25+
pub const NPY_ARRAY_INOUT_ARRAY2: c_int = NPY_ARRAY_CARRAY | NPY_ARRAY_WRITEBACKIFCOPY;
26+
pub const NPY_ARRAY_IN_FARRAY: c_int = NPY_ARRAY_FARRAY_RO;
27+
pub const NPY_ARRAY_OUT_FARRAY: c_int = NPY_ARRAY_FARRAY;
28+
pub const NPY_ARRAY_INOUT_FARRAY: c_int = NPY_ARRAY_FARRAY | NPY_ARRAY_UPDATEIFCOPY;
29+
pub const NPY_ARRAY_INOUT_FARRAY2: c_int = NPY_ARRAY_FARRAY | NPY_ARRAY_WRITEBACKIFCOPY;
30+
pub const NPY_ARRAY_UPDATE_ALL: c_int = NPY_ARRAY_C_CONTIGUOUS | NPY_ARRAY_F_CONTIGUOUS;

src/npyffi/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,11 +48,13 @@ macro_rules! impl_api {
4848
}
4949

5050
pub mod array;
51+
pub mod flags;
5152
pub mod objects;
5253
pub mod types;
5354
pub mod ufunc;
5455

5556
pub use self::array::*;
57+
pub use self::flags::*;
5658
pub use self::objects::*;
5759
pub use self::types::*;
5860
pub use self::ufunc::*;

0 commit comments

Comments
 (0)