Skip to content

Commit 14583d9

Browse files
committed
Hash pointers instead of addresses to avoid exposing the allocations backing NumPy arrays under strict provenance.
1 parent 19bfc9d commit 14583d9

File tree

1 file changed

+81
-80
lines changed

1 file changed

+81
-80
lines changed

src/borrow.rs

Lines changed: 81 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -184,9 +184,9 @@ use crate::npyffi::{self, PyArrayObject, NPY_ARRAY_WRITEABLE};
184184
#[derive(Clone, Copy, PartialEq, Eq, Hash)]
185185
struct BorrowKey {
186186
/// exclusive range of lowest and highest address covered by array
187-
range: (usize, usize),
187+
range: (*mut u8, *mut u8),
188188
/// the data address on which address computations are based
189-
data_ptr: usize,
189+
data_ptr: *mut u8,
190190
/// the greatest common divisor of the strides of the array
191191
gcd_strides: isize,
192192
}
@@ -199,7 +199,7 @@ impl BorrowKey {
199199
{
200200
let range = data_range(array);
201201

202-
let data_ptr = array.data() as usize;
202+
let data_ptr = array.data() as *mut u8;
203203
let gcd_strides = gcd_strides(array.strides());
204204

205205
Self {
@@ -225,7 +225,7 @@ impl BorrowKey {
225225
// but fails when slicing an array with a step size that does not divide the dimension along that axis.
226226
//
227227
// https://users.rust-lang.org/t/math-for-borrow-checking-numpy-arrays/73303
228-
let ptr_diff = abs_diff(self.data_ptr, other.data_ptr) as isize;
228+
let ptr_diff = unsafe { self.data_ptr.offset_from(other.data_ptr).abs() };
229229
let gcd_strides = gcd(self.gcd_strides, other.gcd_strides);
230230

231231
if ptr_diff % gcd_strides != 0 {
@@ -237,7 +237,7 @@ impl BorrowKey {
237237
}
238238
}
239239

240-
type BorrowFlagsInner = AHashMap<usize, AHashMap<BorrowKey, isize>>;
240+
type BorrowFlagsInner = AHashMap<*mut u8, AHashMap<BorrowKey, isize>>;
241241

242242
struct BorrowFlags(UnsafeCell<Option<BorrowFlagsInner>>);
243243

@@ -253,7 +253,7 @@ impl BorrowFlags {
253253
(*self.0.get()).get_or_insert_with(AHashMap::new)
254254
}
255255

256-
fn acquire(&self, _py: Python, address: usize, key: BorrowKey) -> Result<(), BorrowError> {
256+
fn acquire(&self, _py: Python, address: *mut u8, key: BorrowKey) -> Result<(), BorrowError> {
257257
// SAFETY: Having `_py` implies holding the GIL and
258258
// we are not calling into user code which might re-enter this function.
259259
let borrow_flags = unsafe { BORROW_FLAGS.get() };
@@ -296,7 +296,7 @@ impl BorrowFlags {
296296
Ok(())
297297
}
298298

299-
fn release(&self, _py: Python, address: usize, key: BorrowKey) {
299+
fn release(&self, _py: Python, address: *mut u8, key: BorrowKey) {
300300
// SAFETY: Having `_py` implies holding the GIL and
301301
// we are not calling into user code which might re-enter this function.
302302
let borrow_flags = unsafe { BORROW_FLAGS.get() };
@@ -316,7 +316,12 @@ impl BorrowFlags {
316316
}
317317
}
318318

319-
fn acquire_mut(&self, _py: Python, address: usize, key: BorrowKey) -> Result<(), BorrowError> {
319+
fn acquire_mut(
320+
&self,
321+
_py: Python,
322+
address: *mut u8,
323+
key: BorrowKey,
324+
) -> Result<(), BorrowError> {
320325
// SAFETY: Having `_py` implies holding the GIL and
321326
// we are not calling into user code which might re-enter this function.
322327
let borrow_flags = unsafe { BORROW_FLAGS.get() };
@@ -353,7 +358,7 @@ impl BorrowFlags {
353358
Ok(())
354359
}
355360

356-
fn release_mut(&self, _py: Python, address: usize, key: BorrowKey) {
361+
fn release_mut(&self, _py: Python, address: *mut u8, key: BorrowKey) {
357362
// SAFETY: Having `_py` implies holding the GIL and
358363
// we are not calling into user code which might re-enter this function.
359364
let borrow_flags = unsafe { BORROW_FLAGS.get() };
@@ -383,7 +388,7 @@ where
383388
D: Dimension,
384389
{
385390
array: &'py PyArray<T, D>,
386-
address: usize,
391+
address: *mut u8,
387392
key: BorrowKey,
388393
}
389394

@@ -526,7 +531,7 @@ where
526531
D: Dimension,
527532
{
528533
array: &'py PyArray<T, D>,
529-
address: usize,
534+
address: *mut u8,
530535
key: BorrowKey,
531536
}
532537

@@ -680,30 +685,35 @@ where
680685
}
681686
}
682687

683-
fn base_address<T, D>(array: &PyArray<T, D>) -> usize {
684-
fn inner(py: Python, mut array: *mut PyArrayObject) -> usize {
688+
fn base_address<T, D>(array: &PyArray<T, D>) -> *mut u8 {
689+
fn inner(py: Python, mut array: *mut PyArrayObject) -> *mut u8 {
685690
loop {
686691
let base = unsafe { (*array).base };
687692

688693
if base.is_null() {
689-
return array as usize;
694+
return array as *mut u8;
690695
} else if unsafe { npyffi::PyArray_Check(py, base) } != 0 {
691696
array = base as *mut PyArrayObject;
692697
} else {
693-
return base as usize;
698+
return base as *mut u8;
694699
}
695700
}
696701
}
697702

698703
inner(array.py(), array.as_array_ptr())
699704
}
700705

701-
fn data_range<T, D>(array: &PyArray<T, D>) -> (usize, usize)
706+
fn data_range<T, D>(array: &PyArray<T, D>) -> (*mut u8, *mut u8)
702707
where
703708
T: Element,
704709
D: Dimension,
705710
{
706-
fn inner(shape: &[usize], strides: &[isize], itemsize: isize, data: *mut u8) -> (usize, usize) {
711+
fn inner(
712+
shape: &[usize],
713+
strides: &[isize],
714+
itemsize: isize,
715+
data: *mut u8,
716+
) -> (*mut u8, *mut u8) {
707717
let mut start = 0;
708718
let mut end = 0;
709719

@@ -721,33 +731,24 @@ where
721731
end += itemsize;
722732
}
723733

724-
let start = unsafe { data.offset(start) } as usize;
725-
let end = unsafe { data.offset(end) } as usize;
734+
let start = unsafe { data.offset(start) };
735+
let end = unsafe { data.offset(end) };
726736

727737
(start, end)
728738
}
729739

730740
inner(
731741
array.shape(),
732742
array.strides(),
733-
size_of::<T>() as _,
734-
array.data() as _,
743+
size_of::<T>() as isize,
744+
array.data() as *mut u8,
735745
)
736746
}
737747

738748
fn gcd_strides(strides: &[isize]) -> isize {
739749
reduce(strides.iter().copied(), gcd).unwrap_or(1)
740750
}
741751

742-
// FIXME(adamreichold): Use `usize::abs_diff` from std when our MSRV reaches 1.60.
743-
fn abs_diff(lhs: usize, rhs: usize) -> usize {
744-
if lhs >= rhs {
745-
lhs - rhs
746-
} else {
747-
rhs - lhs
748-
}
749-
}
750-
751752
// FIXME(adamreichold): Use `Iterator::reduce` from std when our MSRV reaches 1.51.
752753
fn reduce<I, F>(mut iter: I, f: F) -> Option<I::Item>
753754
where
@@ -777,11 +778,11 @@ mod tests {
777778
assert!(base.is_null());
778779

779780
let base_address = base_address(array);
780-
assert_eq!(base_address, array as *const _ as usize);
781+
assert_eq!(base_address, array as *const _ as *mut u8);
781782

782783
let data_range = data_range(array);
783-
assert_eq!(data_range.0, array.data() as usize);
784-
assert_eq!(data_range.1, unsafe { array.data().add(6) } as usize);
784+
assert_eq!(data_range.0, array.data() as *mut u8);
785+
assert_eq!(data_range.1, unsafe { array.data().add(6) } as *mut u8);
785786
});
786787
}
787788

@@ -794,12 +795,12 @@ mod tests {
794795
assert!(!base.is_null());
795796

796797
let base_address = base_address(array);
797-
assert_ne!(base_address, array as *const _ as usize);
798-
assert_eq!(base_address, base as usize);
798+
assert_ne!(base_address, array as *const _ as *mut u8);
799+
assert_eq!(base_address, base as *mut u8);
799800

800801
let data_range = data_range(array);
801-
assert_eq!(data_range.0, array.data() as usize);
802-
assert_eq!(data_range.1, unsafe { array.data().add(6) } as usize);
802+
assert_eq!(data_range.0, array.data() as *mut u8);
803+
assert_eq!(data_range.1, unsafe { array.data().add(6) } as *mut u8);
803804
});
804805
}
805806

@@ -814,18 +815,18 @@ mod tests {
814815
.unwrap()
815816
.downcast::<PyArray2<f64>>()
816817
.unwrap();
817-
assert_ne!(view as *const _ as usize, array as *const _ as usize);
818+
assert_ne!(view as *const _ as *mut u8, array as *const _ as *mut u8);
818819

819820
let base = unsafe { (*view.as_array_ptr()).base };
820-
assert_eq!(base as usize, array as *const _ as usize);
821+
assert_eq!(base as *mut u8, array as *const _ as *mut u8);
821822

822823
let base_address = base_address(view);
823-
assert_ne!(base_address, view as *const _ as usize);
824-
assert_eq!(base_address, base as usize);
824+
assert_ne!(base_address, view as *const _ as *mut u8);
825+
assert_eq!(base_address, base as *mut u8);
825826

826827
let data_range = data_range(view);
827-
assert_eq!(data_range.0, array.data() as usize);
828-
assert_eq!(data_range.1, unsafe { array.data().add(4) } as usize);
828+
assert_eq!(data_range.0, array.data() as *mut u8);
829+
assert_eq!(data_range.1, unsafe { array.data().add(4) } as *mut u8);
829830
});
830831
}
831832

@@ -840,22 +841,22 @@ mod tests {
840841
.unwrap()
841842
.downcast::<PyArray2<f64>>()
842843
.unwrap();
843-
assert_ne!(view as *const _ as usize, array as *const _ as usize);
844+
assert_ne!(view as *const _ as *mut u8, array as *const _ as *mut u8);
844845

845846
let base = unsafe { (*view.as_array_ptr()).base };
846-
assert_eq!(base as usize, array as *const _ as usize);
847+
assert_eq!(base as *mut u8, array as *const _ as *mut u8);
847848

848849
let base = unsafe { (*array.as_array_ptr()).base };
849850
assert!(!base.is_null());
850851

851852
let base_address = base_address(view);
852-
assert_ne!(base_address, view as *const _ as usize);
853-
assert_ne!(base_address, array as *const _ as usize);
854-
assert_eq!(base_address, base as usize);
853+
assert_ne!(base_address, view as *const _ as *mut u8);
854+
assert_ne!(base_address, array as *const _ as *mut u8);
855+
assert_eq!(base_address, base as *mut u8);
855856

856857
let data_range = data_range(view);
857-
assert_eq!(data_range.0, array.data() as usize);
858-
assert_eq!(data_range.1, unsafe { array.data().add(4) } as usize);
858+
assert_eq!(data_range.0, array.data() as *mut u8);
859+
assert_eq!(data_range.1, unsafe { array.data().add(4) } as *mut u8);
859860
});
860861
}
861862

@@ -870,31 +871,31 @@ mod tests {
870871
.unwrap()
871872
.downcast::<PyArray2<f64>>()
872873
.unwrap();
873-
assert_ne!(view1 as *const _ as usize, array as *const _ as usize);
874+
assert_ne!(view1 as *const _ as *mut u8, array as *const _ as *mut u8);
874875

875876
let locals = [("view1", view1)].into_py_dict(py);
876877
let view2 = py
877878
.eval("view1[:,0]", None, Some(locals))
878879
.unwrap()
879880
.downcast::<PyArray1<f64>>()
880881
.unwrap();
881-
assert_ne!(view2 as *const _ as usize, array as *const _ as usize);
882-
assert_ne!(view2 as *const _ as usize, view1 as *const _ as usize);
882+
assert_ne!(view2 as *const _ as *mut u8, array as *const _ as *mut u8);
883+
assert_ne!(view2 as *const _ as *mut u8, view1 as *const _ as *mut u8);
883884

884885
let base = unsafe { (*view2.as_array_ptr()).base };
885-
assert_eq!(base as usize, array as *const _ as usize);
886+
assert_eq!(base as *mut u8, array as *const _ as *mut u8);
886887

887888
let base = unsafe { (*view1.as_array_ptr()).base };
888-
assert_eq!(base as usize, array as *const _ as usize);
889+
assert_eq!(base as *mut u8, array as *const _ as *mut u8);
889890

890891
let base_address = base_address(view2);
891-
assert_ne!(base_address, view2 as *const _ as usize);
892-
assert_ne!(base_address, view1 as *const _ as usize);
893-
assert_eq!(base_address, base as usize);
892+
assert_ne!(base_address, view2 as *const _ as *mut u8);
893+
assert_ne!(base_address, view1 as *const _ as *mut u8);
894+
assert_eq!(base_address, base as *mut u8);
894895

895896
let data_range = data_range(view2);
896-
assert_eq!(data_range.0, array.data() as usize);
897-
assert_eq!(data_range.1, unsafe { array.data().add(1) } as usize);
897+
assert_eq!(data_range.0, array.data() as *mut u8);
898+
assert_eq!(data_range.1, unsafe { array.data().add(1) } as *mut u8);
898899
});
899900
}
900901

@@ -909,35 +910,35 @@ mod tests {
909910
.unwrap()
910911
.downcast::<PyArray2<f64>>()
911912
.unwrap();
912-
assert_ne!(view1 as *const _ as usize, array as *const _ as usize);
913+
assert_ne!(view1 as *const _ as *mut u8, array as *const _ as *mut u8);
913914

914915
let locals = [("view1", view1)].into_py_dict(py);
915916
let view2 = py
916917
.eval("view1[:,0]", None, Some(locals))
917918
.unwrap()
918919
.downcast::<PyArray1<f64>>()
919920
.unwrap();
920-
assert_ne!(view2 as *const _ as usize, array as *const _ as usize);
921-
assert_ne!(view2 as *const _ as usize, view1 as *const _ as usize);
921+
assert_ne!(view2 as *const _ as *mut u8, array as *const _ as *mut u8);
922+
assert_ne!(view2 as *const _ as *mut u8, view1 as *const _ as *mut u8);
922923

923924
let base = unsafe { (*view2.as_array_ptr()).base };
924-
assert_eq!(base as usize, array as *const _ as usize);
925+
assert_eq!(base as *mut u8, array as *const _ as *mut u8);
925926

926927
let base = unsafe { (*view1.as_array_ptr()).base };
927-
assert_eq!(base as usize, array as *const _ as usize);
928+
assert_eq!(base as *mut u8, array as *const _ as *mut u8);
928929

929930
let base = unsafe { (*array.as_array_ptr()).base };
930931
assert!(!base.is_null());
931932

932933
let base_address = base_address(view2);
933-
assert_ne!(base_address, view2 as *const _ as usize);
934-
assert_ne!(base_address, view1 as *const _ as usize);
935-
assert_ne!(base_address, array as *const _ as usize);
936-
assert_eq!(base_address, base as usize);
934+
assert_ne!(base_address, view2 as *const _ as *mut u8);
935+
assert_ne!(base_address, view1 as *const _ as *mut u8);
936+
assert_ne!(base_address, array as *const _ as *mut u8);
937+
assert_eq!(base_address, base as *mut u8);
937938

938939
let data_range = data_range(view2);
939-
assert_eq!(data_range.0, array.data() as usize);
940-
assert_eq!(data_range.1, unsafe { array.data().add(1) } as usize);
940+
assert_eq!(data_range.0, array.data() as *mut u8);
941+
assert_eq!(data_range.1, unsafe { array.data().add(1) } as *mut u8);
941942
});
942943
}
943944

@@ -952,19 +953,19 @@ mod tests {
952953
.unwrap()
953954
.downcast::<PyArray3<f64>>()
954955
.unwrap();
955-
assert_ne!(view as *const _ as usize, array as *const _ as usize);
956+
assert_ne!(view as *const _ as *mut u8, array as *const _ as *mut u8);
956957

957958
let base = unsafe { (*view.as_array_ptr()).base };
958-
assert_eq!(base as usize, array as *const _ as usize);
959+
assert_eq!(base as *mut u8, array as *const _ as *mut u8);
959960

960961
let base_address = base_address(view);
961-
assert_ne!(base_address, view as *const _ as usize);
962-
assert_eq!(base_address, base as usize);
962+
assert_ne!(base_address, view as *const _ as *mut u8);
963+
assert_eq!(base_address, base as *mut u8);
963964

964965
let data_range = data_range(view);
965966
assert_eq!(view.data(), unsafe { array.data().offset(2) });
966-
assert_eq!(data_range.0, unsafe { view.data().offset(-2) } as usize);
967-
assert_eq!(data_range.1, unsafe { view.data().offset(4) } as usize);
967+
assert_eq!(data_range.0, unsafe { view.data().offset(-2) } as *mut u8);
968+
assert_eq!(data_range.1, unsafe { view.data().offset(4) } as *mut u8);
968969
});
969970
}
970971

@@ -977,11 +978,11 @@ mod tests {
977978
assert!(base.is_null());
978979

979980
let base_address = base_address(array);
980-
assert_eq!(base_address, array as *const _ as usize);
981+
assert_eq!(base_address, array as *const _ as *mut u8);
981982

982983
let data_range = data_range(array);
983-
assert_eq!(data_range.0, array.data() as usize);
984-
assert_eq!(data_range.1, array.data() as usize);
984+
assert_eq!(data_range.0, array.data() as *mut u8);
985+
assert_eq!(data_range.1, array.data() as *mut u8);
985986
});
986987
}
987988

0 commit comments

Comments
 (0)