Skip to content

Commit 98f1370

Browse files
committed
Add strides_usize and use it in ndarray_shape
To avoid allocation
1 parent 4ad2b21 commit 98f1370

File tree

1 file changed

+13
-6
lines changed

1 file changed

+13
-6
lines changed

src/array.rs

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -264,6 +264,15 @@ impl<T, D> PyArray<T, D> {
264264
pub(crate) unsafe fn copy_ptr(&self, other: *const T, len: usize) {
265265
ptr::copy_nonoverlapping(other, self.data(), len)
266266
}
267+
268+
fn strides_usize(&self) -> &[usize] {
269+
let n = self.ndim();
270+
let ptr = self.as_array_ptr();
271+
unsafe {
272+
let p = (*ptr).strides;
273+
::std::slice::from_raw_parts(p as *const _, n)
274+
}
275+
}
267276
}
268277

269278
impl<T: TypeNum, D: Dimension> PyArray<T, D> {
@@ -276,11 +285,9 @@ impl<T: TypeNum, D: Dimension> PyArray<T, D> {
276285
fn ndarray_shape(&self) -> StrideShape<D> {
277286
let shape: Shape<_> = Dim(self.dims()).into();
278287
let size = mem::size_of::<T>();
279-
let st = D::from_dimension(&Dim(
280-
self.strides().iter()
281-
.map(|&s| s as usize / size)
282-
.collect::<Vec<_>>()
283-
)).unwrap();
288+
let mut st = D::from_dimension(&Dim(self.strides_usize()))
289+
.expect("PyArray::ndarray_shape: dimension mismatching");
290+
st.slice_mut().iter_mut().for_each(|e| *e /= size);
284291
shape.strides(st)
285292
}
286293

@@ -1001,6 +1008,6 @@ fn test_get_unchecked() {
10011008
#[test]
10021009
fn test_dyn_to_owned_array() {
10031010
let gil = pyo3::Python::acquire_gil();
1004-
let array = PyArray::from_vec2(gil.python(), &vec![vec![1,2], vec![3,4]]).unwrap();
1011+
let array = PyArray::from_vec2(gil.python(), &vec![vec![1, 2], vec![3, 4]]).unwrap();
10051012
array.into_dyn().to_owned_array();
10061013
}

0 commit comments

Comments
 (0)