Skip to content

Commit ef78219

Browse files
authored
Merge pull request #63 from kngwyu/to-npy-dims
Introduce ToNpyDims trait
2 parents 8c0f427 + 55f7b1d commit ef78219

File tree

3 files changed

+91
-24
lines changed

3 files changed

+91
-24
lines changed

src/array.rs

Lines changed: 16 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ impl<T> PyArray<T> {
7777
/// use numpy::PyArray;
7878
/// fn return_py_array() -> PyArray<i32> {
7979
/// let gil = Python::acquire_gil();
80-
/// let array = PyArray::zeros(gil.python(), &[5], false);
80+
/// let array = PyArray::zeros(gil.python(), [5], false);
8181
/// array.to_owned(gil.python())
8282
/// }
8383
/// let array = return_py_array();
@@ -108,7 +108,7 @@ impl<T> PyArray<T> {
108108
/// # extern crate pyo3; extern crate numpy; fn main() {
109109
/// use numpy::PyArray;
110110
/// let gil = pyo3::Python::acquire_gil();
111-
/// let arr = PyArray::<f64>::new(gil.python(), &[4, 5, 6]);
111+
/// let arr = PyArray::<f64>::new(gil.python(), [4, 5, 6]);
112112
/// assert_eq!(arr.ndim(), 3);
113113
/// # }
114114
/// ```
@@ -148,7 +148,7 @@ impl<T> PyArray<T> {
148148
/// # extern crate pyo3; extern crate numpy; fn main() {
149149
/// use numpy::PyArray;
150150
/// let gil = pyo3::Python::acquire_gil();
151-
/// let arr = PyArray::<f64>::new(gil.python(), &[4, 5, 6]);
151+
/// let arr = PyArray::<f64>::new(gil.python(), [4, 5, 6]);
152152
/// assert_eq!(arr.shape(), &[4, 5, 6]);
153153
/// # }
154154
/// ```
@@ -248,7 +248,7 @@ impl<T: TypeNum> PyArray<T> {
248248
let flattend: Vec<_> = v.iter().cloned().flatten().collect();
249249
unsafe {
250250
let data = convert::into_raw(flattend);
251-
Ok(PyArray::new_(py, &dims, null_mut(), data))
251+
Ok(PyArray::new_(py, dims, null_mut(), data))
252252
}
253253
}
254254

@@ -287,7 +287,7 @@ impl<T: TypeNum> PyArray<T> {
287287
let flattend: Vec<_> = v.iter().flat_map(|v| v.iter().cloned().flatten()).collect();
288288
unsafe {
289289
let data = convert::into_raw(flattend);
290-
Ok(PyArray::new_(py, &dims, null_mut(), data))
290+
Ok(PyArray::new_(py, dims, null_mut(), data))
291291
}
292292
}
293293

@@ -379,17 +379,16 @@ impl<T: TypeNum> PyArray<T> {
379379
/// Construct a new PyArray given a raw pointer and dimensions.
380380
///
381381
/// Please use `new` or from methods instead.
382-
pub unsafe fn new_<'py>(
382+
pub unsafe fn new_<'py, D: ToNpyDims>(
383383
py: Python<'py>,
384-
dims: &[usize],
384+
dims: D,
385385
strides: *mut npy_intp,
386386
data: *mut c_void,
387387
) -> &'py Self {
388-
let dims: Vec<_> = dims.iter().map(|d| *d as npy_intp).collect();
389388
let ptr = PY_ARRAY_API.PyArray_New(
390389
PY_ARRAY_API.get_type_object(npyffi::ArrayType::PyArray_Type),
391-
dims.len() as i32,
392-
dims.as_ptr() as *mut npy_intp,
390+
dims.dims_len(),
391+
dims.dims_ptr(),
393392
T::typenum_default(),
394393
strides,
395394
data,
@@ -409,11 +408,11 @@ impl<T: TypeNum> PyArray<T> {
409408
/// # extern crate pyo3; extern crate numpy; #[macro_use] extern crate ndarray; fn main() {
410409
/// use numpy::PyArray;
411410
/// let gil = pyo3::Python::acquire_gil();
412-
/// let pyarray = PyArray::<i32>::new(gil.python(), &[4, 5, 6]);
411+
/// let pyarray = PyArray::<i32>::new(gil.python(), [4, 5, 6]);
413412
/// assert_eq!(pyarray.shape(), &[4, 5, 6]);
414413
/// # }
415414
/// ```
416-
pub fn new<'py>(py: Python<'py>, dims: &[usize]) -> &'py Self {
415+
pub fn new<'py, D: ToNpyDims>(py: Python<'py>, dims: D) -> &'py Self {
417416
unsafe { Self::new_(py, dims, null_mut(), null_mut()) }
418417
}
419418

@@ -427,17 +426,16 @@ impl<T: TypeNum> PyArray<T> {
427426
/// # extern crate pyo3; extern crate numpy; #[macro_use] extern crate ndarray; fn main() {
428427
/// use numpy::PyArray;
429428
/// let gil = pyo3::Python::acquire_gil();
430-
/// let pyarray = PyArray::zeros(gil.python(), &[2, 2], false);
429+
/// let pyarray = PyArray::zeros(gil.python(), [2, 2], false);
431430
/// assert_eq!(pyarray.as_array().unwrap(), array![[0, 0], [0, 0]].into_dyn());
432431
/// # }
433432
/// ```
434-
pub fn zeros<'py>(py: Python<'py>, dims: &[usize], is_fortran: bool) -> &'py Self {
435-
let dims: Vec<npy_intp> = dims.iter().map(|d| *d as npy_intp).collect();
433+
pub fn zeros<'py, D: ToNpyDims>(py: Python<'py>, dims: D, is_fortran: bool) -> &'py Self {
436434
unsafe {
437435
let descr = PY_ARRAY_API.PyArray_DescrFromType(T::typenum_default());
438436
let ptr = PY_ARRAY_API.PyArray_Zeros(
439-
dims.len() as i32,
440-
dims.as_ptr() as *mut npy_intp,
437+
dims.dims_len(),
438+
dims.dims_ptr(),
441439
descr,
442440
if is_fortran { -1 } else { 0 },
443441
);
@@ -499,7 +497,7 @@ impl<T: TypeNum> PyArray<T> {
499497
/// use numpy::{PyArray, IntoPyArray};
500498
/// let gil = pyo3::Python::acquire_gil();
501499
/// let pyarray_f = PyArray::<f64>::arange(gil.python(), 2.0, 5.0, 1.0);
502-
/// let pyarray_i = PyArray::<i64>::new(gil.python(), &[3]);
500+
/// let pyarray_i = PyArray::<i64>::new(gil.python(), [3]);
503501
/// assert!(pyarray_f.move_to(pyarray_i).is_ok());
504502
/// assert_eq!(pyarray_i.as_slice().unwrap(), &[2, 3, 4]);
505503
/// # }

src/convert.rs

Lines changed: 74 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ use pyo3::Python;
55

66
use std::iter::Iterator;
77
use std::mem::size_of;
8-
use std::os::raw::c_void;
8+
use std::os::raw::{c_int, c_void};
99
use std::ptr::null_mut;
1010

1111
use super::*;
@@ -31,15 +31,15 @@ impl<T: TypeNum> IntoPyArray for Box<[T]> {
3131
fn into_pyarray(self, py: Python) -> &PyArray<Self::Item> {
3232
let dims = [self.len()];
3333
let ptr = Box::into_raw(self);
34-
unsafe { PyArray::new_(py, &dims, null_mut(), ptr as *mut c_void) }
34+
unsafe { PyArray::new_(py, dims, null_mut(), ptr as *mut c_void) }
3535
}
3636
}
3737

3838
impl<T: TypeNum> IntoPyArray for Vec<T> {
3939
type Item = T;
4040
fn into_pyarray(self, py: Python) -> &PyArray<Self::Item> {
4141
let dims = [self.len()];
42-
unsafe { PyArray::new_(py, &dims, null_mut(), into_raw(self)) }
42+
unsafe { PyArray::new_(py, dims, null_mut(), into_raw(self)) }
4343
}
4444
}
4545

@@ -54,7 +54,7 @@ impl<A: TypeNum, D: Dimension> IntoPyArray for Array<A, D> {
5454
.collect();
5555
unsafe {
5656
let data = into_raw(self.into_raw_vec());
57-
PyArray::new_(py, &dims, strides.as_mut_ptr(), data)
57+
PyArray::new_(py, dims, strides.as_mut_ptr(), data)
5858
}
5959
}
6060
}
@@ -68,7 +68,7 @@ macro_rules! array_impls {
6868
let dims = [$N];
6969
let ptr = Box::into_raw(Box::new(self));
7070
unsafe {
71-
PyArray::new_(py, &dims, null_mut(), ptr as *mut c_void)
71+
PyArray::new_(py, dims, null_mut(), ptr as *mut c_void)
7272
}
7373
}
7474
}
@@ -87,3 +87,72 @@ pub(crate) unsafe fn into_raw<T>(x: Vec<T>) -> *mut c_void {
8787
let ptr = Box::into_raw(x.into_boxed_slice());
8888
ptr as *mut c_void
8989
}
90+
91+
/// Utility trait to specify the dimention of array
92+
pub trait ToNpyDims {
93+
fn dims_len(&self) -> c_int;
94+
fn dims_ptr(&self) -> *mut npy_intp;
95+
fn to_npy_dims(&self) -> npyffi::PyArray_Dims {
96+
npyffi::PyArray_Dims {
97+
ptr: self.dims_ptr(),
98+
len: self.dims_len(),
99+
}
100+
}
101+
}
102+
103+
macro_rules! array_dim_impls {
104+
($($N: expr)+) => {
105+
$(
106+
impl ToNpyDims for [usize; $N] {
107+
fn dims_len(&self) -> c_int {
108+
$N as c_int
109+
}
110+
fn dims_ptr(&self) -> *mut npy_intp {
111+
self.as_ptr() as *mut npy_intp
112+
}
113+
}
114+
impl<'a> ToNpyDims for &'a [usize; $N] {
115+
fn dims_len(&self) -> c_int {
116+
$N as c_int
117+
}
118+
fn dims_ptr(&self) -> *mut npy_intp {
119+
self.as_ptr() as *mut npy_intp
120+
}
121+
}
122+
)+
123+
}
124+
}
125+
126+
array_dim_impls! {
127+
0 1 2 3 4 5 6 7 8 9
128+
10 11 12 13 14 15 16 17 18 19
129+
20 21 22 23 24 25 26 27 28 29
130+
30 31 32
131+
}
132+
133+
impl<'a> ToNpyDims for &'a [usize] {
134+
fn dims_len(&self) -> c_int {
135+
self.len() as c_int
136+
}
137+
fn dims_ptr(&self) -> *mut npy_intp {
138+
self.as_ptr() as *mut npy_intp
139+
}
140+
}
141+
142+
impl ToNpyDims for Vec<usize> {
143+
fn dims_len(&self) -> c_int {
144+
self.len() as c_int
145+
}
146+
fn dims_ptr(&self) -> *mut npy_intp {
147+
self.as_ptr() as *mut npy_intp
148+
}
149+
}
150+
151+
impl ToNpyDims for Box<[usize]> {
152+
fn dims_len(&self) -> c_int {
153+
self.len() as c_int
154+
}
155+
fn dims_ptr(&self) -> *mut npy_intp {
156+
self.as_ptr() as *mut npy_intp
157+
}
158+
}

src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ pub mod npyffi;
4444
pub mod types;
4545

4646
pub use array::{get_array_module, PyArray};
47-
pub use convert::IntoPyArray;
47+
pub use convert::{IntoPyArray, ToNpyDims};
4848
pub use error::*;
4949
pub use npyffi::{PY_ARRAY_API, PY_UFUNC_API};
5050
pub use types::*;

0 commit comments

Comments
 (0)