|
| 1 | +use std::collections::BTreeMap; |
1 | 2 | use std::mem::size_of;
|
2 |
| -use std::os::raw::{c_int, c_long, c_longlong, c_short, c_uint, c_ulong, c_ulonglong, c_ushort}; |
| 3 | +use std::os::raw::{ |
| 4 | + c_char, c_int, c_long, c_longlong, c_short, c_uint, c_ulong, c_ulonglong, c_ushort, |
| 5 | +}; |
3 | 6 |
|
4 | 7 | use num_traits::{Bounded, Zero};
|
5 |
| -use pyo3::{ffi, prelude::*, pyobject_native_type_core, types::PyType, AsPyPointer, PyNativeType}; |
| 8 | +use pyo3::{ |
| 9 | + ffi::{self, PyTuple_Size}, |
| 10 | + prelude::*, |
| 11 | + pyobject_native_type_core, |
| 12 | + types::{PyDict, PyTuple, PyType}, |
| 13 | + AsPyPointer, FromPyObject, FromPyPointer, PyNativeType, PyResult, |
| 14 | +}; |
6 | 15 |
|
7 |
| -use crate::npyffi::{NpyTypes, PyArray_Descr, NPY_TYPES, PY_ARRAY_API}; |
| 16 | +use crate::npyffi::{ |
| 17 | + NpyTypes, PyArray_Descr, NPY_ALIGNED_STRUCT, NPY_BYTEORDER_CHAR, NPY_ITEM_HASOBJECT, NPY_TYPES, |
| 18 | + PY_ARRAY_API, |
| 19 | +}; |
8 | 20 |
|
9 | 21 | pub use num_complex::{Complex32, Complex64};
|
10 | 22 |
|
@@ -103,6 +115,193 @@ impl PyArrayDescr {
|
103 | 115 | pub fn num(&self) -> c_int {
|
104 | 116 | unsafe { *self.as_dtype_ptr() }.type_num
|
105 | 117 | }
|
| 118 | + |
| 119 | + /// Returns the element size of this data-type object. |
| 120 | + /// |
| 121 | + /// Equivalent to [`np.dtype.itemsize`](https://numpy.org/doc/stable/reference/generated/numpy.dtype.itemsize.html). |
| 122 | + pub fn itemsize(&self) -> usize { |
| 123 | + unsafe { *self.as_dtype_ptr() }.elsize.max(0) as _ |
| 124 | + } |
| 125 | + |
| 126 | + /// Returns the required alignment (bytes) of this data-type according to the compiler |
| 127 | + /// |
| 128 | + /// Equivalent to [`np.dtype.alignment`](https://numpy.org/doc/stable/reference/generated/numpy.dtype.alignment.html). |
| 129 | + pub fn alignment(&self) -> usize { |
| 130 | + unsafe { *self.as_dtype_ptr() }.alignment.max(0) as _ |
| 131 | + } |
| 132 | + |
| 133 | + /// Returns a character indicating the byte-order of this data-type object. |
| 134 | + /// |
| 135 | + /// All built-in data-type objects have byteorder either `=` or `|`. |
| 136 | + /// |
| 137 | + /// Equivalent to [`np.dtype.byteorder`](https://numpy.org/doc/stable/reference/generated/numpy.dtype.byteorder.html). |
| 138 | + pub fn byteorder(&self) -> u8 { |
| 139 | + unsafe { *self.as_dtype_ptr() }.byteorder.max(0) as _ |
| 140 | + } |
| 141 | + |
| 142 | + /// Returns a unique character code for each of the 21 different built-in types. |
| 143 | + /// |
| 144 | + /// Note: structured data types are categorized as `V` (void). |
| 145 | + /// |
| 146 | + /// Equivalent to [`np.dtype.char`](https://numpy.org/doc/stable/reference/generated/numpy.dtype.char.html) |
| 147 | + pub fn char(&self) -> u8 { |
| 148 | + unsafe { *self.as_dtype_ptr() }.type_.max(0) as _ |
| 149 | + } |
| 150 | + |
| 151 | + /// Returns a character code (one of `biufcmMOSUV`) identifying the general kind of data. |
| 152 | + /// |
| 153 | + /// Note: structured data types are categorized as `V` (void). |
| 154 | + /// |
| 155 | + /// Equivalent to [`np.dtype.kind`](https://numpy.org/doc/stable/reference/generated/numpy.dtype.kind.html) |
| 156 | + pub fn kind(&self) -> u8 { |
| 157 | + unsafe { *self.as_dtype_ptr() }.kind.max(0) as _ |
| 158 | + } |
| 159 | + |
| 160 | + /// Returns bit-flags describing how this data type is to be interpreted. |
| 161 | + /// |
| 162 | + /// Equivalent to [`np.dtype.flags`](https://numpy.org/doc/stable/reference/generated/numpy.dtype.flags.html) |
| 163 | + pub fn flags(&self) -> c_char { |
| 164 | + unsafe { *self.as_dtype_ptr() }.flags |
| 165 | + } |
| 166 | + |
| 167 | + /// Returns the number of dimensions if this data type describes a sub-array, and `0` otherwise. |
| 168 | + /// |
| 169 | + /// Equivalent to [`np.dtype.ndim`](https://numpy.org/doc/stable/reference/generated/numpy.dtype.ndim.html) |
| 170 | + pub fn ndim(&self) -> usize { |
| 171 | + if !self.has_subarray() { |
| 172 | + return 0; |
| 173 | + } |
| 174 | + unsafe { PyTuple_Size((*((*self.as_dtype_ptr()).subarray)).shape).max(0) as _ } |
| 175 | + } |
| 176 | + |
| 177 | + /// Returns dtype for the base element of subarrays, regardless of their dimension or shape. |
| 178 | + /// |
| 179 | + /// Equivalent to [`np.dtype.base`](https://numpy.org/doc/stable/reference/generated/numpy.dtype.base.html). |
| 180 | + pub fn base(&self) -> Option<&PyArrayDescr> { |
| 181 | + if !self.has_subarray() { |
| 182 | + return None; |
| 183 | + } |
| 184 | + Some(unsafe { Self::from_borrowed_ptr(self.py(), (*self.as_dtype_ptr()).subarray as _) }) |
| 185 | + } |
| 186 | + |
| 187 | + /// Returns shape tuple of the sub-array if this dtype is a sub-array, and `None` otherwise. |
| 188 | + /// |
| 189 | + /// Equivalent to [`np.dtype.shape`](https://numpy.org/doc/stable/reference/generated/numpy.dtype.shape.html) |
| 190 | + pub fn shape(&self) -> Option<Vec<usize>> { |
| 191 | + if !self.has_subarray() { |
| 192 | + return None; |
| 193 | + } |
| 194 | + Some( |
| 195 | + // TODO: can this be done simpler, without the incref? |
| 196 | + unsafe { |
| 197 | + PyTuple::from_borrowed_ptr(self.py(), (*(*self.as_dtype_ptr()).subarray).shape) |
| 198 | + } |
| 199 | + .extract() |
| 200 | + .unwrap(), // TODO: unwrap? numpy sort-of guarantees it will be an int tuple |
| 201 | + ) |
| 202 | + } |
| 203 | + |
| 204 | + /// Returns `(item_dtype, shape)` if this dtype describes a sub-array, and `None` otherwise. |
| 205 | + /// |
| 206 | + /// The `shape` is the fixed shape of the sub-array described by this data type, |
| 207 | + /// and `item_dtype` the data type of the array. |
| 208 | + /// |
| 209 | + /// If a field whose dtype object has this attribute is retrieved, then the extra dimensions |
| 210 | + /// implied by shape are tacked on to the end of the retrieved array. |
| 211 | + /// |
| 212 | + /// Equivalent to [`np.dtype.subdtype`](https://numpy.org/doc/stable/reference/generated/numpy.dtype.subdtype.html) |
| 213 | + pub fn subdtype(&self) -> Option<(&PyArrayDescr, Vec<usize>)> { |
| 214 | + self.shape() |
| 215 | + .and_then(|shape| self.base().map(|base| (base, shape))) |
| 216 | + } |
| 217 | + |
| 218 | + /// Returns true if the dtype is a sub-array at the top level. |
| 219 | + /// |
| 220 | + /// Equivalent to [`np.dtype.hasobject`](https://numpy.org/doc/stable/reference/generated/numpy.dtype.hasobject.html) |
| 221 | + pub fn has_object(&self) -> bool { |
| 222 | + self.flags() & NPY_ITEM_HASOBJECT != 0 |
| 223 | + } |
| 224 | + |
| 225 | + /// Returns true if the dtype is a struct which maintains field alignment. |
| 226 | + /// |
| 227 | + /// This flag is sticky, so when combining multiple structs together, it is preserved |
| 228 | + /// and produces new dtypes which are also aligned. |
| 229 | + /// |
| 230 | + /// Equivalent to [`np.dtype.isalignedstruct`](https://numpy.org/doc/stable/reference/generated/numpy.dtype.isalignedstruct.html) |
| 231 | + pub fn is_aligned_struct(&self) -> bool { |
| 232 | + self.flags() & NPY_ALIGNED_STRUCT != 0 |
| 233 | + } |
| 234 | + |
| 235 | + /// Returns true if the data type is a sub-array. |
| 236 | + pub fn has_subarray(&self) -> bool { |
| 237 | + // equivalent to PyDataType_HASSUBARRAY(self) |
| 238 | + unsafe { !(*self.as_dtype_ptr()).subarray.is_null() } |
| 239 | + } |
| 240 | + |
| 241 | + /// Returns true if the data type is a structured type. |
| 242 | + pub fn has_fields(&self) -> bool { |
| 243 | + // equivalent to PyDataType_HASFIELDS(self) |
| 244 | + unsafe { !(*self.as_dtype_ptr()).names.is_null() } |
| 245 | + } |
| 246 | + |
| 247 | + /// Returns true if the data type is unsized |
| 248 | + pub fn is_unsized(&self) -> bool { |
| 249 | + // equivalent to PyDataType_ISUNSIZED(self) |
| 250 | + self.itemsize() == 0 && !self.has_fields() |
| 251 | + } |
| 252 | + |
| 253 | + /// Returns true if data type byteorder is native, or `None` if not applicable. |
| 254 | + pub fn is_native_byteorder(&self) -> Option<bool> { |
| 255 | + // based on PyArray_ISNBO(self->byteorder) |
| 256 | + match self.byteorder() { |
| 257 | + b'=' => Some(true), |
| 258 | + b'|' => None, |
| 259 | + byteorder if byteorder == NPY_BYTEORDER_CHAR::NPY_NATBYTE as u8 => Some(true), |
| 260 | + _ => Some(false), |
| 261 | + } |
| 262 | + } |
| 263 | + |
| 264 | + /// Returns an ordered list of field names, or `None` if there are no fields. |
| 265 | + /// |
| 266 | + /// The names are ordered according to increasing byte offset. |
| 267 | + /// |
| 268 | + /// Equivalent to [`np.dtype.names`](https://numpy.org/doc/stable/reference/generated/numpy.dtype.names.html). |
| 269 | + pub fn names(&self) -> Option<Vec<&str>> { |
| 270 | + if !self.has_fields() { |
| 271 | + return None; |
| 272 | + } |
| 273 | + let names = unsafe { PyTuple::from_borrowed_ptr(self.py(), (*self.as_dtype_ptr()).names) }; |
| 274 | + <_>::extract(names).ok() |
| 275 | + } |
| 276 | + |
| 277 | + /// Returns a dictionary of fields, or `None` if not a structured type. |
| 278 | + /// |
| 279 | + /// The dictionary is indexed by keys that are the names of the fields. Each entry in |
| 280 | + /// the dictionary is a tuple fully describing the field: `(dtype, offset)`. |
| 281 | + /// |
| 282 | + /// Note: titles (the optional 3rd tuple element) are ignored. |
| 283 | + /// |
| 284 | + /// Equivalent to [`np.dtype.fields`](https://numpy.org/doc/stable/reference/generated/numpy.dtype.fields.html). |
| 285 | + pub fn fields(&self) -> Option<BTreeMap<&str, (&PyArrayDescr, usize)>> { |
| 286 | + if !self.has_fields() { |
| 287 | + return None; |
| 288 | + } |
| 289 | + // TODO: can this be done simpler, without the incref? |
| 290 | + let dict = unsafe { PyDict::from_borrowed_ptr(self.py(), (*self.as_dtype_ptr()).fields) }; |
| 291 | + let mut fields = BTreeMap::new(); |
| 292 | + (|| -> PyResult<_> { |
| 293 | + for (k, v) in dict.iter() { |
| 294 | + // TODO: alternatively, could unwrap everything here |
| 295 | + let name = <_>::extract(k)?; |
| 296 | + let tuple = v.downcast::<PyTuple>()?; |
| 297 | + let dtype = <_>::extract(tuple.as_ref().get_item(0)?)?; |
| 298 | + let offset = <_>::extract(tuple.as_ref().get_item(1)?)?; |
| 299 | + fields.insert(name, (dtype, offset)); |
| 300 | + } |
| 301 | + Ok(fields) |
| 302 | + })() |
| 303 | + .ok() |
| 304 | + } |
106 | 305 | }
|
107 | 306 |
|
108 | 307 | /// Represents that a type can be an element of `PyArray`.
|
|
0 commit comments