Skip to content

Commit dfabc95

Browse files
committed
Modify ToNpyDims to inherit Dimension
so that we can take impl IntoDimension as argument
1 parent d80f2a3 commit dfabc95

File tree

4 files changed

+41
-51
lines changed

4 files changed

+41
-51
lines changed

src/array.rs

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -415,15 +415,14 @@ impl<T: TypeNum> PyArray<T> {
415415
S: Data<Elem = T>,
416416
D: Dimension,
417417
{
418-
let dims: Vec<_> = arr.shape().iter().cloned().collect();
419-
let len = dims.iter().fold(1, |prod, d| prod * d);
418+
let len = arr.len();
420419
let mut strides: Vec<_> = arr
421420
.strides()
422421
.into_iter()
423422
.map(|n| n * mem::size_of::<T>() as npy_intp)
424423
.collect();
425424
unsafe {
426-
let array = PyArray::new_(py, &*dims, strides.as_mut_ptr() as *mut npy_intp, 0);
425+
let array = PyArray::new_(py, arr.shape(), strides.as_mut_ptr() as *mut npy_intp, 0);
427426
ptr::copy_nonoverlapping(arr.as_ptr(), array.data(), len);
428427
array
429428
}
@@ -507,21 +506,22 @@ impl<T: TypeNum> PyArray<T> {
507506
/// assert_eq!(pyarray.shape(), &[4, 5, 6]);
508507
/// # }
509508
/// ```
510-
pub fn new<'py, D: ToNpyDims>(py: Python<'py>, dims: D, is_fortran: bool) -> &'py Self {
509+
pub fn new<'py, D: IntoDimension>(py: Python<'py>, dims: D, is_fortran: bool) -> &'py Self {
511510
let flags = if is_fortran { 1 } else { 0 };
512511
unsafe { PyArray::new_(py, dims, ptr::null_mut(), flags) }
513512
}
514513

515-
unsafe fn new_<'py, D: ToNpyDims>(
514+
unsafe fn new_<'py, D: IntoDimension>(
516515
py: Python<'py>,
517516
dims: D,
518517
strides: *mut npy_intp,
519518
flag: c_int,
520519
) -> &'py Self {
520+
let dims = dims.into_dimension();
521521
let ptr = PY_ARRAY_API.PyArray_New(
522522
PY_ARRAY_API.get_type_object(npyffi::ArrayType::PyArray_Type),
523-
dims.dims_len(),
524-
dims.dims_ptr(),
523+
dims.ndim_cint(),
524+
dims.as_dims_ptr(),
525525
T::typenum_default(),
526526
strides, // strides
527527
ptr::null_mut(), // data
@@ -548,12 +548,13 @@ impl<T: TypeNum> PyArray<T> {
548548
/// assert_eq!(pyarray.as_array().unwrap(), array![[0, 0], [0, 0]].into_dyn());
549549
/// # }
550550
/// ```
551-
pub fn zeros<'py, D: ToNpyDims>(py: Python<'py>, dims: D, is_fortran: bool) -> &'py Self {
551+
pub fn zeros<'py, D: IntoDimension>(py: Python<'py>, dims: D, is_fortran: bool) -> &'py Self {
552+
let dims = dims.into_dimension();
552553
unsafe {
553554
let descr = PY_ARRAY_API.PyArray_DescrFromType(T::typenum_default());
554555
let ptr = PY_ARRAY_API.PyArray_Zeros(
555-
dims.dims_len(),
556-
dims.dims_ptr(),
556+
dims.ndim_cint(),
557+
dims.as_dims_ptr(),
557558
descr,
558559
if is_fortran { -1 } else { 0 },
559560
);
@@ -653,16 +654,17 @@ impl<T: TypeNum> PyArray<T> {
653654
/// # }
654655
/// ```
655656
#[inline(always)]
656-
pub fn reshape<'py, D: ToNpyDims>(&'py self, dims: D) -> Result<&Self, ErrorKind> {
657+
pub fn reshape<'py, D: IntoDimension>(&'py self, dims: D) -> Result<&Self, ErrorKind> {
657658
self.reshape_with_order(dims, NPY_ORDER::NPY_ANYORDER)
658659
}
659660

660661
/// Same as [reshape](method.reshape.html), but you can change the order of returned matrix.
661-
pub fn reshape_with_order<'py, D: ToNpyDims>(
662+
pub fn reshape_with_order<'py, D: IntoDimension>(
662663
&'py self,
663664
dims: D,
664665
order: NPY_ORDER,
665666
) -> Result<&Self, ErrorKind> {
667+
let dims = dims.into_dimension();
666668
let mut np_dims = dims.to_npy_dims();
667669
let ptr = unsafe {
668670
PY_ARRAY_API.PyArray_Newshape(
@@ -679,12 +681,13 @@ impl<T: TypeNum> PyArray<T> {
679681
}
680682

681683
// TODO: expose?
682-
fn resize_<'py, D: ToNpyDims>(
684+
fn resize_<'py, D: IntoDimension>(
683685
&'py self,
684686
dims: D,
685687
check_ref: c_int,
686688
order: NPY_ORDER,
687689
) -> Result<(), ErrorKind> {
690+
let dims = dims.into_dimension();
688691
let mut np_dims = dims.to_npy_dims();
689692
let res = unsafe {
690693
PY_ARRAY_API.PyArray_Resize(

src/convert.rs

Lines changed: 15 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -45,49 +45,27 @@ where
4545
}
4646

4747
/// Utility trait to specify the dimention of array
48-
pub trait ToNpyDims {
49-
fn dims_len(&self) -> c_int;
50-
fn dims_ptr(&self) -> *mut npy_intp;
51-
fn dims_ref(&self) -> &[usize];
48+
pub trait ToNpyDims: Dimension {
49+
fn ndim_cint(&self) -> c_int {
50+
self.ndim() as c_int
51+
}
52+
fn as_dims_ptr(&self) -> *mut npy_intp {
53+
self.slice().as_ptr() as *mut npy_intp
54+
}
5255
fn to_npy_dims(&self) -> npyffi::PyArray_Dims {
5356
npyffi::PyArray_Dims {
54-
ptr: self.dims_ptr(),
55-
len: self.dims_len(),
57+
ptr: self.as_dims_ptr(),
58+
len: self.ndim_cint(),
5659
}
5760
}
61+
fn __private__(&self) -> PrivateMarker;
5862
}
5963

60-
macro_rules! array_dim_impls {
61-
($($N: expr)+) => {
62-
$(
63-
impl ToNpyDims for [usize; $N] {
64-
fn dims_len(&self) -> c_int {
65-
$N as c_int
66-
}
67-
fn dims_ptr(&self) -> *mut npy_intp {
68-
self.as_ptr() as *mut npy_intp
69-
}
70-
fn dims_ref(&self) -> &[usize] {
71-
self
72-
}
73-
}
74-
)+
64+
impl<T: Dimension> ToNpyDims for T {
65+
fn __private__(&self) -> PrivateMarker {
66+
PrivateMarker
7567
}
7668
}
7769

78-
array_dim_impls! {
79-
0 1 2 3 4 5 6 7 8 9
80-
10 11 12 13 14 15 16
81-
}
82-
83-
impl<'a> ToNpyDims for &'a [usize] {
84-
fn dims_len(&self) -> c_int {
85-
self.len() as c_int
86-
}
87-
fn dims_ptr(&self) -> *mut npy_intp {
88-
self.as_ptr() as *mut npy_intp
89-
}
90-
fn dims_ref(&self) -> &[usize] {
91-
*self
92-
}
93-
}
70+
#[doc(hidden)]
71+
pub struct PrivateMarker;

src/error.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ impl ErrorKind {
7878
.collect::<Vec<_>>()
7979
.into_boxed_slice();
8080
let dims_to = to_dim
81-
.dims_ref()
81+
.slice()
8282
.into_iter()
8383
.map(|&x| x)
8484
.collect::<Vec<_>>()

tests/array.rs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,15 @@ fn new_fortran_order() {
2828
assert!(arr.strides() == [size, dim[0] as isize * size],);
2929
}
3030

31+
#[test]
32+
fn tuple_as_dim() {
33+
let gil = pyo3::Python::acquire_gil();
34+
let dim = (3, 5);
35+
let arr = PyArray::<f64>::zeros(gil.python(), dim, false);
36+
assert!(arr.ndim() == 2);
37+
assert!(arr.dims() == [3, 5]);
38+
}
39+
3140
#[test]
3241
fn zeros() {
3342
let gil = pyo3::Python::acquire_gil();

0 commit comments

Comments
 (0)