Skip to content

Commit 5f1b3b6

Browse files
committed
Clean up typenums, add DataType::into_typenum()
1 parent e9cc2a6 commit 5f1b3b6

File tree

1 file changed

+34
-26
lines changed

1 file changed

+34
-26
lines changed

src/dtype.rs

Lines changed: 34 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
use std::mem::size_of;
22
use std::os::raw::{c_int, c_long, c_longlong, c_short, c_uint, c_ulong, c_ulonglong, c_ushort};
33

4-
use crate::npyffi::{NpyTypes, PyArray_Descr, NPY_TYPES, PY_ARRAY_API};
54
use cfg_if::cfg_if;
65
use num_traits::{Bounded, Zero};
76
use pyo3::{ffi, prelude::*, pyobject_native_type_core, types::PyType, AsPyPointer, PyNativeType};
87

8+
use crate::npyffi::{NpyTypes, PyArray_Descr, NPY_TYPES, PY_ARRAY_API};
9+
910
pub use num_complex::Complex32 as c32;
1011
pub use num_complex::Complex64 as c64;
1112

@@ -85,17 +86,20 @@ impl PyArrayDescr {
8586

8687
fn from_npy_type(py: Python, npy_type: NPY_TYPES) -> &Self {
8788
unsafe {
88-
let descr = PY_ARRAY_API.PyArray_DescrFromType(npy_type as i32);
89+
let descr = PY_ARRAY_API.PyArray_DescrFromType(npy_type as _);
8990
py.from_owned_ptr(descr as _)
9091
}
9192
}
9293

93-
pub(crate) fn get_typenum(&self) -> std::os::raw::c_int {
94+
/// Retrieves the
95+
/// [enumerated type](https://numpy.org/doc/stable/reference/c-api/dtype.html#enumerated-types)
96+
/// for this type descriptor.
97+
pub fn get_typenum(&self) -> c_int {
9498
unsafe { *self.as_dtype_ptr() }.type_num
9599
}
96100
}
97101

98-
/// Represents numpy data type.
102+
/// Represents NumPy data type.
99103
///
100104
/// This is an incomplete counterpart of
101105
/// [Enumerated Types](https://numpy.org/doc/stable/reference/c-api/dtype.html#enumerated-types)
@@ -119,26 +123,32 @@ pub enum DataType {
119123
}
120124

121125
impl DataType {
122-
/// Construct `DataType` from
123-
/// [Enumerated Types](https://numpy.org/doc/stable/reference/c-api/dtype.html#enumerated-types).
126+
/// Convert `self` into an
127+
/// [enumerated type](https://numpy.org/doc/stable/reference/c-api/dtype.html#enumerated-types).
128+
pub fn into_typenum(self) -> c_int {
129+
self.into_npy_type() as _
130+
}
131+
132+
/// Construct the data type from an
133+
/// [enumerated type](https://numpy.org/doc/stable/reference/c-api/dtype.html#enumerated-types).
124134
pub fn from_typenum(typenum: c_int) -> Option<Self> {
125135
Some(match typenum {
126-
x if x == NPY_TYPES::NPY_BOOL as i32 => DataType::Bool,
127-
x if x == NPY_TYPES::NPY_BYTE as i32 => DataType::Int8,
128-
x if x == NPY_TYPES::NPY_SHORT as i32 => DataType::Int16,
129-
x if x == NPY_TYPES::NPY_INT as i32 => Self::integer::<c_int>()?,
130-
x if x == NPY_TYPES::NPY_LONG as i32 => Self::integer::<c_long>()?,
131-
x if x == NPY_TYPES::NPY_LONGLONG as i32 => Self::integer::<c_longlong>()?,
132-
x if x == NPY_TYPES::NPY_UBYTE as i32 => DataType::Uint8,
133-
x if x == NPY_TYPES::NPY_USHORT as i32 => DataType::Uint16,
134-
x if x == NPY_TYPES::NPY_UINT as i32 => Self::integer::<c_uint>()?,
135-
x if x == NPY_TYPES::NPY_ULONG as i32 => Self::integer::<c_ulong>()?,
136-
x if x == NPY_TYPES::NPY_ULONGLONG as i32 => Self::integer::<c_ulonglong>()?,
137-
x if x == NPY_TYPES::NPY_FLOAT as i32 => DataType::Float32,
138-
x if x == NPY_TYPES::NPY_DOUBLE as i32 => DataType::Float64,
139-
x if x == NPY_TYPES::NPY_CFLOAT as i32 => DataType::Complex32,
140-
x if x == NPY_TYPES::NPY_CDOUBLE as i32 => DataType::Complex64,
141-
x if x == NPY_TYPES::NPY_OBJECT as i32 => DataType::Object,
136+
x if x == NPY_TYPES::NPY_BOOL as c_int => DataType::Bool,
137+
x if x == NPY_TYPES::NPY_BYTE as c_int => DataType::Int8,
138+
x if x == NPY_TYPES::NPY_SHORT as c_int => DataType::Int16,
139+
x if x == NPY_TYPES::NPY_INT as c_int => Self::integer::<c_int>()?,
140+
x if x == NPY_TYPES::NPY_LONG as c_int => Self::integer::<c_long>()?,
141+
x if x == NPY_TYPES::NPY_LONGLONG as c_int => Self::integer::<c_longlong>()?,
142+
x if x == NPY_TYPES::NPY_UBYTE as c_int => DataType::Uint8,
143+
x if x == NPY_TYPES::NPY_USHORT as c_int => DataType::Uint16,
144+
x if x == NPY_TYPES::NPY_UINT as c_int => Self::integer::<c_uint>()?,
145+
x if x == NPY_TYPES::NPY_ULONG as c_int => Self::integer::<c_ulong>()?,
146+
x if x == NPY_TYPES::NPY_ULONGLONG as c_int => Self::integer::<c_ulonglong>()?,
147+
x if x == NPY_TYPES::NPY_FLOAT as c_int => DataType::Float32,
148+
x if x == NPY_TYPES::NPY_DOUBLE as c_int => DataType::Float64,
149+
x if x == NPY_TYPES::NPY_CFLOAT as c_int => DataType::Complex32,
150+
x if x == NPY_TYPES::NPY_CDOUBLE as c_int => DataType::Complex64,
151+
x if x == NPY_TYPES::NPY_OBJECT as c_int => DataType::Object,
142152
_ => return None,
143153
})
144154
}
@@ -160,9 +170,7 @@ impl DataType {
160170
})
161171
}
162172

163-
/// Convert `self` into
164-
/// [Enumerated Types](https://numpy.org/doc/stable/reference/c-api/dtype.html#enumerated-types).
165-
pub fn into_ctype(self) -> NPY_TYPES {
173+
fn into_npy_type(self) -> NPY_TYPES {
166174
fn npy_int_type_lookup<T, T0, T1, T2>(npy_types: [NPY_TYPES; 3]) -> NPY_TYPES {
167175
// `npy_common.h` defines the integer aliases. In order, it checks:
168176
// NPY_BITSOF_LONG, NPY_BITSOF_LONGLONG, NPY_BITSOF_INT, NPY_BITSOF_SHORT, NPY_BITSOF_CHAR
@@ -284,7 +292,7 @@ macro_rules! impl_num_element {
284292
}
285293

286294
fn get_dtype(py: Python) -> &PyArrayDescr {
287-
PyArrayDescr::from_npy_type(py, $data_type.into_ctype())
295+
PyArrayDescr::from_npy_type(py, $data_type.into_npy_type())
288296
}
289297
}
290298
};

0 commit comments

Comments
 (0)