Skip to content

Commit 1ec6e0d

Browse files
committed
A few fixes and simplifications in the new dtype
1 parent 45f6d4a commit 1ec6e0d

File tree

1 file changed

+17
-23
lines changed

1 file changed

+17
-23
lines changed

src/dtype.rs

Lines changed: 17 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
use std::collections::BTreeMap;
21
use std::mem::size_of;
32
use std::os::raw::{
43
c_char, c_int, c_long, c_longlong, c_short, c_uint, c_ulong, c_ulonglong, c_ushort,
@@ -10,7 +9,7 @@ use pyo3::{
109
prelude::*,
1110
pyobject_native_type_core,
1211
types::{PyDict, PyTuple, PyType},
13-
AsPyPointer, FromPyObject, FromPyPointer, PyNativeType, PyResult,
12+
AsPyPointer, FromPyObject, FromPyPointer, PyNativeType,
1413
};
1514

1615
use crate::npyffi::{
@@ -192,12 +191,12 @@ impl PyArrayDescr {
192191
return None;
193192
}
194193
Some(
195-
// TODO: can this be done simpler, without the incref?
194+
// Panic-wise: numpy guarantees that shape is a tuple of non-negative integers
196195
unsafe {
197196
PyTuple::from_borrowed_ptr(self.py(), (*(*self.as_dtype_ptr()).subarray).shape)
198197
}
199198
.extract()
200-
.unwrap(), // TODO: unwrap? numpy sort-of guarantees it will be an int tuple
199+
.unwrap(),
201200
)
202201
}
203202

@@ -271,36 +270,31 @@ impl PyArrayDescr {
271270
return None;
272271
}
273272
let names = unsafe { PyTuple::from_borrowed_ptr(self.py(), (*self.as_dtype_ptr()).names) };
274-
<_>::extract(names).ok()
273+
FromPyObject::extract(names).ok()
275274
}
276275

277-
/// Returns a dictionary of fields, or `None` if not a structured type.
276+
/// Returns names, types and offsets of fields, or `None` if not a structured type.
278277
///
279-
/// The dictionary is indexed by keys that are the names of the fields. Each entry in
280-
/// the dictionary is a tuple fully describing the field: `(dtype, offset)`.
278+
/// The iterator has entries in the form `(name, (dtype, offset))` so it can be
279+
/// collected directly into a map-like structure.
281280
///
282281
/// Note: titles (the optional 3rd tuple element) are ignored.
283282
///
284283
/// Equivalent to [`np.dtype.fields`](https://numpy.org/doc/stable/reference/generated/numpy.dtype.fields.html).
285-
pub fn fields(&self) -> Option<BTreeMap<&str, (&PyArrayDescr, usize)>> {
284+
pub fn fields(&self) -> Option<impl Iterator<Item = (&str, (&PyArrayDescr, usize))> + '_> {
286285
if !self.has_fields() {
287286
return None;
288287
}
289-
// TODO: can this be done simpler, without the incref?
290288
let dict = unsafe { PyDict::from_borrowed_ptr(self.py(), (*self.as_dtype_ptr()).fields) };
291-
let mut fields = BTreeMap::new();
292-
(|| -> PyResult<_> {
293-
for (k, v) in dict.iter() {
294-
// TODO: alternatively, could unwrap everything here
295-
let name = <_>::extract(k)?;
296-
let tuple = v.downcast::<PyTuple>()?;
297-
let dtype = <_>::extract(tuple.as_ref().get_item(0)?)?;
298-
let offset = <_>::extract(tuple.as_ref().get_item(1)?)?;
299-
fields.insert(name, (dtype, offset));
300-
}
301-
Ok(fields)
302-
})()
303-
.ok()
289+
// Panic-wise: numpy guarantees that fields are tuples of proper size and type
290+
Some(dict.iter().map(|(k, v)| {
291+
let name = FromPyObject::extract(k).unwrap();
292+
let tuple = v.downcast::<PyTuple>().unwrap();
293+
// note: we can't just extract the entire tuple since 3rd element can be a title
294+
let dtype = FromPyObject::extract(tuple.as_ref().get_item(0).unwrap()).unwrap();
295+
let offset = FromPyObject::extract(tuple.as_ref().get_item(1).unwrap()).unwrap();
296+
(name, (dtype, offset))
297+
}))
304298
}
305299
}
306300

0 commit comments

Comments
 (0)