Skip to content

Commit 61e6363

Browse files
committed
Add PyArrayDescr::new() constructor
1 parent 7cc945c commit 61e6363

File tree

1 file changed

+31
-0
lines changed

1 file changed

+31
-0
lines changed

src/dtype.rs

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ use std::mem::size_of;
22
use std::os::raw::{
33
c_char, c_int, c_long, c_longlong, c_short, c_uint, c_ulong, c_ulonglong, c_ushort,
44
};
5+
use std::ptr;
56

67
use num_traits::{Bounded, Zero};
78
use pyo3::{
@@ -57,6 +58,23 @@ pub fn dtype<T: Element>(py: Python) -> &PyArrayDescr {
5758
}
5859

5960
impl PyArrayDescr {
61+
/// Creates a new dtype object from an arbitrary object.
62+
///
63+
/// Equivalent to invoking the constructor of `np.dtype`.
64+
#[inline]
65+
pub fn new<'py, T: ToPyObject + ?Sized>(py: Python<'py>, obj: &T) -> PyResult<&'py Self> {
66+
Self::new_impl(py, obj.to_object(py))
67+
}
68+
69+
fn new_impl<'py>(py: Python<'py>, obj: PyObject) -> PyResult<&'py Self> {
70+
let mut descr: *mut PyArray_Descr = ptr::null_mut();
71+
unsafe {
72+
// None is an invalid input here and is not converted to NPY_DEFAULT_TYPE
73+
PY_ARRAY_API.PyArray_DescrConverter2(obj.as_ptr(), &mut descr as *mut _);
74+
py.from_owned_ptr_or_err(descr as _)
75+
}
76+
}
77+
6078
/// Returns `self` as `*mut PyArray_Descr`.
6179
pub fn as_dtype_ptr(&self) -> *mut PyArray_Descr {
6280
self.as_ptr() as _
@@ -423,6 +441,19 @@ mod tests {
423441
use super::{dtype, Complex32, Complex64, Element, PyArrayDescr};
424442
use crate::npyffi::{NPY_ALIGNED_STRUCT, NPY_ITEM_HASOBJECT, NPY_NEEDS_PYAPI, NPY_TYPES};
425443

444+
#[test]
445+
fn test_dtype_new() {
446+
pyo3::Python::with_gil(|py| {
447+
assert_eq!(PyArrayDescr::new(py, "float64").unwrap(), dtype::<f64>(py));
448+
let d = PyArrayDescr::new(py, [("a", "O"), ("b", "?")].as_ref()).unwrap();
449+
assert_eq!(d.names(), Some(vec!["a", "b"]));
450+
assert!(d.has_object());
451+
assert_eq!(d.get_field("a").unwrap().0, dtype::<PyObject>(py));
452+
assert_eq!(d.get_field("b").unwrap().0, dtype::<bool>(py));
453+
assert!(PyArrayDescr::new(py, &123_usize).is_err());
454+
});
455+
}
456+
426457
#[test]
427458
fn test_dtype_names() {
428459
fn type_name<T: Element>(py: pyo3::Python) -> &str {

0 commit comments

Comments
 (0)