Skip to content

Commit eee5e5a

Browse files
authored
Merge pull request #80 from kngwyu/refactor
Refactor error types
2 parents 97883a4 + 8baa60c commit eee5e5a

File tree

5 files changed

+157
-98
lines changed

5 files changed

+157
-98
lines changed

src/array.rs

Lines changed: 41 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
use ndarray::*;
33
use npyffi::{self, npy_intp, NPY_ORDER, PY_ARRAY_API};
44
use num_traits::AsPrimitive;
5-
use pyo3::{exceptions::TypeError, ffi, prelude::*, types::PyObjectRef};
5+
use pyo3::{ffi, prelude::*, types::PyObjectRef};
66
use pyo3::{PyDowncastError, PyObjectWithToken, ToPyPointer};
77
use std::iter::ExactSizeIterator;
88
use std::marker::PhantomData;
@@ -116,30 +116,21 @@ impl<'a, T, D> ::std::convert::From<&'a PyArray<T, D>> for &'a PyObjectRef {
116116
}
117117

118118
impl<'a, T: TypeNum, D: Dimension> FromPyObject<'a> for &'a PyArray<T, D> {
119-
// here we do type-check twice
119+
// here we do type-check three times
120120
// 1. Checks if the object is PyArray
121121
// 2. Checks if the data type of the array is T
122+
// 3. Checks if the dimension is same as D
122123
fn extract(ob: &'a PyObjectRef) -> PyResult<Self> {
123124
let array = unsafe {
124125
if npyffi::PyArray_Check(ob.as_ptr()) == 0 {
125126
return Err(PyDowncastError.into());
126127
}
127-
if let Some(ndim) = D::NDIM {
128-
let ptr = ob.as_ptr() as *mut npyffi::PyArrayObject;
129-
if (*ptr).nd as usize != ndim {
130-
return Err(PyErr::new::<TypeError, _>(format!(
131-
"specified dim was {}, but actual dim was {}",
132-
ndim,
133-
(*ptr).nd
134-
)));
135-
}
136-
}
137128
&*(ob as *const PyObjectRef as *const PyArray<T, D>)
138129
};
139130
array
140131
.type_check()
141132
.map(|_| array)
142-
.into_pyresult_with(|| "FromPyObject::extract typecheck failed")
133+
.into_pyresult_with(|| "[FromPyObject::extract] typecheck failed")
143134
}
144135
}
145136

@@ -398,6 +389,27 @@ impl<T: TypeNum, D: Dimension> PyArray<T, D> {
398389
}
399390
}
400391

392+
/// Get the immutable view of the internal data of `PyArray`, as slice.
393+
/// # Example
394+
/// ```
395+
/// # extern crate pyo3; extern crate numpy; fn main() {
396+
/// use numpy::PyArray;
397+
/// let gil = pyo3::Python::acquire_gil();
398+
/// let py_array = PyArray::arange(gil.python(), 0, 4, 1).reshape([2, 2]).unwrap();
399+
/// assert_eq!(py_array.as_slice(), &[0, 1, 2, 3]);
400+
/// # }
401+
/// ```
402+
pub fn as_slice(&self) -> &[T] {
403+
self.type_check_assert();
404+
unsafe { ::std::slice::from_raw_parts(self.data(), self.len()) }
405+
}
406+
407+
/// Get the mmutable view of the internal data of `PyArray`, as slice.
408+
pub fn as_slice_mut(&self) -> &mut [T] {
409+
self.type_check_assert();
410+
unsafe { ::std::slice::from_raw_parts_mut(self.data(), self.len()) }
411+
}
412+
401413
/// Construct PyArray from `ndarray::ArrayBase`.
402414
///
403415
/// This method allocates memory in Python's heap via numpy api, and then copies all elements
@@ -584,6 +596,22 @@ impl<T: TypeNum, D: Dimension> PyArray<T, D> {
584596
let python = self.py();
585597
unsafe { PyArray::from_borrowed_ptr(python, self.as_ptr()) }
586598
}
599+
600+
fn type_check_assert(&self) {
601+
let type_check = self.type_check();
602+
assert!(type_check.is_ok(), "{:?}", type_check);
603+
}
604+
605+
fn type_check(&self) -> Result<(), ErrorKind> {
606+
let truth = self.typenum();
607+
let dim = self.shape().len();
608+
let dim_ok = D::NDIM.map(|n| n == dim).unwrap_or(true);
609+
if T::is_same_type(truth) && dim_ok {
610+
Ok(())
611+
} else {
612+
Err(ErrorKind::to_rust(truth, dim, T::npy_data_type(), D::NDIM))
613+
}
614+
}
587615
}
588616

589617
impl<T: TypeNum> PyArray<T, Ix1> {
@@ -828,41 +856,6 @@ impl<T: TypeNum, D> PyArray<T, D> {
828856
NpyDataType::from_i32(self.typenum())
829857
}
830858

831-
fn type_check_assert(&self) {
832-
let type_check = self.type_check();
833-
assert!(type_check.is_ok(), "{:?}", type_check);
834-
}
835-
836-
fn type_check(&self) -> Result<(), ErrorKind> {
837-
let truth = self.typenum();
838-
if T::is_same_type(truth) {
839-
Ok(())
840-
} else {
841-
Err(ErrorKind::to_rust(truth, T::npy_data_type()))
842-
}
843-
}
844-
845-
/// Get the immutable view of the internal data of `PyArray`, as slice.
846-
/// # Example
847-
/// ```
848-
/// # extern crate pyo3; extern crate numpy; fn main() {
849-
/// use numpy::PyArray;
850-
/// let gil = pyo3::Python::acquire_gil();
851-
/// let py_array = PyArray::arange(gil.python(), 0, 4, 1).reshape([2, 2]).unwrap();
852-
/// assert_eq!(py_array.as_slice(), &[0, 1, 2, 3]);
853-
/// # }
854-
/// ```
855-
pub fn as_slice(&self) -> &[T] {
856-
self.type_check_assert();
857-
unsafe { ::std::slice::from_raw_parts(self.data(), self.len()) }
858-
}
859-
860-
/// Get the mmutable view of the internal data of `PyArray`, as slice.
861-
pub fn as_slice_mut(&self) -> &mut [T] {
862-
self.type_check_assert();
863-
unsafe { ::std::slice::from_raw_parts_mut(self.data(), self.len()) }
864-
}
865-
866859
/// Copies self into `other`, performing a data-type conversion if necessary.
867860
/// # Example
868861
/// ```

src/error.rs

Lines changed: 69 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -31,74 +31,101 @@ impl<T, E: IntoPyErr> IntoPyResult for Result<T, E> {
3131
}
3232
}
3333

34-
/// Represents a shape and format of numpy array.
34+
/// Represents a shape and dtype of numpy array.
3535
///
3636
/// Only for error formatting.
3737
#[derive(Debug)]
38-
pub struct ArrayFormat {
38+
pub struct ArrayShape {
3939
pub dims: Box<[usize]>,
4040
pub dtype: NpyDataType,
4141
}
4242

43-
impl fmt::Display for ArrayFormat {
43+
impl ArrayShape {
44+
fn boxed_dims(dims: &[usize]) -> Box<[usize]> {
45+
dims.into_iter()
46+
.map(|&x| x)
47+
.collect::<Vec<_>>()
48+
.into_boxed_slice()
49+
}
50+
fn from_array<T: TypeNum, D>(array: &PyArray<T, D>) -> Self {
51+
ArrayShape {
52+
dims: Self::boxed_dims(array.shape()),
53+
dtype: T::npy_data_type(),
54+
}
55+
}
56+
fn from_dims<T: TypeNum, D: ToNpyDims>(dims: D) -> Self {
57+
ArrayShape {
58+
dims: Self::boxed_dims(dims.slice()),
59+
dtype: T::npy_data_type(),
60+
}
61+
}
62+
}
63+
64+
impl fmt::Display for ArrayShape {
4465
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
4566
write!(f, "dims={:?}, dtype={:?}", self.dims, self.dtype)
4667
}
4768
}
4869

70+
/// Represents a dimension and dtype of numpy array.
71+
///
72+
/// Only for error formatting.
73+
#[derive(Debug)]
74+
pub struct ArrayDim {
75+
pub dim: Option<usize>,
76+
pub dtype: NpyDataType,
77+
}
78+
79+
impl fmt::Display for ArrayDim {
80+
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
81+
if let Some(d) = self.dim {
82+
write!(f, "dim={:?}, dtype={:?}", d, self.dtype)
83+
} else {
84+
write!(f, "dim=_, dtype={:?}", self.dtype)
85+
}
86+
}
87+
}
88+
4989
/// Represents a casting error between rust types and numpy array.
5090
#[derive(Debug)]
5191
pub enum ErrorKind {
5292
/// Error for casting `PyArray` into `ArrayView` or `ArrayViewMut`
53-
PyToRust { from: NpyDataType, to: NpyDataType },
93+
PyToRust { from: ArrayDim, to: ArrayDim },
5494
/// Error for casting rust's `Vec` into numpy array.
5595
FromVec { dim1: usize, dim2: usize },
5696
/// Error in numpy -> numpy data conversion
57-
PyToPy(Box<(ArrayFormat, ArrayFormat)>),
97+
PyToPy(Box<(ArrayShape, ArrayShape)>),
5898
}
5999

60100
impl ErrorKind {
61-
pub(crate) fn to_rust(from: i32, to: NpyDataType) -> Self {
101+
pub(crate) fn to_rust(
102+
from_t: i32,
103+
from_d: usize,
104+
to_t: NpyDataType,
105+
to_d: Option<usize>,
106+
) -> Self {
62107
ErrorKind::PyToRust {
63-
from: NpyDataType::from_i32(from),
64-
to,
108+
from: ArrayDim {
109+
dim: Some(from_d),
110+
dtype: NpyDataType::from_i32(from_t),
111+
},
112+
to: ArrayDim {
113+
dim: to_d,
114+
dtype: to_t,
115+
},
65116
}
66117
}
67118
pub(crate) fn dtype_cast<T: TypeNum, D>(from: &PyArray<T, D>, to: NpyDataType) -> Self {
68-
let dims = from
69-
.shape()
70-
.into_iter()
71-
.map(|&x| x)
72-
.collect::<Vec<_>>()
73-
.into_boxed_slice();
74-
let from = ArrayFormat {
75-
dims: dims.clone(),
76-
dtype: T::npy_data_type(),
119+
let from = ArrayShape::from_array(from);
120+
let to = ArrayShape {
121+
dims: from.dims.clone(),
122+
dtype: to,
77123
};
78-
let to = ArrayFormat { dims, dtype: to };
79124
ErrorKind::PyToPy(Box::new((from, to)))
80125
}
81126
pub(crate) fn dims_cast<T: TypeNum, D>(from: &PyArray<T, D>, to_dim: impl ToNpyDims) -> Self {
82-
let dims_from = from
83-
.shape()
84-
.into_iter()
85-
.map(|&x| x)
86-
.collect::<Vec<_>>()
87-
.into_boxed_slice();
88-
let dims_to = to_dim
89-
.slice()
90-
.into_iter()
91-
.map(|&x| x)
92-
.collect::<Vec<_>>()
93-
.into_boxed_slice();
94-
let from = ArrayFormat {
95-
dims: dims_from,
96-
dtype: T::npy_data_type(),
97-
};
98-
let to = ArrayFormat {
99-
dims: dims_to,
100-
dtype: T::npy_data_type(),
101-
};
127+
let from = ArrayShape::from_array(from);
128+
let to = ArrayShape::from_dims::<T, _>(to_dim);
102129
ErrorKind::PyToPy(Box::new((from, to)))
103130
}
104131
}
@@ -107,16 +134,16 @@ impl fmt::Display for ErrorKind {
107134
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
108135
match self {
109136
ErrorKind::PyToRust { from, to } => {
110-
write!(f, "Cast failed: from={:?}, to={:?}", from, to)
137+
write!(f, "Extraction failed:\n from=({}), to=({})", from, to)
111138
}
112139
ErrorKind::FromVec { dim1, dim2 } => write!(
113140
f,
114-
"Cast failed: Vec To PyArray: expect all dim {} but {} was found",
141+
"Cast failed: Vec To PyArray:\n expect all dim {} but {} was found",
115142
dim1, dim2
116143
),
117144
ErrorKind::PyToPy(e) => write!(
118145
f,
119-
"Cast failed: from=ndarray({:?}), to=ndarray(dtype={:?})",
146+
"Cast failed: from=ndarray({}), to=ndarray(dtype={})",
120147
e.0, e.1,
121148
),
122149
}
@@ -142,7 +169,7 @@ impl IntoPyErr for ErrorKind {
142169
fn into_pyerr_with<D: fmt::Display>(self, msg: impl FnOnce() -> D) -> PyErr {
143170
match self {
144171
ErrorKind::PyToRust { .. } | ErrorKind::FromVec { .. } | ErrorKind::PyToPy(_) => {
145-
PyErr::new::<exc::TypeError, _>(format!("{} msg: {}", self, msg()))
172+
PyErr::new::<exc::TypeError, _>(format!("{}\n context: {}", self, msg()))
146173
}
147174
}
148175
}

src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ pub use array::{
5353
PyArrayDyn,
5454
};
5555
pub use convert::{IntoPyArray, NpyIndex, ToNpyDims, ToPyArray};
56-
pub use error::{ArrayFormat, ErrorKind, IntoPyErr, IntoPyResult};
56+
pub use error::{ArrayShape, ErrorKind, IntoPyErr, IntoPyResult};
5757
pub use ndarray::{Ix1, Ix2, Ix3, Ix4, Ix5, Ix6, IxDyn};
5858
pub use npyffi::{PY_ARRAY_API, PY_UFUNC_API};
5959
pub use types::{c32, c64, NpyDataType, TypeNum};

src/slice_box.rs

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@ use crate::types::TypeNum;
22
use pyo3::{ffi, typeob, types::PyObjectRef, PyObjectAlloc, Python, ToPyPointer};
33
use std::os::raw::c_void;
44

5+
/// It's a memory store for IntoPyArray.
6+
/// See IntoPyArray's doc for what concretely this type is for.
57
#[repr(C)]
68
pub(crate) struct SliceBox<T> {
79
ob_base: ffi::PyObject,
@@ -14,7 +16,7 @@ impl<T> SliceBox<T> {
1416
let type_ob = <Self as typeob::PyTypeInfo>::type_object() as *mut _;
1517
let base = ffi::_PyObject_New(type_ob);
1618
*base = ffi::PyObject_HEAD_INIT;
17-
(*base).ob_type = <Self as typeob::PyTypeInfo>::type_object() as *mut _;
19+
(*base).ob_type = type_ob;
1820
let self_ = base as *mut SliceBox<T>;
1921
(*self_).inner = Box::into_raw(box_);
2022
&*self_
@@ -28,12 +30,12 @@ impl<T> typeob::PyTypeInfo for SliceBox<T> {
2830
type Type = ();
2931
type BaseType = PyObjectRef;
3032
const NAME: &'static str = "SliceBox";
31-
const DESCRIPTION: &'static str = "Memory store for PyArray made by IntoPyArray.";
33+
const DESCRIPTION: &'static str = "Memory store for PyArray using rust's Box<[T]>.";
3234
const FLAGS: usize = 0;
33-
const SIZE: usize = { Self::OFFSET as usize + std::mem::size_of::<Self>() + 0 + 0 };
35+
const SIZE: usize = std::mem::size_of::<Self>();
3436
const OFFSET: isize = 0;
3537
#[inline]
36-
unsafe fn type_object() -> &'static mut ::pyo3::ffi::PyTypeObject {
38+
unsafe fn type_object() -> &'static mut ffi::PyTypeObject {
3739
static mut TYPE_OBJECT: ::pyo3::ffi::PyTypeObject = ::pyo3::ffi::PyTypeObject_INIT;
3840
&mut TYPE_OBJECT
3941
}

0 commit comments

Comments
 (0)