@@ -9,7 +9,7 @@ use std::mem;
9
9
use std:: os:: raw:: c_int;
10
10
use std:: ptr;
11
11
12
- use convert:: ToNpyDims ;
12
+ use convert:: { NpyIndex , ToNpyDims } ;
13
13
use error:: { ErrorKind , IntoPyErr } ;
14
14
use types:: { NpyDataType , TypeNum , NPY_ORDER } ;
15
15
@@ -166,8 +166,8 @@ impl<T> PyArray<T> {
166
166
}
167
167
}
168
168
169
- /// Same as [shape](./struct.PyArray.html #method.shape)
170
- #[ inline]
169
+ /// Same as [shape](#method.shape)
170
+ #[ inline( always ) ]
171
171
pub fn dims ( & self ) -> & [ usize ] {
172
172
self . shape ( )
173
173
}
@@ -201,21 +201,65 @@ impl<T> PyArray<T> {
201
201
( * ptr) . data as * mut T
202
202
}
203
203
204
- // TODO: we should provide safe access API
205
- unsafe fn get_unchecked ( & self , index : & [ isize ] ) -> * const T {
206
- let size = mem:: size_of :: < T > ( ) as isize ;
207
- index
208
- . iter ( )
209
- . zip ( self . strides ( ) )
210
- . fold ( self . data ( ) , |pointer, ( idx, stride) | {
211
- pointer. offset ( stride * idx / size)
212
- } )
204
+ /// Get an immutable reference of a specified element, without checking the
205
+ /// passed index is valid.
206
+ ///
207
+ /// See [NpyIndex](../convert/trait.NpyIndex.html) for what types you can use as index.
208
+ ///
209
+ /// Passing an invalid index can cause undefined behavior(mostly SIGSEGV).
210
+ ///
211
+ /// # Example
212
+ /// ```
213
+ /// # extern crate pyo3; extern crate numpy; fn main() {
214
+ /// use numpy::PyArray;
215
+ /// let gil = pyo3::Python::acquire_gil();
216
+ /// let arr = PyArray::arange(gil.python(), 0, 16, 1).reshape([2, 2, 4]).unwrap();
217
+ /// assert_eq!(*arr.get([1, 0, 3]).unwrap(), 11);
218
+ /// assert!(arr.get([2, 0, 3]).is_none());
219
+ /// assert!(arr.get([1, 0, 3, 4]).is_none());
220
+ /// assert!(arr.get([1, 0]).is_none());
221
+ /// # }
222
+ /// ```
223
+ #[ inline( always) ]
224
+ pub fn get < Idx : NpyIndex > ( & self , index : Idx ) -> Option < & T > {
225
+ let offset = index. get_checked :: < T > ( self . shape ( ) , self . strides ( ) ) ?;
226
+ unsafe { Some ( & * self . data ( ) . offset ( offset) ) }
227
+ }
228
+
229
+ /// Same as [get](#method.get), but returns `&mut T`.
230
+ #[ inline( always) ]
231
+ pub fn get_mut < Idx : NpyIndex > ( & self , index : Idx ) -> Option < & mut T > {
232
+ let offset = index. get_checked :: < T > ( self . shape ( ) , self . strides ( ) ) ?;
233
+ unsafe { Some ( & mut * ( self . data ( ) . offset ( offset) as * mut T ) ) }
234
+ }
235
+
236
+ /// Get an immutable reference of a specified element, without checking the
237
+ /// passed index is valid.
238
+ ///
239
+ /// See [NpyIndex](../convert/trait.NpyIndex.html) for what types you can use as index.
240
+ ///
241
+ /// Passing an invalid index can cause undefined behavior(mostly SIGSEGV).
242
+ ///
243
+ /// # Example
244
+ /// ```
245
+ /// # extern crate pyo3; extern crate numpy; fn main() {
246
+ /// use numpy::PyArray;
247
+ /// let gil = pyo3::Python::acquire_gil();
248
+ /// let arr = PyArray::arange(gil.python(), 0, 16, 1).reshape([2, 2, 4]).unwrap();
249
+ /// assert_eq!(unsafe { *arr.uget([1, 0, 3]) }, 11);
250
+ /// # }
251
+ /// ```
252
+ #[ inline( always) ]
253
+ pub unsafe fn uget < Idx : NpyIndex > ( & self , index : Idx ) -> & T {
254
+ let offset = index. get_unchecked :: < T > ( self . strides ( ) ) ;
255
+ & * self . data ( ) . offset ( offset)
213
256
}
214
257
215
- // TODO: we should provide safe access API
258
+ /// Same as [uget](#method.uget), but returns `&mut T`.
216
259
#[ inline( always) ]
217
- unsafe fn get_unchecked_mut ( & self , index : & [ isize ] ) -> * mut T {
218
- self . get_unchecked ( index) as * mut T
260
+ pub unsafe fn uget_mut < Idx : NpyIndex > ( & self , index : Idx ) -> & mut T {
261
+ let offset = index. get_unchecked :: < T > ( self . strides ( ) ) ;
262
+ & mut * ( self . data ( ) . offset ( offset) as * mut T )
219
263
}
220
264
}
221
265
@@ -258,7 +302,7 @@ impl<T: TypeNum> PyArray<T> {
258
302
let array = Self :: new ( py, [ iter. len ( ) ] , false ) ;
259
303
unsafe {
260
304
for ( i, item) in iter. enumerate ( ) {
261
- * array. get_unchecked_mut ( & [ i as isize ] ) = item;
305
+ * array. uget_mut ( [ i ] ) = item;
262
306
}
263
307
}
264
308
array
@@ -293,7 +337,7 @@ impl<T: TypeNum> PyArray<T> {
293
337
. resize_ ( [ capacity] , 0 , NPY_ORDER :: NPY_ANYORDER )
294
338
. expect ( "PyArray::from_iter: Failed to allocate memory" ) ;
295
339
}
296
- * array. get_unchecked_mut ( & [ i as isize ] ) = item;
340
+ * array. uget_mut ( [ i ] ) = item;
297
341
}
298
342
}
299
343
if capacity > length {
@@ -334,7 +378,7 @@ impl<T: TypeNum> PyArray<T> {
334
378
unsafe {
335
379
for y in 0 ..v. len ( ) {
336
380
for x in 0 ..last_len {
337
- * array. get_unchecked_mut ( & [ y as isize , x as isize ] ) = v[ y] [ x] . clone ( ) ;
381
+ * array. uget_mut ( [ y , x] ) = v[ y] [ x] . clone ( ) ;
338
382
}
339
383
}
340
384
}
@@ -387,8 +431,7 @@ impl<T: TypeNum> PyArray<T> {
387
431
for z in 0 ..v. len ( ) {
388
432
for y in 0 ..dim2 {
389
433
for x in 0 ..dim3 {
390
- * array. get_unchecked_mut ( & [ z as isize , y as isize , x as isize ] ) =
391
- v[ z] [ y] [ x] . clone ( ) ;
434
+ * array. uget_mut ( [ z, y, x] ) = v[ z] [ y] [ x] . clone ( ) ;
392
435
}
393
436
}
394
437
}
@@ -739,6 +782,6 @@ fn test_get_unchecked() {
739
782
let gil = pyo3:: Python :: acquire_gil ( ) ;
740
783
let array = PyArray :: from_slice ( gil. python ( ) , & [ 1i32 , 2 , 3 ] ) ;
741
784
unsafe {
742
- assert_eq ! ( * array. get_unchecked ( & [ 1 ] ) , 2 ) ;
785
+ assert_eq ! ( * array. uget ( [ 1 ] ) , 2 ) ;
743
786
}
744
787
}
0 commit comments