@@ -171,7 +171,7 @@ use std::ops::Deref;
171
171
use ahash:: AHashMap ;
172
172
use ndarray:: { ArrayView , ArrayViewMut , Dimension , Ix1 , Ix2 , Ix3 , Ix4 , Ix5 , Ix6 , IxDyn } ;
173
173
use num_integer:: gcd;
174
- use pyo3:: { FromPyObject , PyAny , PyResult } ;
174
+ use pyo3:: { FromPyObject , PyAny , PyResult , Python } ;
175
175
176
176
use crate :: array:: PyArray ;
177
177
use crate :: cold;
@@ -672,52 +672,58 @@ where
672
672
}
673
673
674
674
fn base_address < T , D > ( array : & PyArray < T , D > ) -> usize {
675
- let py = array. py ( ) ;
676
- let mut array = array. as_array_ptr ( ) ;
677
-
678
- loop {
679
- let base = unsafe { ( * array) . base } ;
680
-
681
- if base. is_null ( ) {
682
- return array as usize ;
683
- } else if unsafe { npyffi:: PyArray_Check ( py, base) } != 0 {
684
- array = base as * mut PyArrayObject ;
685
- } else {
686
- return base as usize ;
675
+ fn inner ( py : Python , mut array : * mut PyArrayObject ) -> usize {
676
+ loop {
677
+ let base = unsafe { ( * array) . base } ;
678
+
679
+ if base. is_null ( ) {
680
+ return array as usize ;
681
+ } else if unsafe { npyffi:: PyArray_Check ( py, base) } != 0 {
682
+ array = base as * mut PyArrayObject ;
683
+ } else {
684
+ return base as usize ;
685
+ }
687
686
}
688
687
}
688
+
689
+ inner ( array. py ( ) , array. as_array_ptr ( ) )
689
690
}
690
691
691
692
fn data_range < T , D > ( array : & PyArray < T , D > ) -> ( usize , usize )
692
693
where
693
694
T : Element ,
694
695
D : Dimension ,
695
696
{
696
- let shape = array. shape ( ) ;
697
- let strides = array. strides ( ) ;
698
-
699
- let mut start = 0 ;
700
- let mut end = 0 ;
697
+ fn inner ( shape : & [ usize ] , strides : & [ isize ] , itemsize : isize , data : * mut u8 ) -> ( usize , usize ) {
698
+ let mut start = 0 ;
699
+ let mut end = 0 ;
701
700
702
- if shape. iter ( ) . all ( |dim| * dim != 0 ) {
703
- for ( & dim, & stride) in shape. iter ( ) . zip ( strides) {
704
- let offset = ( dim - 1 ) as isize * stride;
701
+ if shape. iter ( ) . all ( |dim| * dim != 0 ) {
702
+ for ( & dim, & stride) in shape. iter ( ) . zip ( strides) {
703
+ let offset = ( dim - 1 ) as isize * stride;
705
704
706
- if offset >= 0 {
707
- end += offset;
708
- } else {
709
- start += offset;
705
+ if offset >= 0 {
706
+ end += offset;
707
+ } else {
708
+ start += offset;
709
+ }
710
710
}
711
+
712
+ end += itemsize;
711
713
}
712
714
713
- end += size_of :: < T > ( ) as isize ;
714
- }
715
+ let start = unsafe { data . offset ( start ) } as usize ;
716
+ let end = unsafe { data . offset ( end ) } as usize ;
715
717
716
- let data = unsafe { ( * array. as_array_ptr ( ) ) . data } ;
717
- let start = unsafe { data. offset ( start) } as usize ;
718
- let end = unsafe { data. offset ( end) } as usize ;
718
+ ( start, end)
719
+ }
719
720
720
- ( start, end)
721
+ inner (
722
+ array. shape ( ) ,
723
+ array. strides ( ) ,
724
+ size_of :: < T > ( ) as _ ,
725
+ array. data ( ) as _ ,
726
+ )
721
727
}
722
728
723
729
// FIXME(adamreichold): Use `usize::abs_diff` from std when that becomes stable.
0 commit comments