Skip to content

Commit 9d48ef4

Browse files
committed
Revive into_pyarray
1 parent 9dff9ff commit 9d48ef4

File tree

8 files changed

+179
-15
lines changed

8 files changed

+179
-15
lines changed

example/extensions/src/lib.rs

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ extern crate numpy;
33
extern crate pyo3;
44

55
use ndarray::{ArrayD, ArrayViewD, ArrayViewMutD};
6-
use numpy::{IntoPyResult, PyArrayDyn, ToPyArray};
6+
use numpy::{IntoPyArray, IntoPyResult, PyArray1, PyArrayDyn, ToPyArray};
77
use pyo3::prelude::{pymodinit, PyModule, PyResult, Python};
88

99
#[pymodinit]
@@ -28,7 +28,7 @@ fn rust_ext(_py: Python, m: &PyModule) -> PyResult<()> {
2828
) -> PyResult<PyArrayDyn<f64>> {
2929
// you can convert numpy error into PyErr via ?
3030
let x = x.as_array()?;
31-
// you can also specify your error context, via closure
31+
// you can also specify your error context, via closure
3232
let y = y.as_array().into_pyresult_with(|| "y must be f64 array")?;
3333
Ok(axpy(a, x, y).to_pyarray(py).to_owned(py))
3434
}
@@ -41,5 +41,15 @@ fn rust_ext(_py: Python, m: &PyModule) -> PyResult<()> {
4141
Ok(())
4242
}
4343

44+
#[pyfn(m, "get_vec")]
45+
fn get_vec(py: Python, size: usize) -> PyResult<&PyArray1<f32>> {
46+
Ok(vec![0.0; size].into_pyarray(py))
47+
}
48+
// use numpy::slice_box::SliceBox;
49+
// #[pyfn(m, "get_slice")]
50+
// fn get_slice(py: Python, size: usize) -> PyResult<SliceBox<f32>> {
51+
// let sbox = numpy::slice_box::SliceBox::new(vec![0.0; size].into_boxed_slice());
52+
// Ok(sbox)
53+
// }
4454
Ok(())
4555
}

example/setup.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,7 @@
1010
class CmdTest(TestCommand):
1111
def run(self):
1212
self.run_command("test_rust")
13-
test_files = os.listdir('./tests')
14-
ok = 0
15-
for f in test_files:
16-
_, ext = os.path.splitext(f)
17-
if ext == '.py':
18-
res = subprocess.call([sys.executable, f], cwd='./tests')
19-
ok = ok | res
20-
sys.exit(res)
13+
subprocess.check_call([sys.executable, 'test_ext.py'], cwd='./tests')
2114

2215

2316
setup_requires = ['setuptools-rust>=0.6.0']

example/tests/test_ext.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import numpy as np
2-
from rust_ext import axpy, mult
2+
from rust_ext import axpy, mult, get_vec
33
import unittest
44

55
class TestExt(unittest.TestCase):
@@ -17,6 +17,12 @@ def test_mult(self):
1717
mult(3.0, x)
1818
np.testing.assert_array_almost_equal(x, np.array([3.0, 6.0, 9.0]))
1919

20+
def test_into_pyarray(self):
21+
x = get_vec(1000)
22+
np.testing.assert_array_almost_equal(x, np.zeros(1000))
23+
2024

2125
if __name__ == "__main__":
2226
unittest.main()
27+
28+

src/array.rs

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ use std::ptr;
1111

1212
use convert::{NpyIndex, ToNpyDims};
1313
use error::{ErrorKind, IntoPyResult};
14+
use slice_box::SliceBox;
1415
use types::{NpyDataType, TypeNum};
1516

1617
/// A safe, static-typed interface for
@@ -329,6 +330,31 @@ impl<T: TypeNum, D: Dimension> PyArray<T, D> {
329330
Self::from_owned_ptr(py, ptr)
330331
}
331332

333+
pub(crate) unsafe fn new_with_data<'py, ID>(
334+
py: Python<'py>,
335+
dims: ID,
336+
strides: *mut npy_intp,
337+
slice: &SliceBox<T>,
338+
) -> &'py Self
339+
where
340+
ID: IntoDimension<Dim = D>,
341+
{
342+
let dims = dims.into_dimension();
343+
let ptr = PY_ARRAY_API.PyArray_New(
344+
PY_ARRAY_API.get_type_object(npyffi::ArrayType::PyArray_Type),
345+
dims.ndim_cint(),
346+
dims.as_dims_ptr(),
347+
T::typenum_default(),
348+
strides, // strides
349+
slice.data(), // data
350+
0, // itemsize
351+
0, // flag
352+
::std::ptr::null_mut(), //obj
353+
);
354+
PY_ARRAY_API.PyArray_SetBaseObject(ptr as *mut npyffi::PyArrayObject, slice.as_ptr());
355+
Self::from_owned_ptr(py, ptr)
356+
}
357+
332358
/// Construct a new nd-dimensional array filled with 0.
333359
///
334360
/// If `is_fortran` is true, then

src/convert.rs

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,54 @@
22
33
use ndarray::{ArrayBase, Data, Dimension, IntoDimension, Ix1};
44
use pyo3::Python;
5+
use slice_box::SliceBox;
56

67
use std::mem;
78
use std::os::raw::c_int;
9+
use std::ptr;
810

911
use super::*;
1012

11-
/// Covversion trait from rust types to `PyArray`.
13+
/// Covnersion trait from some rust types to `PyArray`.
14+
///
15+
/// This trait takes `self`, which means **it holds a pointer to Rust heap, until `resize` or other
16+
/// destructive method is called**.
17+
/// # Example
18+
/// ```
19+
/// # extern crate pyo3; extern crate numpy; fn main() {
20+
/// use numpy::{PyArray, ToPyArray};
21+
/// let gil = pyo3::Python::acquire_gil();
22+
/// let py_array = vec![1, 2, 3].to_pyarray(gil.python());
23+
/// assert_eq!(py_array.as_slice().unwrap(), &[1, 2, 3]);
24+
/// # }
25+
/// ```
26+
pub trait IntoPyArray {
27+
type Item: TypeNum;
28+
type Dim: Dimension;
29+
fn into_pyarray<'py>(self, Python<'py>) -> &'py PyArray<Self::Item, Self::Dim>;
30+
}
31+
32+
impl<T: TypeNum> IntoPyArray for Box<[T]> {
33+
type Item = T;
34+
type Dim = Ix1;
35+
fn into_pyarray<'py>(self, py: Python<'py>) -> &'py PyArray<Self::Item, Self::Dim> {
36+
let len = self.len();
37+
unsafe {
38+
let slice = SliceBox::new(self);
39+
PyArray::new_with_data(py, [len], ptr::null_mut(), slice)
40+
}
41+
}
42+
}
43+
44+
impl<T: TypeNum> IntoPyArray for Vec<T> {
45+
type Item = T;
46+
type Dim = Ix1;
47+
fn into_pyarray<'py>(self, py: Python<'py>) -> &'py PyArray<Self::Item, Self::Dim> {
48+
self.into_boxed_slice().into_pyarray(py)
49+
}
50+
}
51+
52+
/// Conversion trait from rust types to `PyArray`.
1253
///
1354
/// This trait takes `&self`, which means **it alocates in Python heap and then copies
1455
/// elements there**.

src/lib.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,13 +44,14 @@ pub mod array;
4444
pub mod convert;
4545
pub mod error;
4646
pub mod npyffi;
47+
mod slice_box;
4748
pub mod types;
4849

4950
pub use array::{
5051
get_array_module, PyArray, PyArray1, PyArray2, PyArray3, PyArray4, PyArray5, PyArray6,
5152
PyArrayDyn,
5253
};
53-
pub use convert::{NpyIndex, ToNpyDims, ToPyArray};
54+
pub use convert::{IntoPyArray, NpyIndex, ToNpyDims, ToPyArray};
5455
pub use error::{ArrayFormat, ErrorKind, IntoPyErr, IntoPyResult};
5556
pub use ndarray::{Ix1, Ix2, Ix3, Ix4, Ix5, Ix6, IxDyn};
5657
pub use npyffi::{PY_ARRAY_API, PY_UFUNC_API};

src/slice_box.rs

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
use crate::types::TypeNum;
2+
use pyo3::{self, ffi, typeob, PyObjectAlloc, Python, ToPyPointer};
3+
use std::os::raw::c_void;
4+
5+
#[repr(C)]
6+
pub(crate) struct SliceBox<T> {
7+
ob_base: ffi::PyObject,
8+
inner: *mut [T],
9+
}
10+
11+
impl<T> SliceBox<T> {
12+
pub(crate) unsafe fn new<'a>(box_: Box<[T]>) -> &'a Self {
13+
<Self as typeob::PyTypeObject>::init_type();
14+
let type_ob = <Self as typeob::PyTypeInfo>::type_object() as *mut _;
15+
let base = ffi::_PyObject_New(type_ob);
16+
*base = ffi::PyObject_HEAD_INIT;
17+
(*base).ob_type = <Self as typeob::PyTypeInfo>::type_object() as *mut _;
18+
let self_ = base as *mut SliceBox<T>;
19+
(*self_).inner = Box::into_raw(box_);
20+
&*self_
21+
}
22+
pub(crate) fn data(&self) -> *mut c_void {
23+
self.inner as *mut c_void
24+
}
25+
}
26+
27+
impl<T> typeob::PyTypeInfo for SliceBox<T> {
28+
type Type = ();
29+
type BaseType = pyo3::PyObjectRef;
30+
const NAME: &'static str = "SliceBox";
31+
const DESCRIPTION: &'static str = "Memory store for PyArray made by IntoPyArray.";
32+
const FLAGS: usize = 0;
33+
const SIZE: usize = { Self::OFFSET as usize + std::mem::size_of::<Self>() + 0 + 0 };
34+
const OFFSET: isize = 0;
35+
#[inline]
36+
unsafe fn type_object() -> &'static mut ::pyo3::ffi::PyTypeObject {
37+
static mut TYPE_OBJECT: ::pyo3::ffi::PyTypeObject = ::pyo3::ffi::PyTypeObject_INIT;
38+
&mut TYPE_OBJECT
39+
}
40+
}
41+
42+
impl<T: TypeNum> typeob::PyTypeObject for SliceBox<T> {
43+
#[inline(always)]
44+
fn init_type() {
45+
static START: std::sync::Once = std::sync::ONCE_INIT;
46+
START.call_once(|| {
47+
let ty = unsafe { <Self as typeob::PyTypeInfo>::type_object() };
48+
if (ty.tp_flags & ffi::Py_TPFLAGS_READY) == 0 {
49+
let gil = Python::acquire_gil();
50+
let py = gil.python();
51+
let mod_name = format!("rust_numpy.{:?}", T::npy_data_type());
52+
typeob::initialize_type::<Self>(py, Some(&mod_name))
53+
.map_err(|e| e.print(py))
54+
.expect("Failed to initialize SliceBox");
55+
}
56+
});
57+
}
58+
}
59+
60+
impl<T> ToPyPointer for SliceBox<T> {
61+
#[inline]
62+
fn as_ptr(&self) -> *mut ffi::PyObject {
63+
&self.ob_base as *const _ as *mut _
64+
}
65+
}
66+
67+
impl<T> PyObjectAlloc<SliceBox<T>> for SliceBox<T> {
68+
/// Calls the rust destructor for the object.
69+
unsafe fn drop(py: Python, obj: *mut ffi::PyObject) {
70+
let data = (*(obj as *mut SliceBox<T>)).inner;
71+
let box_ = Box::from_raw(data);
72+
drop(box_);
73+
<Self as typeob::PyTypeInfo>::BaseType::drop(py, obj);
74+
}
75+
unsafe fn dealloc(py: Python, obj: *mut ffi::PyObject) {
76+
Self::drop(py, obj);
77+
ffi::PyObject_Free(obj as *mut c_void);
78+
}
79+
}

tests/array.rs

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ fn as_array() {
7575
}
7676

7777
#[test]
78-
fn into_pyarray_vec() {
78+
fn to_pyarray_vec() {
7979
let gil = pyo3::Python::acquire_gil();
8080

8181
let a = vec![1, 2, 3];
@@ -86,7 +86,7 @@ fn into_pyarray_vec() {
8686
}
8787

8888
#[test]
89-
fn into_pyarray_array() {
89+
fn to_pyarray_array() {
9090
let gil = pyo3::Python::acquire_gil();
9191

9292
let a = Array3::<f64>::zeros((3, 4, 2));
@@ -206,3 +206,11 @@ fn array_cast() {
206206
let arr_i32: &PyArray2<i32> = arr_f64.cast(false).unwrap();
207207
assert_eq!(arr_i32.as_array().unwrap(), array![[1, 2, 3], [1, 2, 3]]);
208208
}
209+
210+
#[test]
211+
fn into_pyarray_vec() {
212+
let gil = pyo3::Python::acquire_gil();
213+
let a = vec![1, 2, 3];
214+
let arr = a.into_pyarray(gil.python());
215+
assert_eq!(arr.as_slice().unwrap(), &[1, 2, 3])
216+
}

0 commit comments

Comments
 (0)