Skip to content

Commit e843bcc

Browse files
authored
Merge pull request #310 from PyO3/debloat-borrow
Factour out the non-generic parts of base_address and data_range to reduce code bloat.
2 parents 682fb0c + a6700ac commit e843bcc

File tree

1 file changed

+37
-31
lines changed

1 file changed

+37
-31
lines changed

src/borrow.rs

Lines changed: 37 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,7 @@ use std::ops::Deref;
171171
use ahash::AHashMap;
172172
use ndarray::{ArrayView, ArrayViewMut, Dimension, Ix1, Ix2, Ix3, Ix4, Ix5, Ix6, IxDyn};
173173
use num_integer::gcd;
174-
use pyo3::{FromPyObject, PyAny, PyResult};
174+
use pyo3::{FromPyObject, PyAny, PyResult, Python};
175175

176176
use crate::array::PyArray;
177177
use crate::cold;
@@ -672,52 +672,58 @@ where
672672
}
673673

674674
fn base_address<T, D>(array: &PyArray<T, D>) -> usize {
675-
let py = array.py();
676-
let mut array = array.as_array_ptr();
677-
678-
loop {
679-
let base = unsafe { (*array).base };
680-
681-
if base.is_null() {
682-
return array as usize;
683-
} else if unsafe { npyffi::PyArray_Check(py, base) } != 0 {
684-
array = base as *mut PyArrayObject;
685-
} else {
686-
return base as usize;
675+
fn inner(py: Python, mut array: *mut PyArrayObject) -> usize {
676+
loop {
677+
let base = unsafe { (*array).base };
678+
679+
if base.is_null() {
680+
return array as usize;
681+
} else if unsafe { npyffi::PyArray_Check(py, base) } != 0 {
682+
array = base as *mut PyArrayObject;
683+
} else {
684+
return base as usize;
685+
}
687686
}
688687
}
688+
689+
inner(array.py(), array.as_array_ptr())
689690
}
690691

691692
fn data_range<T, D>(array: &PyArray<T, D>) -> (usize, usize)
692693
where
693694
T: Element,
694695
D: Dimension,
695696
{
696-
let shape = array.shape();
697-
let strides = array.strides();
698-
699-
let mut start = 0;
700-
let mut end = 0;
697+
fn inner(shape: &[usize], strides: &[isize], itemsize: isize, data: *mut u8) -> (usize, usize) {
698+
let mut start = 0;
699+
let mut end = 0;
701700

702-
if shape.iter().all(|dim| *dim != 0) {
703-
for (&dim, &stride) in shape.iter().zip(strides) {
704-
let offset = (dim - 1) as isize * stride;
701+
if shape.iter().all(|dim| *dim != 0) {
702+
for (&dim, &stride) in shape.iter().zip(strides) {
703+
let offset = (dim - 1) as isize * stride;
705704

706-
if offset >= 0 {
707-
end += offset;
708-
} else {
709-
start += offset;
705+
if offset >= 0 {
706+
end += offset;
707+
} else {
708+
start += offset;
709+
}
710710
}
711+
712+
end += itemsize;
711713
}
712714

713-
end += size_of::<T>() as isize;
714-
}
715+
let start = unsafe { data.offset(start) } as usize;
716+
let end = unsafe { data.offset(end) } as usize;
715717

716-
let data = unsafe { (*array.as_array_ptr()).data };
717-
let start = unsafe { data.offset(start) } as usize;
718-
let end = unsafe { data.offset(end) } as usize;
718+
(start, end)
719+
}
719720

720-
(start, end)
721+
inner(
722+
array.shape(),
723+
array.strides(),
724+
size_of::<T>() as _,
725+
array.data() as _,
726+
)
721727
}
722728

723729
// FIXME(adamreichold): Use `usize::abs_diff` from std when that becomes stable.

0 commit comments

Comments
 (0)