Skip to content

Commit 8f46870

Browse files
committed
Fix example and README
1 parent 149de74 commit 8f46870

File tree

3 files changed

+14
-21
lines changed

3 files changed

+14
-21
lines changed

README.md

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ numpy = "0.3"
4545
``` rust
4646
extern crate numpy;
4747
extern crate pyo3;
48-
use numpy::{IntoPyResult, PyArray, PyArrayModule};
48+
use numpy::{IntoPyResult, PyArray, get_array_module};
4949
use pyo3::prelude::{ObjectProtocol, PyDict, PyResult, Python};
5050

5151
fn main() -> Result<(), ()> {
@@ -59,9 +59,9 @@ fn main() -> Result<(), ()> {
5959
}
6060

6161
fn main_<'py>(py: Python<'py>) -> PyResult<()> {
62-
let np = PyArrayModule::import(py)?;
62+
let np = get_array_module(py)?;
6363
let dict = PyDict::new(py);
64-
dict.set_item("np", np.as_pymodule())?;
64+
dict.set_item("np", np)?;
6565
let pyarray: &PyArray<i32> = py
6666
.eval("np.array([1, 2, 3], dtype='int32')", Some(&dict), None)?
6767
.extract()?;
@@ -95,16 +95,11 @@ extern crate numpy;
9595
extern crate pyo3;
9696

9797
use ndarray::{ArrayD, ArrayViewD, ArrayViewMutD};
98-
use numpy::{IntoPyArray, IntoPyResult, PyArray, PyArrayModule};
98+
use numpy::{IntoPyArray, IntoPyResult, PyArray};
9999
use pyo3::prelude::{pymodinit, PyModule, PyResult, Python};
100100

101101
#[pymodinit]
102-
fn rust_ext(py: Python, m: &PyModule) -> PyResult<()> {
103-
// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
104-
// You **must** write this statement for the PyArray type checker to work correctly
105-
// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
106-
let _np = PyArrayModule::import(py)?;
107-
102+
fn rust_ext(_py: Python, m: &PyModule) -> PyResult<()> {
108103
// immutable example
109104
fn axpy(a: f64, x: ArrayViewD<f64>, y: ArrayViewD<f64>) -> ArrayD<f64> {
110105
a * &x + &y
@@ -118,10 +113,9 @@ fn rust_ext(py: Python, m: &PyModule) -> PyResult<()> {
118113
// wrapper of `axpy`
119114
#[pyfn(m, "axpy")]
120115
fn axpy_py(py: Python, a: f64, x: &PyArray<f64>, y: &PyArray<f64>) -> PyResult<PyArray<f64>> {
121-
let np = PyArrayModule::import(py)?;
122116
let x = x.as_array().into_pyresult("x must be f64 array")?;
123117
let y = y.as_array().into_pyresult("y must be f64 array")?;
124-
Ok(axpy(a, x, y).into_pyarray(py, &np))
118+
Ok(axpy(a, x, y).into_pyarray(py).to_owned(py))
125119
}
126120

127121
// wrapper of `mult`

example/extensions/src/lib.rs

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,16 +3,11 @@ extern crate numpy;
33
extern crate pyo3;
44

55
use ndarray::{ArrayD, ArrayViewD, ArrayViewMutD};
6-
use numpy::{IntoPyArray, IntoPyResult, PyArray, PyArrayModule};
6+
use numpy::{IntoPyArray, IntoPyResult, PyArray};
77
use pyo3::prelude::{pymodinit, PyModule, PyResult, Python};
88

99
#[pymodinit]
10-
fn rust_ext(py: Python, m: &PyModule) -> PyResult<()> {
11-
// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
12-
// You **must** write this statement for the PyArray type checker to work correctly
13-
// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
14-
let _np = PyArrayModule::import(py)?;
15-
10+
fn rust_ext(_py: Python, m: &PyModule) -> PyResult<()> {
1611
// immutable example
1712
fn axpy(a: f64, x: ArrayViewD<f64>, y: ArrayViewD<f64>) -> ArrayD<f64> {
1813
a * &x + &y
@@ -26,10 +21,9 @@ fn rust_ext(py: Python, m: &PyModule) -> PyResult<()> {
2621
// wrapper of `axpy`
2722
#[pyfn(m, "axpy")]
2823
fn axpy_py(py: Python, a: f64, x: &PyArray<f64>, y: &PyArray<f64>) -> PyResult<PyArray<f64>> {
29-
let np = PyArrayModule::import(py)?;
3024
let x = x.as_array().into_pyresult("x must be f64 array")?;
3125
let y = y.as_array().into_pyresult("y must be f64 array")?;
32-
Ok(axpy(a, x, y).into_pyarray(py, &np))
26+
Ok(axpy(a, x, y).into_pyarray(py).to_owned(py))
3327
}
3428

3529
// wrapper of `mult`

src/array.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,11 @@ impl<T> PyArray<T> {
6262
self.as_ptr() as _
6363
}
6464

65+
pub fn to_owned(&self, py: Python) -> Self {
66+
let obj = unsafe { PyObject::from_borrowed_ptr(py, self.as_ptr()) };
67+
PyArray(obj, PhantomData)
68+
}
69+
6570
/// Constructs `PyArray` from raw python object without incrementing reference counts.
6671
pub unsafe fn from_owned_ptr(py: Python, ptr: *mut pyo3::ffi::PyObject) -> &Self {
6772
py.from_owned_ptr(ptr)

0 commit comments

Comments
 (0)