Skip to content

Commit 9c4dcec

Browse files
committed
Expose PyArray::get and uget
1 parent dfabc95 commit 9c4dcec

File tree

3 files changed

+118
-24
lines changed

3 files changed

+118
-24
lines changed

src/array.rs

Lines changed: 64 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ use std::mem;
99
use std::os::raw::c_int;
1010
use std::ptr;
1111

12-
use convert::ToNpyDims;
12+
use convert::{NpyIndex, ToNpyDims};
1313
use error::{ErrorKind, IntoPyErr};
1414
use types::{NpyDataType, TypeNum, NPY_ORDER};
1515

@@ -166,8 +166,8 @@ impl<T> PyArray<T> {
166166
}
167167
}
168168

169-
/// Same as [shape](./struct.PyArray.html#method.shape)
170-
#[inline]
169+
/// Same as [shape](#method.shape)
170+
#[inline(always)]
171171
pub fn dims(&self) -> &[usize] {
172172
self.shape()
173173
}
@@ -201,21 +201,65 @@ impl<T> PyArray<T> {
201201
(*ptr).data as *mut T
202202
}
203203

204-
// TODO: we should provide safe access API
205-
unsafe fn get_unchecked(&self, index: &[isize]) -> *const T {
206-
let size = mem::size_of::<T>() as isize;
207-
index
208-
.iter()
209-
.zip(self.strides())
210-
.fold(self.data(), |pointer, (idx, stride)| {
211-
pointer.offset(stride * idx / size)
212-
})
204+
/// Get an immutable reference of a specified element, without checking the
205+
/// passed index is valid.
206+
///
207+
/// See [NpyIndex](../convert/trait.NpyIndex.html) for what types you can use as index.
208+
///
209+
/// Passing an invalid index can cause undefined behavior(mostly SIGSEGV).
210+
///
211+
/// # Example
212+
/// ```
213+
/// # extern crate pyo3; extern crate numpy; fn main() {
214+
/// use numpy::PyArray;
215+
/// let gil = pyo3::Python::acquire_gil();
216+
/// let arr = PyArray::arange(gil.python(), 0, 16, 1).reshape([2, 2, 4]).unwrap();
217+
/// assert_eq!(*arr.get([1, 0, 3]).unwrap(), 11);
218+
/// assert!(arr.get([2, 0, 3]).is_none());
219+
/// assert!(arr.get([1, 0, 3, 4]).is_none());
220+
/// assert!(arr.get([1, 0]).is_none());
221+
/// # }
222+
/// ```
223+
#[inline(always)]
224+
pub fn get<Idx: NpyIndex>(&self, index: Idx) -> Option<&T> {
225+
let offset = index.get_checked::<T>(self.shape(), self.strides())?;
226+
unsafe { Some(&*self.data().offset(offset)) }
227+
}
228+
229+
/// Same as [get](#method.get), but returns `&mut T`.
230+
#[inline(always)]
231+
pub fn get_mut<Idx: NpyIndex>(&self, index: Idx) -> Option<&mut T> {
232+
let offset = index.get_checked::<T>(self.shape(), self.strides())?;
233+
unsafe { Some(&mut *(self.data().offset(offset) as *mut T)) }
234+
}
235+
236+
/// Get an immutable reference of a specified element, without checking the
237+
/// passed index is valid.
238+
///
239+
/// See [NpyIndex](../convert/trait.NpyIndex.html) for what types you can use as index.
240+
///
241+
/// Passing an invalid index can cause undefined behavior(mostly SIGSEGV).
242+
///
243+
/// # Example
244+
/// ```
245+
/// # extern crate pyo3; extern crate numpy; fn main() {
246+
/// use numpy::PyArray;
247+
/// let gil = pyo3::Python::acquire_gil();
248+
/// let arr = PyArray::arange(gil.python(), 0, 16, 1).reshape([2, 2, 4]).unwrap();
249+
/// assert_eq!(unsafe { *arr.uget([1, 0, 3]) }, 11);
250+
/// # }
251+
/// ```
252+
#[inline(always)]
253+
pub unsafe fn uget<Idx: NpyIndex>(&self, index: Idx) -> &T {
254+
let offset = index.get_unchecked::<T>(self.strides());
255+
&*self.data().offset(offset)
213256
}
214257

215-
// TODO: we should provide safe access API
258+
/// Same as [uget](#method.uget), but returns `&mut T`.
216259
#[inline(always)]
217-
unsafe fn get_unchecked_mut(&self, index: &[isize]) -> *mut T {
218-
self.get_unchecked(index) as *mut T
260+
pub unsafe fn uget_mut<Idx: NpyIndex>(&self, index: Idx) -> &mut T {
261+
let offset = index.get_unchecked::<T>(self.strides());
262+
&mut *(self.data().offset(offset) as *mut T)
219263
}
220264
}
221265

@@ -258,7 +302,7 @@ impl<T: TypeNum> PyArray<T> {
258302
let array = Self::new(py, [iter.len()], false);
259303
unsafe {
260304
for (i, item) in iter.enumerate() {
261-
*array.get_unchecked_mut(&[i as isize]) = item;
305+
*array.uget_mut([i]) = item;
262306
}
263307
}
264308
array
@@ -293,7 +337,7 @@ impl<T: TypeNum> PyArray<T> {
293337
.resize_([capacity], 0, NPY_ORDER::NPY_ANYORDER)
294338
.expect("PyArray::from_iter: Failed to allocate memory");
295339
}
296-
*array.get_unchecked_mut(&[i as isize]) = item;
340+
*array.uget_mut([i]) = item;
297341
}
298342
}
299343
if capacity > length {
@@ -334,7 +378,7 @@ impl<T: TypeNum> PyArray<T> {
334378
unsafe {
335379
for y in 0..v.len() {
336380
for x in 0..last_len {
337-
*array.get_unchecked_mut(&[y as isize, x as isize]) = v[y][x].clone();
381+
*array.uget_mut([y, x]) = v[y][x].clone();
338382
}
339383
}
340384
}
@@ -387,8 +431,7 @@ impl<T: TypeNum> PyArray<T> {
387431
for z in 0..v.len() {
388432
for y in 0..dim2 {
389433
for x in 0..dim3 {
390-
*array.get_unchecked_mut(&[z as isize, y as isize, x as isize]) =
391-
v[z][y][x].clone();
434+
*array.uget_mut([z, y, x]) = v[z][y][x].clone();
392435
}
393436
}
394437
}
@@ -739,6 +782,6 @@ fn test_get_unchecked() {
739782
let gil = pyo3::Python::acquire_gil();
740783
let array = PyArray::from_slice(gil.python(), &[1i32, 2, 3]);
741784
unsafe {
742-
assert_eq!(*array.get_unchecked(&[1]), 2);
785+
assert_eq!(*array.uget([1]), 2);
743786
}
744787
}

src/convert.rs

Lines changed: 53 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
//! Defines conversion traits between rust types and numpy data types.
22
3-
use ndarray::{ArrayBase, Data, Dimension};
3+
use ndarray::{ArrayBase, Data, Dimension, IntoDimension};
44
use pyo3::Python;
55

6+
use std::mem;
67
use std::os::raw::c_int;
78

89
use super::*;
@@ -61,11 +62,61 @@ pub trait ToNpyDims: Dimension {
6162
fn __private__(&self) -> PrivateMarker;
6263
}
6364

64-
impl<T: Dimension> ToNpyDims for T {
65+
impl<D: Dimension> ToNpyDims for D {
6566
fn __private__(&self) -> PrivateMarker {
6667
PrivateMarker
6768
}
6869
}
6970

71+
/// Types that can be used to index an array.
72+
///
73+
/// See[IntoDimension](https://docs.rs/ndarray/0.12/ndarray/dimension/conversion/trait.IntoDimension.html)
74+
/// for what types you can use as `NpyIndex`.
75+
///
76+
/// But basically, you can use
77+
/// - Tuple
78+
/// - Fixed sized array
79+
/// - Slice
80+
// Since Numpy's strides is byte offset, we can't use ndarray::NdIndex directly here.
81+
pub trait NpyIndex {
82+
fn get_checked<T>(self, dims: &[usize], strides: &[isize]) -> Option<isize>;
83+
fn get_unchecked<T>(self, strides: &[isize]) -> isize;
84+
fn __private__(self) -> PrivateMarker;
85+
}
86+
87+
impl<D: IntoDimension> NpyIndex for D {
88+
fn get_checked<T>(self, dims: &[usize], strides: &[isize]) -> Option<isize> {
89+
let indices_ = self.into_dimension();
90+
let indices = indices_.slice();
91+
if indices.len() != dims.len() {
92+
return None;
93+
}
94+
if indices.into_iter().zip(dims).any(|(i, d)| i >= d) {
95+
return None;
96+
}
97+
Some(get_unchecked_impl(
98+
indices,
99+
strides,
100+
mem::size_of::<T>() as isize,
101+
))
102+
}
103+
fn get_unchecked<T>(self, strides: &[isize]) -> isize {
104+
let indices_ = self.into_dimension();
105+
let indices = indices_.slice();
106+
get_unchecked_impl(indices, strides, mem::size_of::<T>() as isize)
107+
}
108+
fn __private__(self) -> PrivateMarker {
109+
PrivateMarker
110+
}
111+
}
112+
113+
fn get_unchecked_impl(indices: &[usize], strides: &[isize], size: isize) -> isize {
114+
indices
115+
.iter()
116+
.zip(strides)
117+
.map(|(&i, stride)| stride * i as isize / size)
118+
.sum()
119+
}
120+
70121
#[doc(hidden)]
71122
pub struct PrivateMarker;

src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ pub mod npyffi;
4444
pub mod types;
4545

4646
pub use array::{get_array_module, PyArray};
47-
pub use convert::{ToNpyDims, ToPyArray};
47+
pub use convert::{NpyIndex, ToNpyDims, ToPyArray};
4848
pub use error::*;
4949
pub use npyffi::{PY_ARRAY_API, PY_UFUNC_API};
5050
pub use types::*;

0 commit comments

Comments
 (0)