Skip to content

Commit 149de74

Browse files
committed
Fix all tests
1 parent b174540 commit 149de74

File tree

7 files changed

+101
-120
lines changed

7 files changed

+101
-120
lines changed

src/array.rs

Lines changed: 44 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,10 @@ use super::*;
1313
/// Interface for [NumPy ndarray](https://docs.scipy.org/doc/numpy/reference/arrays.ndarray.html).
1414
pub struct PyArray<T>(PyObject, PhantomData<T>);
1515

16+
pub fn get_array_module(py: Python) -> PyResult<&PyModule> {
17+
PyModule::import(py, npyffi::array::MOD_NAME)
18+
}
19+
1620
pyobject_native_type_convert!(
1721
PyArray<T>,
1822
*npyffi::PY_ARRAY_API.get_type_object(npyffi::ArrayType::PyArray_Type),
@@ -75,10 +79,9 @@ impl<T> PyArray<T> {
7579
/// # Example
7680
/// ```
7781
/// # extern crate pyo3; extern crate numpy; fn main() {
78-
/// use numpy::{PyArray, PyArrayModule};
82+
/// use numpy::PyArray;
7983
/// let gil = pyo3::Python::acquire_gil();
80-
/// let np = PyArrayModule::import(gil.python()).unwrap();
81-
/// let arr = PyArray::<f64>::new(gil.python(), &np, &[4, 5, 6]);
84+
/// let arr = PyArray::<f64>::new(gil.python(), &[4, 5, 6]);
8285
/// assert_eq!(arr.ndim(), 3);
8386
/// # }
8487
/// ```
@@ -94,10 +97,9 @@ impl<T> PyArray<T> {
9497
/// # Example
9598
/// ```
9699
/// # extern crate pyo3; extern crate numpy; fn main() {
97-
/// use numpy::{PyArray, PyArrayModule};
100+
/// use numpy::PyArray;
98101
/// let gil = pyo3::Python::acquire_gil();
99-
/// let np = PyArrayModule::import(gil.python()).unwrap();
100-
/// let arr = PyArray::<f64>::new(gil.python(), &np, &[4, 5, 6]);
102+
/// let arr = PyArray::<f64>::new(gil.python(), &[4, 5, 6]);
101103
/// assert_eq!(arr.strides(), &[240, 48, 8]);
102104
/// # }
103105
/// ```
@@ -117,10 +119,9 @@ impl<T> PyArray<T> {
117119
/// # Example
118120
/// ```
119121
/// # extern crate pyo3; extern crate numpy; fn main() {
120-
/// use numpy::{PyArray, PyArrayModule};
122+
/// use numpy::PyArray;
121123
/// let gil = pyo3::Python::acquire_gil();
122-
/// let np = PyArrayModule::import(gil.python()).unwrap();
123-
/// let arr = PyArray::<f64>::new(gil.python(), &np, &[4, 5, 6]);
124+
/// let arr = PyArray::<f64>::new(gil.python(), &[4, 5, 6]);
124125
/// assert_eq!(arr.shape(), &[4, 5, 6]);
125126
/// # }
126127
/// ```
@@ -152,11 +153,10 @@ impl<T: TypeNum> PyArray<T> {
152153
/// # Example
153154
/// ```
154155
/// # extern crate pyo3; extern crate numpy; fn main() {
155-
/// use numpy::{PyArray, PyArrayModule};
156+
/// use numpy::PyArray;
156157
/// let gil = pyo3::Python::acquire_gil();
157-
/// let np = PyArrayModule::import(gil.python()).unwrap();
158158
/// let slice = vec![1, 2, 3, 4, 5].into_boxed_slice();
159-
/// let pyarray = PyArray::from_boxed_slice(gil.python(), &np, slice);
159+
/// let pyarray = PyArray::from_boxed_slice(gil.python(), slice);
160160
/// assert_eq!(pyarray.as_slice().unwrap(), &[1, 2, 3, 4, 5]);
161161
/// # }
162162
/// ```
@@ -169,12 +169,11 @@ impl<T: TypeNum> PyArray<T> {
169169
/// # Example
170170
/// ```
171171
/// # extern crate pyo3; extern crate numpy; fn main() {
172-
/// use numpy::{PyArray, PyArrayModule};
172+
/// use numpy::PyArray;
173173
/// use std::collections::BTreeSet;
174174
/// let gil = pyo3::Python::acquire_gil();
175-
/// let np = PyArrayModule::import(gil.python()).unwrap();
176175
/// let set: BTreeSet<u32> = [4, 3, 2, 5, 1].into_iter().cloned().collect();
177-
/// let pyarray = PyArray::from_iter(gil.python(), &np, set);
176+
/// let pyarray = PyArray::from_iter(gil.python(), set);
178177
/// assert_eq!(pyarray.as_slice().unwrap(), &[1, 2, 3, 4, 5]);
179178
/// # }
180179
/// ```
@@ -187,10 +186,9 @@ impl<T: TypeNum> PyArray<T> {
187186
/// # Example
188187
/// ```
189188
/// # extern crate pyo3; extern crate numpy; fn main() {
190-
/// use numpy::{PyArray, PyArrayModule};
189+
/// use numpy::PyArray;
191190
/// let gil = pyo3::Python::acquire_gil();
192-
/// let np = PyArrayModule::import(gil.python()).unwrap();
193-
/// let pyarray = PyArray::from_vec(gil.python(), &np, vec![1, 2, 3, 4, 5]);
191+
/// let pyarray = PyArray::from_vec(gil.python(), vec![1, 2, 3, 4, 5]);
194192
/// assert_eq!(pyarray.as_slice().unwrap(), &[1, 2, 3, 4, 5]);
195193
/// # }
196194
/// ```
@@ -206,13 +204,12 @@ impl<T: TypeNum> PyArray<T> {
206204
/// # Example
207205
/// ```
208206
/// # extern crate pyo3; extern crate numpy; #[macro_use] extern crate ndarray; fn main() {
209-
/// use numpy::{PyArray, PyArrayModule};
207+
/// use numpy::PyArray;
210208
/// let gil = pyo3::Python::acquire_gil();
211-
/// let np = PyArrayModule::import(gil.python()).unwrap();
212209
/// let vec2 = vec![vec![1, 2, 3]; 2];
213-
/// let pyarray = PyArray::from_vec2(gil.python(), &np, &vec2).unwrap();
210+
/// let pyarray = PyArray::from_vec2(gil.python(), &vec2).unwrap();
214211
/// assert_eq!(pyarray.as_array().unwrap(), array![[1, 2, 3], [1, 2, 3]].into_dyn());
215-
/// assert!(PyArray::from_vec2(gil.python(), &np, &vec![vec![1], vec![2, 3]]).is_err());
212+
/// assert!(PyArray::from_vec2(gil.python(), &vec![vec![1], vec![2, 3]]).is_err());
216213
/// # }
217214
/// ```
218215
pub fn from_vec2<'py>(py: Python<'py>, v: &Vec<Vec<T>>) -> Result<&'py Self, ArrayCastError> {
@@ -236,16 +233,15 @@ impl<T: TypeNum> PyArray<T> {
236233
/// # Example
237234
/// ```
238235
/// # extern crate pyo3; extern crate numpy; #[macro_use] extern crate ndarray; fn main() {
239-
/// use numpy::{PyArray, PyArrayModule};
236+
/// use numpy::PyArray;
240237
/// let gil = pyo3::Python::acquire_gil();
241-
/// let np = PyArrayModule::import(gil.python()).unwrap();
242238
/// let vec2 = vec![vec![vec![1, 2]; 2]; 2];
243-
/// let pyarray = PyArray::from_vec3(gil.python(), &np, &vec2).unwrap();
239+
/// let pyarray = PyArray::from_vec3(gil.python(), &vec2).unwrap();
244240
/// assert_eq!(
245241
/// pyarray.as_array().unwrap(),
246242
/// array![[[1, 2], [1, 2]], [[1, 2], [1, 2]]].into_dyn()
247243
/// );
248-
/// assert!(PyArray::from_vec3(gil.python(), &np, &vec![vec![vec![1], vec![]]]).is_err());
244+
/// assert!(PyArray::from_vec3(gil.python(), &vec![vec![vec![1], vec![]]]).is_err());
249245
/// # }
250246
/// ```
251247
pub fn from_vec3<'py>(
@@ -273,10 +269,9 @@ impl<T: TypeNum> PyArray<T> {
273269
/// # Example
274270
/// ```
275271
/// # extern crate pyo3; extern crate numpy; #[macro_use] extern crate ndarray; fn main() {
276-
/// use numpy::{PyArray, PyArrayModule};
272+
/// use numpy::PyArray;
277273
/// let gil = pyo3::Python::acquire_gil();
278-
/// let np = PyArrayModule::import(gil.python()).unwrap();
279-
/// let pyarray = PyArray::from_ndarray(gil.python(), &np, array![[1, 2], [3, 4]]);
274+
/// let pyarray = PyArray::from_ndarray(gil.python(), array![[1, 2], [3, 4]]);
280275
/// assert_eq!(pyarray.as_array().unwrap(), array![[1, 2], [3, 4]].into_dyn());
281276
/// # }
282277
/// ```
@@ -385,10 +380,9 @@ impl<T: TypeNum> PyArray<T> {
385380
/// # Example
386381
/// ```
387382
/// # extern crate pyo3; extern crate numpy; #[macro_use] extern crate ndarray; fn main() {
388-
/// use numpy::{PyArray, PyArrayModule};
383+
/// use numpy::PyArray;
389384
/// let gil = pyo3::Python::acquire_gil();
390-
/// let np = PyArrayModule::import(gil.python()).unwrap();
391-
/// let pyarray = PyArray::<i32>::new(gil.python(), &np, &[4, 5, 6]);
385+
/// let pyarray = PyArray::<i32>::new(gil.python(), &[4, 5, 6]);
392386
/// assert_eq!(pyarray.shape(), &[4, 5, 6]);
393387
/// # }
394388
/// ```
@@ -404,10 +398,9 @@ impl<T: TypeNum> PyArray<T> {
404398
/// # Example
405399
/// ```
406400
/// # extern crate pyo3; extern crate numpy; #[macro_use] extern crate ndarray; fn main() {
407-
/// use numpy::{PyArray, PyArrayModule};
401+
/// use numpy::PyArray;
408402
/// let gil = pyo3::Python::acquire_gil();
409-
/// let np = PyArrayModule::import(gil.python()).unwrap();
410-
/// let pyarray = PyArray::zeros(gil.python(), &np, &[2, 2], false);
403+
/// let pyarray = PyArray::zeros(gil.python(), &[2, 2], false);
411404
/// assert_eq!(pyarray.as_array().unwrap(), array![[0, 0], [0, 0]].into_dyn());
412405
/// # }
413406
/// ```
@@ -433,12 +426,11 @@ impl<T: TypeNum> PyArray<T> {
433426
/// # Example
434427
/// ```
435428
/// # extern crate pyo3; extern crate numpy; fn main() {
436-
/// use numpy::{PyArray, PyArrayModule, IntoPyArray};
429+
/// use numpy::{PyArray, IntoPyArray};
437430
/// let gil = pyo3::Python::acquire_gil();
438-
/// let np = PyArrayModule::import(gil.python()).unwrap();
439-
/// let pyarray = PyArray::<f64>::arange(gil.python(), &np, 2.0, 4.0, 0.5);
431+
/// let pyarray = PyArray::<f64>::arange(gil.python(), 2.0, 4.0, 0.5);
440432
/// assert_eq!(pyarray.as_slice().unwrap(), &[2.0, 2.5, 3.0, 3.5]);
441-
/// let pyarray = PyArray::<i32>::arange(gil.python(), &np, -2.0, 4.0, 3.0);
433+
/// let pyarray = PyArray::<i32>::arange(gil.python(), -2.0, 4.0, 3.0);
442434
/// assert_eq!(pyarray.as_slice().unwrap(), &[-2, 1]);
443435
/// # }
444436
pub fn arange<'py>(py: Python<'py>, start: f64, stop: f64, step: f64) -> &'py Self {
@@ -452,15 +444,14 @@ impl<T: TypeNum> PyArray<T> {
452444
/// # Example
453445
/// ```
454446
/// # extern crate pyo3; extern crate numpy; fn main() {
455-
/// use numpy::{PyArray, PyArrayModule, IntoPyArray};
447+
/// use numpy::{PyArray, IntoPyArray};
456448
/// let gil = pyo3::Python::acquire_gil();
457-
/// let np = PyArrayModule::import(gil.python()).unwrap();
458-
/// let pyarray_f = PyArray::<f64>::arange(gil.python(), &np, 2.0, 5.0, 1.0);
459-
/// let mut pyarray_i = PyArray::<i64>::new(gil.python(), &np, &[3]);
460-
/// assert!(pyarray_f.copy_to(&np, &mut pyarray_i).is_ok());
449+
/// let pyarray_f = PyArray::<f64>::arange(gil.python(), 2.0, 5.0, 1.0);
450+
/// let pyarray_i = PyArray::<i64>::new(gil.python(), &[3]);
451+
/// assert!(pyarray_f.copy_to(pyarray_i).is_ok());
461452
/// assert_eq!(pyarray_i.as_slice().unwrap(), &[2, 3, 4]);
462453
/// # }
463-
pub fn copy_to<U: TypeNum>(&self, other: &mut PyArray<U>) -> Result<(), ArrayCastError> {
454+
pub fn copy_to<U: TypeNum>(&self, other: &PyArray<U>) -> Result<(), ArrayCastError> {
464455
let self_ptr = self.as_array_ptr();
465456
let other_ptr = other.as_array_ptr();
466457
let result = unsafe { PY_ARRAY_API.PyArray_CopyInto(other_ptr, self_ptr) };
@@ -478,15 +469,14 @@ impl<T: TypeNum> PyArray<T> {
478469
/// # Example
479470
/// ```
480471
/// # extern crate pyo3; extern crate numpy; fn main() {
481-
/// use numpy::{PyArray, PyArrayModule, IntoPyArray};
472+
/// use numpy::{PyArray, IntoPyArray};
482473
/// let gil = pyo3::Python::acquire_gil();
483-
/// let np = PyArrayModule::import(gil.python()).unwrap();
484-
/// let pyarray_f = PyArray::<f64>::arange(gil.python(), &np, 2.0, 5.0, 1.0);
485-
/// let mut pyarray_i = PyArray::<i64>::new(gil.python(), &np, &[3]);
486-
/// assert!(pyarray_f.move_to(&np, &mut pyarray_i).is_ok());
474+
/// let pyarray_f = PyArray::<f64>::arange(gil.python(), 2.0, 5.0, 1.0);
475+
/// let pyarray_i = PyArray::<i64>::new(gil.python(), &[3]);
476+
/// assert!(pyarray_f.move_to(pyarray_i).is_ok());
487477
/// assert_eq!(pyarray_i.as_slice().unwrap(), &[2, 3, 4]);
488478
/// # }
489-
pub fn move_to<U: TypeNum>(&self, other: &mut PyArray<U>) -> Result<(), ArrayCastError> {
479+
pub fn move_to<U: TypeNum>(&self, other: &PyArray<U>) -> Result<(), ArrayCastError> {
490480
let self_ptr = self.as_array_ptr();
491481
let other_ptr = other.as_array_ptr();
492482
let result = unsafe { PY_ARRAY_API.PyArray_MoveInto(other_ptr, self_ptr) };
@@ -504,11 +494,10 @@ impl<T: TypeNum> PyArray<T> {
504494
/// # Example
505495
/// ```
506496
/// # extern crate pyo3; extern crate numpy; fn main() {
507-
/// use numpy::{PyArray, PyArrayModule, IntoPyArray};
497+
/// use numpy::{PyArray, IntoPyArray};
508498
/// let gil = pyo3::Python::acquire_gil();
509-
/// let np = PyArrayModule::import(gil.python()).unwrap();
510-
/// let pyarray_f = PyArray::<f64>::arange(gil.python(), &np, 2.0, 5.0, 1.0);
511-
/// let pyarray_i = pyarray_f.cast::<i32>(gil.python(), &np, false).unwrap();
499+
/// let pyarray_f = PyArray::<f64>::arange(gil.python(), 2.0, 5.0, 1.0);
500+
/// let pyarray_i = pyarray_f.cast::<i32>(gil.python(), false).unwrap();
512501
/// assert_eq!(pyarray_i.as_slice().unwrap(), &[2, 3, 4]);
513502
/// # }
514503
pub fn cast<'py, U: TypeNum>(

src/convert.rs

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,9 @@ use super::*;
1515
/// # Example
1616
/// ```
1717
/// # extern crate pyo3; extern crate numpy; fn main() {
18-
/// use numpy::{PyArray, PyArrayModule, IntoPyArray};
18+
/// use numpy::{PyArray, IntoPyArray};
1919
/// let gil = pyo3::Python::acquire_gil();
20-
/// let np = PyArrayModule::import(gil.python()).unwrap();
21-
/// let py_array = vec![1, 2, 3].into_pyarray(gil.python(), &np);
20+
/// let py_array = vec![1, 2, 3].into_pyarray(gil.python());
2221
/// assert_eq!(py_array.as_slice().unwrap(), &[1, 2, 3]);
2322
/// # }
2423
/// ```

src/lib.rs

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,11 @@
1616
//! extern crate numpy;
1717
//! extern crate pyo3;
1818
//! use pyo3::prelude::Python;
19-
//! use numpy::{IntoPyArray, PyArray, PyArrayModule};
19+
//! use numpy::{IntoPyArray, PyArray};
2020
//! fn main() {
2121
//! let gil = Python::acquire_gil();
2222
//! let py = gil.python();
23-
//! let np = PyArrayModule::import(py).unwrap();
24-
//! let py_array = array![[1i64, 2], [3, 4]].into_pyarray(py, &np);
23+
//! let py_array = array![[1i64, 2], [3, 4]].into_pyarray(py);
2524
//! assert_eq!(
2625
//! py_array.as_array().unwrap(),
2726
//! array![[1i64, 2], [3, 4]].into_dyn(),
@@ -44,7 +43,7 @@ pub mod error;
4443
pub mod npyffi;
4544
pub mod types;
4645

47-
pub use array::PyArray;
46+
pub use array::{get_array_module, PyArray};
4847
pub use convert::IntoPyArray;
4948
pub use error::*;
5049
pub use npyffi::{PY_ARRAY_API, PY_UFUNC_API};

src/npyffi/array.rs

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,10 @@ use pyo3::ffi::{self, PyObject, PyTypeObject};
44
use std::ops::Deref;
55
use std::os::raw::*;
66
use std::ptr;
7-
use std::sync::{Once, ONCE_INIT};
87

98
use npyffi::*;
109

11-
const MOD_NAME: &str = "numpy.core.multiarray";
10+
pub(crate) const MOD_NAME: &str = "numpy.core.multiarray";
1211
const CAPSULE_NAME: &str = "_ARRAY_API";
1312

1413
pub static PY_ARRAY_API: PyArrayAPI = PyArrayAPI {
@@ -22,14 +21,14 @@ pub struct PyArrayAPI {
2221
impl Deref for PyArrayAPI {
2322
type Target = PyArrayAPI_Inner;
2423
fn deref(&self) -> &Self::Target {
25-
static INIT_API: Once = ONCE_INIT;
2624
static mut ARRAY_API_CACHE: PyArrayAPI_Inner = PyArrayAPI_Inner(ptr::null());
27-
INIT_API.call_once(|| {
28-
unsafe {
25+
unsafe {
26+
// TODO: this operation is 'mostly safe' because of GIL, but not completely thread safe
27+
if ARRAY_API_CACHE.0.is_null() {
2928
ARRAY_API_CACHE = PyArrayAPI_Inner(get_numpy_api(MOD_NAME, CAPSULE_NAME));
30-
};
31-
});
32-
unsafe { &ARRAY_API_CACHE }
29+
}
30+
&ARRAY_API_CACHE
31+
}
3332
}
3433
}
3534

@@ -367,6 +366,8 @@ pub unsafe fn PyArray_CheckExact(op: *mut PyObject) -> c_int {
367366

368367
#[test]
369368
fn call_api() {
369+
use pyo3::Python;
370+
let _gil = Python::acquire_gil();
370371
unsafe {
371372
assert_eq!(
372373
PY_ARRAY_API.PyArray_MultiplyIntList([1, 2, 3].as_mut_ptr(), 3),

src/npyffi/mod.rs

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -9,20 +9,26 @@
99
//! - http://docs.python.jp/3/c-api/
1010
//! - http://dgrunwald.github.io/rust-pyo3/doc/pyo3/
1111
12-
use pyo3::{ffi, ObjectProtocol, Python, ToPyPointer};
12+
use pyo3::ffi;
13+
use std::ffi::CString;
1314
use std::os::raw::c_void;
1415
use std::ptr::null_mut;
1516

1617
fn get_numpy_api(module: &str, capsule: &str) -> *const *const c_void {
17-
let gil = Python::acquire_gil();
18-
let numpy = gil
19-
.python()
20-
.import("numpy.core.multiarray")
21-
.expect("Failed to import numpy.core.multiarray");
22-
let capsule = numpy
23-
.getattr("_ARRAY_API")
24-
.expect("Failed to import numpy.core.multiarray._ARRAY_API");
25-
unsafe { ffi::PyCapsule_GetPointer(capsule.as_ptr(), null_mut()) as *const *const c_void }
18+
let module = CString::new(module).unwrap();
19+
let capsule = CString::new(capsule).unwrap();
20+
unsafe {
21+
assert_ne!(
22+
ffi::Py_IsInitialized(),
23+
0,
24+
r"Numpy API is called before initializing Python!
25+
Please make sure that you get gil, by `let gil = Python::acquire_gil();`"
26+
);
27+
let numpy = ffi::PyImport_ImportModule(module.as_ptr());
28+
assert!(!numpy.is_null(), "Failed to import numpy module");
29+
let capsule = ffi::PyObject_GetAttrString(numpy as *mut ffi::PyObject, capsule.as_ptr());
30+
ffi::PyCapsule_GetPointer(capsule, null_mut()) as *const *const c_void
31+
}
2632
}
2733

2834
// Define Array&UFunc APIs

src/npyffi/ufunc.rs

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
use std::ops::Deref;
44
use std::os::raw::*;
55
use std::ptr;
6-
use std::sync::{Once, ONCE_INIT};
76

87
use pyo3::ffi::PyObject;
98

@@ -25,14 +24,14 @@ pub struct PyUFuncAPI {
2524
impl Deref for PyUFuncAPI {
2625
type Target = PyUFuncAPI_Inner;
2726
fn deref(&self) -> &Self::Target {
28-
static INIT_API: Once = ONCE_INIT;
2927
static mut UFUNC_API_CACHE: PyUFuncAPI_Inner = PyUFuncAPI_Inner(ptr::null());
30-
INIT_API.call_once(|| {
31-
unsafe {
28+
unsafe {
29+
// TODO: this operation is 'mostly safe' because of GIL, but not completely thread safe
30+
if UFUNC_API_CACHE.0.is_null() {
3231
UFUNC_API_CACHE = PyUFuncAPI_Inner(get_numpy_api(MOD_NAME, CAPSULE_NAME));
33-
};
34-
});
35-
unsafe { &UFUNC_API_CACHE }
32+
}
33+
&UFUNC_API_CACHE
34+
}
3635
}
3736
}
3837

0 commit comments

Comments
 (0)