1
+ use std:: mem:: size_of;
2
+ use std:: os:: raw:: { c_int, c_long, c_longlong, c_short, c_uint, c_ulong, c_ulonglong, c_ushort} ;
3
+
1
4
use crate :: npyffi:: { NpyTypes , PyArray_Descr , NPY_TYPES , PY_ARRAY_API } ;
2
5
use cfg_if:: cfg_if;
6
+ use num_traits:: { Bounded , Zero } ;
3
7
use pyo3:: { ffi, prelude:: * , pyobject_native_type_core, types:: PyType , AsPyPointer , PyNativeType } ;
4
- use std:: os:: raw:: c_int;
5
8
6
9
pub use num_complex:: Complex32 as c32;
7
10
pub use num_complex:: Complex64 as c64;
@@ -123,14 +126,14 @@ impl DataType {
123
126
x if x == NPY_TYPES :: NPY_BOOL as i32 => DataType :: Bool ,
124
127
x if x == NPY_TYPES :: NPY_BYTE as i32 => DataType :: Int8 ,
125
128
x if x == NPY_TYPES :: NPY_SHORT as i32 => DataType :: Int16 ,
126
- x if x == NPY_TYPES :: NPY_INT as i32 => DataType :: Int32 ,
127
- x if x == NPY_TYPES :: NPY_LONG as i32 => return DataType :: from_clong ( false ) ,
128
- x if x == NPY_TYPES :: NPY_LONGLONG as i32 => DataType :: Int64 ,
129
+ x if x == NPY_TYPES :: NPY_INT as i32 => Self :: integer :: < c_int > ( ) ? ,
130
+ x if x == NPY_TYPES :: NPY_LONG as i32 => Self :: integer :: < c_long > ( ) ? ,
131
+ x if x == NPY_TYPES :: NPY_LONGLONG as i32 => Self :: integer :: < c_longlong > ( ) ? ,
129
132
x if x == NPY_TYPES :: NPY_UBYTE as i32 => DataType :: Uint8 ,
130
133
x if x == NPY_TYPES :: NPY_USHORT as i32 => DataType :: Uint16 ,
131
- x if x == NPY_TYPES :: NPY_UINT as i32 => DataType :: Uint32 ,
132
- x if x == NPY_TYPES :: NPY_ULONG as i32 => return DataType :: from_clong ( true ) ,
133
- x if x == NPY_TYPES :: NPY_ULONGLONG as i32 => DataType :: Uint64 ,
134
+ x if x == NPY_TYPES :: NPY_UINT as i32 => Self :: integer :: < c_uint > ( ) ? ,
135
+ x if x == NPY_TYPES :: NPY_ULONG as i32 => Self :: integer :: < c_ulong > ( ) ? ,
136
+ x if x == NPY_TYPES :: NPY_ULONGLONG as i32 => Self :: integer :: < c_ulonglong > ( ) ? ,
134
137
x if x == NPY_TYPES :: NPY_FLOAT as i32 => DataType :: Float32 ,
135
138
x if x == NPY_TYPES :: NPY_DOUBLE as i32 => DataType :: Float64 ,
136
139
x if x == NPY_TYPES :: NPY_CFLOAT as i32 => DataType :: Complex32 ,
@@ -140,48 +143,71 @@ impl DataType {
140
143
} )
141
144
}
142
145
146
+ #[ inline]
147
+ fn integer < T : Bounded + Zero + Sized + PartialEq > ( ) -> Option < Self > {
148
+ let is_unsigned = T :: min_value ( ) == T :: zero ( ) ;
149
+ let bit_width = size_of :: < T > ( ) << 3 ;
150
+ Some ( match ( is_unsigned, bit_width) {
151
+ ( false , 8 ) => Self :: Int8 ,
152
+ ( false , 16 ) => Self :: Int16 ,
153
+ ( false , 32 ) => Self :: Int32 ,
154
+ ( false , 64 ) => Self :: Int64 ,
155
+ ( true , 8 ) => Self :: Uint8 ,
156
+ ( true , 16 ) => Self :: Uint16 ,
157
+ ( true , 32 ) => Self :: Uint32 ,
158
+ ( true , 64 ) => Self :: Uint64 ,
159
+ _ => return None ,
160
+ } )
161
+ }
162
+
143
163
/// Convert `self` into
144
164
/// [Enumerated Types](https://numpy.org/doc/stable/reference/c-api/dtype.html#enumerated-types).
145
165
pub fn into_ctype ( self ) -> NPY_TYPES {
166
+ fn npy_int_type_lookup < T , T0 , T1 , T2 > ( npy_types : [ NPY_TYPES ; 3 ] ) -> NPY_TYPES {
167
+ // `npy_common.h` defines the integer aliases. In order, it checks:
168
+ // NPY_BITSOF_LONG, NPY_BITSOF_LONGLONG, NPY_BITSOF_INT, NPY_BITSOF_SHORT, NPY_BITSOF_CHAR
169
+ // and assigns the alias to the first matching size, so we should check in this order.
170
+ match size_of :: < T > ( ) {
171
+ x if x == size_of :: < T0 > ( ) => npy_types[ 0 ] ,
172
+ x if x == size_of :: < T1 > ( ) => npy_types[ 1 ] ,
173
+ x if x == size_of :: < T2 > ( ) => npy_types[ 2 ] ,
174
+ _ => panic ! ( "Unable to match integer type descriptor: {:?}" , npy_types) ,
175
+ }
176
+ }
177
+
146
178
match self {
147
179
DataType :: Bool => NPY_TYPES :: NPY_BOOL ,
148
180
DataType :: Int8 => NPY_TYPES :: NPY_BYTE ,
149
181
DataType :: Int16 => NPY_TYPES :: NPY_SHORT ,
150
- DataType :: Int32 => NPY_TYPES :: NPY_INT ,
151
- #[ cfg( all( target_pointer_width = "64" , not( windows) ) ) ]
152
- DataType :: Int64 => NPY_TYPES :: NPY_LONG ,
153
- #[ cfg( any( target_pointer_width = "32" , windows) ) ]
154
- DataType :: Int64 => NPY_TYPES :: NPY_LONGLONG ,
182
+ DataType :: Int32 => npy_int_type_lookup :: < i32 , c_long , c_int , c_short > ( [
183
+ NPY_TYPES :: NPY_LONG ,
184
+ NPY_TYPES :: NPY_INT ,
185
+ NPY_TYPES :: NPY_SHORT ,
186
+ ] ) ,
187
+ DataType :: Int64 => npy_int_type_lookup :: < i64 , c_long , c_longlong , c_int > ( [
188
+ NPY_TYPES :: NPY_LONG ,
189
+ NPY_TYPES :: NPY_LONGLONG ,
190
+ NPY_TYPES :: NPY_INT ,
191
+ ] ) ,
155
192
DataType :: Uint8 => NPY_TYPES :: NPY_UBYTE ,
156
193
DataType :: Uint16 => NPY_TYPES :: NPY_USHORT ,
157
- DataType :: Uint32 => NPY_TYPES :: NPY_UINT ,
158
- DataType :: Uint64 => NPY_TYPES :: NPY_ULONGLONG ,
194
+ DataType :: Uint32 => npy_int_type_lookup :: < u32 , c_ulong , c_uint , c_ushort > ( [
195
+ NPY_TYPES :: NPY_ULONG ,
196
+ NPY_TYPES :: NPY_UINT ,
197
+ NPY_TYPES :: NPY_USHORT ,
198
+ ] ) ,
199
+ DataType :: Uint64 => npy_int_type_lookup :: < u64 , c_ulong , c_ulonglong , c_uint > ( [
200
+ NPY_TYPES :: NPY_ULONG ,
201
+ NPY_TYPES :: NPY_ULONGLONG ,
202
+ NPY_TYPES :: NPY_UINT ,
203
+ ] ) ,
159
204
DataType :: Float32 => NPY_TYPES :: NPY_FLOAT ,
160
205
DataType :: Float64 => NPY_TYPES :: NPY_DOUBLE ,
161
206
DataType :: Complex32 => NPY_TYPES :: NPY_CFLOAT ,
162
207
DataType :: Complex64 => NPY_TYPES :: NPY_CDOUBLE ,
163
208
DataType :: Object => NPY_TYPES :: NPY_OBJECT ,
164
209
}
165
210
}
166
-
167
- #[ inline( always) ]
168
- fn from_clong ( is_usize : bool ) -> Option < Self > {
169
- if cfg ! ( any( target_pointer_width = "32" , windows) ) {
170
- Some ( if is_usize {
171
- DataType :: Uint32
172
- } else {
173
- DataType :: Int32
174
- } )
175
- } else if cfg ! ( all( target_pointer_width = "64" , not( windows) ) ) {
176
- Some ( if is_usize {
177
- DataType :: Uint64
178
- } else {
179
- DataType :: Int64
180
- } )
181
- } else {
182
- None
183
- }
184
- }
185
211
}
186
212
187
213
/// Represents that a type can be an element of `PyArray`.
@@ -249,59 +275,50 @@ pub unsafe trait Element: Clone + Send {
249
275
}
250
276
251
277
macro_rules! impl_num_element {
252
- ( $t: ty, $npy_dat_t: ident $( , $npy_types: ident) +) => {
253
- unsafe impl Element for $t {
254
- const DATA_TYPE : DataType = DataType :: $npy_dat_t;
278
+ ( $ty: ty, $data_type: expr) => {
279
+ unsafe impl Element for $ty {
280
+ const DATA_TYPE : DataType = $data_type;
281
+
255
282
fn is_same_type( dtype: & PyArrayDescr ) -> bool {
256
- $ ( dtype. get_typenum ( ) == NPY_TYPES :: $npy_types as i32 || ) + false
283
+ dtype. get_datatype ( ) == Some ( $data_type )
257
284
}
285
+
258
286
fn get_dtype( py: Python ) -> & PyArrayDescr {
259
- PyArrayDescr :: from_npy_type( py, DataType :: $npy_dat_t . into_ctype( ) )
287
+ PyArrayDescr :: from_npy_type( py, $data_type . into_ctype( ) )
260
288
}
261
289
}
262
290
} ;
263
291
}
264
292
265
- impl_num_element ! ( bool , Bool , NPY_BOOL ) ;
266
- impl_num_element ! ( i8 , Int8 , NPY_BYTE ) ;
267
- impl_num_element ! ( i16 , Int16 , NPY_SHORT ) ;
268
- impl_num_element ! ( u8 , Uint8 , NPY_UBYTE ) ;
269
- impl_num_element ! ( u16 , Uint16 , NPY_USHORT ) ;
270
- impl_num_element ! ( f32 , Float32 , NPY_FLOAT ) ;
271
- impl_num_element ! ( f64 , Float64 , NPY_DOUBLE ) ;
272
- impl_num_element ! ( c32, Complex32 , NPY_CFLOAT ) ;
273
- impl_num_element ! ( c64, Complex64 , NPY_CDOUBLE ) ;
293
+ impl_num_element ! ( bool , DataType :: Bool ) ;
294
+ impl_num_element ! ( i8 , DataType :: Int8 ) ;
295
+ impl_num_element ! ( i16 , DataType :: Int16 ) ;
296
+ impl_num_element ! ( i32 , DataType :: Int32 ) ;
297
+ impl_num_element ! ( i64 , DataType :: Int64 ) ;
298
+ impl_num_element ! ( u8 , DataType :: Uint8 ) ;
299
+ impl_num_element ! ( u16 , DataType :: Uint16 ) ;
300
+ impl_num_element ! ( u32 , DataType :: Uint32 ) ;
301
+ impl_num_element ! ( u64 , DataType :: Uint64 ) ;
302
+ impl_num_element ! ( f32 , DataType :: Float32 ) ;
303
+ impl_num_element ! ( f64 , DataType :: Float64 ) ;
304
+ impl_num_element ! ( c32, DataType :: Complex32 ) ;
305
+ impl_num_element ! ( c64, DataType :: Complex64 ) ;
274
306
275
307
cfg_if ! {
276
- if #[ cfg( all( target_pointer_width = "64" , windows) ) ] {
277
- impl_num_element!( usize , Uint64 , NPY_ULONGLONG ) ;
278
- } else if #[ cfg( all( target_pointer_width = "64" , not( windows) ) ) ] {
279
- impl_num_element!( usize , Uint64 , NPY_ULONG , NPY_ULONGLONG ) ;
280
- } else if #[ cfg( all( target_pointer_width = "32" , windows) ) ] {
281
- impl_num_element!( usize , Uint32 , NPY_UINT , NPY_ULONG ) ;
282
- } else if #[ cfg( all( target_pointer_width = "32" , not( windows) ) ) ] {
283
- impl_num_element!( usize , Uint32 , NPY_UINT ) ;
284
- }
285
- }
286
- cfg_if ! {
287
- if #[ cfg( any( target_pointer_width = "32" , windows) ) ] {
288
- impl_num_element!( i32 , Int32 , NPY_INT , NPY_LONG ) ;
289
- impl_num_element!( u32 , Uint32 , NPY_UINT , NPY_ULONG ) ;
290
- impl_num_element!( i64 , Int64 , NPY_LONGLONG ) ;
291
- impl_num_element!( u64 , Uint64 , NPY_ULONGLONG ) ;
292
- } else if #[ cfg( all( target_pointer_width = "64" , not( windows) ) ) ] {
293
- impl_num_element!( i32 , Int32 , NPY_INT ) ;
294
- impl_num_element!( u32 , Uint32 , NPY_UINT ) ;
295
- impl_num_element!( i64 , Int64 , NPY_LONG , NPY_LONGLONG ) ;
296
- impl_num_element!( u64 , Uint64 , NPY_ULONG , NPY_ULONGLONG ) ;
308
+ if #[ cfg( target_pointer_width = "64" ) ] {
309
+ impl_num_element!( usize , DataType :: Uint64 ) ;
310
+ } else if #[ cfg( target_pointer_width = "32" ) ] {
311
+ impl_num_element!( usize , DataType :: Uint32 ) ;
297
312
}
298
313
}
299
314
300
315
unsafe impl Element for PyObject {
301
316
const DATA_TYPE : DataType = DataType :: Object ;
317
+
302
318
fn is_same_type ( dtype : & PyArrayDescr ) -> bool {
303
319
dtype. get_typenum ( ) == NPY_TYPES :: NPY_OBJECT as i32
304
320
}
321
+
305
322
fn get_dtype ( py : Python ) -> & PyArrayDescr {
306
323
PyArrayDescr :: object ( py)
307
324
}
0 commit comments