Skip to content

Commit 5abc676

Browse files
committed
Add tests for all new PyArrayDescr methods
1 parent 67e4719 commit 5abc676

File tree

1 file changed

+112
-1
lines changed

1 file changed

+112
-1
lines changed

src/dtype.rs

Lines changed: 112 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -441,7 +441,12 @@ unsafe impl Element for PyObject {
441441

442442
#[cfg(test)]
443443
mod tests {
444-
use super::{dtype, Complex32, Complex64, Element};
444+
use std::os::raw::c_int;
445+
446+
use pyo3::{py_run, types::PyDict, PyObject};
447+
448+
use super::{dtype, Complex32, Complex64, Element, PyArrayDescr};
449+
use crate::npyffi::{NPY_ALIGNED_STRUCT, NPY_ITEM_HASOBJECT, NPY_NEEDS_PYAPI, NPY_TYPES};
445450

446451
#[test]
447452
fn test_dtype_names() {
@@ -474,4 +479,110 @@ mod tests {
474479
}
475480
});
476481
}
482+
483+
#[test]
484+
fn test_dtype_methods_scalar() {
485+
pyo3::Python::with_gil(|py| {
486+
let dt = dtype::<f64>(py);
487+
488+
assert_eq!(dt.num(), NPY_TYPES::NPY_DOUBLE as c_int);
489+
assert_eq!(dt.flags(), 0);
490+
assert_eq!(dt.typeobj().name().unwrap(), "float64");
491+
assert_eq!(dt.char(), b'd');
492+
assert_eq!(dt.kind(), b'f');
493+
assert_eq!(dt.byteorder(), b'=');
494+
assert_eq!(dt.is_native_byteorder(), Some(true));
495+
assert_eq!(dt.itemsize(), 8);
496+
assert_eq!(dt.alignment(), 8);
497+
assert!(!dt.has_object());
498+
assert_eq!(dt.names(), None);
499+
assert!(!dt.has_fields());
500+
assert!(!dt.is_aligned_struct());
501+
assert!(!dt.has_subarray());
502+
assert!(dt.base().is_equiv_to(dt));
503+
assert_eq!(dt.ndim(), 0);
504+
assert_eq!(dt.shape(), None);
505+
});
506+
}
507+
508+
#[test]
509+
fn test_dtype_methods_subarray() {
510+
pyo3::Python::with_gil(|py| {
511+
let locals = PyDict::new(py);
512+
py_run!(
513+
py,
514+
*locals,
515+
"dtype = __import__('numpy').dtype(('f8', (2, 3)))"
516+
);
517+
let dt = locals
518+
.get_item("dtype")
519+
.unwrap()
520+
.downcast::<PyArrayDescr>()
521+
.unwrap();
522+
523+
assert_eq!(dt.num(), NPY_TYPES::NPY_VOID as c_int);
524+
assert_eq!(dt.flags(), 0);
525+
assert_eq!(dt.typeobj().name().unwrap(), "void");
526+
assert_eq!(dt.char(), b'V');
527+
assert_eq!(dt.kind(), b'V');
528+
assert_eq!(dt.byteorder(), b'|');
529+
assert_eq!(dt.is_native_byteorder(), None);
530+
assert_eq!(dt.itemsize(), 48);
531+
assert_eq!(dt.alignment(), 8);
532+
assert!(!dt.has_object());
533+
assert_eq!(dt.names(), None);
534+
assert!(!dt.has_fields());
535+
assert!(!dt.is_aligned_struct());
536+
assert!(dt.has_subarray());
537+
assert_eq!(dt.ndim(), 2);
538+
assert_eq!(dt.shape().unwrap(), vec![2, 3]);
539+
assert!(dt.base().is_equiv_to(dtype::<f64>(py)));
540+
});
541+
}
542+
543+
#[test]
544+
fn test_dtype_methods_record() {
545+
pyo3::Python::with_gil(|py| {
546+
let locals = PyDict::new(py);
547+
py_run!(
548+
py,
549+
*locals,
550+
"dtype = __import__('numpy').dtype([('x', 'u1'), ('y', 'f8'), ('z', 'O')], align=True)"
551+
);
552+
let dt = locals
553+
.get_item("dtype")
554+
.unwrap()
555+
.downcast::<PyArrayDescr>()
556+
.unwrap();
557+
558+
assert_eq!(dt.num(), NPY_TYPES::NPY_VOID as c_int);
559+
assert_ne!(dt.flags() & NPY_ITEM_HASOBJECT, 0);
560+
assert_ne!(dt.flags() & NPY_NEEDS_PYAPI, 0);
561+
assert_ne!(dt.flags() & NPY_ALIGNED_STRUCT, 0);
562+
assert_eq!(dt.typeobj().name().unwrap(), "void");
563+
assert_eq!(dt.char(), b'V');
564+
assert_eq!(dt.kind(), b'V');
565+
assert_eq!(dt.byteorder(), b'|');
566+
assert_eq!(dt.is_native_byteorder(), None);
567+
assert_eq!(dt.itemsize(), 24);
568+
assert_eq!(dt.alignment(), 8);
569+
assert!(dt.has_object());
570+
assert_eq!(dt.names(), Some(vec!["x", "y", "z"]));
571+
assert!(dt.has_fields());
572+
assert!(dt.is_aligned_struct());
573+
assert!(!dt.has_subarray());
574+
assert_eq!(dt.ndim(), 0);
575+
assert_eq!(dt.shape(), None);
576+
assert!(dt.base().is_equiv_to(dt));
577+
let x = dt.get_field("x").unwrap();
578+
assert!(x.0.is_equiv_to(dtype::<u8>(py)));
579+
assert_eq!(x.1, 0);
580+
let y = dt.get_field("y").unwrap();
581+
assert!(y.0.is_equiv_to(dtype::<f64>(py)));
582+
assert_eq!(y.1, 8);
583+
let z = dt.get_field("z").unwrap();
584+
assert!(z.0.is_equiv_to(dtype::<PyObject>(py)));
585+
assert_eq!(z.1, 16);
586+
});
587+
}
477588
}

0 commit comments

Comments
 (0)