Skip to content

Commit e58c0af

Browse files
committed
refactor BorrowFlags tests to not hold a lock in the tests
1 parent cc109bf commit e58c0af

File tree

1 file changed

+102
-107
lines changed

1 file changed

+102
-107
lines changed

src/borrow/shared.rs

Lines changed: 102 additions & 107 deletions
Original file line numberDiff line numberDiff line change
@@ -247,15 +247,14 @@ impl BorrowKey {
247247
}
248248
}
249249

250-
type BorrowFlagsInner = Mutex<FxHashMap<*mut c_void, FxHashMap<BorrowKey, isize>>>;
250+
type BorrowFlagsInner = FxHashMap<*mut c_void, FxHashMap<BorrowKey, isize>>;
251251

252252
#[derive(Default)]
253-
struct BorrowFlags(BorrowFlagsInner);
253+
struct BorrowFlags(Mutex<BorrowFlagsInner>);
254254

255255
impl BorrowFlags {
256256
fn acquire(&self, address: *mut c_void, key: BorrowKey) -> Result<(), ()> {
257257
let mut borrow_flags = self.0.lock().unwrap();
258-
259258
match borrow_flags.entry(address) {
260259
Entry::Occupied(entry) => {
261260
let same_base_arrays = entry.into_mut();
@@ -448,10 +447,27 @@ mod tests {
448447
use crate::untyped_array::PyUntypedArrayMethods;
449448
use pyo3::ffi::c_str;
450449

451-
fn get_borrow_flags<'py>(py: Python<'py>) -> &'py BorrowFlagsInner {
450+
struct BorrowFlagsState(usize, usize, Option<isize>);
451+
452+
fn get_borrow_flags_state<'py>(
453+
py: Python<'py>,
454+
base: *mut c_void,
455+
key: &BorrowKey,
456+
) -> BorrowFlagsState {
452457
let shared = get_or_insert_shared(py).unwrap();
453458
assert_eq!(shared.version, 1);
454-
unsafe { &(*(shared.flags as *mut BorrowFlags)).0 }
459+
let inner = unsafe { &(*(shared.flags as *mut BorrowFlags)).0 }
460+
.lock()
461+
.unwrap();
462+
if let Some(base_arrays) = inner.get(&base) {
463+
BorrowFlagsState(
464+
inner.len(),
465+
base_arrays.len(),
466+
base_arrays.get(key).map(|x| *x),
467+
)
468+
} else {
469+
BorrowFlagsState(0, 0, None)
470+
}
455471
}
456472

457473
#[test]
@@ -778,36 +794,30 @@ mod tests {
778794
let _exclusive1 = array1.readwrite();
779795

780796
{
781-
let borrow_flags = get_borrow_flags(py).lock().unwrap();
797+
let state = get_borrow_flags_state(py, base1, &key1);
782798
#[cfg(not(Py_GIL_DISABLED))]
783-
assert_eq!(borrow_flags.len(), 1);
784-
785-
let same_base_arrays = &borrow_flags[&base1];
786-
assert_eq!(same_base_arrays.len(), 1);
799+
// borrow checking state is shared and other tests might have registered a borrow
800+
assert_eq!(state.0, 1);
787801

788-
let flag = same_base_arrays[&key1];
789-
assert_eq!(flag, -1);
802+
assert_eq!(state.1, 1);
803+
assert_eq!(state.2, Some(-1));
790804
}
791805

792806
let key2 = borrow_key(py, array2.as_array_ptr());
793807
let _shared2 = array2.readonly();
794808

795809
{
796-
let borrow_flags = get_borrow_flags(py).lock().unwrap();
810+
let state = get_borrow_flags_state(py, base1, &key1);
797811
#[cfg(not(Py_GIL_DISABLED))]
798-
assert_eq!(borrow_flags.len(), 2);
799-
800-
let same_base_arrays = &borrow_flags[&base1];
801-
assert_eq!(same_base_arrays.len(), 1);
812+
// borrow checking state is shared and other tests might have registered a borrow
813+
assert_eq!(state.0, 2);
802814

803-
let flag = same_base_arrays[&key1];
804-
assert_eq!(flag, -1);
815+
assert_eq!(state.1, 1);
816+
assert_eq!(state.2, Some(-1));
805817

806-
let same_base_arrays = &borrow_flags[&base2];
807-
assert_eq!(same_base_arrays.len(), 1);
808-
809-
let flag = same_base_arrays[&key2];
810-
assert_eq!(flag, 1);
818+
let state = get_borrow_flags_state(py, base2, &key2);
819+
assert_eq!(state.1, 1);
820+
assert_eq!(state.2, Some(1));
811821
}
812822
});
813823
}
@@ -830,15 +840,13 @@ mod tests {
830840
let exclusive1 = view1.readwrite();
831841

832842
{
833-
let borrow_flags = get_borrow_flags(py).lock().unwrap();
834-
#[cfg(not(Py_GIL_DISABLED))]
835-
assert_eq!(borrow_flags.len(), 1);
836-
837-
let same_base_arrays = &borrow_flags[&base];
838-
assert_eq!(same_base_arrays.len(), 1);
843+
let state = get_borrow_flags_state(py, base, &key1);
839844

840-
let flag = same_base_arrays[&key1];
841-
assert_eq!(flag, -1);
845+
#[cfg(not(Py_GIL_DISABLED))]
846+
// borrow checking state is shared and other tests might have registered a borrow
847+
assert_eq!(state.0, 1);
848+
assert_eq!(state.1, 1);
849+
assert_eq!(state.2, Some(-1));
842850
}
843851

844852
let view2 = py
@@ -851,18 +859,15 @@ mod tests {
851859
let shared2 = view2.readonly();
852860

853861
{
854-
let borrow_flags = get_borrow_flags(py).lock().unwrap();
862+
let state = get_borrow_flags_state(py, base, &key1);
855863
#[cfg(not(Py_GIL_DISABLED))]
856-
assert_eq!(borrow_flags.len(), 1);
857-
858-
let same_base_arrays = &borrow_flags[&base];
859-
assert_eq!(same_base_arrays.len(), 2);
864+
// borrow checking state is shared and other tests might have registered a borrow
865+
assert_eq!(state.0, 1);
866+
assert_eq!(state.1, 2);
867+
assert_eq!(state.2, Some(-1));
860868

861-
let flag = same_base_arrays[&key1];
862-
assert_eq!(flag, -1);
863-
864-
let flag = same_base_arrays[&key2];
865-
assert_eq!(flag, 1);
869+
let state = get_borrow_flags_state(py, base, &key2);
870+
assert_eq!(state.2, Some(1));
866871
}
867872

868873
let view3 = py
@@ -875,21 +880,18 @@ mod tests {
875880
let shared3 = view3.readonly();
876881

877882
{
878-
let borrow_flags = get_borrow_flags(py).lock().unwrap();
883+
let state = get_borrow_flags_state(py, base, &key1);
879884
#[cfg(not(Py_GIL_DISABLED))]
880-
assert_eq!(borrow_flags.len(), 1);
881-
882-
let same_base_arrays = &borrow_flags[&base];
883-
assert_eq!(same_base_arrays.len(), 2);
884-
885-
let flag = same_base_arrays[&key1];
886-
assert_eq!(flag, -1);
885+
// borrow checking state is shared and other tests might have registered a borrow
886+
assert_eq!(state.0, 1);
887+
assert_eq!(state.1, 2);
888+
assert_eq!(state.2, Some(-1));
887889

888-
let flag = same_base_arrays[&key2];
889-
assert_eq!(flag, 2);
890+
let state = get_borrow_flags_state(py, base, &key2);
891+
assert_eq!(state.2, Some(2));
890892

891-
let flag = same_base_arrays[&key3];
892-
assert_eq!(flag, 2);
893+
let state = get_borrow_flags_state(py, base, &key3);
894+
assert_eq!(state.2, Some(2));
893895
}
894896

895897
let view4 = py
@@ -902,96 +904,89 @@ mod tests {
902904
let shared4 = view4.readonly();
903905

904906
{
905-
let borrow_flags = get_borrow_flags(py).lock().unwrap();
907+
let state = get_borrow_flags_state(py, base, &key1);
906908
#[cfg(not(Py_GIL_DISABLED))]
907-
assert_eq!(borrow_flags.len(), 1);
909+
// borrow checking state is shared and other tests might have registered a borrow
910+
assert_eq!(state.0, 1);
911+
assert_eq!(state.1, 3);
912+
assert_eq!(state.2, Some(-1));
908913

909-
let same_base_arrays = &borrow_flags[&base];
910-
assert_eq!(same_base_arrays.len(), 3);
914+
let state = get_borrow_flags_state(py, base, &key2);
915+
assert_eq!(state.2, Some(2));
911916

912-
let flag = same_base_arrays[&key1];
913-
assert_eq!(flag, -1);
917+
let state = get_borrow_flags_state(py, base, &key3);
918+
assert_eq!(state.2, Some(2));
914919

915-
let flag = same_base_arrays[&key2];
916-
assert_eq!(flag, 2);
917-
918-
let flag = same_base_arrays[&key3];
919-
assert_eq!(flag, 2);
920-
921-
let flag = same_base_arrays[&key4];
922-
assert_eq!(flag, 1);
920+
let state = get_borrow_flags_state(py, base, &key4);
921+
assert_eq!(state.2, Some(1));
923922
}
924923

925924
drop(shared2);
926925

927926
{
928-
let borrow_flags = get_borrow_flags(py).lock().unwrap();
927+
let state = get_borrow_flags_state(py, base, &key1);
929928
#[cfg(not(Py_GIL_DISABLED))]
930-
assert_eq!(borrow_flags.len(), 1);
931-
932-
let same_base_arrays = &borrow_flags[&base];
933-
assert_eq!(same_base_arrays.len(), 3);
929+
// borrow checking state is shared and other tests might have registered a borrow
930+
assert_eq!(state.0, 1);
931+
assert_eq!(state.1, 3);
932+
assert_eq!(state.2, Some(-1));
934933

935-
let flag = same_base_arrays[&key1];
936-
assert_eq!(flag, -1);
934+
let state = get_borrow_flags_state(py, base, &key2);
935+
assert_eq!(state.2, Some(1));
937936

938-
let flag = same_base_arrays[&key2];
939-
assert_eq!(flag, 1);
937+
let state = get_borrow_flags_state(py, base, &key3);
938+
assert_eq!(state.2, Some(1));
940939

941-
let flag = same_base_arrays[&key3];
942-
assert_eq!(flag, 1);
943-
944-
let flag = same_base_arrays[&key4];
945-
assert_eq!(flag, 1);
940+
let state = get_borrow_flags_state(py, base, &key4);
941+
assert_eq!(state.2, Some(1));
946942
}
947943

948944
drop(shared3);
949945

950946
{
951-
let borrow_flags = get_borrow_flags(py).lock().unwrap();
947+
let state = get_borrow_flags_state(py, base, &key1);
952948
#[cfg(not(Py_GIL_DISABLED))]
953-
assert_eq!(borrow_flags.len(), 1);
954-
955-
let same_base_arrays = &borrow_flags[&base];
956-
assert_eq!(same_base_arrays.len(), 2);
957-
958-
let flag = same_base_arrays[&key1];
959-
assert_eq!(flag, -1);
949+
// borrow checking state is shared and other tests might have registered a borrow
950+
assert_eq!(state.0, 1);
951+
assert_eq!(state.1, 2);
952+
assert_eq!(state.2, Some(-1));
960953

961-
assert!(!same_base_arrays.contains_key(&key2));
954+
let state = get_borrow_flags_state(py, base, &key2);
955+
assert_eq!(state.2, None);
962956

963-
assert!(!same_base_arrays.contains_key(&key3));
957+
let state = get_borrow_flags_state(py, base, &key3);
958+
assert_eq!(state.2, None);
964959

965-
let flag = same_base_arrays[&key4];
966-
assert_eq!(flag, 1);
960+
let state = get_borrow_flags_state(py, base, &key4);
961+
assert_eq!(state.2, Some(1));
967962
}
968963

969964
drop(exclusive1);
970965

971966
{
972-
let borrow_flags = get_borrow_flags(py).lock().unwrap();
967+
let state = get_borrow_flags_state(py, base, &key1);
973968
#[cfg(not(Py_GIL_DISABLED))]
974-
assert_eq!(borrow_flags.len(), 1);
975-
976-
let same_base_arrays = &borrow_flags[&base];
977-
assert_eq!(same_base_arrays.len(), 1);
978-
979-
assert!(!same_base_arrays.contains_key(&key1));
969+
// borrow checking state is shared and other tests might have registered a borrow
970+
assert_eq!(state.0, 1);
971+
assert_eq!(state.1, 1);
972+
assert_eq!(state.2, None);
980973

981-
assert!(!same_base_arrays.contains_key(&key2));
974+
let state = get_borrow_flags_state(py, base, &key2);
975+
assert_eq!(state.2, None);
982976

983-
assert!(!same_base_arrays.contains_key(&key3));
977+
let state = get_borrow_flags_state(py, base, &key3);
978+
assert_eq!(state.2, None);
984979

985-
let flag = same_base_arrays[&key4];
986-
assert_eq!(flag, 1);
980+
let state = get_borrow_flags_state(py, base, &key4);
981+
assert_eq!(state.2, Some(1));
987982
}
988983

989984
drop(shared4);
990985

991986
#[cfg(not(Py_GIL_DISABLED))]
987+
// borrow checking state is shared and other tests might have registered a borrow
992988
{
993-
let borrow_flags = get_borrow_flags(py).lock().unwrap();
994-
assert_eq!(borrow_flags.len(), 0);
989+
assert_eq!(get_borrow_flags_state(py, base, &key1).0, 0);
995990
}
996991
});
997992
}

0 commit comments

Comments
 (0)