Skip to content

Commit 41cb84a

Browse files
committed
Document dtype
1 parent 7fcf952 commit 41cb84a

File tree

4 files changed

+68
-25
lines changed

4 files changed

+68
-25
lines changed

src/array.rs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,17 @@ impl<T, D> PyArray<T, D> {
152152
self.as_ptr() as _
153153
}
154154

155+
/// Returns `dtype` of the array.
156+
/// Counterpart of `array.dtype` in Python.
157+
///
158+
/// # Example
159+
/// ```
160+
/// pyo3::Python::with_gil(|py| {
161+
/// let array = numpy::PyArray::from_vec(py, vec![1, 2, 3i32]);
162+
/// let dtype = array.dtype();
163+
/// assert_eq!(dtype.get_datatype().unwrap(), numpy::DataType::Int32);
164+
/// });
165+
/// ```
155166
pub fn dtype(&self) -> &crate::PyArrayDescr {
156167
let descr_ptr = unsafe { (*self.as_array_ptr()).descr };
157168
unsafe { pyo3::FromPyPointer::from_borrowed_ptr(self.py(), descr_ptr as _) }

src/dtype.rs

Lines changed: 45 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
//! Implements conversion utitlities.
21
use crate::npyffi::{NpyTypes, PyArray_Descr, NPY_TYPES, PY_ARRAY_API};
32
pub use num_complex::Complex32 as c32;
43
pub use num_complex::Complex64 as c64;
@@ -8,6 +7,21 @@ use pyo3::types::PyType;
87
use pyo3::{AsPyPointer, PyNativeType};
98
use std::os::raw::c_int;
109

10+
/// Binding of [`numpy.dtype`](https://numpy.org/doc/stable/reference/generated/numpy.dtype.html).
11+
///
12+
/// # Example
13+
/// ```
14+
/// use pyo3::types::IntoPyDict;
15+
/// pyo3::Python::with_gil(|py| {
16+
/// let locals = [("np", numpy::get_array_module(py).unwrap())].into_py_dict(py);
17+
/// let dtype: &numpy::PyArrayDescr = py
18+
/// .eval("np.array([1, 2, 3.0]).dtype", Some(locals), None)
19+
/// .unwrap()
20+
/// .downcast()
21+
/// .unwrap();
22+
/// assert_eq!(dtype.get_datatype().unwrap(), numpy::DataType::Float64);
23+
/// });
24+
/// ```
1125
pub struct PyArrayDescr(PyAny);
1226

1327
pyobject_native_type_core!(
@@ -28,34 +42,48 @@ unsafe fn arraydescr_check(op: *mut ffi::PyObject) -> c_int {
2842
}
2943

3044
impl PyArrayDescr {
45+
/// Returns `self` as `*mut PyArray_Descr`.
3146
pub fn as_dtype_ptr(&self) -> *mut PyArray_Descr {
3247
self.as_ptr() as _
3348
}
3449

50+
/// Returns the internal `PyType` that this `dtype` holds.
51+
///
52+
/// # Example
53+
/// ```
54+
/// pyo3::Python::with_gil(|py| {
55+
/// let array = numpy::PyArray::from_vec(py, vec![0.0, 1.0, 2.0f64]);
56+
/// let dtype = array.dtype();
57+
/// assert_eq!(dtype.get_type().name().to_string(), "numpy.float64");
58+
/// });
59+
/// ```
3560
pub fn get_type(&self) -> &PyType {
3661
let dtype_type_ptr = unsafe { *self.as_dtype_ptr() }.typeobj;
3762
unsafe { PyType::from_type_ptr(self.py(), dtype_type_ptr) }
3863
}
3964

40-
pub fn get_typenum(&self) -> std::os::raw::c_int {
41-
unsafe { *self.as_dtype_ptr() }.type_num
42-
}
43-
65+
/// Returns the data type as `DataType` enum.
4466
pub fn get_datatype(&self) -> Option<DataType> {
4567
DataType::from_typenum(self.get_typenum())
4668
}
4769

48-
pub fn from_npy_type(py: Python, npy_type: NPY_TYPES) -> &Self {
70+
fn from_npy_type(py: Python, npy_type: NPY_TYPES) -> &Self {
4971
unsafe {
5072
let descr = PY_ARRAY_API.PyArray_DescrFromType(npy_type as i32);
5173
py.from_owned_ptr(descr as _)
5274
}
5375
}
76+
77+
fn get_typenum(&self) -> std::os::raw::c_int {
78+
unsafe { *self.as_dtype_ptr() }.type_num
79+
}
5480
}
5581

56-
/// An enum type represents numpy data type.
82+
/// Represents numpy data type.
5783
///
58-
/// This type is mainly for displaying error, and user don't have to use it directly.
84+
/// This is an incomplete counterpart of
85+
/// [Enumerated Types](https://numpy.org/doc/stable/reference/c-api/dtype.html#enumerated-types)
86+
/// in numpy C-API.
5987
#[derive(Clone, Debug, Eq, PartialEq)]
6088
pub enum DataType {
6189
Bool,
@@ -75,6 +103,8 @@ pub enum DataType {
75103
}
76104

77105
impl DataType {
106+
/// Construct `DataType` from
107+
/// [Enumerated Types](https://numpy.org/doc/stable/reference/c-api/dtype.html#enumerated-types).
78108
pub fn from_typenum(typenum: c_int) -> Option<Self> {
79109
Some(match typenum {
80110
x if x == NPY_TYPES::NPY_BOOL as i32 => DataType::Bool,
@@ -97,11 +127,8 @@ impl DataType {
97127
})
98128
}
99129

100-
pub fn from_dtype(dtype: &crate::PyArrayDescr) -> Option<Self> {
101-
Self::from_typenum(dtype.get_typenum())
102-
}
103-
104-
#[inline]
130+
/// Convert `self` into
131+
/// [Enumerated Types](https://numpy.org/doc/stable/reference/c-api/dtype.html#enumerated-types).
105132
pub fn into_ctype(self) -> NPY_TYPES {
106133
match self {
107134
DataType::Bool => NPY_TYPES::NPY_BOOL,
@@ -143,15 +170,20 @@ impl DataType {
143170

144171
/// Represents that a type can be an element of `PyArray`.
145172
pub trait Element: Clone {
173+
/// `DataType` corresponding to this type.
146174
const DATA_TYPE: DataType;
147175

176+
/// Returns if the give `dtype` is convertible to `Self` in Rust.
148177
fn is_same_type(dtype: &PyArrayDescr) -> bool;
149178

179+
/// Returns the corresponding
180+
/// [Enumerated Type](https://numpy.org/doc/stable/reference/c-api/dtype.html#enumerated-types).
150181
#[inline]
151182
fn npy_type() -> NPY_TYPES {
152183
Self::DATA_TYPE.into_ctype()
153184
}
154185

186+
/// Create `dtype`.
155187
fn get_dtype(py: Python) -> &PyArrayDescr {
156188
PyArrayDescr::from_npy_type(py, Self::npy_type())
157189
}

src/error.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ impl ShapeError {
4141
ShapeError {
4242
from: ArrayDim {
4343
dim: Some(from_dim),
44-
dtype: DataType::from_dtype(from_dtype),
44+
dtype: from_dtype.get_datatype(),
4545
},
4646
to: ArrayDim {
4747
dim: to_dim,

tests/array.rs

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -237,15 +237,15 @@ fn handle_negative_strides() {
237237

238238
#[test]
239239
fn dtype_from_py() {
240-
let gil = pyo3::Python::acquire_gil();
241-
let py = gil.python();
242-
let arr = array![[2, 3], [4, 5u32]];
243-
let pyarr = arr.to_pyarray(py);
244-
let dtype: &numpy::PyArrayDescr = py
245-
.eval("a.dtype", Some([("a", pyarr)].into_py_dict(py)), None)
246-
.unwrap()
247-
.downcast()
248-
.unwrap();
249-
assert_eq!(&format!("{:?}", dtype), "dtype('uint32')");
250-
assert_eq!(dtype.get_datatype().unwrap(), numpy::DataType::Uint32);
240+
pyo3::Python::with_gil(|py| {
241+
let arr = array![[2, 3], [4, 5u32]];
242+
let pyarr = arr.to_pyarray(py);
243+
let dtype: &numpy::PyArrayDescr = py
244+
.eval("a.dtype", Some([("a", pyarr)].into_py_dict(py)), None)
245+
.unwrap()
246+
.downcast()
247+
.unwrap();
248+
assert_eq!(&format!("{:?}", dtype), "dtype('uint32')");
249+
assert_eq!(dtype.get_datatype().unwrap(), numpy::DataType::Uint32);
250+
})
251251
}

0 commit comments

Comments
 (0)