Skip to content

Commit 56e6f6f

Browse files
committed
Replace PyArrayDescr::fields() -> get_field()
1 parent 9a8c376 commit 56e6f6f

File tree

1 file changed

+20
-15
lines changed

1 file changed

+20
-15
lines changed

src/dtype.rs

Lines changed: 20 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ use crate::npyffi::{
1818
};
1919

2020
pub use num_complex::{Complex32, Complex64};
21+
use pyo3::exceptions::{PyIndexError, PyValueError};
2122

2223
/// Binding of [`numpy.dtype`](https://numpy.org/doc/stable/reference/generated/numpy.dtype.html).
2324
///
@@ -268,28 +269,32 @@ impl PyArrayDescr {
268269
FromPyObject::extract(names).ok()
269270
}
270271

271-
/// Returns names, types and offsets of fields, or `None` if not a structured type.
272+
/// Returns the dtype and offset of a field with a given name.
272273
///
273-
/// The iterator has entries in the form `(name, (dtype, offset))` so it can be
274-
/// collected directly into a map-like structure.
274+
/// This method will return an error if the dtype is not structured, or if it doesn't
275+
/// contain a field with a given name.
275276
///
276-
/// Note: titles (the optional 3rd tuple element) are ignored.
277+
/// The list of all names can be found via [`PyArrayDescr::names`].
277278
///
278-
/// Equivalent to [`np.dtype.fields`](https://numpy.org/doc/stable/reference/generated/numpy.dtype.fields.html).
279-
pub fn fields(&self) -> Option<impl Iterator<Item = (&str, (&PyArrayDescr, usize))> + '_> {
279+
/// Equivalent to retrieving a single item from
280+
/// [`np.dtype.fields`](https://numpy.org/doc/stable/reference/generated/numpy.dtype.fields.html).
281+
pub fn get_field(&self, name: &str) -> PyResult<(&PyArrayDescr, usize)> {
280282
if !self.has_fields() {
281-
return None;
283+
return Err(PyValueError::new_err(
284+
"cannot get field information: dtype has no fields",
285+
));
282286
}
283287
let dict = unsafe { PyDict::from_borrowed_ptr(self.py(), (*self.as_dtype_ptr()).fields) };
284288
// Panic-wise: numpy guarantees that fields are tuples of proper size and type
285-
Some(dict.iter().map(|(k, v)| {
286-
let name = FromPyObject::extract(k).unwrap();
287-
let tuple = v.downcast::<PyTuple>().unwrap();
288-
// note: we can't just extract the entire tuple since 3rd element can be a title
289-
let dtype = FromPyObject::extract(tuple.as_ref().get_item(0).unwrap()).unwrap();
290-
let offset = FromPyObject::extract(tuple.as_ref().get_item(1).unwrap()).unwrap();
291-
(name, (dtype, offset))
292-
}))
289+
let tuple = dict
290+
.get_item(name)
291+
.ok_or_else(|| PyIndexError::new_err(name.to_owned()))?
292+
.downcast::<PyTuple>()
293+
.unwrap();
294+
// (note: we can't just extract the entire tuple since 3rd element can be a title)
295+
let dtype = FromPyObject::extract(tuple.as_ref().get_item(0).unwrap()).unwrap();
296+
let offset = FromPyObject::extract(tuple.as_ref().get_item(1).unwrap()).unwrap();
297+
Ok((dtype, offset))
293298
}
294299
}
295300

0 commit comments

Comments
 (0)