Skip to content

Commit ca4a49e

Browse files
committed
Small buffer optimization for InvertedAxes
Since InvertedAxes is basically a set of small integers, we can avoid dynamic allocations by using a word-sized bitmap in the common case of up to 32 axes. This should also not unduly increase code size, as there is no dynamic spilling into the large representation; instead, the decision is made once, based on the dimensionality of the array.
1 parent 4fbdebd commit ca4a49e

File tree

1 file changed

+62
-25
lines changed

1 file changed

+62
-25
lines changed

src/array.rs

Lines changed: 62 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -358,12 +358,46 @@ impl<T, D> PyArray<T, D> {
358358
}
359359
}
360360

361-
struct InvertedAxises(Vec<Axis>);
361+
enum InvertedAxes {
362+
Short(u32),
363+
Long(Vec<usize>),
364+
}
365+
366+
impl InvertedAxes {
367+
fn new(len: usize) -> Self {
368+
if len <= 32 {
369+
Self::Short(0)
370+
} else {
371+
Self::Long(Vec::new())
372+
}
373+
}
374+
375+
fn push(&mut self, axis: usize) {
376+
match self {
377+
Self::Short(axes) => {
378+
debug_assert!(axis < 32);
379+
*axes |= 1 << axis;
380+
}
381+
Self::Long(axes) => {
382+
axes.push(axis);
383+
}
384+
}
385+
}
362386

363-
impl InvertedAxises {
364387
fn invert<S: RawData, D: Dimension>(self, array: &mut ArrayBase<S, D>) {
365-
for axis in self.0 {
366-
array.invert_axis(axis);
388+
match self {
389+
Self::Short(mut axes) => {
390+
while axes != 0 {
391+
let axis = axes.trailing_zeros() as usize;
392+
axes &= !(1 << axis);
393+
array.invert_axis(Axis(axis));
394+
}
395+
}
396+
Self::Long(axes) => {
397+
for axis in axes {
398+
array.invert_axis(Axis(axis));
399+
}
400+
}
367401
}
368402
}
369403
}
@@ -372,36 +406,39 @@ impl<T: Element, D: Dimension> PyArray<T, D> {
372406
/// Same as [shape](#method.shape), but returns `D`
373407
#[inline(always)]
374408
pub fn dims(&self) -> D {
375-
D::from_dimension(&Dim(self.shape())).expect("PyArray::dims different dimension")
409+
D::from_dimension(&Dim(self.shape())).expect("mismatching dimensions")
376410
}
377411

378-
fn ndarray_shape_ptr(&self) -> (StrideShape<D>, *mut T, InvertedAxises) {
379-
let shape_slice = self.shape();
380-
let shape: Shape<_> = Dim(self.dims()).into();
381-
let sizeof_t = mem::size_of::<T>();
412+
fn ndarray_shape_ptr(&self) -> (StrideShape<D>, *mut T, InvertedAxes) {
413+
let shape = self.shape();
382414
let strides = self.strides();
415+
383416
let mut new_strides = D::zeros(strides.len());
384417
let mut data_ptr = unsafe { self.data() };
385-
let mut inverted_axises = vec![];
418+
let mut inverted_axes = InvertedAxes::new(strides.len());
419+
386420
for i in 0..strides.len() {
387421
// TODO(kngwyu): Replace this hacky negative strides support with
388422
// a proper constructor, when it's implemented.
389423
// See https://github.com/rust-ndarray/ndarray/issues/842 for more.
390424
if strides[i] < 0 {
391425
// Move the pointer to the start position
392-
let offset = strides[i] * (shape_slice[i] as isize - 1) / sizeof_t as isize;
426+
let offset = strides[i] * (shape[i] as isize - 1) / mem::size_of::<T>() as isize;
393427
unsafe {
394428
data_ptr = data_ptr.offset(offset);
395429
}
396-
new_strides[i] = (-strides[i]) as usize / sizeof_t;
397-
inverted_axises.push(Axis(i));
430+
new_strides[i] = (-strides[i]) as usize / mem::size_of::<T>();
431+
432+
inverted_axes.push(i);
398433
} else {
399-
new_strides[i] = strides[i] as usize / sizeof_t;
434+
new_strides[i] = strides[i] as usize / mem::size_of::<T>();
400435
}
401436
}
402-
let st = D::from_dimension(&Dim(new_strides))
403-
.expect("PyArray::ndarray_shape: dimension mismatching");
404-
(shape.strides(st), data_ptr, InvertedAxises(inverted_axises))
437+
438+
let shape = Shape::from(D::from_dimension(&Dim(shape)).expect("mismatching dimensions"));
439+
let new_strides = D::from_dimension(&Dim(new_strides)).expect("mismatching dimensions");
440+
441+
(shape.strides(new_strides), data_ptr, inverted_axes)
405442
}
406443

407444
/// Creates a new uninitialized PyArray in python heap.
@@ -818,9 +855,9 @@ impl<T: Element, D: Dimension> PyArray<T, D> {
818855
/// If the internal array is not readonly and can be mutated from Python code,
819856
/// holding the `ArrayView` might cause undefined behavior.
820857
pub unsafe fn as_array(&self) -> ArrayView<'_, T, D> {
821-
let (shape, ptr, inverted_axises) = self.ndarray_shape_ptr();
858+
let (shape, ptr, inverted_axes) = self.ndarray_shape_ptr();
822859
let mut res = ArrayView::from_shape_ptr(shape, ptr);
823-
inverted_axises.invert(&mut res);
860+
inverted_axes.invert(&mut res);
824861
res
825862
}
826863

@@ -830,25 +867,25 @@ impl<T: Element, D: Dimension> PyArray<T, D> {
830867
/// If another reference to the internal data exists (e.g., `&[T]` or `ArrayView`),
831868
/// it might cause undefined behavior.
832869
pub unsafe fn as_array_mut(&self) -> ArrayViewMut<'_, T, D> {
833-
let (shape, ptr, inverted_axises) = self.ndarray_shape_ptr();
870+
let (shape, ptr, inverted_axes) = self.ndarray_shape_ptr();
834871
let mut res = ArrayViewMut::from_shape_ptr(shape, ptr);
835-
inverted_axises.invert(&mut res);
872+
inverted_axes.invert(&mut res);
836873
res
837874
}
838875

839876
/// Returns the internal array as [`RawArrayView`] enabling element access via raw pointers
840877
pub fn as_raw_array(&self) -> RawArrayView<T, D> {
841-
let (shape, ptr, inverted_axises) = self.ndarray_shape_ptr();
878+
let (shape, ptr, inverted_axes) = self.ndarray_shape_ptr();
842879
let mut res = unsafe { RawArrayView::from_shape_ptr(shape, ptr) };
843-
inverted_axises.invert(&mut res);
880+
inverted_axes.invert(&mut res);
844881
res
845882
}
846883

847884
/// Returns the internal array as [`RawArrayViewMut`] enabling element access via raw pointers
848885
pub fn as_raw_array_mut(&self) -> RawArrayViewMut<T, D> {
849-
let (shape, ptr, inverted_axises) = self.ndarray_shape_ptr();
886+
let (shape, ptr, inverted_axes) = self.ndarray_shape_ptr();
850887
let mut res = unsafe { RawArrayViewMut::from_shape_ptr(shape, ptr) };
851-
inverted_axises.invert(&mut res);
888+
inverted_axes.invert(&mut res);
852889
res
853890
}
854891

0 commit comments

Comments
 (0)