@@ -2,6 +2,7 @@ use std::mem::size_of;
2
2
use std:: os:: raw:: {
3
3
c_char, c_int, c_long, c_longlong, c_short, c_uint, c_ulong, c_ulonglong, c_ushort,
4
4
} ;
5
+ use std:: ptr;
5
6
6
7
use num_traits:: { Bounded , Zero } ;
7
8
use pyo3:: {
@@ -57,6 +58,23 @@ pub fn dtype<T: Element>(py: Python) -> &PyArrayDescr {
57
58
}
58
59
59
60
impl PyArrayDescr {
61
+ /// Creates a new dtype object from an arbitrary object.
62
+ ///
63
+ /// Equivalent to invoking the constructor of `np.dtype`.
64
+ #[ inline]
65
+ pub fn new < ' py , T : ToPyObject + ?Sized > ( py : Python < ' py > , obj : & T ) -> PyResult < & ' py Self > {
66
+ Self :: new_impl ( py, obj. to_object ( py) )
67
+ }
68
+
69
+ fn new_impl < ' py > ( py : Python < ' py > , obj : PyObject ) -> PyResult < & ' py Self > {
70
+ let mut descr: * mut PyArray_Descr = ptr:: null_mut ( ) ;
71
+ unsafe {
72
+ // None is an invalid input here and is not converted to NPY_DEFAULT_TYPE
73
+ PY_ARRAY_API . PyArray_DescrConverter2 ( obj. as_ptr ( ) , & mut descr as * mut _ ) ;
74
+ py. from_owned_ptr_or_err ( descr as _ )
75
+ }
76
+ }
77
+
60
78
/// Returns `self` as `*mut PyArray_Descr`.
61
79
pub fn as_dtype_ptr ( & self ) -> * mut PyArray_Descr {
62
80
self . as_ptr ( ) as _
@@ -423,6 +441,19 @@ mod tests {
423
441
use super :: { dtype, Complex32 , Complex64 , Element , PyArrayDescr } ;
424
442
use crate :: npyffi:: { NPY_ALIGNED_STRUCT , NPY_ITEM_HASOBJECT , NPY_NEEDS_PYAPI , NPY_TYPES } ;
425
443
444
+ #[ test]
445
+ fn test_dtype_new ( ) {
446
+ pyo3:: Python :: with_gil ( |py| {
447
+ assert_eq ! ( PyArrayDescr :: new( py, "float64" ) . unwrap( ) , dtype:: <f64 >( py) ) ;
448
+ let d = PyArrayDescr :: new ( py, [ ( "a" , "O" ) , ( "b" , "?" ) ] . as_ref ( ) ) . unwrap ( ) ;
449
+ assert_eq ! ( d. names( ) , Some ( vec![ "a" , "b" ] ) ) ;
450
+ assert ! ( d. has_object( ) ) ;
451
+ assert_eq ! ( d. get_field( "a" ) . unwrap( ) . 0 , dtype:: <PyObject >( py) ) ;
452
+ assert_eq ! ( d. get_field( "b" ) . unwrap( ) . 0 , dtype:: <bool >( py) ) ;
453
+ assert ! ( PyArrayDescr :: new( py, & 123_usize ) . is_err( ) ) ;
454
+ } ) ;
455
+ }
456
+
426
457
#[ test]
427
458
fn test_dtype_names ( ) {
428
459
fn type_name < T : Element > ( py : pyo3:: Python ) -> & str {
0 commit comments