Skip to content

Commit a5d7f75

Browse files
committed
Refactor error types
1 parent 97883a4 commit a5d7f75

File tree

5 files changed

+153
-96
lines changed

5 files changed

+153
-96
lines changed

src/array.rs

Lines changed: 41 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
use ndarray::*;
33
use npyffi::{self, npy_intp, NPY_ORDER, PY_ARRAY_API};
44
use num_traits::AsPrimitive;
5-
use pyo3::{exceptions::TypeError, ffi, prelude::*, types::PyObjectRef};
5+
use pyo3::{ffi, prelude::*, types::PyObjectRef};
66
use pyo3::{PyDowncastError, PyObjectWithToken, ToPyPointer};
77
use std::iter::ExactSizeIterator;
88
use std::marker::PhantomData;
@@ -116,30 +116,21 @@ impl<'a, T, D> ::std::convert::From<&'a PyArray<T, D>> for &'a PyObjectRef {
116116
}
117117

118118
impl<'a, T: TypeNum, D: Dimension> FromPyObject<'a> for &'a PyArray<T, D> {
119-
// here we do type-check twice
119+
// here we do type-check three times
120120
// 1. Checks if the object is PyArray
121121
// 2. Checks if the data type of the array is T
122+
// 3. Checks if the dimension is same as D
122123
fn extract(ob: &'a PyObjectRef) -> PyResult<Self> {
123124
let array = unsafe {
124125
if npyffi::PyArray_Check(ob.as_ptr()) == 0 {
125126
return Err(PyDowncastError.into());
126127
}
127-
if let Some(ndim) = D::NDIM {
128-
let ptr = ob.as_ptr() as *mut npyffi::PyArrayObject;
129-
if (*ptr).nd as usize != ndim {
130-
return Err(PyErr::new::<TypeError, _>(format!(
131-
"specified dim was {}, but actual dim was {}",
132-
ndim,
133-
(*ptr).nd
134-
)));
135-
}
136-
}
137128
&*(ob as *const PyObjectRef as *const PyArray<T, D>)
138129
};
139130
array
140131
.type_check()
141132
.map(|_| array)
142-
.into_pyresult_with(|| "FromPyObject::extract typecheck failed")
133+
.into_pyresult_with(|| "[FromPyObject::extract] typecheck failed")
143134
}
144135
}
145136

@@ -398,6 +389,27 @@ impl<T: TypeNum, D: Dimension> PyArray<T, D> {
398389
}
399390
}
400391

392+
/// Get the immutable view of the internal data of `PyArray`, as slice.
393+
/// # Example
394+
/// ```
395+
/// # extern crate pyo3; extern crate numpy; fn main() {
396+
/// use numpy::PyArray;
397+
/// let gil = pyo3::Python::acquire_gil();
398+
/// let py_array = PyArray::arange(gil.python(), 0, 4, 1).reshape([2, 2]).unwrap();
399+
/// assert_eq!(py_array.as_slice(), &[0, 1, 2, 3]);
400+
/// # }
401+
/// ```
402+
pub fn as_slice(&self) -> &[T] {
403+
self.type_check_assert();
404+
unsafe { ::std::slice::from_raw_parts(self.data(), self.len()) }
405+
}
406+
407+
/// Get the mmutable view of the internal data of `PyArray`, as slice.
408+
pub fn as_slice_mut(&self) -> &mut [T] {
409+
self.type_check_assert();
410+
unsafe { ::std::slice::from_raw_parts_mut(self.data(), self.len()) }
411+
}
412+
401413
/// Construct PyArray from `ndarray::ArrayBase`.
402414
///
403415
/// This method allocates memory in Python's heap via numpy api, and then copies all elements
@@ -584,6 +596,22 @@ impl<T: TypeNum, D: Dimension> PyArray<T, D> {
584596
let python = self.py();
585597
unsafe { PyArray::from_borrowed_ptr(python, self.as_ptr()) }
586598
}
599+
600+
fn type_check_assert(&self) {
601+
let type_check = self.type_check();
602+
assert!(type_check.is_ok(), "{:?}", type_check);
603+
}
604+
605+
fn type_check(&self) -> Result<(), ErrorKind> {
606+
let truth = self.typenum();
607+
let dim = self.shape().len();
608+
let dim_ok = D::NDIM.map(|n| n == dim).unwrap_or(true);
609+
if T::is_same_type(truth) && dim_ok {
610+
Ok(())
611+
} else {
612+
Err(ErrorKind::to_rust(truth, dim, T::npy_data_type(), D::NDIM))
613+
}
614+
}
587615
}
588616

589617
impl<T: TypeNum> PyArray<T, Ix1> {
@@ -828,41 +856,6 @@ impl<T: TypeNum, D> PyArray<T, D> {
828856
NpyDataType::from_i32(self.typenum())
829857
}
830858

831-
fn type_check_assert(&self) {
832-
let type_check = self.type_check();
833-
assert!(type_check.is_ok(), "{:?}", type_check);
834-
}
835-
836-
fn type_check(&self) -> Result<(), ErrorKind> {
837-
let truth = self.typenum();
838-
if T::is_same_type(truth) {
839-
Ok(())
840-
} else {
841-
Err(ErrorKind::to_rust(truth, T::npy_data_type()))
842-
}
843-
}
844-
845-
/// Get the immutable view of the internal data of `PyArray`, as slice.
846-
/// # Example
847-
/// ```
848-
/// # extern crate pyo3; extern crate numpy; fn main() {
849-
/// use numpy::PyArray;
850-
/// let gil = pyo3::Python::acquire_gil();
851-
/// let py_array = PyArray::arange(gil.python(), 0, 4, 1).reshape([2, 2]).unwrap();
852-
/// assert_eq!(py_array.as_slice(), &[0, 1, 2, 3]);
853-
/// # }
854-
/// ```
855-
pub fn as_slice(&self) -> &[T] {
856-
self.type_check_assert();
857-
unsafe { ::std::slice::from_raw_parts(self.data(), self.len()) }
858-
}
859-
860-
/// Get the mmutable view of the internal data of `PyArray`, as slice.
861-
pub fn as_slice_mut(&self) -> &mut [T] {
862-
self.type_check_assert();
863-
unsafe { ::std::slice::from_raw_parts_mut(self.data(), self.len()) }
864-
}
865-
866859
/// Copies self into `other`, performing a data-type conversion if necessary.
867860
/// # Example
868861
/// ```

src/error.rs

Lines changed: 69 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -31,74 +31,101 @@ impl<T, E: IntoPyErr> IntoPyResult for Result<T, E> {
3131
}
3232
}
3333

34-
/// Represents a shape and format of numpy array.
34+
/// Represents a shape and dtype of numpy array.
3535
///
3636
/// Only for error formatting.
3737
#[derive(Debug)]
38-
pub struct ArrayFormat {
38+
pub struct ArrayShape {
3939
pub dims: Box<[usize]>,
4040
pub dtype: NpyDataType,
4141
}
4242

43-
impl fmt::Display for ArrayFormat {
43+
impl ArrayShape {
44+
fn boxed_dims(dims: &[usize]) -> Box<[usize]> {
45+
dims.into_iter()
46+
.map(|&x| x)
47+
.collect::<Vec<_>>()
48+
.into_boxed_slice()
49+
}
50+
fn from_array<T: TypeNum, D>(array: &PyArray<T, D>) -> Self {
51+
ArrayShape {
52+
dims: Self::boxed_dims(array.shape()),
53+
dtype: T::npy_data_type(),
54+
}
55+
}
56+
fn from_dims<T: TypeNum, D: ToNpyDims>(dims: D) -> Self {
57+
ArrayShape {
58+
dims: Self::boxed_dims(dims.slice()),
59+
dtype: T::npy_data_type(),
60+
}
61+
}
62+
}
63+
64+
impl fmt::Display for ArrayShape {
4465
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
4566
write!(f, "dims={:?}, dtype={:?}", self.dims, self.dtype)
4667
}
4768
}
4869

70+
/// Represents a dimension and dtype of numpy array.
71+
///
72+
/// Only for error formatting.
73+
#[derive(Debug)]
74+
pub struct ArrayDim {
75+
pub dim: Option<usize>,
76+
pub dtype: NpyDataType,
77+
}
78+
79+
impl fmt::Display for ArrayDim {
80+
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
81+
if let Some(d) = self.dim {
82+
write!(f, "dim={:?}, dtype={:?}", d, self.dtype)
83+
} else {
84+
write!(f, "dim=_, dtype={:?}", self.dtype)
85+
}
86+
}
87+
}
88+
4989
/// Represents a casting error between rust types and numpy array.
5090
#[derive(Debug)]
5191
pub enum ErrorKind {
5292
/// Error for casting `PyArray` into `ArrayView` or `ArrayViewMut`
53-
PyToRust { from: NpyDataType, to: NpyDataType },
93+
PyToRust { from: ArrayDim, to: ArrayDim },
5494
/// Error for casting rust's `Vec` into numpy array.
5595
FromVec { dim1: usize, dim2: usize },
5696
/// Error in numpy -> numpy data conversion
57-
PyToPy(Box<(ArrayFormat, ArrayFormat)>),
97+
PyToPy(Box<(ArrayShape, ArrayShape)>),
5898
}
5999

60100
impl ErrorKind {
61-
pub(crate) fn to_rust(from: i32, to: NpyDataType) -> Self {
101+
pub(crate) fn to_rust(
102+
from_t: i32,
103+
from_d: usize,
104+
to_t: NpyDataType,
105+
to_d: Option<usize>,
106+
) -> Self {
62107
ErrorKind::PyToRust {
63-
from: NpyDataType::from_i32(from),
64-
to,
108+
from: ArrayDim {
109+
dim: Some(from_d),
110+
dtype: NpyDataType::from_i32(from_t),
111+
},
112+
to: ArrayDim {
113+
dim: to_d,
114+
dtype: to_t,
115+
},
65116
}
66117
}
67118
pub(crate) fn dtype_cast<T: TypeNum, D>(from: &PyArray<T, D>, to: NpyDataType) -> Self {
68-
let dims = from
69-
.shape()
70-
.into_iter()
71-
.map(|&x| x)
72-
.collect::<Vec<_>>()
73-
.into_boxed_slice();
74-
let from = ArrayFormat {
75-
dims: dims.clone(),
76-
dtype: T::npy_data_type(),
119+
let from = ArrayShape::from_array(from);
120+
let to = ArrayShape {
121+
dims: from.dims.clone(),
122+
dtype: to,
77123
};
78-
let to = ArrayFormat { dims, dtype: to };
79124
ErrorKind::PyToPy(Box::new((from, to)))
80125
}
81126
pub(crate) fn dims_cast<T: TypeNum, D>(from: &PyArray<T, D>, to_dim: impl ToNpyDims) -> Self {
82-
let dims_from = from
83-
.shape()
84-
.into_iter()
85-
.map(|&x| x)
86-
.collect::<Vec<_>>()
87-
.into_boxed_slice();
88-
let dims_to = to_dim
89-
.slice()
90-
.into_iter()
91-
.map(|&x| x)
92-
.collect::<Vec<_>>()
93-
.into_boxed_slice();
94-
let from = ArrayFormat {
95-
dims: dims_from,
96-
dtype: T::npy_data_type(),
97-
};
98-
let to = ArrayFormat {
99-
dims: dims_to,
100-
dtype: T::npy_data_type(),
101-
};
127+
let from = ArrayShape::from_array(from);
128+
let to = ArrayShape::from_dims::<T, _>(to_dim);
102129
ErrorKind::PyToPy(Box::new((from, to)))
103130
}
104131
}
@@ -107,16 +134,16 @@ impl fmt::Display for ErrorKind {
107134
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
108135
match self {
109136
ErrorKind::PyToRust { from, to } => {
110-
write!(f, "Cast failed: from={:?}, to={:?}", from, to)
137+
write!(f, "Extraction failed:\n from=({}), to=({})", from, to)
111138
}
112139
ErrorKind::FromVec { dim1, dim2 } => write!(
113140
f,
114-
"Cast failed: Vec To PyArray: expect all dim {} but {} was found",
141+
"Cast failed: Vec To PyArray:\n expect all dim {} but {} was found",
115142
dim1, dim2
116143
),
117144
ErrorKind::PyToPy(e) => write!(
118145
f,
119-
"Cast failed: from=ndarray({:?}), to=ndarray(dtype={:?})",
146+
"Cast failed: from=ndarray({}), to=ndarray(dtype={})",
120147
e.0, e.1,
121148
),
122149
}
@@ -142,7 +169,7 @@ impl IntoPyErr for ErrorKind {
142169
fn into_pyerr_with<D: fmt::Display>(self, msg: impl FnOnce() -> D) -> PyErr {
143170
match self {
144171
ErrorKind::PyToRust { .. } | ErrorKind::FromVec { .. } | ErrorKind::PyToPy(_) => {
145-
PyErr::new::<exc::TypeError, _>(format!("{} msg: {}", self, msg()))
172+
PyErr::new::<exc::TypeError, _>(format!("{}\n context: {}", self, msg()))
146173
}
147174
}
148175
}

src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ pub use array::{
5353
PyArrayDyn,
5454
};
5555
pub use convert::{IntoPyArray, NpyIndex, ToNpyDims, ToPyArray};
56-
pub use error::{ArrayFormat, ErrorKind, IntoPyErr, IntoPyResult};
56+
pub use error::{ArrayShape, ErrorKind, IntoPyErr, IntoPyResult};
5757
pub use ndarray::{Ix1, Ix2, Ix3, Ix4, Ix5, Ix6, IxDyn};
5858
pub use npyffi::{PY_ARRAY_API, PY_UFUNC_API};
5959
pub use types::{c32, c64, NpyDataType, TypeNum};

src/slice_box.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ impl<T> SliceBox<T> {
1414
let type_ob = <Self as typeob::PyTypeInfo>::type_object() as *mut _;
1515
let base = ffi::_PyObject_New(type_ob);
1616
*base = ffi::PyObject_HEAD_INIT;
17-
(*base).ob_type = <Self as typeob::PyTypeInfo>::type_object() as *mut _;
17+
(*base).ob_type = type_ob;
1818
let self_ = base as *mut SliceBox<T>;
1919
(*self_).inner = Box::into_raw(box_);
2020
&*self_
@@ -33,7 +33,7 @@ impl<T> typeob::PyTypeInfo for SliceBox<T> {
3333
const SIZE: usize = { Self::OFFSET as usize + std::mem::size_of::<Self>() + 0 + 0 };
3434
const OFFSET: isize = 0;
3535
#[inline]
36-
unsafe fn type_object() -> &'static mut ::pyo3::ffi::PyTypeObject {
36+
unsafe fn type_object() -> &'static mut ffi::PyTypeObject {
3737
static mut TYPE_OBJECT: ::pyo3::ffi::PyTypeObject = ::pyo3::ffi::PyTypeObject_INIT;
3838
&mut TYPE_OBJECT
3939
}

tests/array.rs

Lines changed: 40 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,7 @@ fn from_vec3() {
148148
}
149149

150150
#[test]
151-
fn from_eval() {
151+
fn from_eval_to_fixed() {
152152
let gil = pyo3::Python::acquire_gil();
153153
let np = get_array_module(gil.python()).unwrap();
154154
let dict = PyDict::new(gil.python());
@@ -163,7 +163,26 @@ fn from_eval() {
163163
}
164164

165165
#[test]
166-
fn from_eval_fail() {
166+
fn from_eval_to_dyn() {
167+
let gil = pyo3::Python::acquire_gil();
168+
let np = get_array_module(gil.python()).unwrap();
169+
let dict = PyDict::new(gil.python());
170+
dict.set_item("np", np).unwrap();
171+
let pyarray: &PyArrayDyn<i32> = gil
172+
.python()
173+
.eval(
174+
"np.array([[1, 2], [3, 4]], dtype='int32')",
175+
Some(&dict),
176+
None,
177+
)
178+
.unwrap()
179+
.extract()
180+
.unwrap();
181+
assert_eq!(pyarray.as_slice(), &[1, 2, 3, 4]);
182+
}
183+
184+
#[test]
185+
fn from_eval_fail_by_dtype() {
167186
let gil = pyo3::Python::acquire_gil();
168187
let np = get_array_module(gil.python()).unwrap();
169188
let dict = PyDict::new(gil.python());
@@ -173,7 +192,25 @@ fn from_eval_fail() {
173192
.eval("np.array([1, 2, 3], dtype='float64')", Some(&dict), None)
174193
.unwrap()
175194
.extract();
176-
assert!(converted.is_err());
195+
converted
196+
.unwrap_err()
197+
.print_and_set_sys_last_vars(gil.python());
198+
}
199+
200+
#[test]
201+
fn from_eval_fail_by_dim() {
202+
let gil = pyo3::Python::acquire_gil();
203+
let np = get_array_module(gil.python()).unwrap();
204+
let dict = PyDict::new(gil.python());
205+
dict.set_item("np", np).unwrap();
206+
let converted: Result<&PyArray2<i32>, _> = gil
207+
.python()
208+
.eval("np.array([1, 2, 3], dtype='int32')", Some(&dict), None)
209+
.unwrap()
210+
.extract();
211+
converted
212+
.unwrap_err()
213+
.print_and_set_sys_last_vars(gil.python());
177214
}
178215

179216
macro_rules! small_array_test {

0 commit comments

Comments
 (0)