Skip to content

Commit c8a2d8c

Browse files
committed
Do not assume trivally copyable types for Element impls using DataType::Object.
1 parent 9c6244b commit c8a2d8c

File tree

2 files changed

+34
-21
lines changed

2 files changed

+34
-21
lines changed

src/array.rs

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ use std::{cell::Cell, mem, os::raw::c_int, ptr, slice};
1010
use std::{iter::ExactSizeIterator, marker::PhantomData};
1111

1212
use crate::convert::{IntoPyArray, NpyIndex, ToNpyDims, ToPyArray};
13-
use crate::dtype::Element;
13+
use crate::dtype::{DataType, Element};
1414
use crate::error::{FromVecError, NotContiguousError, ShapeError};
1515
use crate::slice_box::SliceBox;
1616

@@ -731,8 +731,17 @@ impl<T: Element> PyArray<T, Ix1> {
731731
/// ```
732732
pub fn from_slice<'py>(py: Python<'py>, slice: &[T]) -> &'py Self {
733733
let array = PyArray::new(py, [slice.len()], false);
734-
unsafe {
735-
array.copy_ptr(slice.as_ptr(), slice.len());
734+
if T::DATA_TYPE != DataType::Object {
735+
unsafe {
736+
array.copy_ptr(slice.as_ptr(), slice.len());
737+
}
738+
} else {
739+
unsafe {
740+
let data_ptr = array.data();
741+
for (i, item) in slice.iter().enumerate() {
742+
data_ptr.add(i).write(item.clone());
743+
}
744+
}
736745
}
737746
array
738747
}

src/convert.rs

Lines changed: 22 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ use std::{mem, os::raw::c_int};
77

88
use crate::{
99
npyffi::{self, npy_intp},
10-
Element, PyArray,
10+
DataType, Element, PyArray,
1111
};
1212

1313
/// Covnersion trait from some rust types to `PyArray`.
@@ -130,25 +130,29 @@ where
130130
type Dim = D;
131131
fn to_pyarray<'py>(&self, py: Python<'py>) -> &'py PyArray<Self::Item, Self::Dim> {
132132
let len = self.len();
133-
if let Some(order) = self.order() {
134-
// if the array is contiguous, copy it by `copy_ptr`.
135-
let strides = self.npy_strides();
136-
unsafe {
137-
let array = PyArray::new_(py, self.raw_dim(), strides.as_ptr(), order.to_flag());
138-
array.copy_ptr(self.as_ptr(), len);
139-
array
133+
match self.order() {
134+
Some(order) if A::DATA_TYPE != DataType::Object => {
135+
// if the array is contiguous, copy it by `copy_ptr`.
136+
let strides = self.npy_strides();
137+
unsafe {
138+
let array =
139+
PyArray::new_(py, self.raw_dim(), strides.as_ptr(), order.to_flag());
140+
array.copy_ptr(self.as_ptr(), len);
141+
array
142+
}
140143
}
141-
} else {
142-
// if the array is not contiguous, copy all elements by `ArrayBase::iter`.
143-
let dim = self.raw_dim();
144-
let strides = NpyStrides::from_dim(&dim, mem::size_of::<A>());
145-
unsafe {
146-
let array = PyArray::<A, _>::new_(py, dim, strides.as_ptr(), 0);
147-
let data_ptr = array.data();
148-
for (i, item) in self.iter().enumerate() {
149-
data_ptr.add(i).write(item.clone());
144+
_ => {
145+
// if the array is not contiguous, copy all elements by `ArrayBase::iter`.
146+
let dim = self.raw_dim();
147+
let strides = NpyStrides::from_dim(&dim, mem::size_of::<A>());
148+
unsafe {
149+
let array = PyArray::<A, _>::new_(py, dim, strides.as_ptr(), 0);
150+
let data_ptr = array.data();
151+
for (i, item) in self.iter().enumerate() {
152+
data_ptr.add(i).write(item.clone());
153+
}
154+
array
150155
}
151-
array
152156
}
153157
}
154158
}

0 commit comments

Comments
 (0)