Skip to content

Commit 45f6d4a

Browse files
committed
Add many PyArrayDescr methods like in np.dtype
1 parent 5a1a91d commit 45f6d4a

File tree

1 file changed

+202
-3
lines changed

1 file changed

+202
-3
lines changed

src/dtype.rs

Lines changed: 202 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,22 @@
1+
use std::collections::BTreeMap;
12
use std::mem::size_of;
2-
use std::os::raw::{c_int, c_long, c_longlong, c_short, c_uint, c_ulong, c_ulonglong, c_ushort};
3+
use std::os::raw::{
4+
c_char, c_int, c_long, c_longlong, c_short, c_uint, c_ulong, c_ulonglong, c_ushort,
5+
};
36

47
use num_traits::{Bounded, Zero};
5-
use pyo3::{ffi, prelude::*, pyobject_native_type_core, types::PyType, AsPyPointer, PyNativeType};
8+
use pyo3::{
9+
ffi::{self, PyTuple_Size},
10+
prelude::*,
11+
pyobject_native_type_core,
12+
types::{PyDict, PyTuple, PyType},
13+
AsPyPointer, FromPyObject, FromPyPointer, PyNativeType, PyResult,
14+
};
615

7-
use crate::npyffi::{NpyTypes, PyArray_Descr, NPY_TYPES, PY_ARRAY_API};
16+
use crate::npyffi::{
17+
NpyTypes, PyArray_Descr, NPY_ALIGNED_STRUCT, NPY_BYTEORDER_CHAR, NPY_ITEM_HASOBJECT, NPY_TYPES,
18+
PY_ARRAY_API,
19+
};
820

921
pub use num_complex::{Complex32, Complex64};
1022

@@ -103,6 +115,193 @@ impl PyArrayDescr {
103115
pub fn num(&self) -> c_int {
104116
unsafe { *self.as_dtype_ptr() }.type_num
105117
}
118+
119+
/// Returns the element size of this data-type object.
120+
///
121+
/// Equivalent to [`np.dtype.itemsize`](https://numpy.org/doc/stable/reference/generated/numpy.dtype.itemsize.html).
122+
pub fn itemsize(&self) -> usize {
123+
unsafe { *self.as_dtype_ptr() }.elsize.max(0) as _
124+
}
125+
126+
/// Returns the required alignment (bytes) of this data-type according to the compiler
127+
///
128+
/// Equivalent to [`np.dtype.alignment`](https://numpy.org/doc/stable/reference/generated/numpy.dtype.alignment.html).
129+
pub fn alignment(&self) -> usize {
130+
unsafe { *self.as_dtype_ptr() }.alignment.max(0) as _
131+
}
132+
133+
/// Returns a character indicating the byte-order of this data-type object.
134+
///
135+
/// All built-in data-type objects have byteorder either `=` or `|`.
136+
///
137+
/// Equivalent to [`np.dtype.byteorder`](https://numpy.org/doc/stable/reference/generated/numpy.dtype.byteorder.html).
138+
pub fn byteorder(&self) -> u8 {
139+
unsafe { *self.as_dtype_ptr() }.byteorder.max(0) as _
140+
}
141+
142+
/// Returns a unique character code for each of the 21 different built-in types.
143+
///
144+
/// Note: structured data types are categorized as `V` (void).
145+
///
146+
/// Equivalent to [`np.dtype.char`](https://numpy.org/doc/stable/reference/generated/numpy.dtype.char.html)
147+
pub fn char(&self) -> u8 {
148+
unsafe { *self.as_dtype_ptr() }.type_.max(0) as _
149+
}
150+
151+
/// Returns a character code (one of `biufcmMOSUV`) identifying the general kind of data.
152+
///
153+
/// Note: structured data types are categorized as `V` (void).
154+
///
155+
/// Equivalent to [`np.dtype.kind`](https://numpy.org/doc/stable/reference/generated/numpy.dtype.kind.html)
156+
pub fn kind(&self) -> u8 {
157+
unsafe { *self.as_dtype_ptr() }.kind.max(0) as _
158+
}
159+
160+
/// Returns bit-flags describing how this data type is to be interpreted.
161+
///
162+
/// Equivalent to [`np.dtype.flags`](https://numpy.org/doc/stable/reference/generated/numpy.dtype.flags.html)
163+
pub fn flags(&self) -> c_char {
164+
unsafe { *self.as_dtype_ptr() }.flags
165+
}
166+
167+
/// Returns the number of dimensions if this data type describes a sub-array, and `0` otherwise.
168+
///
169+
/// Equivalent to [`np.dtype.ndim`](https://numpy.org/doc/stable/reference/generated/numpy.dtype.ndim.html)
170+
pub fn ndim(&self) -> usize {
171+
if !self.has_subarray() {
172+
return 0;
173+
}
174+
unsafe { PyTuple_Size((*((*self.as_dtype_ptr()).subarray)).shape).max(0) as _ }
175+
}
176+
177+
/// Returns dtype for the base element of subarrays, regardless of their dimension or shape.
178+
///
179+
/// Equivalent to [`np.dtype.base`](https://numpy.org/doc/stable/reference/generated/numpy.dtype.base.html).
180+
pub fn base(&self) -> Option<&PyArrayDescr> {
181+
if !self.has_subarray() {
182+
return None;
183+
}
184+
Some(unsafe { Self::from_borrowed_ptr(self.py(), (*self.as_dtype_ptr()).subarray as _) })
185+
}
186+
187+
/// Returns shape tuple of the sub-array if this dtype is a sub-array, and `None` otherwise.
188+
///
189+
/// Equivalent to [`np.dtype.shape`](https://numpy.org/doc/stable/reference/generated/numpy.dtype.shape.html)
190+
pub fn shape(&self) -> Option<Vec<usize>> {
191+
if !self.has_subarray() {
192+
return None;
193+
}
194+
Some(
195+
// TODO: can this be done simpler, without the incref?
196+
unsafe {
197+
PyTuple::from_borrowed_ptr(self.py(), (*(*self.as_dtype_ptr()).subarray).shape)
198+
}
199+
.extract()
200+
.unwrap(), // TODO: unwrap? numpy sort-of guarantees it will be an int tuple
201+
)
202+
}
203+
204+
/// Returns `(item_dtype, shape)` if this dtype describes a sub-array, and `None` otherwise.
205+
///
206+
/// The `shape` is the fixed shape of the sub-array described by this data type,
207+
/// and `item_dtype` the data type of the array.
208+
///
209+
/// If a field whose dtype object has this attribute is retrieved, then the extra dimensions
210+
/// implied by shape are tacked on to the end of the retrieved array.
211+
///
212+
/// Equivalent to [`np.dtype.subdtype`](https://numpy.org/doc/stable/reference/generated/numpy.dtype.subdtype.html)
213+
pub fn subdtype(&self) -> Option<(&PyArrayDescr, Vec<usize>)> {
214+
self.shape()
215+
.and_then(|shape| self.base().map(|base| (base, shape)))
216+
}
217+
218+
/// Returns true if the dtype is a sub-array at the top level.
219+
///
220+
/// Equivalent to [`np.dtype.hasobject`](https://numpy.org/doc/stable/reference/generated/numpy.dtype.hasobject.html)
221+
pub fn has_object(&self) -> bool {
222+
self.flags() & NPY_ITEM_HASOBJECT != 0
223+
}
224+
225+
/// Returns true if the dtype is a struct which maintains field alignment.
226+
///
227+
/// This flag is sticky, so when combining multiple structs together, it is preserved
228+
/// and produces new dtypes which are also aligned.
229+
///
230+
/// Equivalent to [`np.dtype.isalignedstruct`](https://numpy.org/doc/stable/reference/generated/numpy.dtype.isalignedstruct.html)
231+
pub fn is_aligned_struct(&self) -> bool {
232+
self.flags() & NPY_ALIGNED_STRUCT != 0
233+
}
234+
235+
/// Returns true if the data type is a sub-array.
236+
pub fn has_subarray(&self) -> bool {
237+
// equivalent to PyDataType_HASSUBARRAY(self)
238+
unsafe { !(*self.as_dtype_ptr()).subarray.is_null() }
239+
}
240+
241+
/// Returns true if the data type is a structured type.
242+
pub fn has_fields(&self) -> bool {
243+
// equivalent to PyDataType_HASFIELDS(self)
244+
unsafe { !(*self.as_dtype_ptr()).names.is_null() }
245+
}
246+
247+
/// Returns true if the data type is unsized
248+
pub fn is_unsized(&self) -> bool {
249+
// equivalent to PyDataType_ISUNSIZED(self)
250+
self.itemsize() == 0 && !self.has_fields()
251+
}
252+
253+
/// Returns true if data type byteorder is native, or `None` if not applicable.
254+
pub fn is_native_byteorder(&self) -> Option<bool> {
255+
// based on PyArray_ISNBO(self->byteorder)
256+
match self.byteorder() {
257+
b'=' => Some(true),
258+
b'|' => None,
259+
byteorder if byteorder == NPY_BYTEORDER_CHAR::NPY_NATBYTE as u8 => Some(true),
260+
_ => Some(false),
261+
}
262+
}
263+
264+
/// Returns an ordered list of field names, or `None` if there are no fields.
265+
///
266+
/// The names are ordered according to increasing byte offset.
267+
///
268+
/// Equivalent to [`np.dtype.names`](https://numpy.org/doc/stable/reference/generated/numpy.dtype.names.html).
269+
pub fn names(&self) -> Option<Vec<&str>> {
270+
if !self.has_fields() {
271+
return None;
272+
}
273+
let names = unsafe { PyTuple::from_borrowed_ptr(self.py(), (*self.as_dtype_ptr()).names) };
274+
<_>::extract(names).ok()
275+
}
276+
277+
/// Returns a dictionary of fields, or `None` if not a structured type.
278+
///
279+
/// The dictionary is indexed by keys that are the names of the fields. Each entry in
280+
/// the dictionary is a tuple fully describing the field: `(dtype, offset)`.
281+
///
282+
/// Note: titles (the optional 3rd tuple element) are ignored.
283+
///
284+
/// Equivalent to [`np.dtype.fields`](https://numpy.org/doc/stable/reference/generated/numpy.dtype.fields.html).
285+
pub fn fields(&self) -> Option<BTreeMap<&str, (&PyArrayDescr, usize)>> {
286+
if !self.has_fields() {
287+
return None;
288+
}
289+
// TODO: can this be done simpler, without the incref?
290+
let dict = unsafe { PyDict::from_borrowed_ptr(self.py(), (*self.as_dtype_ptr()).fields) };
291+
let mut fields = BTreeMap::new();
292+
(|| -> PyResult<_> {
293+
for (k, v) in dict.iter() {
294+
// TODO: alternatively, could unwrap everything here
295+
let name = <_>::extract(k)?;
296+
let tuple = v.downcast::<PyTuple>()?;
297+
let dtype = <_>::extract(tuple.as_ref().get_item(0)?)?;
298+
let offset = <_>::extract(tuple.as_ref().get_item(1)?)?;
299+
fields.insert(name, (dtype, offset));
300+
}
301+
Ok(fields)
302+
})()
303+
.ok()
304+
}
106305
}
107306

108307
/// Represents that a type can be an element of `PyArray`.

0 commit comments

Comments
 (0)