Skip to content

Commit 11182ec

Browse files
committed
Make the PyArray::new method unsafe and document what can and cannot be done with its return value.
1 parent e4e33ab commit 11182ec

File tree

2 files changed

+70
-36
lines changed

2 files changed

+70
-36
lines changed

src/array.rs

Lines changed: 67 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -249,7 +249,7 @@ impl<T, D> PyArray<T, D> {
249249
/// ```
250250
/// use numpy::PyArray3;
251251
/// pyo3::Python::with_gil(|py| {
252-
/// let arr = PyArray3::<f64>::new(py, [4, 5, 6], false);
252+
/// let arr = PyArray3::<f64>::zeros(py, [4, 5, 6], false);
253253
/// assert_eq!(arr.ndim(), 3);
254254
/// });
255255
/// ```
@@ -266,7 +266,7 @@ impl<T, D> PyArray<T, D> {
266266
/// ```
267267
/// use numpy::PyArray3;
268268
/// pyo3::Python::with_gil(|py| {
269-
/// let arr = PyArray3::<f64>::new(py, [4, 5, 6], false);
269+
/// let arr = PyArray3::<f64>::zeros(py, [4, 5, 6], false);
270270
/// assert_eq!(arr.strides(), &[240, 48, 8]);
271271
/// });
272272
/// ```
@@ -287,7 +287,7 @@ impl<T, D> PyArray<T, D> {
287287
/// ```
288288
/// use numpy::PyArray3;
289289
/// pyo3::Python::with_gil(|py| {
290-
/// let arr = PyArray3::<f64>::new(py, [4, 5, 6], false);
290+
/// let arr = PyArray3::<f64>::zeros(py, [4, 5, 6], false);
291291
/// assert_eq!(arr.shape(), &[4, 5, 6]);
292292
/// });
293293
/// ```
@@ -371,20 +371,46 @@ impl<T: Element, D: Dimension> PyArray<T, D> {
371371
///
372372
/// If `is_fortran == true`, returns Fortran-order array. Else, returns C-order array.
373373
///
374+
/// # Safety
375+
///
376+
/// The returned array will always be safe to be dropped as the elements must either
377+
/// be trivially copyable or have `DATA_TYPE == DataType::Object`, i.e. be pointers
378+
/// into Python's heap, which NumPy will automatically zero-initialize.
379+
///
380+
/// However, the elements themselves will not be valid and should only be accessed
381+
/// via raw pointers obtained via [uget_raw](#method.uget_raw).
382+
///
383+
/// All methods which produce references to the elements invoke undefined behaviour.
384+
/// In particular, zero-initialized pointers are _not_ valid instances of `PyObject`.
385+
///
374386
/// # Example
375387
/// ```
376388
/// use numpy::PyArray3;
389+
///
377390
/// pyo3::Python::with_gil(|py| {
378-
/// let arr = PyArray3::<i32>::new(py, [4, 5, 6], false);
391+
/// let arr = unsafe {
392+
/// let arr = PyArray3::<i32>::new(py, [4, 5, 6], false);
393+
///
394+
/// for i in 0..4 {
395+
/// for j in 0..5 {
396+
/// for k in 0..6 {
397+
/// arr.uget_raw([i, j, k]).write((i * j * k) as i32);
398+
/// }
399+
/// }
400+
/// }
401+
///
402+
/// arr
403+
/// };
404+
///
379405
/// assert_eq!(arr.shape(), &[4, 5, 6]);
380406
/// });
381407
/// ```
382-
pub fn new<ID>(py: Python, dims: ID, is_fortran: bool) -> &Self
408+
pub unsafe fn new<ID>(py: Python, dims: ID, is_fortran: bool) -> &Self
383409
where
384410
ID: IntoDimension<Dim = D>,
385411
{
386412
let flags = if is_fortran { 1 } else { 0 };
387-
unsafe { PyArray::new_(py, dims, ptr::null_mut(), flags) }
413+
PyArray::new_(py, dims, ptr::null_mut(), flags)
388414
}
389415

390416
pub(crate) unsafe fn new_<ID>(
@@ -448,7 +474,7 @@ impl<T: Element, D: Dimension> PyArray<T, D> {
448474
/// a Fortran-order array is created, otherwise a C-order array is created.
449475
///
450476
/// For elements with `DATA_TYPE == DataType::Object`, this will fill the array
451-
/// valid pointers to objects of type `<class 'int'>` with value zero.
477+
/// with valid pointers to zero-valued Python integer objects.
452478
///
453479
/// See also [PyArray_Zeros](https://numpy.org/doc/stable/reference/c-api/array.html#c.PyArray_Zeros)
454480
///
@@ -596,6 +622,16 @@ impl<T: Element, D: Dimension> PyArray<T, D> {
596622
&mut *(self.data().offset(offset) as *mut _)
597623
}
598624

625+
/// Same as [uget](#method.uget), but returns `*mut T`.
626+
#[inline(always)]
627+
pub unsafe fn uget_raw<Idx>(&self, index: Idx) -> *mut T
628+
where
629+
Idx: NpyIndex<Dim = D>,
630+
{
631+
let offset = index.get_unchecked::<T>(self.strides());
632+
self.data().offset(offset) as *mut _
633+
}
634+
599635
/// Get a dynamically dimensioned array from a fixed-dimension array.
600636
pub fn to_dyn(&self) -> &PyArray<T, IxDyn> {
601637
let python = self.py();
@@ -733,20 +769,18 @@ impl<T: Element> PyArray<T, Ix1> {
733769
/// });
734770
/// ```
735771
pub fn from_slice<'py>(py: Python<'py>, slice: &[T]) -> &'py Self {
736-
let array = PyArray::new(py, [slice.len()], false);
737-
if T::DATA_TYPE != DataType::Object {
738-
unsafe {
772+
unsafe {
773+
let array = PyArray::new(py, [slice.len()], false);
774+
if T::DATA_TYPE != DataType::Object {
739775
array.copy_ptr(slice.as_ptr(), slice.len());
740-
}
741-
} else {
742-
unsafe {
776+
} else {
743777
let data_ptr = array.data();
744778
for (i, item) in slice.iter().enumerate() {
745779
data_ptr.add(i).write(item.clone());
746780
}
747781
}
782+
array
748783
}
749-
array
750784
}
751785

752786
/// Construct a one-dimensional PyArray
@@ -781,13 +815,13 @@ impl<T: Element> PyArray<T, Ix1> {
781815
pub fn from_exact_iter(py: Python<'_>, iter: impl ExactSizeIterator<Item = T>) -> &Self {
782816
// NumPy will always zero-initialize object pointers,
783817
// so the array can be dropped safely if the iterator panics.
784-
let array = Self::new(py, [iter.len()], false);
785818
unsafe {
819+
let array = Self::new(py, [iter.len()], false);
786820
for (i, item) in iter.enumerate() {
787-
*array.uget_mut([i]) = item;
821+
array.uget_raw([i]).write(item);
788822
}
823+
array
789824
}
790-
array
791825
}
792826

793827
/// Construct one-dimension PyArray from a type which implements
@@ -809,11 +843,11 @@ impl<T: Element> PyArray<T, Ix1> {
809843
let iter = iter.into_iter();
810844
let (min_len, max_len) = iter.size_hint();
811845
let mut capacity = max_len.unwrap_or_else(|| min_len.max(512 / mem::size_of::<T>()));
812-
// NumPy will always zero-initialize object pointers,
813-
// so the array can be dropped safely if the iterator panics.
814-
let array = Self::new(py, [capacity], false);
815-
let mut length = 0;
816846
unsafe {
847+
// NumPy will always zero-initialize object pointers,
848+
// so the array can be dropped safely if the iterator panics.
849+
let array = Self::new(py, [capacity], false);
850+
let mut length = 0;
817851
for (i, item) in iter.enumerate() {
818852
length += 1;
819853
if length > capacity {
@@ -822,13 +856,13 @@ impl<T: Element> PyArray<T, Ix1> {
822856
.resize(capacity)
823857
.expect("PyArray::from_iter: Failed to allocate memory");
824858
}
825-
*array.uget_mut([i]) = item;
859+
array.uget_raw([i]).write(item);
826860
}
861+
if capacity > length {
862+
array.resize(length).unwrap()
863+
}
864+
array
827865
}
828-
if capacity > length {
829-
array.resize(length).unwrap()
830-
}
831-
array
832866
}
833867

834868
/// Extends or truncates the length of a one-dimensional PyArray.
@@ -902,15 +936,15 @@ impl<T: Element> PyArray<T, Ix2> {
902936
return Err(FromVecError::new(v.len(), last_len));
903937
}
904938
let dims = [v.len(), last_len];
905-
let array = Self::new(py, dims, false);
906939
unsafe {
940+
let array = Self::new(py, dims, false);
907941
for (y, vy) in v.iter().enumerate() {
908942
for (x, vyx) in vy.iter().enumerate() {
909-
*array.uget_mut([y, x]) = vyx.clone();
943+
array.uget_raw([y, x]).write(vyx.clone());
910944
}
911945
}
946+
Ok(array)
912947
}
913-
Ok(array)
914948
}
915949
}
916950

@@ -944,17 +978,17 @@ impl<T: Element> PyArray<T, Ix3> {
944978
return Err(FromVecError::new(v.len(), len3));
945979
}
946980
let dims = [v.len(), len2, len3];
947-
let array = Self::new(py, dims, false);
948981
unsafe {
982+
let array = Self::new(py, dims, false);
949983
for (z, vz) in v.iter().enumerate() {
950984
for (y, vzy) in vz.iter().enumerate() {
951985
for (x, vzyx) in vzy.iter().enumerate() {
952-
*array.uget_mut([z, y, x]) = vzyx.clone();
986+
array.uget_raw([z, y, x]).write(vzyx.clone());
953987
}
954988
}
955989
}
990+
Ok(array)
956991
}
957-
Ok(array)
958992
}
959993
}
960994

@@ -965,7 +999,7 @@ impl<T: Element, D> PyArray<T, D> {
965999
/// use numpy::PyArray;
9661000
/// pyo3::Python::with_gil(|py| {
9671001
/// let pyarray_f = PyArray::arange(py, 2.0, 5.0, 1.0);
968-
/// let pyarray_i = PyArray::<i64, _>::new(py, [3], false);
1002+
/// let pyarray_i = unsafe { PyArray::<i64, _>::new(py, [3], false) };
9691003
/// assert!(pyarray_f.copy_to(pyarray_i).is_ok());
9701004
/// assert_eq!(pyarray_i.readonly().as_slice().unwrap(), &[2, 3, 4]);
9711005
/// });

tests/array.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ fn not_contiguous_array<'py>(py: Python<'py>) -> &'py PyArray1<i32> {
2525
fn new_c_order() {
2626
let dim = [3, 5];
2727
pyo3::Python::with_gil(|py| {
28-
let arr = PyArray::<f64, _>::new(py, dim, false);
28+
let arr = PyArray::<f64, _>::zeros(py, dim, false);
2929
assert!(arr.ndim() == 2);
3030
assert!(arr.dims() == dim);
3131
let size = std::mem::size_of::<f64>() as isize;
@@ -37,7 +37,7 @@ fn new_c_order() {
3737
fn new_fortran_order() {
3838
let dim = [3, 5];
3939
pyo3::Python::with_gil(|py| {
40-
let arr = PyArray::<f64, _>::new(py, dim, true);
40+
let arr = PyArray::<f64, _>::zeros(py, dim, true);
4141
assert!(arr.ndim() == 2);
4242
assert!(arr.dims() == dim);
4343
let size = std::mem::size_of::<f64>() as isize;
@@ -109,7 +109,7 @@ fn as_slice() {
109109
#[test]
110110
fn is_instance() {
111111
pyo3::Python::with_gil(|py| {
112-
let arr = PyArray2::<f64>::new(py, [3, 5], false);
112+
let arr = PyArray2::<f64>::zeros(py, [3, 5], false);
113113
assert!(arr.is_instance::<PyArray2<f64>>().unwrap());
114114
assert!(!arr.is_instance::<PyList>().unwrap());
115115
})

0 commit comments

Comments
 (0)