|
| 1 | +//! Support datetimes and timedeltas |
| 2 | +//! |
| 3 | +//! [The corresponding section][datetime] of the NumPy documentation contains more information. |
| 4 | +//! |
| 5 | +//! # Example |
| 6 | +//! |
| 7 | +//! ``` |
| 8 | +//! use numpy::{datetime::{units, Datetime, Timedelta}, PyArray1}; |
| 9 | +//! use pyo3::Python; |
| 10 | +//! # use pyo3::types::PyDict; |
| 11 | +//! |
| 12 | +//! Python::with_gil(|py| { |
| 13 | +//! # let locals = py |
| 14 | +//! # .eval("{ 'np': __import__('numpy') }", None, None) |
| 15 | +//! # .unwrap() |
| 16 | +//! # .downcast::<PyDict>() |
| 17 | +//! # .unwrap(); |
| 18 | +//! # |
| 19 | +//! let array = py |
| 20 | +//! .eval( |
| 21 | +//! "np.array([np.datetime64('2017-04-21')])", |
| 22 | +//! None, |
| 23 | +//! Some(locals), |
| 24 | +//! ) |
| 25 | +//! .unwrap() |
| 26 | +//! .downcast::<PyArray1<Datetime<units::Days>>>() |
| 27 | +//! .unwrap(); |
| 28 | +//! |
| 29 | +//! assert_eq!( |
| 30 | +//! array.get_owned(0).unwrap(), |
| 31 | +//! Datetime::<units::Days>::from(17_277) |
| 32 | +//! ); |
| 33 | +//! |
| 34 | +//! let array = py |
| 35 | +//! .eval( |
| 36 | +//! "np.array([np.datetime64('2022-03-29')]) - np.array([np.datetime64('2017-04-21')])", |
| 37 | +//! None, |
| 38 | +//! Some(locals), |
| 39 | +//! ) |
| 40 | +//! .unwrap() |
| 41 | +//! .downcast::<PyArray1<Timedelta<units::Days>>>() |
| 42 | +//! .unwrap(); |
| 43 | +//! |
| 44 | +//! assert_eq!( |
| 45 | +//! array.get_owned(0).unwrap(), |
| 46 | +//! Timedelta::<units::Days>::from(1_803) |
| 47 | +//! ); |
| 48 | +//! }); |
| 49 | +//! ``` |
| 50 | +//! |
| 51 | +//! [datetime]: https://numpy.org/doc/stable/reference/arrays.datetime.html |
| 52 | +
|
| 53 | +use std::fmt; |
| 54 | +use std::hash::Hash; |
| 55 | +use std::marker::PhantomData; |
| 56 | + |
| 57 | +use pyo3::Python; |
| 58 | + |
| 59 | +use crate::dtype::{Element, PyArrayDescr}; |
| 60 | +use crate::npyffi::{PyArray_DatetimeDTypeMetaData, NPY_DATETIMEUNIT, NPY_TYPES}; |
| 61 | + |
| 62 | +/// Represents the [datetime units][datetime-units] supported by NumPy |
| 63 | +/// |
| 64 | +/// [datetime-units]: https://numpy.org/doc/stable/reference/arrays.datetime.html#datetime-units |
| 65 | +pub trait Unit: Send + Sync + Clone + Copy + PartialEq + Eq + Hash + PartialOrd + Ord { |
| 66 | + /// The matching NumPy [datetime unit code][NPY_DATETIMEUNIT] |
| 67 | + /// |
| 68 | + /// [NPY_DATETIMEUNIT]: https://github.com/numpy/numpy/blob/4c60b3263ac50e5e72f6a909e156314fc3c9cba0/numpy/core/include/numpy/ndarraytypes.h#L276 |
| 69 | + const UNIT: NPY_DATETIMEUNIT; |
| 70 | + |
| 71 | + /// The abbrevation used for debug formatting |
| 72 | + const ABBREV: &'static str; |
| 73 | +} |
| 74 | + |
| 75 | +macro_rules! define_units { |
| 76 | + ($($(#[$meta:meta])* $struct:ident => $unit:ident $abbrev:literal,)+) => { |
| 77 | + $( |
| 78 | + |
| 79 | + $(#[$meta])* |
| 80 | + #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)] |
| 81 | + pub struct $struct; |
| 82 | + |
| 83 | + impl Unit for $struct { |
| 84 | + const UNIT: NPY_DATETIMEUNIT = NPY_DATETIMEUNIT::$unit; |
| 85 | + |
| 86 | + const ABBREV: &'static str = $abbrev; |
| 87 | + } |
| 88 | + |
| 89 | + )+ |
| 90 | + }; |
| 91 | +} |
| 92 | + |
| 93 | +/// Predefined implementors of the [`Unit`] trait |
| 94 | +pub mod units { |
| 95 | + use super::*; |
| 96 | + |
| 97 | + define_units!( |
| 98 | + #[doc = "Years, i.e. 12 months"] |
| 99 | + Years => NPY_FR_Y "a", |
| 100 | + #[doc = "Months, i.e. 30 days"] |
| 101 | + Months => NPY_FR_M "mo", |
| 102 | + #[doc = "Weeks, i.e. 7 days"] |
| 103 | + Weeks => NPY_FR_W "w", |
| 104 | + #[doc = "Days, i.e. 24 hours"] |
| 105 | + Days => NPY_FR_D "d", |
| 106 | + #[doc = "Hours, i.e. 60 minutes"] |
| 107 | + Hours => NPY_FR_h "h", |
| 108 | + #[doc = "Minutes, i.e. 60 seconds"] |
| 109 | + Minutes => NPY_FR_m "min", |
| 110 | + #[doc = "Seconds"] |
| 111 | + Seconds => NPY_FR_s "s", |
| 112 | + #[doc = "Milliseconds, i.e. 10^-3 seconds"] |
| 113 | + Milliseconds => NPY_FR_ms "ms", |
| 114 | + #[doc = "Microseconds, i.e. 10^-6 seconds"] |
| 115 | + Microseconds => NPY_FR_us "µs", |
| 116 | + #[doc = "Nanoseconds, i.e. 10^-9 seconds"] |
| 117 | + Nanoseconds => NPY_FR_ns "ns", |
| 118 | + #[doc = "Picoseconds, i.e. 10^-12 seconds"] |
| 119 | + Picoseconds => NPY_FR_ps "ps", |
| 120 | + #[doc = "Femtoseconds, i.e. 10^-15 seconds"] |
| 121 | + Femtoseconds => NPY_FR_fs "fs", |
| 122 | + #[doc = "Attoseconds, i.e. 10^-18 seconds"] |
| 123 | + Attoseconds => NPY_FR_as "as", |
| 124 | + ); |
| 125 | +} |
| 126 | + |
| 127 | +/// Corresponds to the [`datetime64`][scalars-datetime64] scalar type |
| 128 | +/// |
| 129 | +/// [scalars-datetime64]: https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.datetime64 |
| 130 | +#[derive(Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)] |
| 131 | +#[repr(transparent)] |
| 132 | +pub struct Datetime<U: Unit>(i64, PhantomData<U>); |
| 133 | + |
| 134 | +impl<U: Unit> From<i64> for Datetime<U> { |
| 135 | + fn from(val: i64) -> Self { |
| 136 | + Self(val, PhantomData) |
| 137 | + } |
| 138 | +} |
| 139 | + |
| 140 | +impl<U: Unit> From<Datetime<U>> for i64 { |
| 141 | + fn from(val: Datetime<U>) -> Self { |
| 142 | + val.0 |
| 143 | + } |
| 144 | +} |
| 145 | + |
| 146 | +unsafe impl<U: Unit> Element for Datetime<U> { |
| 147 | + const IS_COPY: bool = true; |
| 148 | + |
| 149 | + fn get_dtype(py: Python) -> &PyArrayDescr { |
| 150 | + // FIXME(adamreichold): Memoize these via the Unit trait |
| 151 | + let dtype = PyArrayDescr::new_from_npy_type(py, NPY_TYPES::NPY_DATETIME); |
| 152 | + |
| 153 | + unsafe { |
| 154 | + set_unit(dtype, U::UNIT); |
| 155 | + } |
| 156 | + |
| 157 | + dtype |
| 158 | + } |
| 159 | +} |
| 160 | + |
| 161 | +impl<U: Unit> fmt::Debug for Datetime<U> { |
| 162 | + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { |
| 163 | + write!(f, "Datetime({} {})", self.0, U::ABBREV) |
| 164 | + } |
| 165 | +} |
| 166 | + |
| 167 | +/// Corresponds to the [`timedelta64`][scalars-datetime64] scalar type |
| 168 | +/// |
| 169 | +/// [scalars-timedelta64]: https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.timedelta64 |
| 170 | +#[derive(Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)] |
| 171 | +#[repr(transparent)] |
| 172 | +pub struct Timedelta<U: Unit>(i64, PhantomData<U>); |
| 173 | + |
| 174 | +impl<U: Unit> From<i64> for Timedelta<U> { |
| 175 | + fn from(val: i64) -> Self { |
| 176 | + Self(val, PhantomData) |
| 177 | + } |
| 178 | +} |
| 179 | + |
| 180 | +impl<U: Unit> From<Timedelta<U>> for i64 { |
| 181 | + fn from(val: Timedelta<U>) -> Self { |
| 182 | + val.0 |
| 183 | + } |
| 184 | +} |
| 185 | + |
| 186 | +unsafe impl<U: Unit> Element for Timedelta<U> { |
| 187 | + const IS_COPY: bool = true; |
| 188 | + |
| 189 | + fn get_dtype(py: Python) -> &PyArrayDescr { |
| 190 | + // FIXME(adamreichold): Memoize these via the Unit trait |
| 191 | + let dtype = PyArrayDescr::new_from_npy_type(py, NPY_TYPES::NPY_TIMEDELTA); |
| 192 | + |
| 193 | + unsafe { |
| 194 | + set_unit(dtype, U::UNIT); |
| 195 | + } |
| 196 | + |
| 197 | + dtype |
| 198 | + } |
| 199 | +} |
| 200 | + |
| 201 | +impl<U: Unit> fmt::Debug for Timedelta<U> { |
| 202 | + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { |
| 203 | + write!(f, "Timedelta({} {})", self.0, U::ABBREV) |
| 204 | + } |
| 205 | +} |
| 206 | + |
| 207 | +unsafe fn set_unit(dtype: &PyArrayDescr, unit: NPY_DATETIMEUNIT) { |
| 208 | + let metadata = &mut *((*dtype.as_dtype_ptr()).c_metadata as *mut PyArray_DatetimeDTypeMetaData); |
| 209 | + |
| 210 | + metadata.meta.base = unit; |
| 211 | + metadata.meta.num = 1; |
| 212 | +} |
| 213 | + |
| 214 | +#[cfg(test)] |
| 215 | +mod tests { |
| 216 | + use super::*; |
| 217 | + |
| 218 | + use pyo3::{ |
| 219 | + py_run, |
| 220 | + types::{PyDict, PyModule}, |
| 221 | + }; |
| 222 | + |
| 223 | + use crate::array::PyArray1; |
| 224 | + |
| 225 | + #[test] |
| 226 | + fn from_python_to_rust() { |
| 227 | + Python::with_gil(|py| { |
| 228 | + let locals = py |
| 229 | + .eval("{ 'np': __import__('numpy') }", None, None) |
| 230 | + .unwrap() |
| 231 | + .downcast::<PyDict>() |
| 232 | + .unwrap(); |
| 233 | + |
| 234 | + let array = py |
| 235 | + .eval( |
| 236 | + "np.array([np.datetime64('1970-01-01')])", |
| 237 | + None, |
| 238 | + Some(locals), |
| 239 | + ) |
| 240 | + .unwrap() |
| 241 | + .downcast::<PyArray1<Datetime<units::Days>>>() |
| 242 | + .unwrap(); |
| 243 | + |
| 244 | + let value: i64 = array.get_owned(0).unwrap().into(); |
| 245 | + assert_eq!(value, 0); |
| 246 | + }); |
| 247 | + } |
| 248 | + |
| 249 | + #[test] |
| 250 | + fn from_rust_to_python() { |
| 251 | + Python::with_gil(|py| { |
| 252 | + let array = PyArray1::<Timedelta<units::Minutes>>::zeros(py, 1, false); |
| 253 | + |
| 254 | + *array.readwrite().get_mut(0).unwrap() = Timedelta::<units::Minutes>::from(5); |
| 255 | + |
| 256 | + let np = py |
| 257 | + .eval("__import__('numpy')", None, None) |
| 258 | + .unwrap() |
| 259 | + .downcast::<PyModule>() |
| 260 | + .unwrap(); |
| 261 | + |
| 262 | + py_run!(py, array np, "assert array.dtype == np.dtype('timedelta64[m]')"); |
| 263 | + py_run!(py, array np, "assert array[0] == np.timedelta64(5, 'm')"); |
| 264 | + }); |
| 265 | + } |
| 266 | + |
| 267 | + #[test] |
| 268 | + fn debug_formatting() { |
| 269 | + assert_eq!( |
| 270 | + format!("{:?}", Datetime::<units::Days>::from(28)), |
| 271 | + "Datetime(28 d)" |
| 272 | + ); |
| 273 | + |
| 274 | + assert_eq!( |
| 275 | + format!("{:?}", Timedelta::<units::Milliseconds>::from(160)), |
| 276 | + "Timedelta(160 ms)" |
| 277 | + ); |
| 278 | + } |
| 279 | + |
| 280 | + #[test] |
| 281 | + fn unit_conversion() { |
| 282 | + #[track_caller] |
| 283 | + fn convert<S: Unit, D: Unit>(py: Python<'_>, expected_value: i64) { |
| 284 | + let array = PyArray1::<Timedelta<S>>::from_slice(py, &[Timedelta::<S>::from(1)]); |
| 285 | + let array = array.cast::<Timedelta<D>>(false).unwrap(); |
| 286 | + |
| 287 | + let value: i64 = array.get_owned(0).unwrap().into(); |
| 288 | + assert_eq!(value, expected_value); |
| 289 | + } |
| 290 | + |
| 291 | + Python::with_gil(|py| { |
| 292 | + convert::<units::Years, units::Days>(py, (97 + 400 * 365) / 400); |
| 293 | + convert::<units::Months, units::Days>(py, (97 + 400 * 365) / 400 / 12); |
| 294 | + |
| 295 | + convert::<units::Weeks, units::Seconds>(py, 7 * 24 * 60 * 60); |
| 296 | + convert::<units::Days, units::Seconds>(py, 24 * 60 * 60); |
| 297 | + convert::<units::Hours, units::Seconds>(py, 60 * 60); |
| 298 | + convert::<units::Minutes, units::Seconds>(py, 60); |
| 299 | + |
| 300 | + convert::<units::Seconds, units::Milliseconds>(py, 1_000); |
| 301 | + convert::<units::Seconds, units::Microseconds>(py, 1_000_000); |
| 302 | + convert::<units::Seconds, units::Nanoseconds>(py, 1_000_000_000); |
| 303 | + convert::<units::Seconds, units::Picoseconds>(py, 1_000_000_000_000); |
| 304 | + convert::<units::Seconds, units::Femtoseconds>(py, 1_000_000_000_000_000); |
| 305 | + |
| 306 | + convert::<units::Femtoseconds, units::Attoseconds>(py, 1_000); |
| 307 | + }); |
| 308 | + } |
| 309 | +} |
0 commit comments