Skip to content

Commit 457d013

Browse files
committed
Further debloat borrow module by reducing unnecessary generics.
1 parent e843bcc commit 457d013

File tree

1 file changed

+46
-47
lines changed

1 file changed

+46
-47
lines changed

src/borrow.rs

Lines changed: 46 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,7 @@ impl BorrowKey {
199199
let range = data_range(array);
200200

201201
let data_ptr = array.data() as usize;
202-
let gcd_strides = reduce(array.strides().iter().copied(), gcd).unwrap_or(1);
202+
let gcd_strides = gcd_strides(array.strides());
203203

204204
Self {
205205
range,
@@ -252,16 +252,9 @@ impl BorrowFlags {
252252
(*self.0.get()).get_or_insert_with(AHashMap::new)
253253
}
254254

255-
fn acquire<T, D>(&self, array: &PyArray<T, D>) -> Result<(), BorrowError>
256-
where
257-
T: Element,
258-
D: Dimension,
259-
{
260-
let address = base_address(array);
261-
let key = BorrowKey::from_array(array);
262-
263-
// SAFETY: Access to `&PyArray<T, D>` implies holding the GIL
264-
// and we are not calling into user code which might re-enter this function.
255+
fn acquire(&self, _py: Python, address: usize, key: BorrowKey) -> Result<(), BorrowError> {
256+
// SAFETY: Having `_py` implies holding the GIL and
257+
// we are not calling into user code which might re-enter this function.
265258
let borrow_flags = unsafe { BORROW_FLAGS.get() };
266259

267260
match borrow_flags.entry(address) {
@@ -302,16 +295,9 @@ impl BorrowFlags {
302295
Ok(())
303296
}
304297

305-
fn release<T, D>(&self, array: &PyArray<T, D>)
306-
where
307-
T: Element,
308-
D: Dimension,
309-
{
310-
let address = base_address(array);
311-
let key = BorrowKey::from_array(array);
312-
313-
// SAFETY: Access to `&PyArray<T, D>` implies holding the GIL
314-
// and we are not calling into user code which might re-enter this function.
298+
fn release(&self, _py: Python, address: usize, key: BorrowKey) {
299+
// SAFETY: Having `_py` implies holding the GIL and
300+
// we are not calling into user code which might re-enter this function.
315301
let borrow_flags = unsafe { BORROW_FLAGS.get() };
316302

317303
let same_base_arrays = borrow_flags.get_mut(&address).unwrap();
@@ -329,16 +315,9 @@ impl BorrowFlags {
329315
}
330316
}
331317

332-
fn acquire_mut<T, D>(&self, array: &PyArray<T, D>) -> Result<(), BorrowError>
333-
where
334-
T: Element,
335-
D: Dimension,
336-
{
337-
let address = base_address(array);
338-
let key = BorrowKey::from_array(array);
339-
340-
// SAFETY: Access to `&PyArray<T, D>` implies holding the GIL
341-
// and we are not calling into user code which might re-enter this function.
318+
fn acquire_mut(&self, _py: Python, address: usize, key: BorrowKey) -> Result<(), BorrowError> {
319+
// SAFETY: Having `_py` implies holding the GIL and
320+
// we are not calling into user code which might re-enter this function.
342321
let borrow_flags = unsafe { BORROW_FLAGS.get() };
343322

344323
match borrow_flags.entry(address) {
@@ -373,16 +352,9 @@ impl BorrowFlags {
373352
Ok(())
374353
}
375354

376-
fn release_mut<T, D>(&self, array: &PyArray<T, D>)
377-
where
378-
T: Element,
379-
D: Dimension,
380-
{
381-
let address = base_address(array);
382-
let key = BorrowKey::from_array(array);
383-
384-
// SAFETY: Access to `&PyArray<T, D>` implies holding the GIL
385-
// and we are not calling into user code which might re-enter this function.
355+
fn release_mut(&self, _py: Python, address: usize, key: BorrowKey) {
356+
// SAFETY: Having `_py` implies holding the GIL and
357+
// we are not calling into user code which might re-enter this function.
386358
let borrow_flags = unsafe { BORROW_FLAGS.get() };
387359

388360
let same_base_arrays = borrow_flags.get_mut(&address).unwrap();
@@ -454,7 +426,11 @@ where
454426
D: Dimension,
455427
{
456428
pub(crate) fn try_new(array: &'py PyArray<T, D>) -> Result<Self, BorrowError> {
457-
BORROW_FLAGS.acquire(array)?;
429+
let py = array.py();
430+
let address = base_address(array);
431+
let key = BorrowKey::from_array(array);
432+
433+
BORROW_FLAGS.acquire(py, address, key)?;
458434

459435
Ok(Self(array))
460436
}
@@ -499,7 +475,11 @@ where
499475
D: Dimension,
500476
{
501477
fn drop(&mut self) {
502-
BORROW_FLAGS.release(self.0);
478+
let py = self.0.py();
479+
let address = base_address(self.0);
480+
let key = BorrowKey::from_array(self.0);
481+
482+
BORROW_FLAGS.release(py, address, key);
503483
}
504484
}
505485

@@ -581,7 +561,11 @@ where
581561
return Err(BorrowError::NotWriteable);
582562
}
583563

584-
BORROW_FLAGS.acquire_mut(array)?;
564+
let py = array.py();
565+
let address = base_address(array);
566+
let key = BorrowKey::from_array(array);
567+
568+
BORROW_FLAGS.acquire_mut(py, address, key)?;
585569

586570
Ok(Self(array))
587571
}
@@ -632,14 +616,21 @@ where
632616
/// });
633617
/// ```
634618
pub fn resize(self, new_elems: usize) -> PyResult<Self> {
635-
BORROW_FLAGS.release_mut(self.0);
619+
let py = self.0.py();
620+
let address = base_address(self.0);
621+
let key = BorrowKey::from_array(self.0);
622+
623+
BORROW_FLAGS.release_mut(py, address, key);
636624

637625
// SAFETY: Ownership of `self` proves exclusive access to the interior of the array.
638626
unsafe {
639627
self.0.resize(new_elems)?;
640628
}
641629

642-
BORROW_FLAGS.acquire_mut(self.0)?;
630+
let address = base_address(self.0);
631+
let key = BorrowKey::from_array(self.0);
632+
633+
BORROW_FLAGS.acquire_mut(py, address, key)?;
643634

644635
Ok(self)
645636
}
@@ -651,7 +642,11 @@ where
651642
D: Dimension,
652643
{
653644
fn drop(&mut self) {
654-
BORROW_FLAGS.release_mut(self.0);
645+
let py = self.0.py();
646+
let address = base_address(self.0);
647+
let key = BorrowKey::from_array(self.0);
648+
649+
BORROW_FLAGS.release_mut(py, address, key);
655650
}
656651
}
657652

@@ -726,6 +721,10 @@ where
726721
)
727722
}
728723

724+
fn gcd_strides(strides: &[isize]) -> isize {
725+
reduce(strides.iter().copied(), gcd).unwrap_or(1)
726+
}
727+
729728
// FIXME(adamreichold): Use `usize::abs_diff` from std when that becomes stable.
730729
fn abs_diff(lhs: usize, rhs: usize) -> usize {
731730
if lhs >= rhs {

0 commit comments

Comments
 (0)