Skip to content

Commit 6b75cf9

Browse files
committed
Fix a drop bug in PyReadwriteArray::resize and optimize borrowing by caching address and key.
1 parent 457d013 commit 6b75cf9

File tree

1 file changed

+90
-40
lines changed

1 file changed

+90
-40
lines changed

src/borrow.rs

Lines changed: 90 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,7 @@ use crate::dtype::Element;
180180
use crate::error::{BorrowError, NotContiguousError};
181181
use crate::npyffi::{self, PyArrayObject, NPY_ARRAY_WRITEABLE};
182182

183-
#[derive(PartialEq, Eq, Hash)]
183+
#[derive(Clone, Copy, PartialEq, Eq, Hash)]
184184
struct BorrowKey {
185185
/// exclusive range of lowest and highest address covered by array
186186
range: (usize, usize),
@@ -375,10 +375,16 @@ static BORROW_FLAGS: BorrowFlags = BorrowFlags::new();
375375
/// i.e. that only shared references into the interior of the array can be created safely.
376376
///
377377
/// See the [module-level documentation](self) for more.
378-
pub struct PyReadonlyArray<'py, T, D>(&'py PyArray<T, D>)
378+
#[repr(C)]
379+
pub struct PyReadonlyArray<'py, T, D>
379380
where
380381
T: Element,
381-
D: Dimension;
382+
D: Dimension,
383+
{
384+
array: &'py PyArray<T, D>,
385+
address: usize,
386+
key: BorrowKey,
387+
}
382388

383389
/// Read-only borrow of a one-dimensional array.
384390
pub type PyReadonlyArray1<'py, T> = PyReadonlyArray<'py, T, Ix1>;
@@ -409,7 +415,7 @@ where
409415
type Target = PyArray<T, D>;
410416

411417
fn deref(&self) -> &Self::Target {
412-
self.0
418+
self.array
413419
}
414420
}
415421

@@ -426,27 +432,30 @@ where
426432
D: Dimension,
427433
{
428434
pub(crate) fn try_new(array: &'py PyArray<T, D>) -> Result<Self, BorrowError> {
429-
let py = array.py();
430435
let address = base_address(array);
431436
let key = BorrowKey::from_array(array);
432437

433-
BORROW_FLAGS.acquire(py, address, key)?;
438+
BORROW_FLAGS.acquire(array.py(), address, key)?;
434439

435-
Ok(Self(array))
440+
Ok(Self {
441+
array,
442+
address,
443+
key,
444+
})
436445
}
437446

438447
/// Provides an immutable array view of the interior of the NumPy array.
439448
#[inline(always)]
440449
pub fn as_array(&self) -> ArrayView<T, D> {
441450
// SAFETY: Global borrow flags ensure aliasing discipline.
442-
unsafe { self.0.as_array() }
451+
unsafe { self.array.as_array() }
443452
}
444453

445454
/// Provide an immutable slice view of the interior of the NumPy array if it is contiguous.
446455
#[inline(always)]
447456
pub fn as_slice(&self) -> Result<&[T], NotContiguousError> {
448457
// SAFETY: Global borrow flags ensure aliasing discipline.
449-
unsafe { self.0.as_slice() }
458+
unsafe { self.array.as_slice() }
450459
}
451460

452461
/// Provide an immutable reference to an element of the NumPy array if the index is within bounds.
@@ -455,7 +464,7 @@ where
455464
where
456465
I: NpyIndex<Dim = D>,
457466
{
458-
unsafe { self.0.get(index) }
467+
unsafe { self.array.get(index) }
459468
}
460469
}
461470

@@ -465,7 +474,15 @@ where
465474
D: Dimension,
466475
{
467476
fn clone(&self) -> Self {
468-
Self::try_new(self.0).unwrap()
477+
BORROW_FLAGS
478+
.acquire(self.array.py(), self.address, self.key)
479+
.unwrap();
480+
481+
Self {
482+
array: self.array,
483+
address: self.address,
484+
key: self.key,
485+
}
469486
}
470487
}
471488

@@ -475,11 +492,7 @@ where
475492
D: Dimension,
476493
{
477494
fn drop(&mut self) {
478-
let py = self.0.py();
479-
let address = base_address(self.0);
480-
let key = BorrowKey::from_array(self.0);
481-
482-
BORROW_FLAGS.release(py, address, key);
495+
BORROW_FLAGS.release(self.array.py(), self.address, self.key);
483496
}
484497
}
485498

@@ -505,10 +518,16 @@ where
505518
/// i.e. that only a single exclusive reference into the interior of the array can be created safely.
506519
///
507520
/// See the [module-level documentation](self) for more.
508-
pub struct PyReadwriteArray<'py, T, D>(&'py PyArray<T, D>)
521+
#[repr(C)]
522+
pub struct PyReadwriteArray<'py, T, D>
509523
where
510524
T: Element,
511-
D: Dimension;
525+
D: Dimension,
526+
{
527+
array: &'py PyArray<T, D>,
528+
address: usize,
529+
key: BorrowKey,
530+
}
512531

513532
/// Read-write borrow of a one-dimensional array.
514533
pub type PyReadwriteArray1<'py, T> = PyReadwriteArray<'py, T, Ix1>;
@@ -561,27 +580,30 @@ where
561580
return Err(BorrowError::NotWriteable);
562581
}
563582

564-
let py = array.py();
565583
let address = base_address(array);
566584
let key = BorrowKey::from_array(array);
567585

568-
BORROW_FLAGS.acquire_mut(py, address, key)?;
586+
BORROW_FLAGS.acquire_mut(array.py(), address, key)?;
569587

570-
Ok(Self(array))
588+
Ok(Self {
589+
array,
590+
address,
591+
key,
592+
})
571593
}
572594

573595
/// Provides a mutable array view of the interior of the NumPy array.
574596
#[inline(always)]
575597
pub fn as_array_mut(&mut self) -> ArrayViewMut<T, D> {
576598
// SAFETY: Global borrow flags ensure aliasing discipline.
577-
unsafe { self.0.as_array_mut() }
599+
unsafe { self.array.as_array_mut() }
578600
}
579601

580602
/// Provide a mutable slice view of the interior of the NumPy array if it is contiguous.
581603
#[inline(always)]
582604
pub fn as_slice_mut(&mut self) -> Result<&mut [T], NotContiguousError> {
583605
// SAFETY: Global borrow flags ensure aliasing discipline.
584-
unsafe { self.0.as_slice_mut() }
606+
unsafe { self.array.as_slice_mut() }
585607
}
586608

587609
/// Provide a mutable reference to an element of the NumPy array if the index is within bounds.
@@ -590,7 +612,7 @@ where
590612
where
591613
I: NpyIndex<Dim = D>,
592614
{
593-
unsafe { self.0.get_mut(index) }
615+
unsafe { self.array.get_mut(index) }
594616
}
595617
}
596618

@@ -616,23 +638,16 @@ where
616638
/// });
617639
/// ```
618640
pub fn resize(self, new_elems: usize) -> PyResult<Self> {
619-
let py = self.0.py();
620-
let address = base_address(self.0);
621-
let key = BorrowKey::from_array(self.0);
622-
623-
BORROW_FLAGS.release_mut(py, address, key);
641+
let array = self.array;
624642

625643
// SAFETY: Ownership of `self` proves exclusive access to the interior of the array.
626644
unsafe {
627-
self.0.resize(new_elems)?;
645+
array.resize(new_elems)?;
628646
}
629647

630-
let address = base_address(self.0);
631-
let key = BorrowKey::from_array(self.0);
632-
633-
BORROW_FLAGS.acquire_mut(py, address, key)?;
648+
drop(self);
634649

635-
Ok(self)
650+
Ok(Self::try_new(array).unwrap())
636651
}
637652
}
638653

@@ -642,11 +657,7 @@ where
642657
D: Dimension,
643658
{
644659
fn drop(&mut self) {
645-
let py = self.0.py();
646-
let address = base_address(self.0);
647-
let key = BorrowKey::from_array(self.0);
648-
649-
BORROW_FLAGS.release_mut(py, address, key);
660+
BORROW_FLAGS.release_mut(self.array.py(), self.address, self.key);
650661
}
651662
}
652663

@@ -1275,4 +1286,43 @@ mod tests {
12751286
}
12761287
});
12771288
}
1289+
1290+
#[test]
1291+
#[should_panic(expected = "AlreadyBorrowed")]
1292+
fn cannot_clone_exclusive_borrow_via_deref() {
1293+
Python::with_gil(|py| {
1294+
let array = PyArray::<f64, _>::zeros(py, (3, 2, 1), false);
1295+
1296+
let exclusive = array.readwrite();
1297+
let _shared = exclusive.clone();
1298+
});
1299+
}
1300+
1301+
#[test]
1302+
fn failed_resize_does_not_double_release() {
1303+
Python::with_gil(|py| {
1304+
let array = PyArray::<f64, _>::zeros(py, 10, false);
1305+
1306+
// The view will make the internal reference check of `PyArray_Resize` fail.
1307+
let locals = [("array", array)].into_py_dict(py);
1308+
let _view = py
1309+
.eval("array[:]", None, Some(locals))
1310+
.unwrap()
1311+
.downcast::<PyArray1<f64>>()
1312+
.unwrap();
1313+
1314+
let exclusive = array.readwrite();
1315+
assert!(exclusive.resize(100).is_err());
1316+
});
1317+
}
1318+
1319+
#[test]
1320+
fn ineffective_resize_does_not_conflict() {
1321+
Python::with_gil(|py| {
1322+
let array = PyArray::<f64, _>::zeros(py, 10, false);
1323+
1324+
let exclusive = array.readwrite();
1325+
assert!(exclusive.resize(10).is_ok());
1326+
});
1327+
}
12781328
}

0 commit comments

Comments
 (0)