Skip to content

Commit 75dee30

Browse files
committed
Implement PyArray::reshape
1 parent ef78219 commit 75dee30

File tree

4 files changed

+126
-25
lines changed

4 files changed

+126
-25
lines changed

src/array.rs

Lines changed: 50 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -481,10 +481,7 @@ impl<T: TypeNum> PyArray<T> {
481481
let other_ptr = other.as_array_ptr();
482482
let result = unsafe { PY_ARRAY_API.PyArray_CopyInto(other_ptr, self_ptr) };
483483
if result == -1 {
484-
Err(ArrayCastError::Numpy {
485-
from: T::npy_data_type(),
486-
to: U::npy_data_type(),
487-
})
484+
Err(ArrayCastError::dtype_cast(self, U::npy_data_type()))
488485
} else {
489486
Ok(())
490487
}
@@ -506,10 +503,7 @@ impl<T: TypeNum> PyArray<T> {
506503
let other_ptr = other.as_array_ptr();
507504
let result = unsafe { PY_ARRAY_API.PyArray_MoveInto(other_ptr, self_ptr) };
508505
if result == -1 {
509-
Err(ArrayCastError::Numpy {
510-
from: T::npy_data_type(),
511-
to: U::npy_data_type(),
512-
})
506+
Err(ArrayCastError::dtype_cast(self, U::npy_data_type()))
513507
} else {
514508
Ok(())
515509
}
@@ -522,12 +516,11 @@ impl<T: TypeNum> PyArray<T> {
522516
/// use numpy::{PyArray, IntoPyArray};
523517
/// let gil = pyo3::Python::acquire_gil();
524518
/// let pyarray_f = PyArray::<f64>::arange(gil.python(), 2.0, 5.0, 1.0);
525-
/// let pyarray_i = pyarray_f.cast::<i32>(gil.python(), false).unwrap();
519+
/// let pyarray_i = pyarray_f.cast::<i32>(false).unwrap();
526520
/// assert_eq!(pyarray_i.as_slice().unwrap(), &[2, 3, 4]);
527521
/// # }
528522
pub fn cast<'py, U: TypeNum>(
529-
&self,
530-
py: Python<'py>,
523+
&'py self,
531524
is_fortran: bool,
532525
) -> Result<&'py PyArray<U>, ArrayCastError> {
533526
let ptr = unsafe {
@@ -539,12 +532,53 @@ impl<T: TypeNum> PyArray<T> {
539532
)
540533
};
541534
if ptr.is_null() {
542-
Err(ArrayCastError::Numpy {
543-
from: T::npy_data_type(),
544-
to: U::npy_data_type(),
545-
})
535+
Err(ArrayCastError::dtype_cast(self, U::npy_data_type()))
536+
} else {
537+
Ok(unsafe { PyArray::<U>::from_owned_ptr(self.py(), ptr) })
538+
}
539+
}
540+
541+
/// Construct a new array which has same values as self, same matrix order, but has different
542+
/// dimensions specified by `dims`.
543+
///
544+
/// Since a returned array can contain a same pointer as self, we highly recommend to drop an
545+
/// old array, if this method returns `Ok`.
546+
///
547+
/// # Example
548+
///
549+
/// ```
550+
/// # #[macro_use] extern crate ndarray; extern crate pyo3; extern crate numpy; fn main() {
551+
/// use numpy::PyArray;
552+
/// let gil = pyo3::Python::acquire_gil();
553+
/// let array = PyArray::from_vec(gil.python(), (0..9).collect());
554+
/// let array = array.reshape([3, 3]).unwrap();
555+
/// assert_eq!(array.as_array().unwrap(), array![[0, 1, 2], [3, 4, 5], [6, 7, 8]].into_dyn());
556+
/// assert!(array.reshape([5]).is_err());
557+
/// # }
558+
/// ```
559+
#[inline(always)]
560+
pub fn reshape<'py, D: ToNpyDims>(&'py self, dims: D) -> Result<&Self, ArrayCastError> {
561+
self.reshape_with_order(dims, NPY_ORDER::NPY_ANYORDER)
562+
}
563+
564+
/// Same as [reshape](method.reshape.html), but you can change the order of returned matrix.
565+
pub fn reshape_with_order<'py, D: ToNpyDims>(
566+
&'py self,
567+
dims: D,
568+
order: NPY_ORDER,
569+
) -> Result<&Self, ArrayCastError> {
570+
let mut np_dims = dims.to_npy_dims();
571+
let ptr = unsafe {
572+
PY_ARRAY_API.PyArray_Newshape(
573+
self.as_array_ptr(),
574+
&mut np_dims as *mut npyffi::PyArray_Dims,
575+
order,
576+
)
577+
};
578+
if ptr.is_null() {
579+
Err(ArrayCastError::dims_cast(self, dims))
546580
} else {
547-
unsafe { Ok(PyArray::<U>::from_owned_ptr(py, ptr)) }
581+
Ok(unsafe { PyArray::<T>::from_owned_ptr(self.py(), ptr) })
548582
}
549583
}
550584
}

src/convert.rs

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ pub(crate) unsafe fn into_raw<T>(x: Vec<T>) -> *mut c_void {
9292
pub trait ToNpyDims {
9393
fn dims_len(&self) -> c_int;
9494
fn dims_ptr(&self) -> *mut npy_intp;
95+
fn dims_ref(&self) -> &[usize];
9596
fn to_npy_dims(&self) -> npyffi::PyArray_Dims {
9697
npyffi::PyArray_Dims {
9798
ptr: self.dims_ptr(),
@@ -110,6 +111,9 @@ macro_rules! array_dim_impls {
110111
fn dims_ptr(&self) -> *mut npy_intp {
111112
self.as_ptr() as *mut npy_intp
112113
}
114+
fn dims_ref(&self) -> &[usize] {
115+
self
116+
}
113117
}
114118
impl<'a> ToNpyDims for &'a [usize; $N] {
115119
fn dims_len(&self) -> c_int {
@@ -118,6 +122,9 @@ macro_rules! array_dim_impls {
118122
fn dims_ptr(&self) -> *mut npy_intp {
119123
self.as_ptr() as *mut npy_intp
120124
}
125+
fn dims_ref(&self) -> &[usize] {
126+
*self
127+
}
121128
}
122129
)+
123130
}
@@ -137,6 +144,9 @@ impl<'a> ToNpyDims for &'a [usize] {
137144
fn dims_ptr(&self) -> *mut npy_intp {
138145
self.as_ptr() as *mut npy_intp
139146
}
147+
fn dims_ref(&self) -> &[usize] {
148+
*self
149+
}
140150
}
141151

142152
impl ToNpyDims for Vec<usize> {
@@ -146,6 +156,9 @@ impl ToNpyDims for Vec<usize> {
146156
fn dims_ptr(&self) -> *mut npy_intp {
147157
self.as_ptr() as *mut npy_intp
148158
}
159+
fn dims_ref(&self) -> &[usize] {
160+
&*self
161+
}
149162
}
150163

151164
impl ToNpyDims for Box<[usize]> {
@@ -155,4 +168,7 @@ impl ToNpyDims for Box<[usize]> {
155168
fn dims_ptr(&self) -> *mut npy_intp {
156169
self.as_ptr() as *mut npy_intp
157170
}
171+
fn dims_ref(&self) -> &[usize] {
172+
&*self
173+
}
158174
}

src/error.rs

Lines changed: 59 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
//! Defines error types.
22
3+
use array::PyArray;
4+
use convert::ToNpyDims;
35
use pyo3::*;
46
use std::error;
57
use std::fmt;
6-
use types::NpyDataType;
8+
use types::{NpyDataType, TypeNum};
79

810
pub trait IntoPyErr {
911
fn into_pyerr(self, msg: &str) -> PyErr;
@@ -21,6 +23,19 @@ impl<T, E: IntoPyErr> IntoPyResult for Result<T, E> {
2123
}
2224
}
2325

26+
/// Represents a shape and format of numpy array.
27+
#[derive(Debug)]
28+
pub struct ArrayFormat {
29+
pub dims: Box<[usize]>,
30+
pub dtype: NpyDataType,
31+
}
32+
33+
impl fmt::Display for ArrayFormat {
34+
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
35+
write!(f, "dims={:?}, dtype={:?}", self.dims, self.dtype)
36+
}
37+
}
38+
2439
/// Represents a casting error between rust types and numpy array.
2540
#[derive(Debug)]
2641
pub enum ArrayCastError {
@@ -29,7 +44,7 @@ pub enum ArrayCastError {
2944
/// Error for casting rust's `Vec` into numpy array.
3045
FromVec,
3146
/// Error in numpy -> numpy data conversion
32-
Numpy { from: NpyDataType, to: NpyDataType },
47+
Numpy(Box<(ArrayFormat, ArrayFormat)>),
3348
}
3449

3550
impl ArrayCastError {
@@ -39,6 +54,43 @@ impl ArrayCastError {
3954
to,
4055
}
4156
}
57+
pub(crate) fn dtype_cast<T: TypeNum>(from: &PyArray<T>, to: NpyDataType) -> Self {
58+
let dims = from
59+
.shape()
60+
.into_iter()
61+
.map(|&x| x)
62+
.collect::<Vec<_>>()
63+
.into_boxed_slice();
64+
let from = ArrayFormat {
65+
dims: dims.clone(),
66+
dtype: T::npy_data_type(),
67+
};
68+
let to = ArrayFormat { dims, dtype: to };
69+
ArrayCastError::Numpy(Box::new((from, to)))
70+
}
71+
pub(crate) fn dims_cast<T: TypeNum>(from: &PyArray<T>, to_dim: impl ToNpyDims) -> Self {
72+
let dims_from = from
73+
.shape()
74+
.into_iter()
75+
.map(|&x| x)
76+
.collect::<Vec<_>>()
77+
.into_boxed_slice();
78+
let dims_to = to_dim
79+
.dims_ref()
80+
.into_iter()
81+
.map(|&x| x)
82+
.collect::<Vec<_>>()
83+
.into_boxed_slice();
84+
let from = ArrayFormat {
85+
dims: dims_from,
86+
dtype: T::npy_data_type(),
87+
};
88+
let to = ArrayFormat {
89+
dims: dims_to,
90+
dtype: T::npy_data_type(),
91+
};
92+
ArrayCastError::Numpy(Box::new((from, to)))
93+
}
4294
}
4395

4496
impl fmt::Display for ArrayCastError {
@@ -48,10 +100,10 @@ impl fmt::Display for ArrayCastError {
48100
write!(f, "Cast failed: from={:?}, to={:?}", from, to)
49101
}
50102
ArrayCastError::FromVec => write!(f, "Cast failed: FromVec (maybe invalid dimension)"),
51-
ArrayCastError::Numpy { from, to } => write!(
103+
ArrayCastError::Numpy(e) => write!(
52104
f,
53-
"Cast failed: from=ndarray(dtype={:?}), to=ndarray(dtype={:?})",
54-
from, to
105+
"Cast failed: from=ndarray({:?}), to=ndarray(dtype={:?})",
106+
e.0, e.1,
55107
),
56108
}
57109
}
@@ -67,9 +119,9 @@ impl IntoPyErr for ArrayCastError {
67119
from, to, msg
68120
),
69121
ArrayCastError::FromVec => format!("ArrayCastError::FromVec: {}", msg),
70-
ArrayCastError::Numpy { from, to } => format!(
122+
ArrayCastError::Numpy(e) => format!(
71123
"ArrayCastError::Numpy: from: {:?}, to: {:?}, msg: {}",
72-
from, to, msg
124+
e.0, e.1, msg
73125
),
74126
};
75127
PyErr::new::<exc::TypeError, _>(msg)

tests/array.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -174,10 +174,9 @@ small_array_test!(i8 u8 i16 u16 i32 u32 i64 u64);
174174
#[test]
175175
fn array_cast() {
176176
let gil = pyo3::Python::acquire_gil();
177-
let py = gil.python();
178177
let vec2 = vec![vec![1.0, 2.0, 3.0]; 2];
179178
let arr_f64 = PyArray::from_vec2(gil.python(), &vec2).unwrap();
180-
let arr_i32: &PyArray<i32> = arr_f64.cast(py, false).unwrap();
179+
let arr_i32: &PyArray<i32> = arr_f64.cast(false).unwrap();
181180
assert_eq!(
182181
arr_i32.as_array().unwrap(),
183182
array![[1, 2, 3], [1, 2, 3]].into_dyn()

0 commit comments

Comments
 (0)