@@ -358,12 +358,46 @@ impl<T, D> PyArray<T, D> {
358
358
}
359
359
}
360
360
361
- struct InvertedAxises ( Vec < Axis > ) ;
361
+ enum InvertedAxes {
362
+ Short ( u32 ) ,
363
+ Long ( Vec < usize > ) ,
364
+ }
365
+
366
+ impl InvertedAxes {
367
+ fn new ( len : usize ) -> Self {
368
+ if len <= 32 {
369
+ Self :: Short ( 0 )
370
+ } else {
371
+ Self :: Long ( Vec :: new ( ) )
372
+ }
373
+ }
374
+
375
+ fn push ( & mut self , axis : usize ) {
376
+ match self {
377
+ Self :: Short ( axes) => {
378
+ debug_assert ! ( axis < 32 ) ;
379
+ * axes |= 1 << axis;
380
+ }
381
+ Self :: Long ( axes) => {
382
+ axes. push ( axis) ;
383
+ }
384
+ }
385
+ }
362
386
363
- impl InvertedAxises {
364
387
fn invert < S : RawData , D : Dimension > ( self , array : & mut ArrayBase < S , D > ) {
365
- for axis in self . 0 {
366
- array. invert_axis ( axis) ;
388
+ match self {
389
+ Self :: Short ( mut axes) => {
390
+ while axes != 0 {
391
+ let axis = axes. trailing_zeros ( ) as usize ;
392
+ axes &= !( 1 << axis) ;
393
+ array. invert_axis ( Axis ( axis) ) ;
394
+ }
395
+ }
396
+ Self :: Long ( axes) => {
397
+ for axis in axes {
398
+ array. invert_axis ( Axis ( axis) ) ;
399
+ }
400
+ }
367
401
}
368
402
}
369
403
}
@@ -372,36 +406,39 @@ impl<T: Element, D: Dimension> PyArray<T, D> {
372
406
/// Same as [shape](#method.shape), but returns `D`
373
407
#[ inline( always) ]
374
408
pub fn dims ( & self ) -> D {
375
- D :: from_dimension ( & Dim ( self . shape ( ) ) ) . expect ( "PyArray::dims different dimension " )
409
+ D :: from_dimension ( & Dim ( self . shape ( ) ) ) . expect ( "mismatching dimensions " )
376
410
}
377
411
378
- fn ndarray_shape_ptr ( & self ) -> ( StrideShape < D > , * mut T , InvertedAxises ) {
379
- let shape_slice = self . shape ( ) ;
380
- let shape: Shape < _ > = Dim ( self . dims ( ) ) . into ( ) ;
381
- let sizeof_t = mem:: size_of :: < T > ( ) ;
412
+ fn ndarray_shape_ptr ( & self ) -> ( StrideShape < D > , * mut T , InvertedAxes ) {
413
+ let shape = self . shape ( ) ;
382
414
let strides = self . strides ( ) ;
415
+
383
416
let mut new_strides = D :: zeros ( strides. len ( ) ) ;
384
417
let mut data_ptr = unsafe { self . data ( ) } ;
385
- let mut inverted_axises = vec ! [ ] ;
418
+ let mut inverted_axes = InvertedAxes :: new ( strides. len ( ) ) ;
419
+
386
420
for i in 0 ..strides. len ( ) {
387
421
// TODO(kngwyu): Replace this hacky negative strides support with
388
422
// a proper constructor, when it's implemented.
389
423
// See https://github.com/rust-ndarray/ndarray/issues/842 for more.
390
424
if strides[ i] < 0 {
391
425
// Move the pointer to the start position
392
- let offset = strides[ i] * ( shape_slice [ i] as isize - 1 ) / sizeof_t as isize ;
426
+ let offset = strides[ i] * ( shape [ i] as isize - 1 ) / mem :: size_of :: < T > ( ) as isize ;
393
427
unsafe {
394
428
data_ptr = data_ptr. offset ( offset) ;
395
429
}
396
- new_strides[ i] = ( -strides[ i] ) as usize / sizeof_t;
397
- inverted_axises. push ( Axis ( i) ) ;
430
+ new_strides[ i] = ( -strides[ i] ) as usize / mem:: size_of :: < T > ( ) ;
431
+
432
+ inverted_axes. push ( i) ;
398
433
} else {
399
- new_strides[ i] = strides[ i] as usize / sizeof_t ;
434
+ new_strides[ i] = strides[ i] as usize / mem :: size_of :: < T > ( ) ;
400
435
}
401
436
}
402
- let st = D :: from_dimension ( & Dim ( new_strides) )
403
- . expect ( "PyArray::ndarray_shape: dimension mismatching" ) ;
404
- ( shape. strides ( st) , data_ptr, InvertedAxises ( inverted_axises) )
437
+
438
+ let shape = Shape :: from ( D :: from_dimension ( & Dim ( shape) ) . expect ( "mismatching dimensions" ) ) ;
439
+ let new_strides = D :: from_dimension ( & Dim ( new_strides) ) . expect ( "mismatching dimensions" ) ;
440
+
441
+ ( shape. strides ( new_strides) , data_ptr, inverted_axes)
405
442
}
406
443
407
444
/// Creates a new uninitialized PyArray in python heap.
@@ -818,9 +855,9 @@ impl<T: Element, D: Dimension> PyArray<T, D> {
818
855
/// If the internal array is not readonly and can be mutated from Python code,
819
856
/// holding the `ArrayView` might cause undefined behavior.
820
857
pub unsafe fn as_array ( & self ) -> ArrayView < ' _ , T , D > {
821
- let ( shape, ptr, inverted_axises ) = self . ndarray_shape_ptr ( ) ;
858
+ let ( shape, ptr, inverted_axes ) = self . ndarray_shape_ptr ( ) ;
822
859
let mut res = ArrayView :: from_shape_ptr ( shape, ptr) ;
823
- inverted_axises . invert ( & mut res) ;
860
+ inverted_axes . invert ( & mut res) ;
824
861
res
825
862
}
826
863
@@ -830,25 +867,25 @@ impl<T: Element, D: Dimension> PyArray<T, D> {
830
867
/// If another reference to the internal data exists(e.g., `&[T]` or `ArrayView`),
831
868
/// it might cause undefined behavior.
832
869
pub unsafe fn as_array_mut ( & self ) -> ArrayViewMut < ' _ , T , D > {
833
- let ( shape, ptr, inverted_axises ) = self . ndarray_shape_ptr ( ) ;
870
+ let ( shape, ptr, inverted_axes ) = self . ndarray_shape_ptr ( ) ;
834
871
let mut res = ArrayViewMut :: from_shape_ptr ( shape, ptr) ;
835
- inverted_axises . invert ( & mut res) ;
872
+ inverted_axes . invert ( & mut res) ;
836
873
res
837
874
}
838
875
839
876
/// Returns the internal array as [`RawArrayView`] enabling element access via raw pointers
840
877
pub fn as_raw_array ( & self ) -> RawArrayView < T , D > {
841
- let ( shape, ptr, inverted_axises ) = self . ndarray_shape_ptr ( ) ;
878
+ let ( shape, ptr, inverted_axes ) = self . ndarray_shape_ptr ( ) ;
842
879
let mut res = unsafe { RawArrayView :: from_shape_ptr ( shape, ptr) } ;
843
- inverted_axises . invert ( & mut res) ;
880
+ inverted_axes . invert ( & mut res) ;
844
881
res
845
882
}
846
883
847
884
/// Returns the internal array as [`RawArrayViewMut`] enabling element access via raw pointers
848
885
pub fn as_raw_array_mut ( & self ) -> RawArrayViewMut < T , D > {
849
- let ( shape, ptr, inverted_axises ) = self . ndarray_shape_ptr ( ) ;
886
+ let ( shape, ptr, inverted_axes ) = self . ndarray_shape_ptr ( ) ;
850
887
let mut res = unsafe { RawArrayViewMut :: from_shape_ptr ( shape, ptr) } ;
851
- inverted_axises . invert ( & mut res) ;
888
+ inverted_axes . invert ( & mut res) ;
852
889
res
853
890
}
854
891
0 commit comments