Skip to content

Commit 2ebf5d3

Browse files
committed
feat: support datetime coordinates
1 parent cab3298 commit 2ebf5d3

File tree

5 files changed

+141
-22
lines changed

5 files changed

+141
-22
lines changed

src/common.rs

Lines changed: 83 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,9 @@ use nuts_rs::Value;
66
use pyo3::{
77
exceptions::PyRuntimeError,
88
pyclass, pymethods,
9-
types::{PyAnyMethods, PyDict, PyDictMethods, PyList, PyListMethods, PyType},
10-
Bound, FromPyObject, IntoPyObject, IntoPyObjectExt, Py, PyAny, PyErr, Python,
9+
types::{PyAnyMethods, PyDict, PyDictMethods, PyList, PyListMethods, PyType, PyTypeMethods},
10+
Borrowed, Bound, BoundObject, FromPyObject, IntoPyObject, IntoPyObjectExt, Py, PyAny, PyErr,
11+
Python,
1112
};
1213
use smallvec::SmallVec;
1314

@@ -75,15 +76,29 @@ impl<'py> IntoPyObject<'py> for &ItemType {
7576
nuts_rs::ItemType::F32 => "float32",
7677
nuts_rs::ItemType::Bool => "bool",
7778
nuts_rs::ItemType::String => "object",
79+
nuts_rs::ItemType::DateTime64(unit) => match unit {
80+
nuts_rs::DateTimeUnit::Seconds => "datetime64[s]",
81+
nuts_rs::DateTimeUnit::Milliseconds => "datetime64[ms]",
82+
nuts_rs::DateTimeUnit::Microseconds => "datetime64[us]",
83+
nuts_rs::DateTimeUnit::Nanoseconds => "datetime64[ns]",
84+
},
85+
nuts_rs::ItemType::TimeDelta64(unit) => match unit {
86+
nuts_rs::DateTimeUnit::Seconds => "timedelta64[s]",
87+
nuts_rs::DateTimeUnit::Milliseconds => "timedelta64[ms]",
88+
nuts_rs::DateTimeUnit::Microseconds => "timedelta64[us]",
89+
nuts_rs::DateTimeUnit::Nanoseconds => "timedelta64[ns]",
90+
},
7891
};
7992
let numpy = py.import("numpy")?;
8093
let dtype = numpy.getattr("dtype")?.call1((dtype_str,))?;
8194
Ok(dtype)
8295
}
8396
}
8497

85-
impl<'py> FromPyObject<'py> for ItemType {
86-
fn extract_bound(ob: &Bound<'_, PyAny>) -> std::result::Result<Self, PyErr> {
98+
impl<'a, 'py> FromPyObject<'a, 'py> for ItemType {
99+
type Error = PyErr;
100+
101+
fn extract(ob: Borrowed<'a, 'py, PyAny>) -> std::result::Result<Self, PyErr> {
87102
let dtype_str: &str = ob.extract()?;
88103
let item_type = match dtype_str {
89104
"uint64" => nuts_rs::ItemType::U64,
@@ -106,12 +121,14 @@ impl<'py> FromPyObject<'py> for ItemType {
106121
#[pyclass]
107122
pub struct PyValue(Value);
108123

109-
impl<'py> FromPyObject<'py> for PyValue {
110-
fn extract_bound(ob: &Bound<'py, PyAny>) -> std::result::Result<Self, PyErr> {
124+
impl<'a, 'py> FromPyObject<'a, 'py> for PyValue {
125+
type Error = PyErr;
126+
127+
fn extract(ob: Borrowed<'a, 'py, PyAny>) -> std::result::Result<Self, PyErr> {
111128
let ob = if ob.hasattr("values")? {
112-
&ob.getattr("values")?
129+
ob.getattr("values")?
113130
} else {
114-
ob
131+
ob.into_bound()
115132
};
116133
if let Ok(arr) = ob.extract::<PyReadonlyArray1<f64>>() {
117134
let vec = arr
@@ -166,9 +183,35 @@ impl<'py> FromPyObject<'py> for PyValue {
166183
.collect::<Result<_, _>>()?;
167184
return Ok(PyValue(Value::Strings(vals_as_str)));
168185
}
169-
Err(PyRuntimeError::new_err(
170-
"Could not convert to Value. Unsupported type.",
171-
))
186+
187+
macro_rules! extract_time {
188+
($unit:ident, $type:ident, $value:ident) => {
189+
if let Ok(arr) = ob.extract::<PyReadonlyArray1<numpy::datetime::$type<numpy::datetime::units::$unit>>>() {
190+
let vec = arr
191+
.as_slice()
192+
.map_err(|_| PyRuntimeError::new_err("Array is not contiguous"))?;
193+
let vals_as_i64 = vec.iter().map(|&dt| dt.into()).collect();
194+
return Ok(PyValue(Value::$value(
195+
nuts_rs::DateTimeUnit::$unit,
196+
vals_as_i64,
197+
)));
198+
}
199+
};
200+
}
201+
202+
extract_time!(Seconds, Datetime, DateTime64);
203+
extract_time!(Milliseconds, Datetime, DateTime64);
204+
extract_time!(Microseconds, Datetime, DateTime64);
205+
extract_time!(Nanoseconds, Datetime, DateTime64);
206+
extract_time!(Seconds, Timedelta, TimeDelta64);
207+
extract_time!(Milliseconds, Timedelta, TimeDelta64);
208+
extract_time!(Microseconds, Timedelta, TimeDelta64);
209+
extract_time!(Nanoseconds, Timedelta, TimeDelta64);
210+
211+
Err(PyRuntimeError::new_err(format!(
212+
"Could not convert to Value. Unsupported type: {}",
213+
ob.get_type().name()?
214+
)))
172215
}
173216
}
174217

@@ -178,6 +221,23 @@ impl PyValue {
178221
}
179222

180223
pub fn into_array(self, py: Python) -> Result<Bound<PyAny>> {
224+
macro_rules! from_time {
225+
($unit:ident, $items:expr, $type:ident) => {
226+
Ok(
227+
PyArray1::<numpy::datetime::$type<numpy::datetime::units::$unit>>::from_vec(
228+
py,
229+
$items
230+
.into_iter()
231+
.map(|ts| {
232+
numpy::datetime::$type::<numpy::datetime::units::$unit>::from(ts)
233+
})
234+
.collect(),
235+
)
236+
.into_any(),
237+
)
238+
};
239+
}
240+
181241
match self.0 {
182242
Value::F64(vec) => Ok(PyArray1::from_vec(py, vec).into_any()),
183243
Value::F32(vec) => Ok(PyArray1::from_vec(py, vec).into_any()),
@@ -191,6 +251,18 @@ impl PyValue {
191251
Value::ScalarF64(val) => Ok(val.into_bound_py_any(py)?),
192252
Value::ScalarF32(val) => Ok(val.into_bound_py_any(py)?),
193253
Value::ScalarBool(val) => Ok(val.into_bound_py_any(py)?),
254+
Value::DateTime64(date_time_unit, items) => match date_time_unit {
255+
nuts_rs::DateTimeUnit::Seconds => from_time!(Seconds, items, Datetime),
256+
nuts_rs::DateTimeUnit::Milliseconds => from_time!(Milliseconds, items, Datetime),
257+
nuts_rs::DateTimeUnit::Microseconds => from_time!(Microseconds, items, Datetime),
258+
nuts_rs::DateTimeUnit::Nanoseconds => from_time!(Nanoseconds, items, Datetime),
259+
},
260+
Value::TimeDelta64(date_time_unit, items) => match date_time_unit {
261+
nuts_rs::DateTimeUnit::Seconds => from_time!(Seconds, items, Timedelta),
262+
nuts_rs::DateTimeUnit::Milliseconds => from_time!(Milliseconds, items, Timedelta),
263+
nuts_rs::DateTimeUnit::Microseconds => from_time!(Microseconds, items, Timedelta),
264+
nuts_rs::DateTimeUnit::Nanoseconds => from_time!(Nanoseconds, items, Timedelta),
265+
},
194266
}
195267
}
196268
}

src/pyfunc.rs

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use std::{collections::HashMap, sync::Arc};
22

3-
use anyhow::{bail, Context, Result};
3+
use anyhow::{anyhow, bail, Context, Result};
44
use numpy::{
55
NotContiguousError, PyArray1, PyReadonlyArray1, PyReadonlyArrayDyn, PyUntypedArrayMethods,
66
};
@@ -67,7 +67,7 @@ impl PyModel {
6767
let key: String = key.extract().context("Coordinate key is not a string")?;
6868
let value: PyValue = value
6969
.extract()
70-
.context("Coordinate value has incorrect type")?;
70+
.with_context(|| format!("Coordinate {} value has unsupported type", key))?;
7171
Ok((key, value.into_value()))
7272
})
7373
.collect::<Result<HashMap<_, _>>>()?;
@@ -371,6 +371,20 @@ impl CpuLogpFunc for PyDensity {
371371
.collect::<Result<_, _>>()?;
372372
Some(Value::Strings(vec))
373373
}
374+
nuts_rs::ItemType::DateTime64(date_time_unit) => {
375+
let arr = as_value::<i64>(var, &val)?;
376+
let slice = arr.as_slice().map_err(|_| {
377+
nuts_rs::CpuMathError::ExpandError("Could not read as slice".into())
378+
})?;
379+
Some(Value::DateTime64(*date_time_unit, slice.to_vec()))
380+
}
381+
nuts_rs::ItemType::TimeDelta64(date_time_unit) => {
382+
let arr = as_value::<i64>(var, &val)?;
383+
let slice = arr.as_slice().map_err(|_| {
384+
nuts_rs::CpuMathError::ExpandError("Could not read as slice".into())
385+
})?;
386+
Some(Value::TimeDelta64(*date_time_unit, slice.to_vec()))
387+
}
374388
};
375389
expanded.push(val_array);
376390
}
@@ -528,7 +542,7 @@ impl Model for PyModel {
528542

529543
let init_point: PyReadonlyArray1<f64> = init_point
530544
.extract(py)
531-
.context("Initializition array returned incorrect argument")?;
545+
.map_err(|_| anyhow!("Initialization array returned incorrect argument"))?;
532546

533547
let init_point = init_point
534548
.as_slice()

src/pymc.rs

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use std::{collections::HashMap, ffi::c_void, sync::Arc};
22

3-
use anyhow::{bail, Context, Result};
3+
use anyhow::{anyhow, bail, Context, Result};
44
use numpy::{NotContiguousError, PyReadonlyArray1};
55
use nuts_rs::{CpuLogpFunc, CpuMath, HasDims, LogpError, Model, Storable, Value};
66
use pyo3::{
@@ -265,6 +265,11 @@ impl CpuLogpFunc for PyMcModelRef<'_> {
265265
"String type not supported in expansion".into(),
266266
));
267267
}
268+
nuts_rs::ItemType::DateTime64(_) | nuts_rs::ItemType::TimeDelta64(_) => {
269+
return Err(nuts_rs::CpuMathError::ExpandError(
270+
"DateTime64 and TimeDelta64 types not supported in expansion".into(),
271+
));
272+
}
268273
};
269274

270275
values.push(Some(value));
@@ -437,7 +442,7 @@ impl PyMcModel {
437442
let key: String = key.extract().context("Coordinate key is not a string")?;
438443
let value: PyValue = value
439444
.extract()
440-
.context("Coordinate value has incorrect type")?;
445+
.with_context(|| format!("Coordinate {} value has unsupported type", key))?;
441446
Ok((key, value.into_value()))
442447
})
443448
.collect::<Result<HashMap<_, _>>>()?;
@@ -500,7 +505,7 @@ impl Model for PyMcModel {
500505

501506
let init_point: PyReadonlyArray1<f64> = init_point
502507
.extract(py)
503-
.context("Initializition array returned incorrect argument")?;
508+
.map_err(|_| anyhow!("Initialization array returned incorrect argument"))?;
504509

505510
let init_point = init_point
506511
.as_slice()

src/stan.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -294,7 +294,7 @@ impl StanModel {
294294
let key: String = key.extract().context("Coordinate key is not a string")?;
295295
let value: PyValue = value
296296
.extract()
297-
.context("Coordinate value has incorrect type")?;
297+
.with_context(|| format!("Coordinate {} value has unsupported type", key))?;
298298
Ok((key, value.into_value()))
299299
})
300300
.collect::<Result<HashMap<_, _>>>()?;

tests/test_pymc.py

Lines changed: 32 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
1-
from importlib.util import find_spec
21
import time
2+
from importlib.util import find_spec
3+
34
import pytest
45

56
if find_spec("pymc") is None:
67
pytest.skip("Skip pymc tests", allow_module_level=True)
78

89
import numpy as np
10+
import pandas as pd
911
import pymc as pm
1012
import pytest
1113

@@ -467,8 +469,17 @@ def test_deterministic_sampling_jax():
467469

468470
@pytest.mark.pymc
469471
def test_zarr_store(tmp_path):
470-
with pm.Model() as model:
471-
pm.HalfNormal("a")
472+
coords = {
473+
"a": np.arange(2).astype("f"),
474+
"b": pd.date_range("2023-01-01", periods=1),
475+
"c": ["x", "y", "z"],
476+
"d": [1],
477+
"e": pd.factorize(pd.Index(["foo"]))[1],
478+
"f": np.arange(2).astype("d"),
479+
}
480+
with pm.Model(coords=coords) as model:
481+
pm.HalfNormal("x")
482+
pm.Normal("y", dims=("a", "b", "c", "d", "e", "f"))
472483

473484
compiled = nutpie.compile_pymc_model(model, backend="numba")
474485

@@ -478,7 +489,24 @@ def test_zarr_store(tmp_path):
478489
trace = nutpie.sample(
479490
compiled, chains=2, seed=123, draws=100, tune=100, zarr_store=store
480491
)
481-
trace.load().posterior.a # noqa: B018
492+
trace.load().posterior.x
493+
494+
assert trace.posterior.coords["a"].dtype == np.float32
495+
assert trace.posterior.coords["b"].dtype == "datetime64[ns]"
496+
assert trace.posterior.coords["b"].values[0] == np.datetime64("2023-01-01")
497+
assert list(trace.posterior.coords["c"]) == ["x", "y", "z"]
498+
assert list(trace.posterior.coords["d"]) == [1]
499+
assert list(trace.posterior.coords["e"]) == ["foo"]
500+
assert trace.posterior.coords["f"].dtype == np.float64
501+
502+
trace = nutpie.sample(compiled, chains=2, seed=1234, draws=50, tune=50)
503+
assert trace.posterior.coords["a"].dtype == np.float32
504+
assert trace.posterior.coords["b"].dtype == "datetime64[ns]"
505+
assert trace.posterior.coords["b"].values[0] == np.datetime64("2023-01-01")
506+
assert list(trace.posterior.coords["c"]) == ["x", "y", "z"]
507+
assert list(trace.posterior.coords["d"]) == [1]
508+
assert list(trace.posterior.coords["e"]) == ["foo"]
509+
assert trace.posterior.coords["f"].dtype == np.float64
482510

483511

484512
@pytest.fixture

0 commit comments

Comments
 (0)