Skip to content

Commit 430e022

Browse files
authored
Merge pull request #69 from kngwyu/fixed-dim
Fixed dimension by PhantomData
2 parents 2976137 + bc271a2 commit 430e022

File tree

9 files changed

+469
-294
lines changed

9 files changed

+469
-294
lines changed

README.md

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ numpy = "0.3"
4545
``` rust
4646
extern crate numpy;
4747
extern crate pyo3;
48-
use numpy::{IntoPyResult, PyArray, get_array_module};
48+
use numpy::{IntoPyResult, PyArray1, get_array_module};
4949
use pyo3::prelude::{ObjectProtocol, PyDict, PyResult, Python};
5050

5151
fn main() -> Result<(), ()> {
@@ -62,7 +62,7 @@ fn main_<'py>(py: Python<'py>) -> PyResult<()> {
6262
let np = get_array_module(py)?;
6363
let dict = PyDict::new(py);
6464
dict.set_item("np", np)?;
65-
let pyarray: &PyArray<i32> = py
65+
let pyarray: &PyArray1<i32> = py
6666
.eval("np.array([1, 2, 3], dtype='int32')", Some(&dict), None)?
6767
.extract()?;
6868
let slice = pyarray.as_slice().into_pyresult("Array Cast failed")?;
@@ -95,7 +95,7 @@ extern crate numpy;
9595
extern crate pyo3;
9696

9797
use ndarray::{ArrayD, ArrayViewD, ArrayViewMutD};
98-
use numpy::{IntoPyResult, PyArray, ToPyArray};
98+
use numpy::{IntoPyResult, PyArrayDyn, ToPyArray};
9999
use pyo3::prelude::{pymodinit, PyModule, PyResult, Python};
100100

101101
#[pymodinit]
@@ -112,15 +112,20 @@ fn rust_ext(_py: Python, m: &PyModule) -> PyResult<()> {
112112

113113
// wrapper of `axpy`
114114
#[pyfn(m, "axpy")]
115-
fn axpy_py(py: Python, a: f64, x: &PyArray<f64>, y: &PyArray<f64>) -> PyResult<PyArray<f64>> {
115+
fn axpy_py(
116+
py: Python,
117+
a: f64,
118+
x: &PyArrayDyn<f64>,
119+
y: &PyArrayDyn<f64>,
120+
) -> PyResult<PyArrayDyn<f64>> {
116121
let x = x.as_array().into_pyresult("x must be f64 array")?;
117122
let y = y.as_array().into_pyresult("y must be f64 array")?;
118123
Ok(axpy(a, x, y).to_pyarray(py).to_owned(py))
119124
}
120125

121126
// wrapper of `mult`
122127
#[pyfn(m, "mult")]
123-
fn mult_py(_py: Python, a: f64, x: &PyArray<f64>) -> PyResult<()> {
128+
fn mult_py(_py: Python, a: f64, x: &PyArrayDyn<f64>) -> PyResult<()> {
124129
let x = x.as_array_mut().into_pyresult("x must be f64 array")?;
125130
mult(a, x);
126131
Ok(())

example/extensions/src/lib.rs

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ extern crate numpy;
33
extern crate pyo3;
44

55
use ndarray::{ArrayD, ArrayViewD, ArrayViewMutD};
6-
use numpy::{IntoPyResult, PyArray, ToPyArray};
6+
use numpy::{IntoPyResult, PyArrayDyn, ToPyArray};
77
use pyo3::prelude::{pymodinit, PyModule, PyResult, Python};
88

99
#[pymodinit]
@@ -20,15 +20,20 @@ fn rust_ext(_py: Python, m: &PyModule) -> PyResult<()> {
2020

2121
// wrapper of `axpy`
2222
#[pyfn(m, "axpy")]
23-
fn axpy_py(py: Python, a: f64, x: &PyArray<f64>, y: &PyArray<f64>) -> PyResult<PyArray<f64>> {
23+
fn axpy_py(
24+
py: Python,
25+
a: f64,
26+
x: &PyArrayDyn<f64>,
27+
y: &PyArrayDyn<f64>,
28+
) -> PyResult<PyArrayDyn<f64>> {
2429
let x = x.as_array().into_pyresult("x must be f64 array")?;
2530
let y = y.as_array().into_pyresult("y must be f64 array")?;
2631
Ok(axpy(a, x, y).to_pyarray(py).to_owned(py))
2732
}
2833

2934
// wrapper of `mult`
3035
#[pyfn(m, "mult")]
31-
fn mult_py(_py: Python, a: f64, x: &PyArray<f64>) -> PyResult<()> {
36+
fn mult_py(_py: Python, a: f64, x: &PyArrayDyn<f64>) -> PyResult<()> {
3237
let x = x.as_array_mut().into_pyresult("x must be f64 array")?;
3338
mult(a, x);
3439
Ok(())

0 commit comments

Comments
 (0)