@@ -18,6 +18,7 @@ use crate::npyffi::{
1818} ;
1919
2020pub use num_complex:: { Complex32 , Complex64 } ;
21+ use pyo3:: exceptions:: { PyIndexError , PyValueError } ;
2122
2223/// Binding of [`numpy.dtype`](https://numpy.org/doc/stable/reference/generated/numpy.dtype.html).
2324///
@@ -268,28 +269,32 @@ impl PyArrayDescr {
268269 FromPyObject :: extract ( names) . ok ( )
269270 }
270271
271- /// Returns names, types and offsets of fields, or `None` if not a structured type .
272+ /// Returns the dtype and offset of a field with a given name .
272273 ///
273- /// The iterator has entries in the form `(name, ( dtype, offset))` so it can be
274- /// collected directly into a map-like structure .
274+ /// This method will return an error if the dtype is not structured, or if it doesn't
275+ /// contain a field with a given name .
275276 ///
276- /// Note: titles (the optional 3rd tuple element) are ignored .
277+ /// The list of all names can be found via [`PyArrayDescr::names`] .
277278 ///
278- /// Equivalent to [`np.dtype.fields`](https://numpy.org/doc/stable/reference/generated/numpy.dtype.fields.html).
279- pub fn fields ( & self ) -> Option < impl Iterator < Item = ( & str , ( & PyArrayDescr , usize ) ) > + ' _ > {
279+ /// Equivalent to retrieving a single item from
280+ /// [`np.dtype.fields`](https://numpy.org/doc/stable/reference/generated/numpy.dtype.fields.html).
281+ pub fn get_field ( & self , name : & str ) -> PyResult < ( & PyArrayDescr , usize ) > {
280282 if !self . has_fields ( ) {
281- return None ;
283+ return Err ( PyValueError :: new_err (
284+ "cannot get field information: dtype has no fields" ,
285+ ) ) ;
282286 }
283287 let dict = unsafe { PyDict :: from_borrowed_ptr ( self . py ( ) , ( * self . as_dtype_ptr ( ) ) . fields ) } ;
284288 // Panic-wise: numpy guarantees that fields are tuples of proper size and type
285- Some ( dict. iter ( ) . map ( |( k, v) | {
286- let name = FromPyObject :: extract ( k) . unwrap ( ) ;
287- let tuple = v. downcast :: < PyTuple > ( ) . unwrap ( ) ;
288- // note: we can't just extract the entire tuple since 3rd element can be a title
289- let dtype = FromPyObject :: extract ( tuple. as_ref ( ) . get_item ( 0 ) . unwrap ( ) ) . unwrap ( ) ;
290- let offset = FromPyObject :: extract ( tuple. as_ref ( ) . get_item ( 1 ) . unwrap ( ) ) . unwrap ( ) ;
291- ( name, ( dtype, offset) )
292- } ) )
289+ let tuple = dict
290+ . get_item ( name)
291+ . ok_or_else ( || PyIndexError :: new_err ( name. to_owned ( ) ) ) ?
292+ . downcast :: < PyTuple > ( )
293+ . unwrap ( ) ;
294+ // (note: we can't just extract the entire tuple since 3rd element can be a title)
295+ let dtype = FromPyObject :: extract ( tuple. as_ref ( ) . get_item ( 0 ) . unwrap ( ) ) . unwrap ( ) ;
296+ let offset = FromPyObject :: extract ( tuple. as_ref ( ) . get_item ( 1 ) . unwrap ( ) ) . unwrap ( ) ;
297+ Ok ( ( dtype, offset) )
293298 }
294299}
295300
0 commit comments