Skip to content

Commit e9cc2a6

Browse files
committed
Rework dtype integer type mapping to match numpy
1 parent a2ba3fb commit e9cc2a6

File tree

1 file changed

+85
-68
lines changed

1 file changed

+85
-68
lines changed

src/dtype.rs

Lines changed: 85 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
1+
use std::mem::size_of;
2+
use std::os::raw::{c_int, c_long, c_longlong, c_short, c_uint, c_ulong, c_ulonglong, c_ushort};
3+
14
use crate::npyffi::{NpyTypes, PyArray_Descr, NPY_TYPES, PY_ARRAY_API};
25
use cfg_if::cfg_if;
6+
use num_traits::{Bounded, Zero};
37
use pyo3::{ffi, prelude::*, pyobject_native_type_core, types::PyType, AsPyPointer, PyNativeType};
4-
use std::os::raw::c_int;
58

69
pub use num_complex::Complex32 as c32;
710
pub use num_complex::Complex64 as c64;
@@ -123,14 +126,14 @@ impl DataType {
123126
x if x == NPY_TYPES::NPY_BOOL as i32 => DataType::Bool,
124127
x if x == NPY_TYPES::NPY_BYTE as i32 => DataType::Int8,
125128
x if x == NPY_TYPES::NPY_SHORT as i32 => DataType::Int16,
126-
x if x == NPY_TYPES::NPY_INT as i32 => DataType::Int32,
127-
x if x == NPY_TYPES::NPY_LONG as i32 => return DataType::from_clong(false),
128-
x if x == NPY_TYPES::NPY_LONGLONG as i32 => DataType::Int64,
129+
x if x == NPY_TYPES::NPY_INT as i32 => Self::integer::<c_int>()?,
130+
x if x == NPY_TYPES::NPY_LONG as i32 => Self::integer::<c_long>()?,
131+
x if x == NPY_TYPES::NPY_LONGLONG as i32 => Self::integer::<c_longlong>()?,
129132
x if x == NPY_TYPES::NPY_UBYTE as i32 => DataType::Uint8,
130133
x if x == NPY_TYPES::NPY_USHORT as i32 => DataType::Uint16,
131-
x if x == NPY_TYPES::NPY_UINT as i32 => DataType::Uint32,
132-
x if x == NPY_TYPES::NPY_ULONG as i32 => return DataType::from_clong(true),
133-
x if x == NPY_TYPES::NPY_ULONGLONG as i32 => DataType::Uint64,
134+
x if x == NPY_TYPES::NPY_UINT as i32 => Self::integer::<c_uint>()?,
135+
x if x == NPY_TYPES::NPY_ULONG as i32 => Self::integer::<c_ulong>()?,
136+
x if x == NPY_TYPES::NPY_ULONGLONG as i32 => Self::integer::<c_ulonglong>()?,
134137
x if x == NPY_TYPES::NPY_FLOAT as i32 => DataType::Float32,
135138
x if x == NPY_TYPES::NPY_DOUBLE as i32 => DataType::Float64,
136139
x if x == NPY_TYPES::NPY_CFLOAT as i32 => DataType::Complex32,
@@ -140,48 +143,71 @@ impl DataType {
140143
})
141144
}
142145

146+
#[inline]
147+
fn integer<T: Bounded + Zero + Sized + PartialEq>() -> Option<Self> {
148+
let is_unsigned = T::min_value() == T::zero();
149+
let bit_width = size_of::<T>() << 3;
150+
Some(match (is_unsigned, bit_width) {
151+
(false, 8) => Self::Int8,
152+
(false, 16) => Self::Int16,
153+
(false, 32) => Self::Int32,
154+
(false, 64) => Self::Int64,
155+
(true, 8) => Self::Uint8,
156+
(true, 16) => Self::Uint16,
157+
(true, 32) => Self::Uint32,
158+
(true, 64) => Self::Uint64,
159+
_ => return None,
160+
})
161+
}
162+
143163
/// Convert `self` into
144164
/// [Enumerated Types](https://numpy.org/doc/stable/reference/c-api/dtype.html#enumerated-types).
145165
pub fn into_ctype(self) -> NPY_TYPES {
166+
fn npy_int_type_lookup<T, T0, T1, T2>(npy_types: [NPY_TYPES; 3]) -> NPY_TYPES {
167+
// `npy_common.h` defines the integer aliases. In order, it checks:
168+
// NPY_BITSOF_LONG, NPY_BITSOF_LONGLONG, NPY_BITSOF_INT, NPY_BITSOF_SHORT, NPY_BITSOF_CHAR
169+
// and assigns the alias to the first matching size, so we should check in this order.
170+
match size_of::<T>() {
171+
x if x == size_of::<T0>() => npy_types[0],
172+
x if x == size_of::<T1>() => npy_types[1],
173+
x if x == size_of::<T2>() => npy_types[2],
174+
_ => panic!("Unable to match integer type descriptor: {:?}", npy_types),
175+
}
176+
}
177+
146178
match self {
147179
DataType::Bool => NPY_TYPES::NPY_BOOL,
148180
DataType::Int8 => NPY_TYPES::NPY_BYTE,
149181
DataType::Int16 => NPY_TYPES::NPY_SHORT,
150-
DataType::Int32 => NPY_TYPES::NPY_INT,
151-
#[cfg(all(target_pointer_width = "64", not(windows)))]
152-
DataType::Int64 => NPY_TYPES::NPY_LONG,
153-
#[cfg(any(target_pointer_width = "32", windows))]
154-
DataType::Int64 => NPY_TYPES::NPY_LONGLONG,
182+
DataType::Int32 => npy_int_type_lookup::<i32, c_long, c_int, c_short>([
183+
NPY_TYPES::NPY_LONG,
184+
NPY_TYPES::NPY_INT,
185+
NPY_TYPES::NPY_SHORT,
186+
]),
187+
DataType::Int64 => npy_int_type_lookup::<i64, c_long, c_longlong, c_int>([
188+
NPY_TYPES::NPY_LONG,
189+
NPY_TYPES::NPY_LONGLONG,
190+
NPY_TYPES::NPY_INT,
191+
]),
155192
DataType::Uint8 => NPY_TYPES::NPY_UBYTE,
156193
DataType::Uint16 => NPY_TYPES::NPY_USHORT,
157-
DataType::Uint32 => NPY_TYPES::NPY_UINT,
158-
DataType::Uint64 => NPY_TYPES::NPY_ULONGLONG,
194+
DataType::Uint32 => npy_int_type_lookup::<u32, c_ulong, c_uint, c_ushort>([
195+
NPY_TYPES::NPY_ULONG,
196+
NPY_TYPES::NPY_UINT,
197+
NPY_TYPES::NPY_USHORT,
198+
]),
199+
DataType::Uint64 => npy_int_type_lookup::<u64, c_ulong, c_ulonglong, c_uint>([
200+
NPY_TYPES::NPY_ULONG,
201+
NPY_TYPES::NPY_ULONGLONG,
202+
NPY_TYPES::NPY_UINT,
203+
]),
159204
DataType::Float32 => NPY_TYPES::NPY_FLOAT,
160205
DataType::Float64 => NPY_TYPES::NPY_DOUBLE,
161206
DataType::Complex32 => NPY_TYPES::NPY_CFLOAT,
162207
DataType::Complex64 => NPY_TYPES::NPY_CDOUBLE,
163208
DataType::Object => NPY_TYPES::NPY_OBJECT,
164209
}
165210
}
166-
167-
#[inline(always)]
168-
fn from_clong(is_usize: bool) -> Option<Self> {
169-
if cfg!(any(target_pointer_width = "32", windows)) {
170-
Some(if is_usize {
171-
DataType::Uint32
172-
} else {
173-
DataType::Int32
174-
})
175-
} else if cfg!(all(target_pointer_width = "64", not(windows))) {
176-
Some(if is_usize {
177-
DataType::Uint64
178-
} else {
179-
DataType::Int64
180-
})
181-
} else {
182-
None
183-
}
184-
}
185211
}
186212

187213
/// Represents that a type can be an element of `PyArray`.
@@ -249,59 +275,50 @@ pub unsafe trait Element: Clone + Send {
249275
}
250276

251277
macro_rules! impl_num_element {
252-
($t:ty, $npy_dat_t:ident $(,$npy_types: ident)+) => {
253-
unsafe impl Element for $t {
254-
const DATA_TYPE: DataType = DataType::$npy_dat_t;
278+
($ty:ty, $data_type:expr) => {
279+
unsafe impl Element for $ty {
280+
const DATA_TYPE: DataType = $data_type;
281+
255282
fn is_same_type(dtype: &PyArrayDescr) -> bool {
256-
$(dtype.get_typenum() == NPY_TYPES::$npy_types as i32 ||)+ false
283+
dtype.get_datatype() == Some($data_type)
257284
}
285+
258286
fn get_dtype(py: Python) -> &PyArrayDescr {
259-
PyArrayDescr::from_npy_type(py, DataType::$npy_dat_t.into_ctype())
287+
PyArrayDescr::from_npy_type(py, $data_type.into_ctype())
260288
}
261289
}
262290
};
263291
}
264292

265-
impl_num_element!(bool, Bool, NPY_BOOL);
266-
impl_num_element!(i8, Int8, NPY_BYTE);
267-
impl_num_element!(i16, Int16, NPY_SHORT);
268-
impl_num_element!(u8, Uint8, NPY_UBYTE);
269-
impl_num_element!(u16, Uint16, NPY_USHORT);
270-
impl_num_element!(f32, Float32, NPY_FLOAT);
271-
impl_num_element!(f64, Float64, NPY_DOUBLE);
272-
impl_num_element!(c32, Complex32, NPY_CFLOAT);
273-
impl_num_element!(c64, Complex64, NPY_CDOUBLE);
293+
impl_num_element!(bool, DataType::Bool);
294+
impl_num_element!(i8, DataType::Int8);
295+
impl_num_element!(i16, DataType::Int16);
296+
impl_num_element!(i32, DataType::Int32);
297+
impl_num_element!(i64, DataType::Int64);
298+
impl_num_element!(u8, DataType::Uint8);
299+
impl_num_element!(u16, DataType::Uint16);
300+
impl_num_element!(u32, DataType::Uint32);
301+
impl_num_element!(u64, DataType::Uint64);
302+
impl_num_element!(f32, DataType::Float32);
303+
impl_num_element!(f64, DataType::Float64);
304+
impl_num_element!(c32, DataType::Complex32);
305+
impl_num_element!(c64, DataType::Complex64);
274306

275307
cfg_if! {
276-
if #[cfg(all(target_pointer_width = "64", windows))] {
277-
impl_num_element!(usize, Uint64, NPY_ULONGLONG);
278-
} else if #[cfg(all(target_pointer_width = "64", not(windows)))] {
279-
impl_num_element!(usize, Uint64, NPY_ULONG, NPY_ULONGLONG);
280-
} else if #[cfg(all(target_pointer_width = "32", windows))] {
281-
impl_num_element!(usize, Uint32, NPY_UINT, NPY_ULONG);
282-
} else if #[cfg(all(target_pointer_width = "32", not(windows)))] {
283-
impl_num_element!(usize, Uint32, NPY_UINT);
284-
}
285-
}
286-
cfg_if! {
287-
if #[cfg(any(target_pointer_width = "32", windows))] {
288-
impl_num_element!(i32, Int32, NPY_INT, NPY_LONG);
289-
impl_num_element!(u32, Uint32, NPY_UINT, NPY_ULONG);
290-
impl_num_element!(i64, Int64, NPY_LONGLONG);
291-
impl_num_element!(u64, Uint64, NPY_ULONGLONG);
292-
} else if #[cfg(all(target_pointer_width = "64", not(windows)))] {
293-
impl_num_element!(i32, Int32, NPY_INT);
294-
impl_num_element!(u32, Uint32, NPY_UINT);
295-
impl_num_element!(i64, Int64, NPY_LONG, NPY_LONGLONG);
296-
impl_num_element!(u64, Uint64, NPY_ULONG, NPY_ULONGLONG);
308+
if #[cfg(target_pointer_width = "64")] {
309+
impl_num_element!(usize, DataType::Uint64);
310+
} else if #[cfg(target_pointer_width = "32")] {
311+
impl_num_element!(usize, DataType::Uint32);
297312
}
298313
}
299314

300315
unsafe impl Element for PyObject {
301316
const DATA_TYPE: DataType = DataType::Object;
317+
302318
fn is_same_type(dtype: &PyArrayDescr) -> bool {
303319
dtype.get_typenum() == NPY_TYPES::NPY_OBJECT as i32
304320
}
321+
305322
fn get_dtype(py: Python) -> &PyArrayDescr {
306323
PyArrayDescr::object(py)
307324
}

0 commit comments

Comments
 (0)