Skip to content

Commit b75729e

Browse files
committed
Use short array optimization for numpy strides
1 parent 312f00a commit b75729e

File tree

3 files changed

+46
-20
lines changed

3 files changed

+46
-20
lines changed

src/array.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -355,7 +355,7 @@ impl<T: TypeNum, D: Dimension> PyArray<T, D> {
355355
pub(crate) unsafe fn new_<'py, ID>(
356356
py: Python<'py>,
357357
dims: ID,
358-
strides: *mut npy_intp,
358+
strides: *const npy_intp,
359359
flag: c_int,
360360
) -> &'py Self
361361
where
@@ -367,7 +367,7 @@ impl<T: TypeNum, D: Dimension> PyArray<T, D> {
367367
dims.ndim_cint(),
368368
dims.as_dims_ptr(),
369369
T::typenum_default(),
370-
strides, // strides
370+
strides as *mut _, // strides
371371
ptr::null_mut(), // data
372372
0, // itemsize
373373
flag, // flag

src/convert.rs

Lines changed: 43 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ where
5858
type Item = A;
5959
type Dim = D;
6060
fn into_pyarray<'py>(self, py: Python<'py>) -> &'py PyArray<Self::Item, Self::Dim> {
61-
let strides = npy_strides(&self);
61+
let strides = NpyStrides::from_array(&self);
6262
let dim = self.raw_dim();
6363
let boxed = self.into_raw_vec().into_boxed_slice();
6464
unsafe { PyArray::from_boxed_slice(py, dim, strides.as_ptr(), boxed) }
@@ -102,26 +102,55 @@ where
102102
type Dim = D;
103103
fn to_pyarray<'py>(&self, py: Python<'py>) -> &'py PyArray<Self::Item, Self::Dim> {
104104
let len = self.len();
105-
let mut strides = npy_strides(self);
105+
let strides = NpyStrides::from_array(self);
106106
unsafe {
107-
let array = PyArray::new_(py, self.raw_dim(), strides.as_mut_ptr() as *mut npy_intp, 0);
107+
let array = PyArray::new_(py, self.raw_dim(), strides.as_ptr(), 0);
108108
array.copy_ptr(self.as_ptr(), len);
109109
array
110110
}
111111
}
112112
}
113113

114-
fn npy_strides<S, D, A>(array: &ArrayBase<S, D>) -> Vec<npyffi::npy_intp>
115-
where
116-
S: Data<Elem = A>,
117-
D: Dimension,
118-
A: TypeNum,
119-
{
120-
array
121-
.strides()
122-
.into_iter()
123-
.map(|n| n * mem::size_of::<A>() as npyffi::npy_intp)
124-
.collect()
114+
/// Numpy strides with short array optimization
115+
enum NpyStrides {
116+
Short([npyffi::npy_intp; 8]),
117+
Long(Vec<npyffi::npy_intp>),
118+
}
119+
120+
impl NpyStrides {
121+
fn as_ptr(&self) -> *const npy_intp {
122+
match self {
123+
NpyStrides::Short(inner) => inner.as_ptr(),
124+
NpyStrides::Long(inner) => inner.as_ptr(),
125+
}
126+
}
127+
128+
fn from_array<A, S, D>(array: &ArrayBase<S, D>) -> Self
129+
where
130+
S: Data<Elem = A>,
131+
D: Dimension,
132+
A: TypeNum,
133+
{
134+
Self::from_strides(array.strides(), mem::size_of::<A>())
135+
}
136+
fn from_strides(strides: &[isize], type_size: usize) -> Self {
137+
let len = strides.len();
138+
let type_size = type_size as npyffi::npy_intp;
139+
if len <= 8 {
140+
let mut res = [0; 8];
141+
for i in 0..len {
142+
res[i] = strides[i] as npyffi::npy_intp * type_size;
143+
}
144+
NpyStrides::Short(res)
145+
} else {
146+
NpyStrides::Long(
147+
strides
148+
.into_iter()
149+
.map(|&n| n as npyffi::npy_intp * type_size)
150+
.collect(),
151+
)
152+
}
153+
}
125154
}
126155

127156
/// Utility trait to specify the dimention of array

src/error.rs

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,5 @@
11
//! Defines error types.
2-
3-
use crate::array::PyArray;
4-
use crate::convert::ToNpyDims;
5-
use crate::types::{NpyDataType, TypeNum};
2+
use crate::types::NpyDataType;
63
use pyo3::{exceptions as exc, PyErr, PyResult, Python};
74
use std::error;
85
use std::fmt;

0 commit comments

Comments
 (0)