Skip to content

Commit 9727040

Browse files
committed
Fix README example
1 parent a6ab468 commit 9727040

File tree

1 file changed

+24
-33
lines changed

1 file changed

+24
-33
lines changed

README.md

Lines changed: 24 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -45,37 +45,28 @@ numpy = "0.3"
4545
``` rust
4646
extern crate numpy;
4747
extern crate pyo3;
48-
use numpy::{IntoPyResult, PyArray, ToPyArray};
49-
use pyo3::prelude::{pymodinit, PyModule, PyResult, Python};
50-
51-
#[pymodinit]
52-
fn rust_ext(_py: Python, m: &PyModule) -> PyResult<()> {
53-
// immutable example
54-
fn axpy(a: f64, x: ArrayViewD<f64>, y: ArrayViewD<f64>) -> ArrayD<f64> {
55-
a * &x + &y
56-
}
57-
58-
// mutable example (no return)
59-
fn mult(a: f64, mut x: ArrayViewMutD<f64>) {
60-
x *= a;
61-
}
62-
63-
// wrapper of `axpy`
64-
#[pyfn(m, "axpy")]
65-
fn axpy_py(py: Python, a: f64, x: &PyArray<f64>, y: &PyArray<f64>) -> PyResult<PyArray<f64>> {
66-
let x = x.as_array().into_pyresult("x must be f64 array")?;
67-
let y = y.as_array().into_pyresult("y must be f64 array")?;
68-
Ok(axpy(a, x, y).to_pyarray(py).to_owned(py))
69-
}
70-
71-
// wrapper of `mult`
72-
#[pyfn(m, "mult")]
73-
fn mult_py(_py: Python, a: f64, x: &PyArray<f64>) -> PyResult<()> {
74-
let x = x.as_array_mut().into_pyresult("x must be f64 array")?;
75-
mult(a, x);
76-
Ok(())
77-
}
48+
use numpy::{IntoPyResult, PyArray, get_array_module};
49+
use pyo3::prelude::{ObjectProtocol, PyDict, PyResult, Python};
50+
51+
fn main() -> Result<(), ()> {
52+
let gil = Python::acquire_gil();
53+
main_(gil.python()).map_err(|e| {
54+
eprintln!("error! :{:?}", e);
55+
// we can't display python error type via ::std::fmt::Display
56+
// so print error here manually
57+
e.print_and_set_sys_last_vars(gil.python());
58+
})
59+
}
7860

61+
fn main_<'py>(py: Python<'py>) -> PyResult<()> {
62+
let np = get_array_module(py)?;
63+
let dict = PyDict::new(py);
64+
dict.set_item("np", np)?;
65+
let pyarray: &PyArray<i32> = py
66+
.eval("np.array([1, 2, 3], dtype='int32')", Some(&dict), None)?
67+
.extract()?;
68+
let slice = pyarray.as_slice().into_pyresult("Array Cast failed")?;
69+
assert_eq!(slice, &[1, 2, 3]);
7970
Ok(())
8071
}
8172
```
@@ -91,7 +82,7 @@ crate-type = ["cdylib"]
9182

9283
[dependencies]
9384
numpy = "0.3"
94-
ndarray = "0.11"
85+
ndarray = "0.12"
9586

9687
[dependencies.pyo3]
9788
version = "^0.4.1"
@@ -104,7 +95,7 @@ extern crate numpy;
10495
extern crate pyo3;
10596

10697
use ndarray::{ArrayD, ArrayViewD, ArrayViewMutD};
107-
use numpy::{IntoPyArray, IntoPyResult, PyArray};
98+
use numpy::{IntoPyResult, PyArray, ToPyArray};
10899
use pyo3::prelude::{pymodinit, PyModule, PyResult, Python};
109100

110101
#[pymodinit]
@@ -124,7 +115,7 @@ fn rust_ext(_py: Python, m: &PyModule) -> PyResult<()> {
124115
fn axpy_py(py: Python, a: f64, x: &PyArray<f64>, y: &PyArray<f64>) -> PyResult<PyArray<f64>> {
125116
let x = x.as_array().into_pyresult("x must be f64 array")?;
126117
let y = y.as_array().into_pyresult("y must be f64 array")?;
127-
Ok(axpy(a, x, y).into_pyarray(py).to_owned(py))
118+
Ok(axpy(a, x, y).to_pyarray(py).to_owned(py))
128119
}
129120

130121
// wrapper of `mult`

0 commit comments

Comments
 (0)