Skip to content

Commit 7fcf952

Browse files
committed
Add some basic methods to PyArrayDescr
1 parent 8d510db commit 7fcf952

File tree

8 files changed

+206
-168
lines changed

8 files changed

+206
-168
lines changed

src/array.rs

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,9 @@ use std::{cell::Cell, mem, os::raw::c_int, ptr, slice};
1010
use std::{iter::ExactSizeIterator, marker::PhantomData};
1111

1212
use crate::convert::{IntoPyArray, NpyIndex, ToNpyDims, ToPyArray};
13+
use crate::dtype::Element;
1314
use crate::error::{FromVecError, NotContiguousError, ShapeError};
1415
use crate::slice_box::SliceBox;
15-
use crate::types::Element;
1616

1717
/// A safe, static-typed interface for
1818
/// [NumPy ndarray](https://numpy.org/doc/stable/reference/arrays.ndarray.html).
@@ -136,12 +136,12 @@ impl<'a, T: Element, D: Dimension> FromPyObject<'a> for &'a PyArray<T, D> {
136136
}
137137
&*(ob as *const PyAny as *const PyArray<T, D>)
138138
};
139-
let type_ = unsafe { (*(*array.as_array_ptr()).descr).type_num };
139+
let dtype = array.dtype();
140140
let dim = array.shape().len();
141-
if T::is_same_type(type_) && D::NDIM.map(|n| n == dim).unwrap_or(true) {
141+
if T::is_same_type(dtype) && D::NDIM.map(|n| n == dim).unwrap_or(true) {
142142
Ok(array)
143143
} else {
144-
Err(ShapeError::new(type_, dim, T::DATA_TYPE, D::NDIM).into())
144+
Err(ShapeError::new(dtype, dim, T::DATA_TYPE, D::NDIM).into())
145145
}
146146
}
147147
}
@@ -152,6 +152,11 @@ impl<T, D> PyArray<T, D> {
152152
self.as_ptr() as _
153153
}
154154

155+
pub fn dtype(&self) -> &crate::PyArrayDescr {
156+
let descr_ptr = unsafe { (*self.as_array_ptr()).descr };
157+
unsafe { pyo3::FromPyPointer::from_borrowed_ptr(self.py(), descr_ptr as _) }
158+
}
159+
155160
#[inline(always)]
156161
fn check_flag(&self, flag: c_int) -> bool {
157162
unsafe { *self.as_array_ptr() }.flags & flag == flag
@@ -389,7 +394,7 @@ impl<T: Element, D: Dimension> PyArray<T, D> {
389394
PY_ARRAY_API.get_type_object(npyffi::NpyTypes::PyArray_Type),
390395
dims.ndim_cint(),
391396
dims.as_dims_ptr(),
392-
T::ffi_dtype() as i32,
397+
T::npy_type() as i32,
393398
strides as *mut _, // strides
394399
ptr::null_mut(), // data
395400
0, // itemsize
@@ -418,7 +423,7 @@ impl<T: Element, D: Dimension> PyArray<T, D> {
418423
PY_ARRAY_API.get_type_object(npyffi::NpyTypes::PyArray_Type),
419424
dims.ndim_cint(),
420425
dims.as_dims_ptr(),
421-
T::ffi_dtype() as i32,
426+
T::npy_type() as i32,
422427
strides as *mut _, // strides
423428
data_ptr as _, // data
424429
mem::size_of::<T>() as i32, // itemsize
@@ -450,11 +455,11 @@ impl<T: Element, D: Dimension> PyArray<T, D> {
450455
{
451456
let dims = dims.into_dimension();
452457
unsafe {
453-
let descr = PY_ARRAY_API.PyArray_DescrFromType(T::ffi_dtype() as i32);
458+
let dtype = T::get_dtype(py);
454459
let ptr = PY_ARRAY_API.PyArray_Zeros(
455460
dims.ndim_cint(),
456461
dims.as_dims_ptr(),
457-
descr,
462+
dtype.into_ptr() as _,
458463
if is_fortran { -1 } else { 0 },
459464
);
460465
Self::from_owned_ptr(py, ptr)
@@ -941,10 +946,10 @@ impl<T: Element, D> PyArray<T, D> {
941946
/// assert_eq!(pyarray_i.readonly().as_slice().unwrap(), &[2, 3, 4]);
942947
pub fn cast<'py, U: Element>(&'py self, is_fortran: bool) -> PyResult<&'py PyArray<U, D>> {
943948
let ptr = unsafe {
944-
let descr = PY_ARRAY_API.PyArray_DescrFromType(U::ffi_dtype() as i32);
949+
let dtype = U::get_dtype(self.py());
945950
PY_ARRAY_API.PyArray_CastToType(
946951
self.as_array_ptr(),
947-
descr,
952+
dtype.into_ptr() as _,
948953
if is_fortran { -1 } else { 0 },
949954
)
950955
};
@@ -1028,7 +1033,7 @@ impl<T: Element + AsPrimitive<f64>> PyArray<T, Ix1> {
10281033
start.as_(),
10291034
stop.as_(),
10301035
step.as_(),
1031-
T::ffi_dtype() as i32,
1036+
T::npy_type() as i32,
10321037
);
10331038
Self::from_owned_ptr(py, ptr)
10341039
}

src/dtype.rs

Lines changed: 182 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,11 @@
1-
use crate::npyffi::{NpyTypes, PyArray_Descr, PY_ARRAY_API};
1+
//! Implements conversion utitlities.
2+
use crate::npyffi::{NpyTypes, PyArray_Descr, NPY_TYPES, PY_ARRAY_API};
3+
pub use num_complex::Complex32 as c32;
4+
pub use num_complex::Complex64 as c64;
25
use pyo3::ffi;
36
use pyo3::prelude::*;
7+
use pyo3::types::PyType;
8+
use pyo3::{AsPyPointer, PyNativeType};
49
use std::os::raw::c_int;
510

611
pub struct PyArrayDescr(PyAny);
@@ -21,3 +26,179 @@ unsafe fn arraydescr_check(op: *mut ffi::PyObject) -> c_int {
2126
PY_ARRAY_API.get_type_object(NpyTypes::PyArrayDescr_Type),
2227
)
2328
}
29+
30+
impl PyArrayDescr {
31+
pub fn as_dtype_ptr(&self) -> *mut PyArray_Descr {
32+
self.as_ptr() as _
33+
}
34+
35+
pub fn get_type(&self) -> &PyType {
36+
let dtype_type_ptr = unsafe { *self.as_dtype_ptr() }.typeobj;
37+
unsafe { PyType::from_type_ptr(self.py(), dtype_type_ptr) }
38+
}
39+
40+
pub fn get_typenum(&self) -> std::os::raw::c_int {
41+
unsafe { *self.as_dtype_ptr() }.type_num
42+
}
43+
44+
pub fn get_datatype(&self) -> Option<DataType> {
45+
DataType::from_typenum(self.get_typenum())
46+
}
47+
48+
pub fn from_npy_type(py: Python, npy_type: NPY_TYPES) -> &Self {
49+
unsafe {
50+
let descr = PY_ARRAY_API.PyArray_DescrFromType(npy_type as i32);
51+
py.from_owned_ptr(descr as _)
52+
}
53+
}
54+
}
55+
56+
/// An enum type represents numpy data type.
57+
///
58+
/// This type is mainly for displaying error, and user don't have to use it directly.
59+
#[derive(Clone, Debug, Eq, PartialEq)]
60+
pub enum DataType {
61+
Bool,
62+
Int8,
63+
Int16,
64+
Int32,
65+
Int64,
66+
Uint8,
67+
Uint16,
68+
Uint32,
69+
Uint64,
70+
Float32,
71+
Float64,
72+
Complex32,
73+
Complex64,
74+
Object,
75+
}
76+
77+
impl DataType {
78+
pub fn from_typenum(typenum: c_int) -> Option<Self> {
79+
Some(match typenum {
80+
x if x == NPY_TYPES::NPY_BOOL as i32 => DataType::Bool,
81+
x if x == NPY_TYPES::NPY_BYTE as i32 => DataType::Int8,
82+
x if x == NPY_TYPES::NPY_SHORT as i32 => DataType::Int16,
83+
x if x == NPY_TYPES::NPY_INT as i32 => DataType::Int32,
84+
x if x == NPY_TYPES::NPY_LONG as i32 => return DataType::from_clong(false),
85+
x if x == NPY_TYPES::NPY_LONGLONG as i32 => DataType::Int64,
86+
x if x == NPY_TYPES::NPY_UBYTE as i32 => DataType::Uint8,
87+
x if x == NPY_TYPES::NPY_USHORT as i32 => DataType::Uint16,
88+
x if x == NPY_TYPES::NPY_UINT as i32 => DataType::Uint32,
89+
x if x == NPY_TYPES::NPY_ULONG as i32 => return DataType::from_clong(true),
90+
x if x == NPY_TYPES::NPY_ULONGLONG as i32 => DataType::Uint64,
91+
x if x == NPY_TYPES::NPY_FLOAT as i32 => DataType::Float32,
92+
x if x == NPY_TYPES::NPY_DOUBLE as i32 => DataType::Float64,
93+
x if x == NPY_TYPES::NPY_CFLOAT as i32 => DataType::Complex32,
94+
x if x == NPY_TYPES::NPY_CDOUBLE as i32 => DataType::Complex64,
95+
x if x == NPY_TYPES::NPY_OBJECT as i32 => DataType::Object,
96+
_ => return None,
97+
})
98+
}
99+
100+
pub fn from_dtype(dtype: &crate::PyArrayDescr) -> Option<Self> {
101+
Self::from_typenum(dtype.get_typenum())
102+
}
103+
104+
#[inline]
105+
pub fn into_ctype(self) -> NPY_TYPES {
106+
match self {
107+
DataType::Bool => NPY_TYPES::NPY_BOOL,
108+
DataType::Int8 => NPY_TYPES::NPY_BYTE,
109+
DataType::Int16 => NPY_TYPES::NPY_SHORT,
110+
DataType::Int32 => NPY_TYPES::NPY_INT,
111+
DataType::Int64 => NPY_TYPES::NPY_LONGLONG,
112+
DataType::Uint8 => NPY_TYPES::NPY_UBYTE,
113+
DataType::Uint16 => NPY_TYPES::NPY_USHORT,
114+
DataType::Uint32 => NPY_TYPES::NPY_UINT,
115+
DataType::Uint64 => NPY_TYPES::NPY_ULONGLONG,
116+
DataType::Float32 => NPY_TYPES::NPY_FLOAT,
117+
DataType::Float64 => NPY_TYPES::NPY_DOUBLE,
118+
DataType::Complex32 => NPY_TYPES::NPY_CFLOAT,
119+
DataType::Complex64 => NPY_TYPES::NPY_CDOUBLE,
120+
DataType::Object => NPY_TYPES::NPY_OBJECT,
121+
}
122+
}
123+
124+
#[inline(always)]
125+
fn from_clong(is_usize: bool) -> Option<Self> {
126+
if cfg!(any(target_pointer_width = "32", windows)) {
127+
Some(if is_usize {
128+
DataType::Uint32
129+
} else {
130+
DataType::Int32
131+
})
132+
} else if cfg!(all(target_pointer_width = "64", not(windows))) {
133+
Some(if is_usize {
134+
DataType::Uint64
135+
} else {
136+
DataType::Int64
137+
})
138+
} else {
139+
None
140+
}
141+
}
142+
}
143+
144+
/// Represents that a type can be an element of `PyArray`.
145+
pub trait Element: Clone {
146+
const DATA_TYPE: DataType;
147+
148+
fn is_same_type(dtype: &PyArrayDescr) -> bool;
149+
150+
#[inline]
151+
fn npy_type() -> NPY_TYPES {
152+
Self::DATA_TYPE.into_ctype()
153+
}
154+
155+
fn get_dtype(py: Python) -> &PyArrayDescr {
156+
PyArrayDescr::from_npy_type(py, Self::npy_type())
157+
}
158+
}
159+
160+
macro_rules! impl_num_element {
161+
($t:ty, $npy_dat_t:ident $(,$npy_types: ident)+) => {
162+
impl Element for $t {
163+
const DATA_TYPE: DataType = DataType::$npy_dat_t;
164+
fn is_same_type(dtype: &PyArrayDescr) -> bool {
165+
$(dtype.get_typenum() == NPY_TYPES::$npy_types as i32 ||)+ false
166+
}
167+
}
168+
};
169+
}
170+
171+
impl_num_element!(bool, Bool, NPY_BOOL);
172+
impl_num_element!(i8, Int8, NPY_BYTE);
173+
impl_num_element!(i16, Int16, NPY_SHORT);
174+
impl_num_element!(u8, Uint8, NPY_UBYTE);
175+
impl_num_element!(u16, Uint16, NPY_USHORT);
176+
impl_num_element!(f32, Float32, NPY_FLOAT);
177+
impl_num_element!(f64, Float64, NPY_DOUBLE);
178+
impl_num_element!(c32, Complex32, NPY_CFLOAT);
179+
impl_num_element!(c64, Complex64, NPY_CDOUBLE);
180+
181+
cfg_if! {
182+
if #[cfg(all(target_pointer_width = "64", windows))] {
183+
impl_num_element!(usize, Uint64, NPY_ULONGLONG);
184+
} else if #[cfg(all(target_pointer_width = "64", not(windows)))] {
185+
impl_num_element!(usize, Uint64, NPY_ULONG, NPY_ULONGLONG);
186+
} else if #[cfg(all(target_pointer_width = "32", windows))] {
187+
impl_num_element!(usize, Uint32, NPY_UINT, NPY_ULONG);
188+
} else if #[cfg(all(target_pointer_width = "32", not(windows)))] {
189+
impl_num_element!(usize, Uint32, NPY_UINT);
190+
}
191+
}
192+
cfg_if! {
193+
if #[cfg(any(target_pointer_width = "32", windows))] {
194+
impl_num_element!(i32, Int32, NPY_INT, NPY_LONG);
195+
impl_num_element!(u32, Uint32, NPY_UINT, NPY_ULONG);
196+
impl_num_element!(i64, Int64, NPY_LONGLONG);
197+
impl_num_element!(u64, Uint64, NPY_ULONGLONG);
198+
} else if #[cfg(all(target_pointer_width = "64", not(windows)))] {
199+
impl_num_element!(i32, Int32, NPY_INT);
200+
impl_num_element!(u32, Uint32, NPY_UINT);
201+
impl_num_element!(i64, Int64, NPY_LONG, NPY_LONGLONG);
202+
impl_num_element!(u64, Uint64, NPY_ULONG, NPY_ULONGLONG);
203+
}
204+
}

src/error.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
//! Defines error types.
2-
use crate::types::DataType;
2+
use crate::DataType;
33
use pyo3::{exceptions as exc, PyErr, PyErrArguments, PyObject, Python, ToPyObject};
44
use std::fmt;
55

@@ -33,15 +33,15 @@ pub struct ShapeError {
3333

3434
impl ShapeError {
3535
pub(crate) fn new(
36-
from_type: i32,
36+
from_dtype: &crate::PyArrayDescr,
3737
from_dim: usize,
3838
to_type: DataType,
3939
to_dim: Option<usize>,
4040
) -> Self {
4141
ShapeError {
4242
from: ArrayDim {
4343
dim: Some(from_dim),
44-
dtype: DataType::from_i32(from_type),
44+
dtype: DataType::from_dtype(from_dtype),
4545
},
4646
to: ArrayDim {
4747
dim: to_dim,

src/lib.rs

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,14 +44,13 @@ pub mod npyffi;
4444
pub mod npyiter;
4545
mod readonly;
4646
mod slice_box;
47-
mod types;
4847

4948
pub use crate::array::{
5049
get_array_module, PyArray, PyArray1, PyArray2, PyArray3, PyArray4, PyArray5, PyArray6,
5150
PyArrayDyn,
5251
};
5352
pub use crate::convert::{IntoPyArray, NpyIndex, ToNpyDims, ToPyArray};
54-
pub use crate::dtype::PyArrayDescr;
53+
pub use crate::dtype::{c32, c64, DataType, Element, PyArrayDescr};
5554
pub use crate::error::{FromVecError, NotContiguousError, ShapeError};
5655
pub use crate::npyffi::{PY_ARRAY_API, PY_UFUNC_API};
5756
pub use crate::npyiter::{
@@ -61,7 +60,6 @@ pub use crate::readonly::{
6160
PyReadonlyArray, PyReadonlyArray1, PyReadonlyArray2, PyReadonlyArray3, PyReadonlyArray4,
6261
PyReadonlyArray5, PyReadonlyArray6, PyReadonlyArrayDyn,
6362
};
64-
pub use crate::types::{c32, c64, DataType, Element};
6563
pub use ndarray::{Ix1, Ix2, Ix3, Ix4, Ix5, Ix6, IxDyn};
6664

6765
/// Test readme

src/npyffi/array.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -315,10 +315,10 @@ impl PyArrayAPI {
315315
impl_api![303; PyArray_SetWritebackIfCopyBase(arr: *mut PyArrayObject, base: *mut PyArrayObject) -> c_int];
316316
}
317317

318-
/// Define PyTypeObject related to Array API
318+
// Define type objects that belongs to Numpy API
319319
macro_rules! impl_array_type {
320320
($(($offset:expr, $tname:ident)),*) => {
321-
/// All type objects that numpy has.
321+
/// All type objects of numpy API.
322322
#[allow(non_camel_case_types)]
323323
pub enum NpyTypes { $($tname),* }
324324
impl PyArrayAPI {

src/npyiter.rs

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
//! This module exposes two iterators:
44
//! [NpySingleIter](./struct.NpySingleIter.html) and
55
//! [NpyMultiIter](./struct.NpyMultiIter.html).
6-
use crate::array::{PyArray, PyArrayDyn};
76
use crate::npyffi::{
87
array::PY_ARRAY_API,
98
npy_intp, npy_uint32,
@@ -14,8 +13,7 @@ use crate::npyffi::{
1413
NPY_ITER_READONLY, NPY_ITER_READWRITE, NPY_ITER_REDUCE_OK, NPY_ITER_REFS_OK,
1514
NPY_ITER_ZEROSIZE_OK,
1615
};
17-
use crate::readonly::PyReadonlyArray;
18-
use crate::types::Element;
16+
use crate::{Element, PyArray, PyArrayDyn, PyReadonlyArray};
1917
use pyo3::{prelude::*, PyNativeType};
2018

2119
use std::marker::PhantomData;

0 commit comments

Comments
 (0)