@@ -184,9 +184,9 @@ use crate::npyffi::{self, PyArrayObject, NPY_ARRAY_WRITEABLE};
184
184
#[ derive( Clone , Copy , PartialEq , Eq , Hash ) ]
185
185
struct BorrowKey {
186
186
/// exclusive range of lowest and highest address covered by array
187
- range : ( usize , usize ) ,
187
+ range : ( * mut u8 , * mut u8 ) ,
188
188
/// the data address on which address computations are based
189
- data_ptr : usize ,
189
+ data_ptr : * mut u8 ,
190
190
/// the greatest common divisor of the strides of the array
191
191
gcd_strides : isize ,
192
192
}
@@ -199,7 +199,7 @@ impl BorrowKey {
199
199
{
200
200
let range = data_range ( array) ;
201
201
202
- let data_ptr = array. data ( ) as usize ;
202
+ let data_ptr = array. data ( ) as * mut u8 ;
203
203
let gcd_strides = gcd_strides ( array. strides ( ) ) ;
204
204
205
205
Self {
@@ -225,7 +225,7 @@ impl BorrowKey {
225
225
// but fails when slicing an array with a step size that does not divide the dimension along that axis.
226
226
//
227
227
// https://users.rust-lang.org/t/math-for-borrow-checking-numpy-arrays/73303
228
- let ptr_diff = abs_diff ( self . data_ptr , other. data_ptr ) as isize ;
228
+ let ptr_diff = unsafe { self . data_ptr . offset_from ( other. data_ptr ) . abs ( ) } ;
229
229
let gcd_strides = gcd ( self . gcd_strides , other. gcd_strides ) ;
230
230
231
231
if ptr_diff % gcd_strides != 0 {
@@ -237,7 +237,7 @@ impl BorrowKey {
237
237
}
238
238
}
239
239
240
- type BorrowFlagsInner = AHashMap < usize , AHashMap < BorrowKey , isize > > ;
240
+ type BorrowFlagsInner = AHashMap < * mut u8 , AHashMap < BorrowKey , isize > > ;
241
241
242
242
struct BorrowFlags ( UnsafeCell < Option < BorrowFlagsInner > > ) ;
243
243
@@ -253,7 +253,7 @@ impl BorrowFlags {
253
253
( * self . 0 . get ( ) ) . get_or_insert_with ( AHashMap :: new)
254
254
}
255
255
256
- fn acquire ( & self , _py : Python , address : usize , key : BorrowKey ) -> Result < ( ) , BorrowError > {
256
+ fn acquire ( & self , _py : Python , address : * mut u8 , key : BorrowKey ) -> Result < ( ) , BorrowError > {
257
257
// SAFETY: Having `_py` implies holding the GIL and
258
258
// we are not calling into user code which might re-enter this function.
259
259
let borrow_flags = unsafe { BORROW_FLAGS . get ( ) } ;
@@ -296,7 +296,7 @@ impl BorrowFlags {
296
296
Ok ( ( ) )
297
297
}
298
298
299
- fn release ( & self , _py : Python , address : usize , key : BorrowKey ) {
299
+ fn release ( & self , _py : Python , address : * mut u8 , key : BorrowKey ) {
300
300
// SAFETY: Having `_py` implies holding the GIL and
301
301
// we are not calling into user code which might re-enter this function.
302
302
let borrow_flags = unsafe { BORROW_FLAGS . get ( ) } ;
@@ -316,7 +316,12 @@ impl BorrowFlags {
316
316
}
317
317
}
318
318
319
- fn acquire_mut ( & self , _py : Python , address : usize , key : BorrowKey ) -> Result < ( ) , BorrowError > {
319
+ fn acquire_mut (
320
+ & self ,
321
+ _py : Python ,
322
+ address : * mut u8 ,
323
+ key : BorrowKey ,
324
+ ) -> Result < ( ) , BorrowError > {
320
325
// SAFETY: Having `_py` implies holding the GIL and
321
326
// we are not calling into user code which might re-enter this function.
322
327
let borrow_flags = unsafe { BORROW_FLAGS . get ( ) } ;
@@ -353,7 +358,7 @@ impl BorrowFlags {
353
358
Ok ( ( ) )
354
359
}
355
360
356
- fn release_mut ( & self , _py : Python , address : usize , key : BorrowKey ) {
361
+ fn release_mut ( & self , _py : Python , address : * mut u8 , key : BorrowKey ) {
357
362
// SAFETY: Having `_py` implies holding the GIL and
358
363
// we are not calling into user code which might re-enter this function.
359
364
let borrow_flags = unsafe { BORROW_FLAGS . get ( ) } ;
@@ -383,7 +388,7 @@ where
383
388
D : Dimension ,
384
389
{
385
390
array : & ' py PyArray < T , D > ,
386
- address : usize ,
391
+ address : * mut u8 ,
387
392
key : BorrowKey ,
388
393
}
389
394
@@ -526,7 +531,7 @@ where
526
531
D : Dimension ,
527
532
{
528
533
array : & ' py PyArray < T , D > ,
529
- address : usize ,
534
+ address : * mut u8 ,
530
535
key : BorrowKey ,
531
536
}
532
537
@@ -680,30 +685,35 @@ where
680
685
}
681
686
}
682
687
683
- fn base_address < T , D > ( array : & PyArray < T , D > ) -> usize {
684
- fn inner ( py : Python , mut array : * mut PyArrayObject ) -> usize {
688
+ fn base_address < T , D > ( array : & PyArray < T , D > ) -> * mut u8 {
689
+ fn inner ( py : Python , mut array : * mut PyArrayObject ) -> * mut u8 {
685
690
loop {
686
691
let base = unsafe { ( * array) . base } ;
687
692
688
693
if base. is_null ( ) {
689
- return array as usize ;
694
+ return array as * mut u8 ;
690
695
} else if unsafe { npyffi:: PyArray_Check ( py, base) } != 0 {
691
696
array = base as * mut PyArrayObject ;
692
697
} else {
693
- return base as usize ;
698
+ return base as * mut u8 ;
694
699
}
695
700
}
696
701
}
697
702
698
703
inner ( array. py ( ) , array. as_array_ptr ( ) )
699
704
}
700
705
701
- fn data_range < T , D > ( array : & PyArray < T , D > ) -> ( usize , usize )
706
+ fn data_range < T , D > ( array : & PyArray < T , D > ) -> ( * mut u8 , * mut u8 )
702
707
where
703
708
T : Element ,
704
709
D : Dimension ,
705
710
{
706
- fn inner ( shape : & [ usize ] , strides : & [ isize ] , itemsize : isize , data : * mut u8 ) -> ( usize , usize ) {
711
+ fn inner (
712
+ shape : & [ usize ] ,
713
+ strides : & [ isize ] ,
714
+ itemsize : isize ,
715
+ data : * mut u8 ,
716
+ ) -> ( * mut u8 , * mut u8 ) {
707
717
let mut start = 0 ;
708
718
let mut end = 0 ;
709
719
@@ -721,33 +731,24 @@ where
721
731
end += itemsize;
722
732
}
723
733
724
- let start = unsafe { data. offset ( start) } as usize ;
725
- let end = unsafe { data. offset ( end) } as usize ;
734
+ let start = unsafe { data. offset ( start) } ;
735
+ let end = unsafe { data. offset ( end) } ;
726
736
727
737
( start, end)
728
738
}
729
739
730
740
inner (
731
741
array. shape ( ) ,
732
742
array. strides ( ) ,
733
- size_of :: < T > ( ) as _ ,
734
- array. data ( ) as _ ,
743
+ size_of :: < T > ( ) as isize ,
744
+ array. data ( ) as * mut u8 ,
735
745
)
736
746
}
737
747
738
748
fn gcd_strides ( strides : & [ isize ] ) -> isize {
739
749
reduce ( strides. iter ( ) . copied ( ) , gcd) . unwrap_or ( 1 )
740
750
}
741
751
742
- // FIXME(adamreichold): Use `usize::abs_diff` from std when our MSRV reaches 1.60.
743
- fn abs_diff ( lhs : usize , rhs : usize ) -> usize {
744
- if lhs >= rhs {
745
- lhs - rhs
746
- } else {
747
- rhs - lhs
748
- }
749
- }
750
-
751
752
// FIXME(adamreichold): Use `Iterator::reduce` from std when our MSRV reaches 1.51.
752
753
fn reduce < I , F > ( mut iter : I , f : F ) -> Option < I :: Item >
753
754
where
@@ -777,11 +778,11 @@ mod tests {
777
778
assert ! ( base. is_null( ) ) ;
778
779
779
780
let base_address = base_address ( array) ;
780
- assert_eq ! ( base_address, array as * const _ as usize ) ;
781
+ assert_eq ! ( base_address, array as * const _ as * mut u8 ) ;
781
782
782
783
let data_range = data_range ( array) ;
783
- assert_eq ! ( data_range. 0 , array. data( ) as usize ) ;
784
- assert_eq ! ( data_range. 1 , unsafe { array. data( ) . add( 6 ) } as usize ) ;
784
+ assert_eq ! ( data_range. 0 , array. data( ) as * mut u8 ) ;
785
+ assert_eq ! ( data_range. 1 , unsafe { array. data( ) . add( 6 ) } as * mut u8 ) ;
785
786
} ) ;
786
787
}
787
788
@@ -794,12 +795,12 @@ mod tests {
794
795
assert ! ( !base. is_null( ) ) ;
795
796
796
797
let base_address = base_address ( array) ;
797
- assert_ne ! ( base_address, array as * const _ as usize ) ;
798
- assert_eq ! ( base_address, base as usize ) ;
798
+ assert_ne ! ( base_address, array as * const _ as * mut u8 ) ;
799
+ assert_eq ! ( base_address, base as * mut u8 ) ;
799
800
800
801
let data_range = data_range ( array) ;
801
- assert_eq ! ( data_range. 0 , array. data( ) as usize ) ;
802
- assert_eq ! ( data_range. 1 , unsafe { array. data( ) . add( 6 ) } as usize ) ;
802
+ assert_eq ! ( data_range. 0 , array. data( ) as * mut u8 ) ;
803
+ assert_eq ! ( data_range. 1 , unsafe { array. data( ) . add( 6 ) } as * mut u8 ) ;
803
804
} ) ;
804
805
}
805
806
@@ -814,18 +815,18 @@ mod tests {
814
815
. unwrap ( )
815
816
. downcast :: < PyArray2 < f64 > > ( )
816
817
. unwrap ( ) ;
817
- assert_ne ! ( view as * const _ as usize , array as * const _ as usize ) ;
818
+ assert_ne ! ( view as * const _ as * mut u8 , array as * const _ as * mut u8 ) ;
818
819
819
820
let base = unsafe { ( * view. as_array_ptr ( ) ) . base } ;
820
- assert_eq ! ( base as usize , array as * const _ as usize ) ;
821
+ assert_eq ! ( base as * mut u8 , array as * const _ as * mut u8 ) ;
821
822
822
823
let base_address = base_address ( view) ;
823
- assert_ne ! ( base_address, view as * const _ as usize ) ;
824
- assert_eq ! ( base_address, base as usize ) ;
824
+ assert_ne ! ( base_address, view as * const _ as * mut u8 ) ;
825
+ assert_eq ! ( base_address, base as * mut u8 ) ;
825
826
826
827
let data_range = data_range ( view) ;
827
- assert_eq ! ( data_range. 0 , array. data( ) as usize ) ;
828
- assert_eq ! ( data_range. 1 , unsafe { array. data( ) . add( 4 ) } as usize ) ;
828
+ assert_eq ! ( data_range. 0 , array. data( ) as * mut u8 ) ;
829
+ assert_eq ! ( data_range. 1 , unsafe { array. data( ) . add( 4 ) } as * mut u8 ) ;
829
830
} ) ;
830
831
}
831
832
@@ -840,22 +841,22 @@ mod tests {
840
841
. unwrap ( )
841
842
. downcast :: < PyArray2 < f64 > > ( )
842
843
. unwrap ( ) ;
843
- assert_ne ! ( view as * const _ as usize , array as * const _ as usize ) ;
844
+ assert_ne ! ( view as * const _ as * mut u8 , array as * const _ as * mut u8 ) ;
844
845
845
846
let base = unsafe { ( * view. as_array_ptr ( ) ) . base } ;
846
- assert_eq ! ( base as usize , array as * const _ as usize ) ;
847
+ assert_eq ! ( base as * mut u8 , array as * const _ as * mut u8 ) ;
847
848
848
849
let base = unsafe { ( * array. as_array_ptr ( ) ) . base } ;
849
850
assert ! ( !base. is_null( ) ) ;
850
851
851
852
let base_address = base_address ( view) ;
852
- assert_ne ! ( base_address, view as * const _ as usize ) ;
853
- assert_ne ! ( base_address, array as * const _ as usize ) ;
854
- assert_eq ! ( base_address, base as usize ) ;
853
+ assert_ne ! ( base_address, view as * const _ as * mut u8 ) ;
854
+ assert_ne ! ( base_address, array as * const _ as * mut u8 ) ;
855
+ assert_eq ! ( base_address, base as * mut u8 ) ;
855
856
856
857
let data_range = data_range ( view) ;
857
- assert_eq ! ( data_range. 0 , array. data( ) as usize ) ;
858
- assert_eq ! ( data_range. 1 , unsafe { array. data( ) . add( 4 ) } as usize ) ;
858
+ assert_eq ! ( data_range. 0 , array. data( ) as * mut u8 ) ;
859
+ assert_eq ! ( data_range. 1 , unsafe { array. data( ) . add( 4 ) } as * mut u8 ) ;
859
860
} ) ;
860
861
}
861
862
@@ -870,31 +871,31 @@ mod tests {
870
871
. unwrap ( )
871
872
. downcast :: < PyArray2 < f64 > > ( )
872
873
. unwrap ( ) ;
873
- assert_ne ! ( view1 as * const _ as usize , array as * const _ as usize ) ;
874
+ assert_ne ! ( view1 as * const _ as * mut u8 , array as * const _ as * mut u8 ) ;
874
875
875
876
let locals = [ ( "view1" , view1) ] . into_py_dict ( py) ;
876
877
let view2 = py
877
878
. eval ( "view1[:,0]" , None , Some ( locals) )
878
879
. unwrap ( )
879
880
. downcast :: < PyArray1 < f64 > > ( )
880
881
. unwrap ( ) ;
881
- assert_ne ! ( view2 as * const _ as usize , array as * const _ as usize ) ;
882
- assert_ne ! ( view2 as * const _ as usize , view1 as * const _ as usize ) ;
882
+ assert_ne ! ( view2 as * const _ as * mut u8 , array as * const _ as * mut u8 ) ;
883
+ assert_ne ! ( view2 as * const _ as * mut u8 , view1 as * const _ as * mut u8 ) ;
883
884
884
885
let base = unsafe { ( * view2. as_array_ptr ( ) ) . base } ;
885
- assert_eq ! ( base as usize , array as * const _ as usize ) ;
886
+ assert_eq ! ( base as * mut u8 , array as * const _ as * mut u8 ) ;
886
887
887
888
let base = unsafe { ( * view1. as_array_ptr ( ) ) . base } ;
888
- assert_eq ! ( base as usize , array as * const _ as usize ) ;
889
+ assert_eq ! ( base as * mut u8 , array as * const _ as * mut u8 ) ;
889
890
890
891
let base_address = base_address ( view2) ;
891
- assert_ne ! ( base_address, view2 as * const _ as usize ) ;
892
- assert_ne ! ( base_address, view1 as * const _ as usize ) ;
893
- assert_eq ! ( base_address, base as usize ) ;
892
+ assert_ne ! ( base_address, view2 as * const _ as * mut u8 ) ;
893
+ assert_ne ! ( base_address, view1 as * const _ as * mut u8 ) ;
894
+ assert_eq ! ( base_address, base as * mut u8 ) ;
894
895
895
896
let data_range = data_range ( view2) ;
896
- assert_eq ! ( data_range. 0 , array. data( ) as usize ) ;
897
- assert_eq ! ( data_range. 1 , unsafe { array. data( ) . add( 1 ) } as usize ) ;
897
+ assert_eq ! ( data_range. 0 , array. data( ) as * mut u8 ) ;
898
+ assert_eq ! ( data_range. 1 , unsafe { array. data( ) . add( 1 ) } as * mut u8 ) ;
898
899
} ) ;
899
900
}
900
901
@@ -909,35 +910,35 @@ mod tests {
909
910
. unwrap ( )
910
911
. downcast :: < PyArray2 < f64 > > ( )
911
912
. unwrap ( ) ;
912
- assert_ne ! ( view1 as * const _ as usize , array as * const _ as usize ) ;
913
+ assert_ne ! ( view1 as * const _ as * mut u8 , array as * const _ as * mut u8 ) ;
913
914
914
915
let locals = [ ( "view1" , view1) ] . into_py_dict ( py) ;
915
916
let view2 = py
916
917
. eval ( "view1[:,0]" , None , Some ( locals) )
917
918
. unwrap ( )
918
919
. downcast :: < PyArray1 < f64 > > ( )
919
920
. unwrap ( ) ;
920
- assert_ne ! ( view2 as * const _ as usize , array as * const _ as usize ) ;
921
- assert_ne ! ( view2 as * const _ as usize , view1 as * const _ as usize ) ;
921
+ assert_ne ! ( view2 as * const _ as * mut u8 , array as * const _ as * mut u8 ) ;
922
+ assert_ne ! ( view2 as * const _ as * mut u8 , view1 as * const _ as * mut u8 ) ;
922
923
923
924
let base = unsafe { ( * view2. as_array_ptr ( ) ) . base } ;
924
- assert_eq ! ( base as usize , array as * const _ as usize ) ;
925
+ assert_eq ! ( base as * mut u8 , array as * const _ as * mut u8 ) ;
925
926
926
927
let base = unsafe { ( * view1. as_array_ptr ( ) ) . base } ;
927
- assert_eq ! ( base as usize , array as * const _ as usize ) ;
928
+ assert_eq ! ( base as * mut u8 , array as * const _ as * mut u8 ) ;
928
929
929
930
let base = unsafe { ( * array. as_array_ptr ( ) ) . base } ;
930
931
assert ! ( !base. is_null( ) ) ;
931
932
932
933
let base_address = base_address ( view2) ;
933
- assert_ne ! ( base_address, view2 as * const _ as usize ) ;
934
- assert_ne ! ( base_address, view1 as * const _ as usize ) ;
935
- assert_ne ! ( base_address, array as * const _ as usize ) ;
936
- assert_eq ! ( base_address, base as usize ) ;
934
+ assert_ne ! ( base_address, view2 as * const _ as * mut u8 ) ;
935
+ assert_ne ! ( base_address, view1 as * const _ as * mut u8 ) ;
936
+ assert_ne ! ( base_address, array as * const _ as * mut u8 ) ;
937
+ assert_eq ! ( base_address, base as * mut u8 ) ;
937
938
938
939
let data_range = data_range ( view2) ;
939
- assert_eq ! ( data_range. 0 , array. data( ) as usize ) ;
940
- assert_eq ! ( data_range. 1 , unsafe { array. data( ) . add( 1 ) } as usize ) ;
940
+ assert_eq ! ( data_range. 0 , array. data( ) as * mut u8 ) ;
941
+ assert_eq ! ( data_range. 1 , unsafe { array. data( ) . add( 1 ) } as * mut u8 ) ;
941
942
} ) ;
942
943
}
943
944
@@ -952,19 +953,19 @@ mod tests {
952
953
. unwrap ( )
953
954
. downcast :: < PyArray3 < f64 > > ( )
954
955
. unwrap ( ) ;
955
- assert_ne ! ( view as * const _ as usize , array as * const _ as usize ) ;
956
+ assert_ne ! ( view as * const _ as * mut u8 , array as * const _ as * mut u8 ) ;
956
957
957
958
let base = unsafe { ( * view. as_array_ptr ( ) ) . base } ;
958
- assert_eq ! ( base as usize , array as * const _ as usize ) ;
959
+ assert_eq ! ( base as * mut u8 , array as * const _ as * mut u8 ) ;
959
960
960
961
let base_address = base_address ( view) ;
961
- assert_ne ! ( base_address, view as * const _ as usize ) ;
962
- assert_eq ! ( base_address, base as usize ) ;
962
+ assert_ne ! ( base_address, view as * const _ as * mut u8 ) ;
963
+ assert_eq ! ( base_address, base as * mut u8 ) ;
963
964
964
965
let data_range = data_range ( view) ;
965
966
assert_eq ! ( view. data( ) , unsafe { array. data( ) . offset( 2 ) } ) ;
966
- assert_eq ! ( data_range. 0 , unsafe { view. data( ) . offset( -2 ) } as usize ) ;
967
- assert_eq ! ( data_range. 1 , unsafe { view. data( ) . offset( 4 ) } as usize ) ;
967
+ assert_eq ! ( data_range. 0 , unsafe { view. data( ) . offset( -2 ) } as * mut u8 ) ;
968
+ assert_eq ! ( data_range. 1 , unsafe { view. data( ) . offset( 4 ) } as * mut u8 ) ;
968
969
} ) ;
969
970
}
970
971
@@ -977,11 +978,11 @@ mod tests {
977
978
assert ! ( base. is_null( ) ) ;
978
979
979
980
let base_address = base_address ( array) ;
980
- assert_eq ! ( base_address, array as * const _ as usize ) ;
981
+ assert_eq ! ( base_address, array as * const _ as * mut u8 ) ;
981
982
982
983
let data_range = data_range ( array) ;
983
- assert_eq ! ( data_range. 0 , array. data( ) as usize ) ;
984
- assert_eq ! ( data_range. 1 , array. data( ) as usize ) ;
984
+ assert_eq ! ( data_range. 0 , array. data( ) as * mut u8 ) ;
985
+ assert_eq ! ( data_range. 1 , array. data( ) as * mut u8 ) ;
985
986
} ) ;
986
987
}
987
988
0 commit comments