Skip to content

Commit 1425e40

Browse files
authored
Merge pull request #311 from PyO3/debloat-borrow-redux
Further fixes and optimizations for the borrow module
2 parents e843bcc + dfb7ba0 commit 1425e40

File tree

2 files changed

+116
-65
lines changed

2 files changed

+116
-65
lines changed

src/borrow.rs

Lines changed: 113 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,7 @@ use crate::dtype::Element;
180180
use crate::error::{BorrowError, NotContiguousError};
181181
use crate::npyffi::{self, PyArrayObject, NPY_ARRAY_WRITEABLE};
182182

183-
#[derive(PartialEq, Eq, Hash)]
183+
#[derive(Clone, Copy, PartialEq, Eq, Hash)]
184184
struct BorrowKey {
185185
/// exclusive range of lowest and highest address covered by array
186186
range: (usize, usize),
@@ -199,7 +199,7 @@ impl BorrowKey {
199199
let range = data_range(array);
200200

201201
let data_ptr = array.data() as usize;
202-
let gcd_strides = reduce(array.strides().iter().copied(), gcd).unwrap_or(1);
202+
let gcd_strides = gcd_strides(array.strides());
203203

204204
Self {
205205
range,
@@ -252,16 +252,9 @@ impl BorrowFlags {
252252
(*self.0.get()).get_or_insert_with(AHashMap::new)
253253
}
254254

255-
fn acquire<T, D>(&self, array: &PyArray<T, D>) -> Result<(), BorrowError>
256-
where
257-
T: Element,
258-
D: Dimension,
259-
{
260-
let address = base_address(array);
261-
let key = BorrowKey::from_array(array);
262-
263-
// SAFETY: Access to `&PyArray<T, D>` implies holding the GIL
264-
// and we are not calling into user code which might re-enter this function.
255+
fn acquire(&self, _py: Python, address: usize, key: BorrowKey) -> Result<(), BorrowError> {
256+
// SAFETY: Having `_py` implies holding the GIL and
257+
// we are not calling into user code which might re-enter this function.
265258
let borrow_flags = unsafe { BORROW_FLAGS.get() };
266259

267260
match borrow_flags.entry(address) {
@@ -302,16 +295,9 @@ impl BorrowFlags {
302295
Ok(())
303296
}
304297

305-
fn release<T, D>(&self, array: &PyArray<T, D>)
306-
where
307-
T: Element,
308-
D: Dimension,
309-
{
310-
let address = base_address(array);
311-
let key = BorrowKey::from_array(array);
312-
313-
// SAFETY: Access to `&PyArray<T, D>` implies holding the GIL
314-
// and we are not calling into user code which might re-enter this function.
298+
fn release(&self, _py: Python, address: usize, key: BorrowKey) {
299+
// SAFETY: Having `_py` implies holding the GIL and
300+
// we are not calling into user code which might re-enter this function.
315301
let borrow_flags = unsafe { BORROW_FLAGS.get() };
316302

317303
let same_base_arrays = borrow_flags.get_mut(&address).unwrap();
@@ -329,16 +315,9 @@ impl BorrowFlags {
329315
}
330316
}
331317

332-
fn acquire_mut<T, D>(&self, array: &PyArray<T, D>) -> Result<(), BorrowError>
333-
where
334-
T: Element,
335-
D: Dimension,
336-
{
337-
let address = base_address(array);
338-
let key = BorrowKey::from_array(array);
339-
340-
// SAFETY: Access to `&PyArray<T, D>` implies holding the GIL
341-
// and we are not calling into user code which might re-enter this function.
318+
fn acquire_mut(&self, _py: Python, address: usize, key: BorrowKey) -> Result<(), BorrowError> {
319+
// SAFETY: Having `_py` implies holding the GIL and
320+
// we are not calling into user code which might re-enter this function.
342321
let borrow_flags = unsafe { BORROW_FLAGS.get() };
343322

344323
match borrow_flags.entry(address) {
@@ -373,16 +352,9 @@ impl BorrowFlags {
373352
Ok(())
374353
}
375354

376-
fn release_mut<T, D>(&self, array: &PyArray<T, D>)
377-
where
378-
T: Element,
379-
D: Dimension,
380-
{
381-
let address = base_address(array);
382-
let key = BorrowKey::from_array(array);
383-
384-
// SAFETY: Access to `&PyArray<T, D>` implies holding the GIL
385-
// and we are not calling into user code which might re-enter this function.
355+
fn release_mut(&self, _py: Python, address: usize, key: BorrowKey) {
356+
// SAFETY: Having `_py` implies holding the GIL and
357+
// we are not calling into user code which might re-enter this function.
386358
let borrow_flags = unsafe { BORROW_FLAGS.get() };
387359

388360
let same_base_arrays = borrow_flags.get_mut(&address).unwrap();
@@ -403,10 +375,16 @@ static BORROW_FLAGS: BorrowFlags = BorrowFlags::new();
403375
/// i.e. that only shared references into the interior of the array can be created safely.
404376
///
405377
/// See the [module-level documentation](self) for more.
406-
pub struct PyReadonlyArray<'py, T, D>(&'py PyArray<T, D>)
378+
#[repr(C)]
379+
pub struct PyReadonlyArray<'py, T, D>
407380
where
408381
T: Element,
409-
D: Dimension;
382+
D: Dimension,
383+
{
384+
array: &'py PyArray<T, D>,
385+
address: usize,
386+
key: BorrowKey,
387+
}
410388

411389
/// Read-only borrow of a one-dimensional array.
412390
pub type PyReadonlyArray1<'py, T> = PyReadonlyArray<'py, T, Ix1>;
@@ -437,7 +415,7 @@ where
437415
type Target = PyArray<T, D>;
438416

439417
fn deref(&self) -> &Self::Target {
440-
self.0
418+
self.array
441419
}
442420
}
443421

@@ -454,23 +432,30 @@ where
454432
D: Dimension,
455433
{
456434
pub(crate) fn try_new(array: &'py PyArray<T, D>) -> Result<Self, BorrowError> {
457-
BORROW_FLAGS.acquire(array)?;
435+
let address = base_address(array);
436+
let key = BorrowKey::from_array(array);
458437

459-
Ok(Self(array))
438+
BORROW_FLAGS.acquire(array.py(), address, key)?;
439+
440+
Ok(Self {
441+
array,
442+
address,
443+
key,
444+
})
460445
}
461446

462447
/// Provides an immutable array view of the interior of the NumPy array.
463448
#[inline(always)]
464449
pub fn as_array(&self) -> ArrayView<T, D> {
465450
// SAFETY: Global borrow flags ensure aliasing discipline.
466-
unsafe { self.0.as_array() }
451+
unsafe { self.array.as_array() }
467452
}
468453

469454
/// Provide an immutable slice view of the interior of the NumPy array if it is contiguous.
470455
#[inline(always)]
471456
pub fn as_slice(&self) -> Result<&[T], NotContiguousError> {
472457
// SAFETY: Global borrow flags ensure aliasing discipline.
473-
unsafe { self.0.as_slice() }
458+
unsafe { self.array.as_slice() }
474459
}
475460

476461
/// Provide an immutable reference to an element of the NumPy array if the index is within bounds.
@@ -479,7 +464,7 @@ where
479464
where
480465
I: NpyIndex<Dim = D>,
481466
{
482-
unsafe { self.0.get(index) }
467+
unsafe { self.array.get(index) }
483468
}
484469
}
485470

@@ -489,7 +474,15 @@ where
489474
D: Dimension,
490475
{
491476
fn clone(&self) -> Self {
492-
Self::try_new(self.0).unwrap()
477+
BORROW_FLAGS
478+
.acquire(self.array.py(), self.address, self.key)
479+
.unwrap();
480+
481+
Self {
482+
array: self.array,
483+
address: self.address,
484+
key: self.key,
485+
}
493486
}
494487
}
495488

@@ -499,7 +492,7 @@ where
499492
D: Dimension,
500493
{
501494
fn drop(&mut self) {
502-
BORROW_FLAGS.release(self.0);
495+
BORROW_FLAGS.release(self.array.py(), self.address, self.key);
503496
}
504497
}
505498

@@ -525,10 +518,16 @@ where
525518
/// i.e. that only a single exclusive reference into the interior of the array can be created safely.
526519
///
527520
/// See the [module-level documentation](self) for more.
528-
pub struct PyReadwriteArray<'py, T, D>(&'py PyArray<T, D>)
521+
#[repr(C)]
522+
pub struct PyReadwriteArray<'py, T, D>
529523
where
530524
T: Element,
531-
D: Dimension;
525+
D: Dimension,
526+
{
527+
array: &'py PyArray<T, D>,
528+
address: usize,
529+
key: BorrowKey,
530+
}
532531

533532
/// Read-write borrow of a one-dimensional array.
534533
pub type PyReadwriteArray1<'py, T> = PyReadwriteArray<'py, T, Ix1>;
@@ -581,23 +580,30 @@ where
581580
return Err(BorrowError::NotWriteable);
582581
}
583582

584-
BORROW_FLAGS.acquire_mut(array)?;
583+
let address = base_address(array);
584+
let key = BorrowKey::from_array(array);
585585

586-
Ok(Self(array))
586+
BORROW_FLAGS.acquire_mut(array.py(), address, key)?;
587+
588+
Ok(Self {
589+
array,
590+
address,
591+
key,
592+
})
587593
}
588594

589595
/// Provides a mutable array view of the interior of the NumPy array.
590596
#[inline(always)]
591597
pub fn as_array_mut(&mut self) -> ArrayViewMut<T, D> {
592598
// SAFETY: Global borrow flags ensure aliasing discipline.
593-
unsafe { self.0.as_array_mut() }
599+
unsafe { self.array.as_array_mut() }
594600
}
595601

596602
/// Provide a mutable slice view of the interior of the NumPy array if it is contiguous.
597603
#[inline(always)]
598604
pub fn as_slice_mut(&mut self) -> Result<&mut [T], NotContiguousError> {
599605
// SAFETY: Global borrow flags ensure aliasing discipline.
600-
unsafe { self.0.as_slice_mut() }
606+
unsafe { self.array.as_slice_mut() }
601607
}
602608

603609
/// Provide a mutable reference to an element of the NumPy array if the index is within bounds.
@@ -606,7 +612,7 @@ where
606612
where
607613
I: NpyIndex<Dim = D>,
608614
{
609-
unsafe { self.0.get_mut(index) }
615+
unsafe { self.array.get_mut(index) }
610616
}
611617
}
612618

@@ -632,16 +638,16 @@ where
632638
/// });
633639
/// ```
634640
pub fn resize(self, new_elems: usize) -> PyResult<Self> {
635-
BORROW_FLAGS.release_mut(self.0);
641+
let array = self.array;
636642

637643
// SAFETY: Ownership of `self` proves exclusive access to the interior of the array.
638644
unsafe {
639-
self.0.resize(new_elems)?;
645+
array.resize(new_elems)?;
640646
}
641647

642-
BORROW_FLAGS.acquire_mut(self.0)?;
648+
drop(self);
643649

644-
Ok(self)
650+
Ok(Self::try_new(array).unwrap())
645651
}
646652
}
647653

@@ -651,7 +657,7 @@ where
651657
D: Dimension,
652658
{
653659
fn drop(&mut self) {
654-
BORROW_FLAGS.release_mut(self.0);
660+
BORROW_FLAGS.release_mut(self.array.py(), self.address, self.key);
655661
}
656662
}
657663

@@ -726,6 +732,10 @@ where
726732
)
727733
}
728734

735+
fn gcd_strides(strides: &[isize]) -> isize {
736+
reduce(strides.iter().copied(), gcd).unwrap_or(1)
737+
}
738+
729739
// FIXME(adamreichold): Use `usize::abs_diff` from std when that becomes stable.
730740
fn abs_diff(lhs: usize, rhs: usize) -> usize {
731741
if lhs >= rhs {
@@ -1276,4 +1286,43 @@ mod tests {
12761286
}
12771287
});
12781288
}
1289+
1290+
#[test]
1291+
#[should_panic(expected = "AlreadyBorrowed")]
1292+
fn cannot_clone_exclusive_borrow_via_deref() {
1293+
Python::with_gil(|py| {
1294+
let array = PyArray::<f64, _>::zeros(py, (3, 2, 1), false);
1295+
1296+
let exclusive = array.readwrite();
1297+
let _shared = exclusive.clone();
1298+
});
1299+
}
1300+
1301+
#[test]
1302+
fn failed_resize_does_not_double_release() {
1303+
Python::with_gil(|py| {
1304+
let array = PyArray::<f64, _>::zeros(py, 10, false);
1305+
1306+
// The view will make the internal reference check of `PyArray_Resize` fail.
1307+
let locals = [("array", array)].into_py_dict(py);
1308+
let _view = py
1309+
.eval("array[:]", None, Some(locals))
1310+
.unwrap()
1311+
.downcast::<PyArray1<f64>>()
1312+
.unwrap();
1313+
1314+
let exclusive = array.readwrite();
1315+
assert!(exclusive.resize(100).is_err());
1316+
});
1317+
}
1318+
1319+
#[test]
1320+
fn ineffective_resize_does_not_conflict() {
1321+
Python::with_gil(|py| {
1322+
let array = PyArray::<f64, _>::zeros(py, 10, false);
1323+
1324+
let exclusive = array.readwrite();
1325+
assert!(exclusive.resize(10).is_ok());
1326+
});
1327+
}
12791328
}

src/lib.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,9 @@ pub use crate::borrow::{
6060
};
6161
pub use crate::convert::{IntoPyArray, NpyIndex, ToNpyDims, ToPyArray};
6262
pub use crate::dtype::{dtype, Complex32, Complex64, Element, PyArrayDescr};
63-
pub use crate::error::{DimensionalityError, FromVecError, NotContiguousError, TypeError};
63+
pub use crate::error::{
64+
BorrowError, DimensionalityError, FromVecError, NotContiguousError, TypeError,
65+
};
6466
pub use crate::npyffi::{PY_ARRAY_API, PY_UFUNC_API};
6567
#[allow(deprecated)]
6668
pub use crate::npyiter::{

0 commit comments

Comments
 (0)