Skip to content

Commit e85e3a9

Browse files
authored
Merge pull request #157 from PyO3/dtype
Introduce PyArrayDescr
2 parents 9703a15 + 41cb84a commit e85e3a9

File tree

7 files changed

+184
-65
lines changed

7 files changed

+184
-65
lines changed

src/array.rs

Lines changed: 30 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,9 @@ use std::{cell::Cell, mem, os::raw::c_int, ptr, slice};
1010
use std::{iter::ExactSizeIterator, marker::PhantomData};
1111

1212
use crate::convert::{IntoPyArray, NpyIndex, ToNpyDims, ToPyArray};
13+
use crate::dtype::Element;
1314
use crate::error::{FromVecError, NotContiguousError, ShapeError};
1415
use crate::slice_box::SliceBox;
15-
use crate::types::Element;
1616

1717
/// A safe, static-typed interface for
1818
/// [NumPy ndarray](https://numpy.org/doc/stable/reference/arrays.ndarray.html).
@@ -103,7 +103,7 @@ impl<T, D> type_object::PySizedLayout<PyArray<T, D>> for npyffi::PyArrayObject {
103103
pyobject_native_type_convert!(
104104
PyArray<T, D>,
105105
npyffi::PyArrayObject,
106-
*npyffi::PY_ARRAY_API.get_type_object(npyffi::ArrayType::PyArray_Type),
106+
*npyffi::PY_ARRAY_API.get_type_object(npyffi::NpyTypes::PyArray_Type),
107107
Some("numpy"),
108108
npyffi::PyArray_Check,
109109
T, D
@@ -136,12 +136,12 @@ impl<'a, T: Element, D: Dimension> FromPyObject<'a> for &'a PyArray<T, D> {
136136
}
137137
&*(ob as *const PyAny as *const PyArray<T, D>)
138138
};
139-
let type_ = unsafe { (*(*array.as_array_ptr()).descr).type_num };
139+
let dtype = array.dtype();
140140
let dim = array.shape().len();
141-
if T::is_same_type(type_) && D::NDIM.map(|n| n == dim).unwrap_or(true) {
141+
if T::is_same_type(dtype) && D::NDIM.map(|n| n == dim).unwrap_or(true) {
142142
Ok(array)
143143
} else {
144-
Err(ShapeError::new(type_, dim, T::DATA_TYPE, D::NDIM).into())
144+
Err(ShapeError::new(dtype, dim, T::DATA_TYPE, D::NDIM).into())
145145
}
146146
}
147147
}
@@ -152,6 +152,22 @@ impl<T, D> PyArray<T, D> {
152152
self.as_ptr() as _
153153
}
154154

155+
/// Returns `dtype` of the array.
156+
/// Counterpart of `array.dtype` in Python.
157+
///
158+
/// # Example
159+
/// ```
160+
/// pyo3::Python::with_gil(|py| {
161+
/// let array = numpy::PyArray::from_vec(py, vec![1, 2, 3i32]);
162+
/// let dtype = array.dtype();
163+
/// assert_eq!(dtype.get_datatype().unwrap(), numpy::DataType::Int32);
164+
/// });
165+
/// ```
166+
pub fn dtype(&self) -> &crate::PyArrayDescr {
167+
let descr_ptr = unsafe { (*self.as_array_ptr()).descr };
168+
unsafe { pyo3::FromPyPointer::from_borrowed_ptr(self.py(), descr_ptr as _) }
169+
}
170+
155171
#[inline(always)]
156172
fn check_flag(&self, flag: c_int) -> bool {
157173
unsafe { *self.as_array_ptr() }.flags & flag == flag
@@ -386,10 +402,10 @@ impl<T: Element, D: Dimension> PyArray<T, D> {
386402
{
387403
let dims = dims.into_dimension();
388404
let ptr = PY_ARRAY_API.PyArray_New(
389-
PY_ARRAY_API.get_type_object(npyffi::ArrayType::PyArray_Type),
405+
PY_ARRAY_API.get_type_object(npyffi::NpyTypes::PyArray_Type),
390406
dims.ndim_cint(),
391407
dims.as_dims_ptr(),
392-
T::ffi_dtype() as i32,
408+
T::npy_type() as i32,
393409
strides as *mut _, // strides
394410
ptr::null_mut(), // data
395411
0, // itemsize
@@ -415,10 +431,10 @@ impl<T: Element, D: Dimension> PyArray<T, D> {
415431
.create_cell(py)
416432
.expect("Object creation failed.");
417433
let ptr = PY_ARRAY_API.PyArray_New(
418-
PY_ARRAY_API.get_type_object(npyffi::ArrayType::PyArray_Type),
434+
PY_ARRAY_API.get_type_object(npyffi::NpyTypes::PyArray_Type),
419435
dims.ndim_cint(),
420436
dims.as_dims_ptr(),
421-
T::ffi_dtype() as i32,
437+
T::npy_type() as i32,
422438
strides as *mut _, // strides
423439
data_ptr as _, // data
424440
mem::size_of::<T>() as i32, // itemsize
@@ -450,11 +466,11 @@ impl<T: Element, D: Dimension> PyArray<T, D> {
450466
{
451467
let dims = dims.into_dimension();
452468
unsafe {
453-
let descr = PY_ARRAY_API.PyArray_DescrFromType(T::ffi_dtype() as i32);
469+
let dtype = T::get_dtype(py);
454470
let ptr = PY_ARRAY_API.PyArray_Zeros(
455471
dims.ndim_cint(),
456472
dims.as_dims_ptr(),
457-
descr,
473+
dtype.into_ptr() as _,
458474
if is_fortran { -1 } else { 0 },
459475
);
460476
Self::from_owned_ptr(py, ptr)
@@ -941,10 +957,10 @@ impl<T: Element, D> PyArray<T, D> {
941957
/// assert_eq!(pyarray_i.readonly().as_slice().unwrap(), &[2, 3, 4]);
942958
pub fn cast<'py, U: Element>(&'py self, is_fortran: bool) -> PyResult<&'py PyArray<U, D>> {
943959
let ptr = unsafe {
944-
let descr = PY_ARRAY_API.PyArray_DescrFromType(U::ffi_dtype() as i32);
960+
let dtype = U::get_dtype(self.py());
945961
PY_ARRAY_API.PyArray_CastToType(
946962
self.as_array_ptr(),
947-
descr,
963+
dtype.into_ptr() as _,
948964
if is_fortran { -1 } else { 0 },
949965
)
950966
};
@@ -1028,7 +1044,7 @@ impl<T: Element + AsPrimitive<f64>> PyArray<T, Ix1> {
10281044
start.as_(),
10291045
stop.as_(),
10301046
step.as_(),
1031-
T::ffi_dtype() as i32,
1047+
T::npy_type() as i32,
10321048
);
10331049
Self::from_owned_ptr(py, ptr)
10341050
}

src/types.rs renamed to src/dtype.rs

Lines changed: 123 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,89 @@
1-
//! Implements conversion utitlities.
2-
/// alias of Complex32
1+
use crate::npyffi::{NpyTypes, PyArray_Descr, NPY_TYPES, PY_ARRAY_API};
32
pub use num_complex::Complex32 as c32;
4-
/// alias of Complex64
53
pub use num_complex::Complex64 as c64;
4+
use pyo3::ffi;
5+
use pyo3::prelude::*;
6+
use pyo3::types::PyType;
7+
use pyo3::{AsPyPointer, PyNativeType};
8+
use std::os::raw::c_int;
69

7-
use super::npyffi::NPY_TYPES;
10+
/// Binding of [`numpy.dtype`](https://numpy.org/doc/stable/reference/generated/numpy.dtype.html).
11+
///
12+
/// # Example
13+
/// ```
14+
/// use pyo3::types::IntoPyDict;
15+
/// pyo3::Python::with_gil(|py| {
16+
/// let locals = [("np", numpy::get_array_module(py).unwrap())].into_py_dict(py);
17+
/// let dtype: &numpy::PyArrayDescr = py
18+
/// .eval("np.array([1, 2, 3.0]).dtype", Some(locals), None)
19+
/// .unwrap()
20+
/// .downcast()
21+
/// .unwrap();
22+
/// assert_eq!(dtype.get_datatype().unwrap(), numpy::DataType::Float64);
23+
/// });
24+
/// ```
25+
pub struct PyArrayDescr(PyAny);
26+
27+
pyobject_native_type_core!(
28+
PyArrayDescr,
29+
PyArray_Descr,
30+
*PY_ARRAY_API.get_type_object(NpyTypes::PyArrayDescr_Type),
31+
Some("numpy"),
32+
arraydescr_check
33+
);
34+
35+
pyobject_native_type_fmt!(PyArrayDescr);
36+
37+
unsafe fn arraydescr_check(op: *mut ffi::PyObject) -> c_int {
38+
ffi::PyObject_TypeCheck(
39+
op,
40+
PY_ARRAY_API.get_type_object(NpyTypes::PyArrayDescr_Type),
41+
)
42+
}
43+
44+
impl PyArrayDescr {
45+
/// Returns `self` as `*mut PyArray_Descr`.
46+
pub fn as_dtype_ptr(&self) -> *mut PyArray_Descr {
47+
self.as_ptr() as _
48+
}
849

9-
/// An enum type represents numpy data type.
50+
/// Returns the internal `PyType` that this `dtype` holds.
51+
///
52+
/// # Example
53+
/// ```
54+
/// pyo3::Python::with_gil(|py| {
55+
/// let array = numpy::PyArray::from_vec(py, vec![0.0, 1.0, 2.0f64]);
56+
/// let dtype = array.dtype();
57+
/// assert_eq!(dtype.get_type().name().to_string(), "numpy.float64");
58+
/// });
59+
/// ```
60+
pub fn get_type(&self) -> &PyType {
61+
let dtype_type_ptr = unsafe { *self.as_dtype_ptr() }.typeobj;
62+
unsafe { PyType::from_type_ptr(self.py(), dtype_type_ptr) }
63+
}
64+
65+
/// Returns the data type as `DataType` enum.
66+
pub fn get_datatype(&self) -> Option<DataType> {
67+
DataType::from_typenum(self.get_typenum())
68+
}
69+
70+
fn from_npy_type(py: Python, npy_type: NPY_TYPES) -> &Self {
71+
unsafe {
72+
let descr = PY_ARRAY_API.PyArray_DescrFromType(npy_type as i32);
73+
py.from_owned_ptr(descr as _)
74+
}
75+
}
76+
77+
fn get_typenum(&self) -> std::os::raw::c_int {
78+
unsafe { *self.as_dtype_ptr() }.type_num
79+
}
80+
}
81+
82+
/// Represents numpy data type.
1083
///
11-
/// This type is mainly for displaying error, and user don't have to use it directly.
84+
/// This is an incomplete counterpart of
85+
/// [Enumerated Types](https://numpy.org/doc/stable/reference/c-api/dtype.html#enumerated-types)
86+
/// in numpy C-API.
1287
#[derive(Clone, Debug, Eq, PartialEq)]
1388
pub enum DataType {
1489
Bool,
@@ -28,8 +103,10 @@ pub enum DataType {
28103
}
29104

30105
impl DataType {
31-
pub(crate) fn from_i32(npy_t: i32) -> Option<Self> {
32-
Some(match npy_t {
106+
/// Construct `DataType` from
107+
/// [Enumerated Types](https://numpy.org/doc/stable/reference/c-api/dtype.html#enumerated-types).
108+
pub fn from_typenum(typenum: c_int) -> Option<Self> {
109+
Some(match typenum {
33110
x if x == NPY_TYPES::NPY_BOOL as i32 => DataType::Bool,
34111
x if x == NPY_TYPES::NPY_BYTE as i32 => DataType::Int8,
35112
x if x == NPY_TYPES::NPY_SHORT as i32 => DataType::Int16,
@@ -49,6 +126,28 @@ impl DataType {
49126
_ => return None,
50127
})
51128
}
129+
130+
/// Convert `self` into
131+
/// [Enumerated Types](https://numpy.org/doc/stable/reference/c-api/dtype.html#enumerated-types).
132+
pub fn into_ctype(self) -> NPY_TYPES {
133+
match self {
134+
DataType::Bool => NPY_TYPES::NPY_BOOL,
135+
DataType::Int8 => NPY_TYPES::NPY_BYTE,
136+
DataType::Int16 => NPY_TYPES::NPY_SHORT,
137+
DataType::Int32 => NPY_TYPES::NPY_INT,
138+
DataType::Int64 => NPY_TYPES::NPY_LONGLONG,
139+
DataType::Uint8 => NPY_TYPES::NPY_UBYTE,
140+
DataType::Uint16 => NPY_TYPES::NPY_USHORT,
141+
DataType::Uint32 => NPY_TYPES::NPY_UINT,
142+
DataType::Uint64 => NPY_TYPES::NPY_ULONGLONG,
143+
DataType::Float32 => NPY_TYPES::NPY_FLOAT,
144+
DataType::Float64 => NPY_TYPES::NPY_DOUBLE,
145+
DataType::Complex32 => NPY_TYPES::NPY_CFLOAT,
146+
DataType::Complex64 => NPY_TYPES::NPY_CDOUBLE,
147+
DataType::Object => NPY_TYPES::NPY_OBJECT,
148+
}
149+
}
150+
52151
#[inline(always)]
53152
fn from_clong(is_usize: bool) -> Option<Self> {
54153
if cfg!(any(target_pointer_width = "32", windows)) {
@@ -67,43 +166,35 @@ impl DataType {
67166
None
68167
}
69168
}
70-
#[inline]
71-
pub fn into_ffi_dtype(self) -> NPY_TYPES {
72-
match self {
73-
DataType::Bool => NPY_TYPES::NPY_BOOL,
74-
DataType::Int8 => NPY_TYPES::NPY_BYTE,
75-
DataType::Int16 => NPY_TYPES::NPY_SHORT,
76-
DataType::Int32 => NPY_TYPES::NPY_INT,
77-
DataType::Int64 => NPY_TYPES::NPY_LONGLONG,
78-
DataType::Uint8 => NPY_TYPES::NPY_UBYTE,
79-
DataType::Uint16 => NPY_TYPES::NPY_USHORT,
80-
DataType::Uint32 => NPY_TYPES::NPY_UINT,
81-
DataType::Uint64 => NPY_TYPES::NPY_ULONGLONG,
82-
DataType::Float32 => NPY_TYPES::NPY_FLOAT,
83-
DataType::Float64 => NPY_TYPES::NPY_DOUBLE,
84-
DataType::Complex32 => NPY_TYPES::NPY_CFLOAT,
85-
DataType::Complex64 => NPY_TYPES::NPY_CDOUBLE,
86-
DataType::Object => NPY_TYPES::NPY_OBJECT,
87-
}
88-
}
89169
}
90170

91171
/// Represents that a type can be an element of `PyArray`.
92172
pub trait Element: Clone {
173+
/// `DataType` corresponding to this type.
93174
const DATA_TYPE: DataType;
94-
fn is_same_type(other: i32) -> bool;
175+
176+
/// Returns if the give `dtype` is convertible to `Self` in Rust.
177+
fn is_same_type(dtype: &PyArrayDescr) -> bool;
178+
179+
/// Returns the corresponding
180+
/// [Enumerated Type](https://numpy.org/doc/stable/reference/c-api/dtype.html#enumerated-types).
95181
#[inline]
96-
fn ffi_dtype() -> NPY_TYPES {
97-
Self::DATA_TYPE.into_ffi_dtype()
182+
fn npy_type() -> NPY_TYPES {
183+
Self::DATA_TYPE.into_ctype()
184+
}
185+
186+
/// Create `dtype`.
187+
fn get_dtype(py: Python) -> &PyArrayDescr {
188+
PyArrayDescr::from_npy_type(py, Self::npy_type())
98189
}
99190
}
100191

101192
macro_rules! impl_num_element {
102193
($t:ty, $npy_dat_t:ident $(,$npy_types: ident)+) => {
103194
impl Element for $t {
104195
const DATA_TYPE: DataType = DataType::$npy_dat_t;
105-
fn is_same_type(other: i32) -> bool {
106-
$(other == NPY_TYPES::$npy_types as i32 ||)+ false
196+
fn is_same_type(dtype: &PyArrayDescr) -> bool {
197+
$(dtype.get_typenum() == NPY_TYPES::$npy_types as i32 ||)+ false
107198
}
108199
}
109200
};

src/error.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
//! Defines error types.
2-
use crate::types::DataType;
2+
use crate::DataType;
33
use pyo3::{exceptions as exc, PyErr, PyErrArguments, PyObject, Python, ToPyObject};
44
use std::fmt;
55

@@ -33,15 +33,15 @@ pub struct ShapeError {
3333

3434
impl ShapeError {
3535
pub(crate) fn new(
36-
from_type: i32,
36+
from_dtype: &crate::PyArrayDescr,
3737
from_dim: usize,
3838
to_type: DataType,
3939
to_dim: Option<usize>,
4040
) -> Self {
4141
ShapeError {
4242
from: ArrayDim {
4343
dim: Some(from_dim),
44-
dtype: DataType::from_i32(from_type),
44+
dtype: from_dtype.get_datatype(),
4545
},
4646
to: ArrayDim {
4747
dim: to_dim,

src/lib.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,18 +38,19 @@ extern crate pyo3;
3838

3939
pub mod array;
4040
pub mod convert;
41+
mod dtype;
4142
mod error;
4243
pub mod npyffi;
4344
pub mod npyiter;
4445
mod readonly;
4546
mod slice_box;
46-
mod types;
4747

4848
pub use crate::array::{
4949
get_array_module, PyArray, PyArray1, PyArray2, PyArray3, PyArray4, PyArray5, PyArray6,
5050
PyArrayDyn,
5151
};
5252
pub use crate::convert::{IntoPyArray, NpyIndex, ToNpyDims, ToPyArray};
53+
pub use crate::dtype::{c32, c64, DataType, Element, PyArrayDescr};
5354
pub use crate::error::{FromVecError, NotContiguousError, ShapeError};
5455
pub use crate::npyffi::{PY_ARRAY_API, PY_UFUNC_API};
5556
pub use crate::npyiter::{
@@ -59,7 +60,6 @@ pub use crate::readonly::{
5960
PyReadonlyArray, PyReadonlyArray1, PyReadonlyArray2, PyReadonlyArray3, PyReadonlyArray4,
6061
PyReadonlyArray5, PyReadonlyArray6, PyReadonlyArrayDyn,
6162
};
62-
pub use crate::types::{c32, c64, DataType, Element};
6363
pub use ndarray::{Ix1, Ix2, Ix3, Ix4, Ix5, Ix6, IxDyn};
6464

6565
/// Test readme

0 commit comments

Comments
 (0)