Skip to content

Commit fee3283

Browse files
committed
Map scalar types directly, bypassing DataType
1 parent 147775d commit fee3283

File tree

1 file changed

+69
-19
lines changed

1 file changed

+69
-19
lines changed

src/dtype.rs

Lines changed: 69 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -285,39 +285,89 @@ pub unsafe trait Element: Clone + Send {
285285
fn get_dtype(py: Python) -> &PyArrayDescr;
286286
}
287287

288-
macro_rules! impl_num_element {
289-
($ty:ty, $data_type:expr $(,#[$meta:meta])*) => {
288+
fn npy_int_type_lookup<T, T0, T1, T2>(npy_types: [NPY_TYPES; 3]) -> NPY_TYPES {
289+
// `npy_common.h` defines the integer aliases. In order, it checks:
290+
// NPY_BITSOF_LONG, NPY_BITSOF_LONGLONG, NPY_BITSOF_INT, NPY_BITSOF_SHORT, NPY_BITSOF_CHAR
291+
// and assigns the alias to the first matching size, so we should check in this order.
292+
match size_of::<T>() {
293+
x if x == size_of::<T0>() => npy_types[0],
294+
x if x == size_of::<T1>() => npy_types[1],
295+
x if x == size_of::<T2>() => npy_types[2],
296+
_ => panic!("Unable to match integer type descriptor: {:?}", npy_types),
297+
}
298+
}
299+
300+
fn npy_int_type<T: Bounded + Zero + Sized + PartialEq>() -> NPY_TYPES {
301+
let is_unsigned = T::min_value() == T::zero();
302+
let bit_width = size_of::<T>() << 3;
303+
304+
match (is_unsigned, bit_width) {
305+
(false, 8) => NPY_TYPES::NPY_BYTE,
306+
(false, 16) => NPY_TYPES::NPY_SHORT,
307+
(false, 32) => npy_int_type_lookup::<i32, c_long, c_int, c_short>([
308+
NPY_TYPES::NPY_LONG,
309+
NPY_TYPES::NPY_INT,
310+
NPY_TYPES::NPY_SHORT,
311+
]),
312+
(false, 64) => npy_int_type_lookup::<i64, c_long, c_longlong, c_int>([
313+
NPY_TYPES::NPY_LONG,
314+
NPY_TYPES::NPY_LONGLONG,
315+
NPY_TYPES::NPY_INT,
316+
]),
317+
(true, 8) => NPY_TYPES::NPY_UBYTE,
318+
(true, 16) => NPY_TYPES::NPY_USHORT,
319+
(true, 32) => npy_int_type_lookup::<u32, c_ulong, c_uint, c_ushort>([
320+
NPY_TYPES::NPY_ULONG,
321+
NPY_TYPES::NPY_UINT,
322+
NPY_TYPES::NPY_USHORT,
323+
]),
324+
(true, 64) => npy_int_type_lookup::<u64, c_ulong, c_ulonglong, c_uint>([
325+
NPY_TYPES::NPY_ULONG,
326+
NPY_TYPES::NPY_ULONGLONG,
327+
NPY_TYPES::NPY_UINT,
328+
]),
329+
_ => unreachable!(),
330+
}
331+
}
332+
333+
macro_rules! impl_element_scalar {
334+
(@impl: $ty:ty, $npy_type:expr $(,#[$meta:meta])*) => {
290335
$(#[$meta])*
291336
unsafe impl Element for $ty {
292337
const IS_COPY: bool = true;
293-
294338
fn get_dtype(py: Python) -> &PyArrayDescr {
295-
PyArrayDescr::from_npy_type(py, $data_type.into_npy_type())
339+
PyArrayDescr::from_npy_type(py, $npy_type)
296340
}
297341
}
298342
};
343+
($ty:ty, $npy_type:ident $(,#[$meta:meta])*) => {
344+
impl_element_scalar!(@impl: $ty, NPY_TYPES::$npy_type $(,#[$meta])*);
345+
};
346+
($ty:ty $(,#[$meta:meta])*) => {
347+
impl_element_scalar!(@impl: $ty, npy_int_type::<$ty>() $(,#[$meta])*);
348+
};
299349
}
300350

301-
impl_num_element!(bool, DataType::Bool);
302-
impl_num_element!(i8, DataType::Int8);
303-
impl_num_element!(i16, DataType::Int16);
304-
impl_num_element!(i32, DataType::Int32);
305-
impl_num_element!(i64, DataType::Int64);
306-
impl_num_element!(u8, DataType::Uint8);
307-
impl_num_element!(u16, DataType::Uint16);
308-
impl_num_element!(u32, DataType::Uint32);
309-
impl_num_element!(u64, DataType::Uint64);
310-
impl_num_element!(f32, DataType::Float32);
311-
impl_num_element!(f64, DataType::Float64);
312-
impl_num_element!(Complex32, DataType::Complex32,
351+
impl_element_scalar!(bool, NPY_BOOL);
352+
impl_element_scalar!(i8);
353+
impl_element_scalar!(i16);
354+
impl_element_scalar!(i32);
355+
impl_element_scalar!(i64);
356+
impl_element_scalar!(u8);
357+
impl_element_scalar!(u16);
358+
impl_element_scalar!(u32);
359+
impl_element_scalar!(u64);
360+
impl_element_scalar!(f32, NPY_FLOAT);
361+
impl_element_scalar!(f64, NPY_DOUBLE);
362+
impl_element_scalar!(Complex32, NPY_CFLOAT,
313363
#[doc = "Complex type with `f32` components which maps to `np.csingle` (`np.complex64`)."]);
314-
impl_num_element!(Complex64, DataType::Complex64,
364+
impl_element_scalar!(Complex64, NPY_CDOUBLE,
315365
#[doc = "Complex type with `f64` components which maps to `np.cdouble` (`np.complex128`)."]);
316366

317367
cfg_if! {
318368
if #[cfg(any(target_pointer_width = "32", target_pointer_width = "64"))] {
319-
impl_num_element!(usize, DataType::integer::<usize>().unwrap());
320-
impl_num_element!(isize, DataType::integer::<isize>().unwrap());
369+
impl_element_scalar!(usize);
370+
impl_element_scalar!(isize);
321371
}
322372
}
323373

0 commit comments

Comments
 (0)