1
- use crate :: npyffi:: { NpyTypes , PyArray_Descr , PY_ARRAY_API } ;
1
+ //! Implements conversion utitlities.
2
+ use crate :: npyffi:: { NpyTypes , PyArray_Descr , NPY_TYPES , PY_ARRAY_API } ;
3
+ pub use num_complex:: Complex32 as c32;
4
+ pub use num_complex:: Complex64 as c64;
2
5
use pyo3:: ffi;
3
6
use pyo3:: prelude:: * ;
7
+ use pyo3:: types:: PyType ;
8
+ use pyo3:: { AsPyPointer , PyNativeType } ;
4
9
use std:: os:: raw:: c_int;
5
10
6
11
pub struct PyArrayDescr ( PyAny ) ;
@@ -21,3 +26,179 @@ unsafe fn arraydescr_check(op: *mut ffi::PyObject) -> c_int {
21
26
PY_ARRAY_API . get_type_object ( NpyTypes :: PyArrayDescr_Type ) ,
22
27
)
23
28
}
29
+
30
+ impl PyArrayDescr {
31
+ pub fn as_dtype_ptr ( & self ) -> * mut PyArray_Descr {
32
+ self . as_ptr ( ) as _
33
+ }
34
+
35
+ pub fn get_type ( & self ) -> & PyType {
36
+ let dtype_type_ptr = unsafe { * self . as_dtype_ptr ( ) } . typeobj ;
37
+ unsafe { PyType :: from_type_ptr ( self . py ( ) , dtype_type_ptr) }
38
+ }
39
+
40
+ pub fn get_typenum ( & self ) -> std:: os:: raw:: c_int {
41
+ unsafe { * self . as_dtype_ptr ( ) } . type_num
42
+ }
43
+
44
+ pub fn get_datatype ( & self ) -> Option < DataType > {
45
+ DataType :: from_typenum ( self . get_typenum ( ) )
46
+ }
47
+
48
+ pub fn from_npy_type ( py : Python , npy_type : NPY_TYPES ) -> & Self {
49
+ unsafe {
50
+ let descr = PY_ARRAY_API . PyArray_DescrFromType ( npy_type as i32 ) ;
51
+ py. from_owned_ptr ( descr as _ )
52
+ }
53
+ }
54
+ }
55
+
56
+ /// An enum type represents numpy data type.
57
+ ///
58
+ /// This type is mainly for displaying error, and user don't have to use it directly.
59
+ #[ derive( Clone , Debug , Eq , PartialEq ) ]
60
+ pub enum DataType {
61
+ Bool ,
62
+ Int8 ,
63
+ Int16 ,
64
+ Int32 ,
65
+ Int64 ,
66
+ Uint8 ,
67
+ Uint16 ,
68
+ Uint32 ,
69
+ Uint64 ,
70
+ Float32 ,
71
+ Float64 ,
72
+ Complex32 ,
73
+ Complex64 ,
74
+ Object ,
75
+ }
76
+
77
+ impl DataType {
78
+ pub fn from_typenum ( typenum : c_int ) -> Option < Self > {
79
+ Some ( match typenum {
80
+ x if x == NPY_TYPES :: NPY_BOOL as i32 => DataType :: Bool ,
81
+ x if x == NPY_TYPES :: NPY_BYTE as i32 => DataType :: Int8 ,
82
+ x if x == NPY_TYPES :: NPY_SHORT as i32 => DataType :: Int16 ,
83
+ x if x == NPY_TYPES :: NPY_INT as i32 => DataType :: Int32 ,
84
+ x if x == NPY_TYPES :: NPY_LONG as i32 => return DataType :: from_clong ( false ) ,
85
+ x if x == NPY_TYPES :: NPY_LONGLONG as i32 => DataType :: Int64 ,
86
+ x if x == NPY_TYPES :: NPY_UBYTE as i32 => DataType :: Uint8 ,
87
+ x if x == NPY_TYPES :: NPY_USHORT as i32 => DataType :: Uint16 ,
88
+ x if x == NPY_TYPES :: NPY_UINT as i32 => DataType :: Uint32 ,
89
+ x if x == NPY_TYPES :: NPY_ULONG as i32 => return DataType :: from_clong ( true ) ,
90
+ x if x == NPY_TYPES :: NPY_ULONGLONG as i32 => DataType :: Uint64 ,
91
+ x if x == NPY_TYPES :: NPY_FLOAT as i32 => DataType :: Float32 ,
92
+ x if x == NPY_TYPES :: NPY_DOUBLE as i32 => DataType :: Float64 ,
93
+ x if x == NPY_TYPES :: NPY_CFLOAT as i32 => DataType :: Complex32 ,
94
+ x if x == NPY_TYPES :: NPY_CDOUBLE as i32 => DataType :: Complex64 ,
95
+ x if x == NPY_TYPES :: NPY_OBJECT as i32 => DataType :: Object ,
96
+ _ => return None ,
97
+ } )
98
+ }
99
+
100
+ pub fn from_dtype ( dtype : & crate :: PyArrayDescr ) -> Option < Self > {
101
+ Self :: from_typenum ( dtype. get_typenum ( ) )
102
+ }
103
+
104
+ #[ inline]
105
+ pub fn into_ctype ( self ) -> NPY_TYPES {
106
+ match self {
107
+ DataType :: Bool => NPY_TYPES :: NPY_BOOL ,
108
+ DataType :: Int8 => NPY_TYPES :: NPY_BYTE ,
109
+ DataType :: Int16 => NPY_TYPES :: NPY_SHORT ,
110
+ DataType :: Int32 => NPY_TYPES :: NPY_INT ,
111
+ DataType :: Int64 => NPY_TYPES :: NPY_LONGLONG ,
112
+ DataType :: Uint8 => NPY_TYPES :: NPY_UBYTE ,
113
+ DataType :: Uint16 => NPY_TYPES :: NPY_USHORT ,
114
+ DataType :: Uint32 => NPY_TYPES :: NPY_UINT ,
115
+ DataType :: Uint64 => NPY_TYPES :: NPY_ULONGLONG ,
116
+ DataType :: Float32 => NPY_TYPES :: NPY_FLOAT ,
117
+ DataType :: Float64 => NPY_TYPES :: NPY_DOUBLE ,
118
+ DataType :: Complex32 => NPY_TYPES :: NPY_CFLOAT ,
119
+ DataType :: Complex64 => NPY_TYPES :: NPY_CDOUBLE ,
120
+ DataType :: Object => NPY_TYPES :: NPY_OBJECT ,
121
+ }
122
+ }
123
+
124
+ #[ inline( always) ]
125
+ fn from_clong ( is_usize : bool ) -> Option < Self > {
126
+ if cfg ! ( any( target_pointer_width = "32" , windows) ) {
127
+ Some ( if is_usize {
128
+ DataType :: Uint32
129
+ } else {
130
+ DataType :: Int32
131
+ } )
132
+ } else if cfg ! ( all( target_pointer_width = "64" , not( windows) ) ) {
133
+ Some ( if is_usize {
134
+ DataType :: Uint64
135
+ } else {
136
+ DataType :: Int64
137
+ } )
138
+ } else {
139
+ None
140
+ }
141
+ }
142
+ }
143
+
144
+ /// Represents that a type can be an element of `PyArray`.
145
+ pub trait Element : Clone {
146
+ const DATA_TYPE : DataType ;
147
+
148
+ fn is_same_type ( dtype : & PyArrayDescr ) -> bool ;
149
+
150
+ #[ inline]
151
+ fn npy_type ( ) -> NPY_TYPES {
152
+ Self :: DATA_TYPE . into_ctype ( )
153
+ }
154
+
155
+ fn get_dtype ( py : Python ) -> & PyArrayDescr {
156
+ PyArrayDescr :: from_npy_type ( py, Self :: npy_type ( ) )
157
+ }
158
+ }
159
+
160
+ macro_rules! impl_num_element {
161
+ ( $t: ty, $npy_dat_t: ident $( , $npy_types: ident) +) => {
162
+ impl Element for $t {
163
+ const DATA_TYPE : DataType = DataType :: $npy_dat_t;
164
+ fn is_same_type( dtype: & PyArrayDescr ) -> bool {
165
+ $( dtype. get_typenum( ) == NPY_TYPES :: $npy_types as i32 ||) + false
166
+ }
167
+ }
168
+ } ;
169
+ }
170
+
171
+ impl_num_element ! ( bool , Bool , NPY_BOOL ) ;
172
+ impl_num_element ! ( i8 , Int8 , NPY_BYTE ) ;
173
+ impl_num_element ! ( i16 , Int16 , NPY_SHORT ) ;
174
+ impl_num_element ! ( u8 , Uint8 , NPY_UBYTE ) ;
175
+ impl_num_element ! ( u16 , Uint16 , NPY_USHORT ) ;
176
+ impl_num_element ! ( f32 , Float32 , NPY_FLOAT ) ;
177
+ impl_num_element ! ( f64 , Float64 , NPY_DOUBLE ) ;
178
+ impl_num_element ! ( c32, Complex32 , NPY_CFLOAT ) ;
179
+ impl_num_element ! ( c64, Complex64 , NPY_CDOUBLE ) ;
180
+
181
+ cfg_if ! {
182
+ if #[ cfg( all( target_pointer_width = "64" , windows) ) ] {
183
+ impl_num_element!( usize , Uint64 , NPY_ULONGLONG ) ;
184
+ } else if #[ cfg( all( target_pointer_width = "64" , not( windows) ) ) ] {
185
+ impl_num_element!( usize , Uint64 , NPY_ULONG , NPY_ULONGLONG ) ;
186
+ } else if #[ cfg( all( target_pointer_width = "32" , windows) ) ] {
187
+ impl_num_element!( usize , Uint32 , NPY_UINT , NPY_ULONG ) ;
188
+ } else if #[ cfg( all( target_pointer_width = "32" , not( windows) ) ) ] {
189
+ impl_num_element!( usize , Uint32 , NPY_UINT ) ;
190
+ }
191
+ }
192
+ cfg_if ! {
193
+ if #[ cfg( any( target_pointer_width = "32" , windows) ) ] {
194
+ impl_num_element!( i32 , Int32 , NPY_INT , NPY_LONG ) ;
195
+ impl_num_element!( u32 , Uint32 , NPY_UINT , NPY_ULONG ) ;
196
+ impl_num_element!( i64 , Int64 , NPY_LONGLONG ) ;
197
+ impl_num_element!( u64 , Uint64 , NPY_ULONGLONG ) ;
198
+ } else if #[ cfg( all( target_pointer_width = "64" , not( windows) ) ) ] {
199
+ impl_num_element!( i32 , Int32 , NPY_INT ) ;
200
+ impl_num_element!( u32 , Uint32 , NPY_UINT ) ;
201
+ impl_num_element!( i64 , Int64 , NPY_LONG , NPY_LONGLONG ) ;
202
+ impl_num_element!( u64 , Uint64 , NPY_ULONG , NPY_ULONGLONG ) ;
203
+ }
204
+ }
0 commit comments