Skip to content

Commit 72ff2be

Browse files
committed
Add datetime module for handling NumPy's datetime64 and timedelta64 types.
1 parent 7fc47f0 commit 72ff2be

File tree

7 files changed

+355
-5
lines changed

7 files changed

+355
-5
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
- Add dynamic borrow checking to safely construct references into the interior of NumPy arrays. ([#274](https://github.com/PyO3/rust-numpy/pull/274))
55
- The deprecated iterator builders `NpySingleIterBuilder::{readonly,readwrite}` and `NpyMultiIterBuilder::add_{readonly,readwrite}` now take referencces to `PyReadonlyArray` and `PyReadwriteArray` instead of consuming them.
66
- The destructive `PyArray::resize` method is now unsafe if used without an instance of `PyReadwriteArray`. ([#302](https://github.com/PyO3/rust-numpy/pull/302))
7+
- Add support for `datetime64` and `timedelta64` element types via the `datetime` module. ([#308](https://github.com/PyO3/rust-numpy/pull/308))
78
- Add support for IEEE 754-2008 16-bit floating point numbers via an optional dependency on the `half` crate. ([#314](https://github.com/PyO3/rust-numpy/pull/314))
89
- The `inner`, `dot` and `einsum` functions can also return a scalar instead of a zero-dimensional array to match NumPy's types ([#285](https://github.com/PyO3/rust-numpy/pull/285))
910
- The `PyArray::resize` function supports n-dimensional contiguous arrays. ([#312](https://github.com/PyO3/rust-numpy/pull/312))

examples/simple/src/lib.rs

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
1-
use numpy::ndarray::{ArrayD, ArrayViewD, ArrayViewMutD};
1+
use numpy::ndarray::{ArrayD, ArrayViewD, ArrayViewMutD, Zip};
22
use numpy::{
3-
Complex64, IntoPyArray, PyArray1, PyArrayDyn, PyReadonlyArrayDyn, PyReadwriteArrayDyn,
3+
datetime::{units, Timedelta},
4+
Complex64, IntoPyArray, PyArray1, PyArrayDyn, PyReadonlyArray1, PyReadonlyArrayDyn,
5+
PyReadwriteArray1, PyReadwriteArrayDyn,
46
};
57
use pyo3::{
68
pymodule,
@@ -70,5 +72,17 @@ fn rust_ext(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
7072
x.readonly().as_array().sum()
7173
}
7274

75+
// example using timedelta64 array
76+
#[pyfn(m)]
77+
fn add_minutes_to_seconds(
78+
mut x: PyReadwriteArray1<Timedelta<units::Seconds>>,
79+
y: PyReadonlyArray1<Timedelta<units::Minutes>>,
80+
) {
81+
#[allow(deprecated)]
82+
Zip::from(x.as_array_mut())
83+
.and(y.as_array())
84+
.apply(|x, y| *x = (i64::from(*x) + 60 * i64::from(*y)).into());
85+
}
86+
7387
Ok(())
7488
}

examples/simple/tests/test_ext.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import numpy as np
2-
from rust_ext import axpy, conj, mult, extract
2+
from rust_ext import axpy, conj, mult, extract, add_minutes_to_seconds
33

44

55
def test_axpy():
@@ -24,3 +24,12 @@ def test_extract():
2424
x = np.arange(5.0)
2525
d = {"x": x}
2626
np.testing.assert_almost_equal(extract(d), 10.0)
27+
28+
29+
def test_add_minutes_to_seconds():
30+
x = np.array([10, 20, 30], dtype="timedelta64[s]")
31+
y = np.array([1, 2, 3], dtype="timedelta64[m]")
32+
33+
add_minutes_to_seconds(x, y)
34+
35+
assert np.all(x == np.array([70, 140, 210], dtype="timedelta64[s]"))

src/datetime.rs

Lines changed: 309 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,309 @@
1+
//! Support datetimes and timedeltas
2+
//!
3+
//! [The corresponding section][datetime] of the NumPy documentation contains more information.
4+
//!
5+
//! # Example
6+
//!
7+
//! ```
8+
//! use numpy::{datetime::{units, Datetime, Timedelta}, PyArray1};
9+
//! use pyo3::Python;
10+
//! # use pyo3::types::PyDict;
11+
//!
12+
//! Python::with_gil(|py| {
13+
//! # let locals = py
14+
//! # .eval("{ 'np': __import__('numpy') }", None, None)
15+
//! # .unwrap()
16+
//! # .downcast::<PyDict>()
17+
//! # .unwrap();
18+
//! #
19+
//! let array = py
20+
//! .eval(
21+
//! "np.array([np.datetime64('2017-04-21')])",
22+
//! None,
23+
//! Some(locals),
24+
//! )
25+
//! .unwrap()
26+
//! .downcast::<PyArray1<Datetime<units::Days>>>()
27+
//! .unwrap();
28+
//!
29+
//! assert_eq!(
30+
//! array.get_owned(0).unwrap(),
31+
//! Datetime::<units::Days>::from(17_277)
32+
//! );
33+
//!
34+
//! let array = py
35+
//! .eval(
36+
//! "np.array([np.datetime64('2022-03-29')]) - np.array([np.datetime64('2017-04-21')])",
37+
//! None,
38+
//! Some(locals),
39+
//! )
40+
//! .unwrap()
41+
//! .downcast::<PyArray1<Timedelta<units::Days>>>()
42+
//! .unwrap();
43+
//!
44+
//! assert_eq!(
45+
//! array.get_owned(0).unwrap(),
46+
//! Timedelta::<units::Days>::from(1_803)
47+
//! );
48+
//! });
49+
//! ```
50+
//!
51+
//! [datetime]: https://numpy.org/doc/stable/reference/arrays.datetime.html
52+
53+
use std::fmt;
54+
use std::hash::Hash;
55+
use std::marker::PhantomData;
56+
57+
use pyo3::Python;
58+
59+
use crate::dtype::{Element, PyArrayDescr};
60+
use crate::npyffi::{PyArray_DatetimeDTypeMetaData, NPY_DATETIMEUNIT, NPY_TYPES};
61+
62+
/// Represents the [datetime units][datetime-units] supported by NumPy
63+
///
64+
/// [datetime-units]: https://numpy.org/doc/stable/reference/arrays.datetime.html#datetime-units
65+
pub trait Unit: Send + Sync + Clone + Copy + PartialEq + Eq + Hash + PartialOrd + Ord {
66+
/// The matching NumPy [datetime unit code][NPY_DATETIMEUNIT]
67+
///
68+
/// [NPY_DATETIMEUNIT]: https://github.com/numpy/numpy/blob/4c60b3263ac50e5e72f6a909e156314fc3c9cba0/numpy/core/include/numpy/ndarraytypes.h#L276
69+
const UNIT: NPY_DATETIMEUNIT;
70+
71+
/// The abbrevation used for debug formatting
72+
const ABBREV: &'static str;
73+
}
74+
75+
macro_rules! define_units {
76+
($($(#[$meta:meta])* $struct:ident => $unit:ident $abbrev:literal,)+) => {
77+
$(
78+
79+
$(#[$meta])*
80+
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)]
81+
pub struct $struct;
82+
83+
impl Unit for $struct {
84+
const UNIT: NPY_DATETIMEUNIT = NPY_DATETIMEUNIT::$unit;
85+
86+
const ABBREV: &'static str = $abbrev;
87+
}
88+
89+
)+
90+
};
91+
}
92+
93+
/// Predefined implementors of the [`Unit`] trait
94+
pub mod units {
95+
use super::*;
96+
97+
define_units!(
98+
#[doc = "Years, i.e. 12 months"]
99+
Years => NPY_FR_Y "a",
100+
#[doc = "Months, i.e. 30 days"]
101+
Months => NPY_FR_M "mo",
102+
#[doc = "Weeks, i.e. 7 days"]
103+
Weeks => NPY_FR_W "w",
104+
#[doc = "Days, i.e. 24 hours"]
105+
Days => NPY_FR_D "d",
106+
#[doc = "Hours, i.e. 60 minutes"]
107+
Hours => NPY_FR_h "h",
108+
#[doc = "Minutes, i.e. 60 seconds"]
109+
Minutes => NPY_FR_m "min",
110+
#[doc = "Seconds"]
111+
Seconds => NPY_FR_s "s",
112+
#[doc = "Milliseconds, i.e. 10^-3 seconds"]
113+
Milliseconds => NPY_FR_ms "ms",
114+
#[doc = "Microseconds, i.e. 10^-6 seconds"]
115+
Microseconds => NPY_FR_us "µs",
116+
#[doc = "Nanoseconds, i.e. 10^-9 seconds"]
117+
Nanoseconds => NPY_FR_ns "ns",
118+
#[doc = "Picoseconds, i.e. 10^-12 seconds"]
119+
Picoseconds => NPY_FR_ps "ps",
120+
#[doc = "Femtoseconds, i.e. 10^-15 seconds"]
121+
Femtoseconds => NPY_FR_fs "fs",
122+
#[doc = "Attoseconds, i.e. 10^-18 seconds"]
123+
Attoseconds => NPY_FR_as "as",
124+
);
125+
}
126+
127+
/// Corresponds to the [`datetime64`][scalars-datetime64] scalar type
128+
///
129+
/// [scalars-datetime64]: https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.datetime64
130+
#[derive(Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
131+
#[repr(transparent)]
132+
pub struct Datetime<U: Unit>(i64, PhantomData<U>);
133+
134+
impl<U: Unit> From<i64> for Datetime<U> {
135+
fn from(val: i64) -> Self {
136+
Self(val, PhantomData)
137+
}
138+
}
139+
140+
impl<U: Unit> From<Datetime<U>> for i64 {
141+
fn from(val: Datetime<U>) -> Self {
142+
val.0
143+
}
144+
}
145+
146+
unsafe impl<U: Unit> Element for Datetime<U> {
147+
const IS_COPY: bool = true;
148+
149+
fn get_dtype(py: Python) -> &PyArrayDescr {
150+
// FIXME(adamreichold): Memoize these via the Unit trait
151+
let dtype = PyArrayDescr::new_from_npy_type(py, NPY_TYPES::NPY_DATETIME);
152+
153+
unsafe {
154+
set_unit(dtype, U::UNIT);
155+
}
156+
157+
dtype
158+
}
159+
}
160+
161+
impl<U: Unit> fmt::Debug for Datetime<U> {
162+
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
163+
write!(f, "Datetime({} {})", self.0, U::ABBREV)
164+
}
165+
}
166+
167+
/// Corresponds to the [`timedelta64`][scalars-datetime64] scalar type
168+
///
169+
/// [scalars-timedelta64]: https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.timedelta64
170+
#[derive(Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
171+
#[repr(transparent)]
172+
pub struct Timedelta<U: Unit>(i64, PhantomData<U>);
173+
174+
impl<U: Unit> From<i64> for Timedelta<U> {
175+
fn from(val: i64) -> Self {
176+
Self(val, PhantomData)
177+
}
178+
}
179+
180+
impl<U: Unit> From<Timedelta<U>> for i64 {
181+
fn from(val: Timedelta<U>) -> Self {
182+
val.0
183+
}
184+
}
185+
186+
unsafe impl<U: Unit> Element for Timedelta<U> {
187+
const IS_COPY: bool = true;
188+
189+
fn get_dtype(py: Python) -> &PyArrayDescr {
190+
// FIXME(adamreichold): Memoize these via the Unit trait
191+
let dtype = PyArrayDescr::new_from_npy_type(py, NPY_TYPES::NPY_TIMEDELTA);
192+
193+
unsafe {
194+
set_unit(dtype, U::UNIT);
195+
}
196+
197+
dtype
198+
}
199+
}
200+
201+
impl<U: Unit> fmt::Debug for Timedelta<U> {
202+
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
203+
write!(f, "Timedelta({} {})", self.0, U::ABBREV)
204+
}
205+
}
206+
207+
unsafe fn set_unit(dtype: &PyArrayDescr, unit: NPY_DATETIMEUNIT) {
208+
let metadata = &mut *((*dtype.as_dtype_ptr()).c_metadata as *mut PyArray_DatetimeDTypeMetaData);
209+
210+
metadata.meta.base = unit;
211+
metadata.meta.num = 1;
212+
}
213+
214+
#[cfg(test)]
215+
mod tests {
216+
use super::*;
217+
218+
use pyo3::{
219+
py_run,
220+
types::{PyDict, PyModule},
221+
};
222+
223+
use crate::array::PyArray1;
224+
225+
#[test]
226+
fn from_python_to_rust() {
227+
Python::with_gil(|py| {
228+
let locals = py
229+
.eval("{ 'np': __import__('numpy') }", None, None)
230+
.unwrap()
231+
.downcast::<PyDict>()
232+
.unwrap();
233+
234+
let array = py
235+
.eval(
236+
"np.array([np.datetime64('1970-01-01')])",
237+
None,
238+
Some(locals),
239+
)
240+
.unwrap()
241+
.downcast::<PyArray1<Datetime<units::Days>>>()
242+
.unwrap();
243+
244+
let value: i64 = array.get_owned(0).unwrap().into();
245+
assert_eq!(value, 0);
246+
});
247+
}
248+
249+
#[test]
250+
fn from_rust_to_python() {
251+
Python::with_gil(|py| {
252+
let array = PyArray1::<Timedelta<units::Minutes>>::zeros(py, 1, false);
253+
254+
*array.readwrite().get_mut(0).unwrap() = Timedelta::<units::Minutes>::from(5);
255+
256+
let np = py
257+
.eval("__import__('numpy')", None, None)
258+
.unwrap()
259+
.downcast::<PyModule>()
260+
.unwrap();
261+
262+
py_run!(py, array np, "assert array.dtype == np.dtype('timedelta64[m]')");
263+
py_run!(py, array np, "assert array[0] == np.timedelta64(5, 'm')");
264+
});
265+
}
266+
267+
#[test]
268+
fn debug_formatting() {
269+
assert_eq!(
270+
format!("{:?}", Datetime::<units::Days>::from(28)),
271+
"Datetime(28 d)"
272+
);
273+
274+
assert_eq!(
275+
format!("{:?}", Timedelta::<units::Milliseconds>::from(160)),
276+
"Timedelta(160 ms)"
277+
);
278+
}
279+
280+
#[test]
281+
fn unit_conversion() {
282+
#[track_caller]
283+
fn convert<S: Unit, D: Unit>(py: Python<'_>, expected_value: i64) {
284+
let array = PyArray1::<Timedelta<S>>::from_slice(py, &[Timedelta::<S>::from(1)]);
285+
let array = array.cast::<Timedelta<D>>(false).unwrap();
286+
287+
let value: i64 = array.get_owned(0).unwrap().into();
288+
assert_eq!(value, expected_value);
289+
}
290+
291+
Python::with_gil(|py| {
292+
convert::<units::Years, units::Days>(py, (97 + 400 * 365) / 400);
293+
convert::<units::Months, units::Days>(py, (97 + 400 * 365) / 400 / 12);
294+
295+
convert::<units::Weeks, units::Seconds>(py, 7 * 24 * 60 * 60);
296+
convert::<units::Days, units::Seconds>(py, 24 * 60 * 60);
297+
convert::<units::Hours, units::Seconds>(py, 60 * 60);
298+
convert::<units::Minutes, units::Seconds>(py, 60);
299+
300+
convert::<units::Seconds, units::Milliseconds>(py, 1_000);
301+
convert::<units::Seconds, units::Microseconds>(py, 1_000_000);
302+
convert::<units::Seconds, units::Nanoseconds>(py, 1_000_000_000);
303+
convert::<units::Seconds, units::Picoseconds>(py, 1_000_000_000_000);
304+
convert::<units::Seconds, units::Femtoseconds>(py, 1_000_000_000_000_000);
305+
306+
convert::<units::Femtoseconds, units::Attoseconds>(py, 1_000);
307+
});
308+
}
309+
}

src/dtype.rs

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,13 +127,20 @@ impl PyArrayDescr {
127127
}
128128
}
129129

130-
pub(crate) fn from_npy_type(py: Python, npy_type: NPY_TYPES) -> &Self {
130+
fn from_npy_type(py: Python, npy_type: NPY_TYPES) -> &Self {
131131
unsafe {
132132
let descr = PY_ARRAY_API.PyArray_DescrFromType(py, npy_type as _);
133133
py.from_owned_ptr(descr as _)
134134
}
135135
}
136136

137+
pub(crate) fn new_from_npy_type(py: Python, npy_type: NPY_TYPES) -> &Self {
138+
unsafe {
139+
let descr = PY_ARRAY_API.PyArray_DescrNewFromType(py, npy_type as _);
140+
py.from_owned_ptr(descr as _)
141+
}
142+
}
143+
137144
/// Returns the [array scalar][arrays-scalars] corresponding to this type descriptor.
138145
///
139146
/// Equivalent to [`numpy.dtype.type`][dtype-type].

src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
pub mod array;
4040
pub mod borrow;
4141
pub mod convert;
42+
pub mod datetime;
4243
mod dtype;
4344
mod error;
4445
pub mod npyffi;

0 commit comments

Comments
 (0)