Skip to content

Commit 8d510db

Browse files
committed
Introduce PyArrayDescr
1 parent 9703a15 commit 8d510db

File tree

5 files changed

+50
-12
lines changed

5 files changed

+50
-12
lines changed

src/array.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ impl<T, D> type_object::PySizedLayout<PyArray<T, D>> for npyffi::PyArrayObject {
103103
pyobject_native_type_convert!(
104104
PyArray<T, D>,
105105
npyffi::PyArrayObject,
106-
*npyffi::PY_ARRAY_API.get_type_object(npyffi::ArrayType::PyArray_Type),
106+
*npyffi::PY_ARRAY_API.get_type_object(npyffi::NpyTypes::PyArray_Type),
107107
Some("numpy"),
108108
npyffi::PyArray_Check,
109109
T, D
@@ -386,7 +386,7 @@ impl<T: Element, D: Dimension> PyArray<T, D> {
386386
{
387387
let dims = dims.into_dimension();
388388
let ptr = PY_ARRAY_API.PyArray_New(
389-
PY_ARRAY_API.get_type_object(npyffi::ArrayType::PyArray_Type),
389+
PY_ARRAY_API.get_type_object(npyffi::NpyTypes::PyArray_Type),
390390
dims.ndim_cint(),
391391
dims.as_dims_ptr(),
392392
T::ffi_dtype() as i32,
@@ -415,7 +415,7 @@ impl<T: Element, D: Dimension> PyArray<T, D> {
415415
.create_cell(py)
416416
.expect("Object creation failed.");
417417
let ptr = PY_ARRAY_API.PyArray_New(
418-
PY_ARRAY_API.get_type_object(npyffi::ArrayType::PyArray_Type),
418+
PY_ARRAY_API.get_type_object(npyffi::NpyTypes::PyArray_Type),
419419
dims.ndim_cint(),
420420
dims.as_dims_ptr(),
421421
T::ffi_dtype() as i32,

src/dtype.rs

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
use crate::npyffi::{NpyTypes, PyArray_Descr, PY_ARRAY_API};
2+
use pyo3::ffi;
3+
use pyo3::prelude::*;
4+
use std::os::raw::c_int;
5+
6+
pub struct PyArrayDescr(PyAny);
7+
8+
pyobject_native_type_core!(
9+
PyArrayDescr,
10+
PyArray_Descr,
11+
*PY_ARRAY_API.get_type_object(NpyTypes::PyArrayDescr_Type),
12+
Some("numpy"),
13+
arraydescr_check
14+
);
15+
16+
pyobject_native_type_fmt!(PyArrayDescr);
17+
18+
unsafe fn arraydescr_check(op: *mut ffi::PyObject) -> c_int {
19+
ffi::PyObject_TypeCheck(
20+
op,
21+
PY_ARRAY_API.get_type_object(NpyTypes::PyArrayDescr_Type),
22+
)
23+
}

src/lib.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ extern crate pyo3;
3838

3939
pub mod array;
4040
pub mod convert;
41+
mod dtype;
4142
mod error;
4243
pub mod npyffi;
4344
pub mod npyiter;
@@ -50,6 +51,7 @@ pub use crate::array::{
5051
PyArrayDyn,
5152
};
5253
pub use crate::convert::{IntoPyArray, NpyIndex, ToNpyDims, ToPyArray};
54+
pub use crate::dtype::PyArrayDescr;
5355
pub use crate::error::{FromVecError, NotContiguousError, ShapeError};
5456
pub use crate::npyffi::{PY_ARRAY_API, PY_UFUNC_API};
5557
pub use crate::npyiter::{

src/npyffi/array.rs

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -320,20 +320,19 @@ macro_rules! impl_array_type {
320320
($(($offset:expr, $tname:ident)),*) => {
321321
/// All type objects that numpy has.
322322
#[allow(non_camel_case_types)]
323-
#[repr(i32)]
324-
pub enum ArrayType { $($tname),* }
323+
pub enum NpyTypes { $($tname),* }
325324
impl PyArrayAPI {
326325
/// Get the pointer of the type object that `self` refers.
327-
pub unsafe fn get_type_object(&self, ty: ArrayType) -> *mut PyTypeObject {
326+
pub unsafe fn get_type_object(&self, ty: NpyTypes) -> *mut PyTypeObject {
328327
match ty {
329-
$( ArrayType::$tname => *(self.get($offset)) as *mut PyTypeObject ),*
328+
$( NpyTypes::$tname => *(self.get($offset)) as _ ),*
330329
}
331330
}
332331
}
333332
}
334-
} // impl_array_type!;
333+
}
335334

336-
impl_array_type!(
335+
impl_array_type! {
337336
(1, PyBigArray_Type),
338337
(2, PyArray_Type),
339338
(3, PyArrayDescr_Type),
@@ -373,18 +372,18 @@ impl_array_type!(
373372
(37, PyStringArrType_Type),
374373
(38, PyUnicodeArrType_Type),
375374
(39, PyVoidArrType_Type)
376-
);
375+
}
377376

378377
/// Checks that `op` is an instance of `PyArray` or not.
379378
#[allow(non_snake_case)]
380379
pub unsafe fn PyArray_Check(op: *mut PyObject) -> c_int {
381-
ffi::PyObject_TypeCheck(op, PY_ARRAY_API.get_type_object(ArrayType::PyArray_Type))
380+
ffi::PyObject_TypeCheck(op, PY_ARRAY_API.get_type_object(NpyTypes::PyArray_Type))
382381
}
383382

384383
/// Checks that `op` is an exact instance of `PyArray` or not.
385384
#[allow(non_snake_case)]
386385
pub unsafe fn PyArray_CheckExact(op: *mut PyObject) -> c_int {
387-
(ffi::Py_TYPE(op) == PY_ARRAY_API.get_type_object(ArrayType::PyArray_Type)) as c_int
386+
(ffi::Py_TYPE(op) == PY_ARRAY_API.get_type_object(NpyTypes::PyArray_Type)) as _
388387
}
389388

390389
#[test]

tests/array.rs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -234,3 +234,17 @@ fn handle_negative_strides() {
234234
.unwrap();
235235
assert_eq!(negstr_pyarr.to_owned_array(), arr.slice(s![..;-1, ..]));
236236
}
237+
238+
#[test]
239+
fn dtype_from_py() {
240+
let gil = pyo3::Python::acquire_gil();
241+
let py = gil.python();
242+
let arr = array![[2, 3], [4, 5u32]];
243+
let pyarr = arr.to_pyarray(py);
244+
let dtype: &numpy::PyArrayDescr = py
245+
.eval("a.dtype", Some([("a", pyarr)].into_py_dict(py)), None)
246+
.unwrap()
247+
.downcast()
248+
.unwrap();
249+
assert_eq!(&format!("{:?}", dtype), "dtype('uint32')");
250+
}

0 commit comments

Comments
 (0)