@@ -285,39 +285,89 @@ pub unsafe trait Element: Clone + Send {
285
285
fn get_dtype ( py : Python ) -> & PyArrayDescr ;
286
286
}
287
287
288
- macro_rules! impl_num_element {
289
- ( $ty: ty, $data_type: expr $( , #[ $meta: meta] ) * ) => {
288
+ fn npy_int_type_lookup < T , T0 , T1 , T2 > ( npy_types : [ NPY_TYPES ; 3 ] ) -> NPY_TYPES {
289
+ // `npy_common.h` defines the integer aliases. In order, it checks:
290
+ // NPY_BITSOF_LONG, NPY_BITSOF_LONGLONG, NPY_BITSOF_INT, NPY_BITSOF_SHORT, NPY_BITSOF_CHAR
291
+ // and assigns the alias to the first matching size, so we should check in this order.
292
+ match size_of :: < T > ( ) {
293
+ x if x == size_of :: < T0 > ( ) => npy_types[ 0 ] ,
294
+ x if x == size_of :: < T1 > ( ) => npy_types[ 1 ] ,
295
+ x if x == size_of :: < T2 > ( ) => npy_types[ 2 ] ,
296
+ _ => panic ! ( "Unable to match integer type descriptor: {:?}" , npy_types) ,
297
+ }
298
+ }
299
+
300
+ fn npy_int_type < T : Bounded + Zero + Sized + PartialEq > ( ) -> NPY_TYPES {
301
+ let is_unsigned = T :: min_value ( ) == T :: zero ( ) ;
302
+ let bit_width = size_of :: < T > ( ) << 3 ;
303
+
304
+ match ( is_unsigned, bit_width) {
305
+ ( false , 8 ) => NPY_TYPES :: NPY_BYTE ,
306
+ ( false , 16 ) => NPY_TYPES :: NPY_SHORT ,
307
+ ( false , 32 ) => npy_int_type_lookup :: < i32 , c_long , c_int , c_short > ( [
308
+ NPY_TYPES :: NPY_LONG ,
309
+ NPY_TYPES :: NPY_INT ,
310
+ NPY_TYPES :: NPY_SHORT ,
311
+ ] ) ,
312
+ ( false , 64 ) => npy_int_type_lookup :: < i64 , c_long , c_longlong , c_int > ( [
313
+ NPY_TYPES :: NPY_LONG ,
314
+ NPY_TYPES :: NPY_LONGLONG ,
315
+ NPY_TYPES :: NPY_INT ,
316
+ ] ) ,
317
+ ( true , 8 ) => NPY_TYPES :: NPY_UBYTE ,
318
+ ( true , 16 ) => NPY_TYPES :: NPY_USHORT ,
319
+ ( true , 32 ) => npy_int_type_lookup :: < u32 , c_ulong , c_uint , c_ushort > ( [
320
+ NPY_TYPES :: NPY_ULONG ,
321
+ NPY_TYPES :: NPY_UINT ,
322
+ NPY_TYPES :: NPY_USHORT ,
323
+ ] ) ,
324
+ ( true , 64 ) => npy_int_type_lookup :: < u64 , c_ulong , c_ulonglong , c_uint > ( [
325
+ NPY_TYPES :: NPY_ULONG ,
326
+ NPY_TYPES :: NPY_ULONGLONG ,
327
+ NPY_TYPES :: NPY_UINT ,
328
+ ] ) ,
329
+ _ => unreachable ! ( ) ,
330
+ }
331
+ }
332
+
333
+ macro_rules! impl_element_scalar {
334
+ ( @impl : $ty: ty, $npy_type: expr $( , #[ $meta: meta] ) * ) => {
290
335
$( #[ $meta] ) *
291
336
unsafe impl Element for $ty {
292
337
const IS_COPY : bool = true ;
293
-
294
338
fn get_dtype( py: Python ) -> & PyArrayDescr {
295
- PyArrayDescr :: from_npy_type( py, $data_type . into_npy_type ( ) )
339
+ PyArrayDescr :: from_npy_type( py, $npy_type )
296
340
}
297
341
}
298
342
} ;
343
+ ( $ty: ty, $npy_type: ident $( , #[ $meta: meta] ) * ) => {
344
+ impl_element_scalar!( @impl : $ty, NPY_TYPES :: $npy_type $( , #[ $meta] ) * ) ;
345
+ } ;
346
+ ( $ty: ty $( , #[ $meta: meta] ) * ) => {
347
+ impl_element_scalar!( @impl : $ty, npy_int_type:: <$ty>( ) $( , #[ $meta] ) * ) ;
348
+ } ;
299
349
}
300
350
301
- impl_num_element ! ( bool , DataType :: Bool ) ;
302
- impl_num_element ! ( i8 , DataType :: Int8 ) ;
303
- impl_num_element ! ( i16 , DataType :: Int16 ) ;
304
- impl_num_element ! ( i32 , DataType :: Int32 ) ;
305
- impl_num_element ! ( i64 , DataType :: Int64 ) ;
306
- impl_num_element ! ( u8 , DataType :: Uint8 ) ;
307
- impl_num_element ! ( u16 , DataType :: Uint16 ) ;
308
- impl_num_element ! ( u32 , DataType :: Uint32 ) ;
309
- impl_num_element ! ( u64 , DataType :: Uint64 ) ;
310
- impl_num_element ! ( f32 , DataType :: Float32 ) ;
311
- impl_num_element ! ( f64 , DataType :: Float64 ) ;
312
- impl_num_element ! ( Complex32 , DataType :: Complex32 ,
351
+ impl_element_scalar ! ( bool , NPY_BOOL ) ;
352
+ impl_element_scalar ! ( i8 ) ;
353
+ impl_element_scalar ! ( i16 ) ;
354
+ impl_element_scalar ! ( i32 ) ;
355
+ impl_element_scalar ! ( i64 ) ;
356
+ impl_element_scalar ! ( u8 ) ;
357
+ impl_element_scalar ! ( u16 ) ;
358
+ impl_element_scalar ! ( u32 ) ;
359
+ impl_element_scalar ! ( u64 ) ;
360
+ impl_element_scalar ! ( f32 , NPY_FLOAT ) ;
361
+ impl_element_scalar ! ( f64 , NPY_DOUBLE ) ;
362
+ impl_element_scalar ! ( Complex32 , NPY_CFLOAT ,
313
363
#[ doc = "Complex type with `f32` components which maps to `np.csingle` (`np.complex64`)." ] ) ;
314
- impl_num_element ! ( Complex64 , DataType :: Complex64 ,
364
+ impl_element_scalar ! ( Complex64 , NPY_CDOUBLE ,
315
365
#[ doc = "Complex type with `f64` components which maps to `np.cdouble` (`np.complex128`)." ] ) ;
316
366
317
367
cfg_if ! {
318
368
if #[ cfg( any( target_pointer_width = "32" , target_pointer_width = "64" ) ) ] {
319
- impl_num_element !( usize , DataType :: integer :: < usize > ( ) . unwrap ( ) ) ;
320
- impl_num_element !( isize , DataType :: integer :: < isize > ( ) . unwrap ( ) ) ;
369
+ impl_element_scalar !( usize ) ;
370
+ impl_element_scalar !( isize ) ;
321
371
}
322
372
}
323
373
0 commit comments