@@ -358,12 +358,46 @@ impl<T, D> PyArray<T, D> {
358358 }
359359}
360360
361- struct InvertedAxises ( Vec < Axis > ) ;
361+ enum InvertedAxes {
362+ Short ( u32 ) ,
363+ Long ( Vec < usize > ) ,
364+ }
365+
366+ impl InvertedAxes {
367+ fn new ( len : usize ) -> Self {
368+ if len <= 32 {
369+ Self :: Short ( 0 )
370+ } else {
371+ Self :: Long ( Vec :: new ( ) )
372+ }
373+ }
374+
375+ fn push ( & mut self , axis : usize ) {
376+ match self {
377+ Self :: Short ( axes) => {
378+ debug_assert ! ( axis < 32 ) ;
379+ * axes |= 1 << axis;
380+ }
381+ Self :: Long ( axes) => {
382+ axes. push ( axis) ;
383+ }
384+ }
385+ }
362386
363- impl InvertedAxises {
364387 fn invert < S : RawData , D : Dimension > ( self , array : & mut ArrayBase < S , D > ) {
365- for axis in self . 0 {
366- array. invert_axis ( axis) ;
388+ match self {
389+ Self :: Short ( mut axes) => {
390+ while axes != 0 {
391+ let axis = axes. trailing_zeros ( ) as usize ;
392+ axes &= !( 1 << axis) ;
393+ array. invert_axis ( Axis ( axis) ) ;
394+ }
395+ }
396+ Self :: Long ( axes) => {
397+ for axis in axes {
398+ array. invert_axis ( Axis ( axis) ) ;
399+ }
400+ }
367401 }
368402 }
369403}
@@ -372,36 +406,39 @@ impl<T: Element, D: Dimension> PyArray<T, D> {
372406 /// Same as [shape](#method.shape), but returns `D`
373407 #[ inline( always) ]
374408 pub fn dims ( & self ) -> D {
375- D :: from_dimension ( & Dim ( self . shape ( ) ) ) . expect ( "PyArray::dims different dimension " )
409+ D :: from_dimension ( & Dim ( self . shape ( ) ) ) . expect ( "mismatching dimensions " )
376410 }
377411
378- fn ndarray_shape_ptr ( & self ) -> ( StrideShape < D > , * mut T , InvertedAxises ) {
379- let shape_slice = self . shape ( ) ;
380- let shape: Shape < _ > = Dim ( self . dims ( ) ) . into ( ) ;
381- let sizeof_t = mem:: size_of :: < T > ( ) ;
412+ fn ndarray_shape_ptr ( & self ) -> ( StrideShape < D > , * mut T , InvertedAxes ) {
413+ let shape = self . shape ( ) ;
382414 let strides = self . strides ( ) ;
415+
383416 let mut new_strides = D :: zeros ( strides. len ( ) ) ;
384417 let mut data_ptr = unsafe { self . data ( ) } ;
385- let mut inverted_axises = vec ! [ ] ;
418+ let mut inverted_axes = InvertedAxes :: new ( strides. len ( ) ) ;
419+
386420 for i in 0 ..strides. len ( ) {
387421 // TODO(kngwyu): Replace this hacky negative strides support with
388422 // a proper constructor, when it's implemented.
389423 // See https://github.com/rust-ndarray/ndarray/issues/842 for more.
390424 if strides[ i] < 0 {
391425 // Move the pointer to the start position
392- let offset = strides[ i] * ( shape_slice [ i] as isize - 1 ) / sizeof_t as isize ;
426+ let offset = strides[ i] * ( shape [ i] as isize - 1 ) / mem :: size_of :: < T > ( ) as isize ;
393427 unsafe {
394428 data_ptr = data_ptr. offset ( offset) ;
395429 }
396- new_strides[ i] = ( -strides[ i] ) as usize / sizeof_t;
397- inverted_axises. push ( Axis ( i) ) ;
430+ new_strides[ i] = ( -strides[ i] ) as usize / mem:: size_of :: < T > ( ) ;
431+
432+ inverted_axes. push ( i) ;
398433 } else {
399- new_strides[ i] = strides[ i] as usize / sizeof_t ;
434+ new_strides[ i] = strides[ i] as usize / mem :: size_of :: < T > ( ) ;
400435 }
401436 }
402- let st = D :: from_dimension ( & Dim ( new_strides) )
403- . expect ( "PyArray::ndarray_shape: dimension mismatching" ) ;
404- ( shape. strides ( st) , data_ptr, InvertedAxises ( inverted_axises) )
437+
438+ let shape = Shape :: from ( D :: from_dimension ( & Dim ( shape) ) . expect ( "mismatching dimensions" ) ) ;
439+ let new_strides = D :: from_dimension ( & Dim ( new_strides) ) . expect ( "mismatching dimensions" ) ;
440+
441+ ( shape. strides ( new_strides) , data_ptr, inverted_axes)
405442 }
406443
407444 /// Creates a new uninitialized PyArray in python heap.
@@ -818,9 +855,9 @@ impl<T: Element, D: Dimension> PyArray<T, D> {
818855 /// If the internal array is not readonly and can be mutated from Python code,
819856 /// holding the `ArrayView` might cause undefined behavior.
820857 pub unsafe fn as_array ( & self ) -> ArrayView < ' _ , T , D > {
821- let ( shape, ptr, inverted_axises ) = self . ndarray_shape_ptr ( ) ;
858+ let ( shape, ptr, inverted_axes ) = self . ndarray_shape_ptr ( ) ;
822859 let mut res = ArrayView :: from_shape_ptr ( shape, ptr) ;
823- inverted_axises . invert ( & mut res) ;
860+ inverted_axes . invert ( & mut res) ;
824861 res
825862 }
826863
@@ -830,25 +867,25 @@ impl<T: Element, D: Dimension> PyArray<T, D> {
830867 /// If another reference to the internal data exists(e.g., `&[T]` or `ArrayView`),
831868 /// it might cause undefined behavior.
832869 pub unsafe fn as_array_mut ( & self ) -> ArrayViewMut < ' _ , T , D > {
833- let ( shape, ptr, inverted_axises ) = self . ndarray_shape_ptr ( ) ;
870+ let ( shape, ptr, inverted_axes ) = self . ndarray_shape_ptr ( ) ;
834871 let mut res = ArrayViewMut :: from_shape_ptr ( shape, ptr) ;
835- inverted_axises . invert ( & mut res) ;
872+ inverted_axes . invert ( & mut res) ;
836873 res
837874 }
838875
839876 /// Returns the internal array as [`RawArrayView`] enabling element access via raw pointers
840877 pub fn as_raw_array ( & self ) -> RawArrayView < T , D > {
841- let ( shape, ptr, inverted_axises ) = self . ndarray_shape_ptr ( ) ;
878+ let ( shape, ptr, inverted_axes ) = self . ndarray_shape_ptr ( ) ;
842879 let mut res = unsafe { RawArrayView :: from_shape_ptr ( shape, ptr) } ;
843- inverted_axises . invert ( & mut res) ;
880+ inverted_axes . invert ( & mut res) ;
844881 res
845882 }
846883
847884 /// Returns the internal array as [`RawArrayViewMut`] enabling element access via raw pointers
848885 pub fn as_raw_array_mut ( & self ) -> RawArrayViewMut < T , D > {
849- let ( shape, ptr, inverted_axises ) = self . ndarray_shape_ptr ( ) ;
886+ let ( shape, ptr, inverted_axes ) = self . ndarray_shape_ptr ( ) ;
850887 let mut res = unsafe { RawArrayViewMut :: from_shape_ptr ( shape, ptr) } ;
851- inverted_axises . invert ( & mut res) ;
888+ inverted_axes . invert ( & mut res) ;
852889 res
853890 }
854891
0 commit comments