Skip to content

Commit 550efd0

Browse files
committed
Fix array.slice.to_pyarray
1 parent b75729e commit 550efd0

File tree

4 files changed

+97
-25
lines changed

4 files changed

+97
-25
lines changed

src/array.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -296,7 +296,7 @@ impl<T, D> PyArray<T, D> {
296296
}
297297

298298
/// Returns the pointer to the first element of the inner array.
299-
unsafe fn data(&self) -> *mut T {
299+
pub(crate) unsafe fn data(&self) -> *mut T {
300300
let ptr = self.as_array_ptr();
301301
(*ptr).data as *mut T
302302
}

src/convert.rs

Lines changed: 75 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ where
5858
type Item = A;
5959
type Dim = D;
6060
fn into_pyarray<'py>(self, py: Python<'py>) -> &'py PyArray<Self::Item, Self::Dim> {
61-
let strides = NpyStrides::from_array(&self);
61+
let strides = self.npy_strides();
6262
let dim = self.raw_dim();
6363
let boxed = self.into_raw_vec().into_boxed_slice();
6464
unsafe { PyArray::from_boxed_slice(py, dim, strides.as_ptr(), boxed) }
@@ -102,11 +102,68 @@ where
102102
type Dim = D;
103103
fn to_pyarray<'py>(&self, py: Python<'py>) -> &'py PyArray<Self::Item, Self::Dim> {
104104
let len = self.len();
105-
let strides = NpyStrides::from_array(self);
106-
unsafe {
107-
let array = PyArray::new_(py, self.raw_dim(), strides.as_ptr(), 0);
108-
array.copy_ptr(self.as_ptr(), len);
109-
array
105+
if let Some(order) = self.order() {
106+
// if the array is contiguous, copy it by `copy_ptr`.
107+
let strides = self.npy_strides();
108+
unsafe {
109+
let array = PyArray::new_(py, self.raw_dim(), strides.as_ptr(), order.to_flag());
110+
array.copy_ptr(self.as_ptr(), len);
111+
array
112+
}
113+
} else {
114+
// if the array is not contiguous, copy all elements by `ArrayBase::iter`.
115+
let dim = self.raw_dim();
116+
let strides = NpyStrides::from_dim(&dim, mem::size_of::<A>());
117+
unsafe {
118+
let array = PyArray::<A, _>::new_(py, dim, strides.as_ptr(), 0);
119+
let data_ptr = array.data();
120+
for (i, item) in self.iter().enumerate() {
121+
data_ptr.offset(i as isize).write(*item);
122+
}
123+
array
124+
}
125+
}
126+
}
127+
}
128+
129+
enum Order {
130+
Standard,
131+
Fortran,
132+
}
133+
134+
impl Order {
135+
fn to_flag(&self) -> c_int {
136+
match self {
137+
Order::Standard => 0,
138+
Order::Fortran => 1,
139+
}
140+
}
141+
}
142+
143+
trait ArrayExt {
144+
fn npy_strides(&self) -> NpyStrides;
145+
fn order(&self) -> Option<Order>;
146+
}
147+
148+
impl<A, S, D> ArrayExt for ArrayBase<S, D>
149+
where
150+
S: Data<Elem = A>,
151+
D: Dimension,
152+
{
153+
fn npy_strides(&self) -> NpyStrides {
154+
NpyStrides::new(
155+
self.strides().into_iter().map(|&x| x as npyffi::npy_intp),
156+
mem::size_of::<A>(),
157+
)
158+
}
159+
160+
fn order(&self) -> Option<Order> {
161+
if self.is_standard_layout() {
162+
Some(Order::Standard)
163+
} else if self.ndim() > 1 && self.raw_view().reversed_axes().is_standard_layout() {
164+
Some(Order::Fortran)
165+
} else {
166+
None
110167
}
111168
}
112169
}
@@ -124,31 +181,26 @@ impl NpyStrides {
124181
NpyStrides::Long(inner) => inner.as_ptr(),
125182
}
126183
}
127-
128-
fn from_array<A, S, D>(array: &ArrayBase<S, D>) -> Self
129-
where
130-
S: Data<Elem = A>,
131-
D: Dimension,
132-
A: TypeNum,
133-
{
134-
Self::from_strides(array.strides(), mem::size_of::<A>())
184+
fn from_dim<D: Dimension>(dim: &D, type_size: usize) -> Self {
185+
Self::new(
186+
dim.default_strides()
187+
.slice()
188+
.into_iter()
189+
.map(|&x| x as npyffi::npy_intp),
190+
type_size,
191+
)
135192
}
136-
fn from_strides(strides: &[isize], type_size: usize) -> Self {
193+
fn new(strides: impl ExactSizeIterator<Item = npyffi::npy_intp>, type_size: usize) -> Self {
137194
let len = strides.len();
138195
let type_size = type_size as npyffi::npy_intp;
139196
if len <= 8 {
140197
let mut res = [0; 8];
141-
for i in 0..len {
142-
res[i] = strides[i] as npyffi::npy_intp * type_size;
198+
for (i, s) in strides.enumerate() {
199+
res[i] = s * type_size;
143200
}
144201
NpyStrides::Short(res)
145202
} else {
146-
NpyStrides::Long(
147-
strides
148-
.into_iter()
149-
.map(|&n| n as npyffi::npy_intp * type_size)
150-
.collect(),
151-
)
203+
NpyStrides::Long(strides.map(|n| n as npyffi::npy_intp * type_size).collect())
152204
}
153205
}
154206
}

src/types.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ impl NpyDataType {
6868
}
6969
}
7070

71-
pub trait TypeNum {
71+
pub trait TypeNum: std::fmt::Debug + Copy {
7272
fn is_same_type(other: i32) -> bool;
7373
fn npy_data_type() -> NpyDataType;
7474
fn typenum_default() -> i32;

tests/to_py.rs

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,3 +108,23 @@ fn into_pyarray_cant_resize() {
108108
let arr = a.into_pyarray(gil.python());
109109
assert!(arr.resize(100).is_err())
110110
}
111+
112+
#[test]
113+
fn forder_to_pyarray() {
114+
let gil = pyo3::Python::acquire_gil();
115+
let py = gil.python();
116+
let matrix = Array2::from_shape_vec([4, 2], vec![0, 1, 2, 3, 4, 5, 6, 7]).unwrap();
117+
let fortran_matrix = matrix.reversed_axes();
118+
let fmat_py = fortran_matrix.to_pyarray(py);
119+
assert_eq!(fmat_py.as_array(), array![[0, 2, 4, 6], [1, 3, 5, 7]],);
120+
}
121+
122+
#[test]
123+
fn slice_to_pyarray() {
124+
let gil = pyo3::Python::acquire_gil();
125+
let py = gil.python();
126+
let matrix = Array2::from_shape_vec([4, 2], vec![0, 1, 2, 3, 4, 5, 6, 7]).unwrap();
127+
let slice = matrix.slice(s![1..4; -1, ..]);
128+
let slice_py = slice.to_pyarray(py);
129+
assert_eq!(slice_py.as_array(), array![[6, 7], [4, 5], [2, 3]],);
130+
}

0 commit comments

Comments
 (0)