@@ -2,6 +2,7 @@ use std::mem::size_of;
22use std:: os:: raw:: {
33 c_char, c_int, c_long, c_longlong, c_short, c_uint, c_ulong, c_ulonglong, c_ushort,
44} ;
5+ use std:: ptr;
56
67use num_traits:: { Bounded , Zero } ;
78use pyo3:: {
@@ -57,6 +58,23 @@ pub fn dtype<T: Element>(py: Python) -> &PyArrayDescr {
5758}
5859
5960impl PyArrayDescr {
61+ /// Creates a new dtype object from an arbitrary object.
62+ ///
63+ /// Equivalent to invoking the constructor of `np.dtype`.
64+ #[ inline]
65+ pub fn new < ' py , T : ToPyObject + ?Sized > ( py : Python < ' py > , obj : & T ) -> PyResult < & ' py Self > {
66+ Self :: new_impl ( py, obj. to_object ( py) )
67+ }
68+
69+ fn new_impl < ' py > ( py : Python < ' py > , obj : PyObject ) -> PyResult < & ' py Self > {
70+ let mut descr: * mut PyArray_Descr = ptr:: null_mut ( ) ;
71+ unsafe {
72+ // None is an invalid input here and is not converted to NPY_DEFAULT_TYPE
73+ PY_ARRAY_API . PyArray_DescrConverter2 ( obj. as_ptr ( ) , & mut descr as * mut _ ) ;
74+ py. from_owned_ptr_or_err ( descr as _ )
75+ }
76+ }
77+
6078 /// Returns `self` as `*mut PyArray_Descr`.
6179 pub fn as_dtype_ptr ( & self ) -> * mut PyArray_Descr {
6280 self . as_ptr ( ) as _
@@ -423,6 +441,19 @@ mod tests {
423441 use super :: { dtype, Complex32 , Complex64 , Element , PyArrayDescr } ;
424442 use crate :: npyffi:: { NPY_ALIGNED_STRUCT , NPY_ITEM_HASOBJECT , NPY_NEEDS_PYAPI , NPY_TYPES } ;
425443
444+ #[ test]
445+ fn test_dtype_new ( ) {
446+ pyo3:: Python :: with_gil ( |py| {
447+ assert_eq ! ( PyArrayDescr :: new( py, "float64" ) . unwrap( ) , dtype:: <f64 >( py) ) ;
448+ let d = PyArrayDescr :: new ( py, [ ( "a" , "O" ) , ( "b" , "?" ) ] . as_ref ( ) ) . unwrap ( ) ;
449+ assert_eq ! ( d. names( ) , Some ( vec![ "a" , "b" ] ) ) ;
450+ assert ! ( d. has_object( ) ) ;
451+ assert_eq ! ( d. get_field( "a" ) . unwrap( ) . 0 , dtype:: <PyObject >( py) ) ;
452+ assert_eq ! ( d. get_field( "b" ) . unwrap( ) . 0 , dtype:: <bool >( py) ) ;
453+ assert ! ( PyArrayDescr :: new( py, & 123_usize ) . is_err( ) ) ;
454+ } ) ;
455+ }
456+
426457 #[ test]
427458 fn test_dtype_names ( ) {
428459 fn type_name < T : Element > ( py : pyo3:: Python ) -> & str {
0 commit comments