Skip to content

Commit 2dda531

Browse files
authored
Allow *mut ffi::PyObject as an element of PyArray (#91)
* Allow PyObject as element of PyArray * Refactor test code * Allow *mut ffi::PyObject as PyArray's element instead of PyObject * Remove travis_wait
1 parent 1bbf535 commit 2dda531

File tree

6 files changed

+66
-27
lines changed

6 files changed

+66
-27
lines changed

.travis.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ install:
4545

4646
script:
4747
- flake8 examples/
48-
- travis_wait ./ci/travis/test.sh
48+
- ./ci/travis/test.sh
4949

5050
deploy:
5151
- provider: script

appveyor.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,6 @@ build_script:
2424
- cargo build --verbose --features %FEATURES%
2525

2626
test_script:
27-
- cargo test --verbose --features %FEATURES%
27+
- cargo test --verbose --features %FEATURES% -- --test-threads=1
2828
- rustdoc --test README.md -L native="%PYTHON%\\libs" -L target/debug/deps/
2929
- cd examples/simple-extension && python setup.py install && python setup.py test

ci/travis/test.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
set -ex
44

55
cargo build --verbose --features $FEATURES
6-
cargo test --verbose --features $FEATURES
6+
cargo test --verbose --features $FEATURES -- --test-threads=1
77
rustdoc -L target/debug/deps/ --test README.md
88

99
for example in examples/*; do

src/array.rs

Lines changed: 24 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -482,28 +482,6 @@ impl<T: TypeNum, D: Dimension> PyArray<T, D> {
482482
unsafe { ArrayViewMut::from_shape_ptr(self.ndarray_shape(), self.data()) }
483483
}
484484

485-
/// Get a copy of `PyArray` as
486-
/// [`ndarray::Array`](https://docs.rs/ndarray/0.12/ndarray/type.Array.html).
487-
///
488-
/// # Example
489-
/// ```
490-
/// # #[macro_use] extern crate ndarray; extern crate pyo3; extern crate numpy; fn main() {
491-
/// use numpy::PyArray;
492-
/// let gil = pyo3::Python::acquire_gil();
493-
/// let py_array = PyArray::arange(gil.python(), 0, 4, 1).reshape([2, 2]).unwrap();
494-
/// assert_eq!(
495-
/// py_array.to_owned_array(),
496-
/// array![[0, 1], [2, 3]]
497-
/// )
498-
/// # }
499-
/// ```
500-
pub fn to_owned_array(&self) -> Array<T, D> {
501-
unsafe {
502-
let vec = self.as_slice().to_owned();
503-
Array::from_shape_vec_unchecked(self.ndarray_shape(), vec)
504-
}
505-
}
506-
507485
/// Get an immutable reference of a specified element, without checking the
508486
/// passed index is valid.
509487
///
@@ -620,6 +598,30 @@ impl<T: TypeNum, D: Dimension> PyArray<T, D> {
620598
}
621599
}
622600

601+
impl<T: TypeNum + Clone, D: Dimension> PyArray<T, D> {
602+
/// Get a copy of `PyArray` as
603+
/// [`ndarray::Array`](https://docs.rs/ndarray/0.12/ndarray/type.Array.html).
604+
///
605+
/// # Example
606+
/// ```
607+
/// # #[macro_use] extern crate ndarray; extern crate pyo3; extern crate numpy; fn main() {
608+
/// use numpy::PyArray;
609+
/// let gil = pyo3::Python::acquire_gil();
610+
/// let py_array = PyArray::arange(gil.python(), 0, 4, 1).reshape([2, 2]).unwrap();
611+
/// assert_eq!(
612+
/// py_array.to_owned_array(),
613+
/// array![[0, 1], [2, 3]]
614+
/// )
615+
/// # }
616+
/// ```
617+
pub fn to_owned_array(&self) -> Array<T, D> {
618+
unsafe {
619+
let vec = self.as_slice().to_vec();
620+
Array::from_shape_vec_unchecked(self.ndarray_shape(), vec)
621+
}
622+
}
623+
}
624+
623625
impl<T: TypeNum> PyArray<T, Ix1> {
624626
/// Construct one-dimension PyArray from slice.
625627
///

src/types.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ pub use num_complex::Complex32 as c32;
55
pub use num_complex::Complex64 as c64;
66

77
use super::npyffi::NPY_TYPES;
8+
use pyo3::ffi::PyObject;
89

910
/// An enum type represents numpy data type.
1011
///
@@ -24,6 +25,7 @@ pub enum NpyDataType {
2425
Float64,
2526
Complex32,
2627
Complex64,
28+
PyObject,
2729
Unsupported,
2830
}
2931

@@ -45,6 +47,7 @@ impl NpyDataType {
4547
x if x == NPY_TYPES::NPY_DOUBLE as i32 => NpyDataType::Float64,
4648
x if x == NPY_TYPES::NPY_CFLOAT as i32 => NpyDataType::Complex32,
4749
x if x == NPY_TYPES::NPY_CDOUBLE as i32 => NpyDataType::Complex64,
50+
x if x == NPY_TYPES::NPY_OBJECT as i32 => NpyDataType::PyObject,
4851
_ => NpyDataType::Unsupported,
4952
}
5053
}
@@ -68,7 +71,7 @@ impl NpyDataType {
6871
}
6972
}
7073

71-
pub trait TypeNum: Clone {
74+
pub trait TypeNum {
7275
fn is_same_type(other: i32) -> bool;
7376
fn npy_data_type() -> NpyDataType;
7477
fn typenum_default() -> i32;
@@ -100,6 +103,7 @@ impl_type_num!(f32, Float32, NPY_FLOAT);
100103
impl_type_num!(f64, Float64, NPY_DOUBLE);
101104
impl_type_num!(c32, Complex32, NPY_CFLOAT);
102105
impl_type_num!(c64, Complex64, NPY_CDOUBLE);
106+
impl_type_num!(*mut PyObject, PyObject, NPY_OBJECT);
103107

104108
cfg_if! {
105109
if #[cfg(any(target_pointer_width = "32", windows))] {

tests/array.rs

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ extern crate pyo3;
44

55
use ndarray::*;
66
use numpy::*;
7-
use pyo3::{prelude::*, types::PyDict, types::PyList};
7+
use pyo3::{prelude::*, types::PyDict, types::PyList, ToPyPointer};
88

99
#[test]
1010
fn new_c_order() {
@@ -267,3 +267,36 @@ fn into_pyarray_cant_resize() {
267267
let arr = a.into_pyarray(gil.python());
268268
assert!(arr.resize(100).is_err())
269269
}
270+
271+
// from pyo3, but modified for ease
272+
macro_rules! py_run {
273+
($py:expr, $val:expr, $code:expr) => {{
274+
let d = pyo3::types::PyDict::new($py);
275+
d.set_item(stringify!($val), &$val).unwrap();
276+
$py.run($code, None, Some(d))
277+
.map_err(|e| {
278+
e.print($py);
279+
$py.run("import sys; sys.stderr.flush()", None, None)
280+
.unwrap();
281+
})
282+
.expect($code)
283+
}};
284+
}
285+
286+
macro_rules! py_assert {
287+
($py:expr, $val:ident, $assertion:expr) => {
288+
py_run!($py, $val, concat!("assert ", $assertion))
289+
};
290+
}
291+
292+
#[test]
293+
fn into_obj_vec_to_pyarray() {
294+
let gil = pyo3::Python::acquire_gil();
295+
let py = gil.python();
296+
let dict = PyDict::new(py);
297+
let string = pyo3::types::PyString::new(py, "Hello python :)");
298+
let a = vec![dict.as_ptr(), string.as_ptr()];
299+
let arr = a.into_pyarray(py);
300+
py_assert!(py, arr, "arr[0] == {}");
301+
py_assert!(py, arr, "arr[1] == 'Hello python :)'");
302+
}

0 commit comments

Comments
 (0)