From 8acbb7dc3699234baf0a42d54ffc497ee44c52ac Mon Sep 17 00:00:00 2001 From: Adrian Seyboldt Date: Fri, 3 Oct 2025 17:18:47 +0200 Subject: [PATCH 01/12] fix: keep python thread handle alive --- src/progress.rs | 37 +++++++++++++++++++++++++++++++++---- 1 file changed, 33 insertions(+), 4 deletions(-) diff --git a/src/progress.rs b/src/progress.rs index 2b130a4..575efcc 100644 --- a/src/progress.rs +++ b/src/progress.rs @@ -1,4 +1,12 @@ -use std::{collections::BTreeMap, sync::Arc, time::Duration}; +use std::{ + collections::BTreeMap, + sync::{ + mpsc::{sync_channel, SyncSender}, + Arc, + }, + thread::spawn, + time::Duration, +}; use anyhow::{Context, Result}; use indicatif::ProgressBar; @@ -10,20 +18,38 @@ use upon::{Engine, Value}; pub struct ProgressHandler { engine: Engine<'static>, template: String, - callback: Arc>, rate: Duration, n_cores: usize, + updates: SyncSender, } impl ProgressHandler { pub fn new(callback: Arc>, rate: Duration, template: String, n_cores: usize) -> Self { let engine = Engine::new(); + + let (update_tx, update_rx) = sync_channel(1); + + spawn(move || { + Python::with_gil(move |py| { + py.allow_threads(move || { + let update = update_rx.recv(); + let Ok(update) = update else { + return; + }; + let res = Python::with_gil(|py| callback.call1(py, (update,))); + if let Err(err) = res { + eprintln!("Error in progress callback: {err}"); + } + }); + }); + }); + Self { engine, - callback, rate, template, n_cores, + updates: update_tx, } } @@ -50,7 +76,10 @@ impl ProgressHandler { progress_to_value(progress_update_count, self.n_cores, time_sampling, progress); let rendered = template.render_from(&self.engine, &progress).to_string(); let rendered = rendered.unwrap_or_else(|err| format!("{err}")); - let _ = Python::with_gil(|py| self.callback.call1(py, (rendered,))); + if let Err(e) = self.updates.send(rendered) { + eprintln!("Could not send progress update: {e}"); + return; + } progress_update_count += 1; }; From 
df2162d66283b3c3bc731789b0c961f8bc40da7c Mon Sep 17 00:00:00 2001 From: Adrian Seyboldt Date: Wed, 16 Jul 2025 18:59:34 +0200 Subject: [PATCH 02/12] chore(release): prepare 0.15.2 From 31294e108375f8350dac1dd864e1c52f8ddc579a Mon Sep 17 00:00:00 2001 From: Adrian Seyboldt Date: Tue, 2 Sep 2025 20:45:31 +0200 Subject: [PATCH 03/12] feat: support step size adaptation method --- src/wrapper.rs | 202 +++++++++++++++++++++++++++++++++++++------------ 1 file changed, 153 insertions(+), 49 deletions(-) diff --git a/src/wrapper.rs b/src/wrapper.rs index f29620c..67fa4b3 100644 --- a/src/wrapper.rs +++ b/src/wrapper.rs @@ -17,7 +17,7 @@ use arrow::array::Array; use numpy::{PyArray1, PyReadonlyArray1}; use nuts_rs::{ ChainProgress, DiagGradNutsSettings, LowRankNutsSettings, ProgressCallback, Sampler, - SamplerWaitResult, Trace, TransformedNutsSettings, + SamplerWaitResult, StepSizeAdaptMethod, Trace, TransformedNutsSettings, }; use pyo3::{ exceptions::PyTimeoutError, @@ -276,22 +276,13 @@ impl PyNutsSettings { fn initial_step(&self) -> f64 { match &self.inner { Settings::Diag(nuts_settings) => { - nuts_settings - .adapt_options - .dual_average_options - .initial_step + nuts_settings.adapt_options.step_size_settings.initial_step } Settings::LowRank(nuts_settings) => { - nuts_settings - .adapt_options - .dual_average_options - .initial_step + nuts_settings.adapt_options.step_size_settings.initial_step } Settings::Transforming(nuts_settings) => { - nuts_settings - .adapt_options - .dual_average_options - .initial_step + nuts_settings.adapt_options.step_size_settings.initial_step } } } @@ -300,22 +291,13 @@ impl PyNutsSettings { fn set_initial_step(&mut self, val: f64) { match &mut self.inner { Settings::Diag(nuts_settings) => { - nuts_settings - .adapt_options - .dual_average_options - .initial_step = val; + nuts_settings.adapt_options.step_size_settings.initial_step = val; } Settings::LowRank(nuts_settings) => { - nuts_settings - .adapt_options - .dual_average_options - 
.initial_step = val; + nuts_settings.adapt_options.step_size_settings.initial_step = val; } Settings::Transforming(nuts_settings) => { - nuts_settings - .adapt_options - .dual_average_options - .initial_step = val; + nuts_settings.adapt_options.step_size_settings.initial_step = val; } } } @@ -414,22 +396,13 @@ impl PyNutsSettings { fn set_target_accept(&self) -> f64 { match &self.inner { Settings::Diag(nuts_settings) => { - nuts_settings - .adapt_options - .dual_average_options - .target_accept + nuts_settings.adapt_options.step_size_settings.target_accept } Settings::LowRank(nuts_settings) => { - nuts_settings - .adapt_options - .dual_average_options - .target_accept + nuts_settings.adapt_options.step_size_settings.target_accept } Settings::Transforming(nuts_settings) => { - nuts_settings - .adapt_options - .dual_average_options - .target_accept + nuts_settings.adapt_options.step_size_settings.target_accept } } } @@ -438,22 +411,13 @@ impl PyNutsSettings { fn target_accept(&mut self, val: f64) { match &mut self.inner { Settings::Diag(nuts_settings) => { - nuts_settings - .adapt_options - .dual_average_options - .target_accept = val + nuts_settings.adapt_options.step_size_settings.target_accept = val } Settings::LowRank(nuts_settings) => { - nuts_settings - .adapt_options - .dual_average_options - .target_accept = val + nuts_settings.adapt_options.step_size_settings.target_accept = val } Settings::Transforming(nuts_settings) => { - nuts_settings - .adapt_options - .dual_average_options - .target_accept = val + nuts_settings.adapt_options.step_size_settings.target_accept = val } } } @@ -654,6 +618,146 @@ impl PyNutsSettings { } Ok(()) } + + #[getter] + fn step_size_adapt_method(&self) -> String { + let method = match &self.inner { + Settings::LowRank(inner) => inner.adapt_options.step_size_settings.adapt_options.method, + Settings::Diag(inner) => inner.adapt_options.step_size_settings.adapt_options.method, + Settings::Transforming(inner) => { + 
inner.adapt_options.step_size_settings.adapt_options.method + } + }; + + match method { + nuts_rs::StepSizeAdaptMethod::DualAverage => "dual_average", + nuts_rs::StepSizeAdaptMethod::Adam => "adam", + nuts_rs::StepSizeAdaptMethod::Fixed(_) => "fixed", + } + .to_string() + } + + #[setter(step_size_adapt_method)] + fn set_step_size_adapt_method(&mut self, method: Py) -> Result<()> { + let method = Python::with_gil(|py| { + if let Ok(method) = method.extract::(py) { + match method.as_str() { + "dual_average" => Ok(StepSizeAdaptMethod::DualAverage), + "adam" => Ok(StepSizeAdaptMethod::Adam), + _ => { + if let Ok(step_size) = method.parse::() { + Ok(StepSizeAdaptMethod::Fixed(step_size)) + } else { + bail!("step_size_adapt_method must be a positive float when using fixed step size"); + } + } + } + } else { + bail!("step_size_adapt_method must be a string"); + } + })?; + + match &mut self.inner { + Settings::LowRank(inner) => { + inner.adapt_options.step_size_settings.adapt_options.method = method + } + Settings::Diag(inner) => { + inner.adapt_options.step_size_settings.adapt_options.method = method + } + Settings::Transforming(inner) => { + inner.adapt_options.step_size_settings.adapt_options.method = method + } + }; + Ok(()) + } + + #[getter] + fn step_size_adam_learning_rate(&self) -> Option { + match &self.inner { + Settings::LowRank(inner) => { + if let StepSizeAdaptMethod::Adam = + inner.adapt_options.step_size_settings.adapt_options.method + { + Some( + inner + .adapt_options + .step_size_settings + .adapt_options + .adam + .learning_rate, + ) + } else { + None + } + } + Settings::Diag(inner) => { + if let StepSizeAdaptMethod::Adam = + inner.adapt_options.step_size_settings.adapt_options.method + { + Some( + inner + .adapt_options + .step_size_settings + .adapt_options + .adam + .learning_rate, + ) + } else { + None + } + } + Settings::Transforming(inner) => { + if let StepSizeAdaptMethod::Adam = + inner.adapt_options.step_size_settings.adapt_options.method + { + 
Some( + inner + .adapt_options + .step_size_settings + .adapt_options + .adam + .learning_rate, + ) + } else { + None + } + } + } + } + + #[setter(step_size_adam_learning_rate)] + fn set_step_size_adam_learning_rate(&mut self, val: Option) -> Result<()> { + let Some(val) = val else { + return Ok(()); + }; + match &mut self.inner { + Settings::LowRank(inner) => { + inner + .adapt_options + .step_size_settings + .adapt_options + .adam + .learning_rate = val + } + Settings::Diag(inner) => { + inner + .adapt_options + .step_size_settings + .adapt_options + .adam + .learning_rate = val + } + Settings::Transforming(inner) => { + inner + .adapt_options + .step_size_settings + .adapt_options + .adam + .learning_rate = val + } + }; + Ok(()) + } } pub(crate) enum SamplerState { From d4886983b61bf46f1595647b225ffa9702530e1e Mon Sep 17 00:00:00 2001 From: Adrian Seyboldt Date: Tue, 2 Sep 2025 20:45:31 +0200 Subject: [PATCH 04/12] chore: update pyo3 --- Cargo.lock | 541 ++++++++++++++++++++++++++++--------------------- Cargo.toml | 8 +- src/pyfunc.rs | 12 +- src/pymc.rs | 12 +- src/wrapper.rs | 32 +-- 5 files changed, 342 insertions(+), 263 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 26eb163..4bd1975 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -71,9 +71,9 @@ checksum = "862ed96ca487e809f1c8e5a8447f6ee2cf102f846893800b20cebdf541fc6bbd" [[package]] name = "anyhow" -version = "1.0.98" +version = "1.0.99" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e16d2d3311acee920a9eb8d33b8cbc1787ce4a264e85f964c2404b969bdcd487" +checksum = "b0674a1ddeecb70197781e945de4b3b8ffb61fa939a5597bcf48503737663100" [[package]] name = "arrow" @@ -81,16 +81,34 @@ version = "55.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f3f15b4c6b148206ff3a2b35002e08929c2462467b62b9c02036d9c34f9ef994" dependencies = [ - "arrow-arith", - "arrow-array", - "arrow-buffer", - "arrow-cast", - "arrow-data", - "arrow-ord", - "arrow-row", - 
"arrow-schema", - "arrow-select", - "arrow-string", + "arrow-arith 55.2.0", + "arrow-array 55.2.0", + "arrow-buffer 55.2.0", + "arrow-cast 55.2.0", + "arrow-data 55.2.0", + "arrow-ord 55.2.0", + "arrow-row 55.2.0", + "arrow-schema 55.2.0", + "arrow-select 55.2.0", + "arrow-string 55.2.0", +] + +[[package]] +name = "arrow" +version = "56.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c26b57282a08ae92f727497805122fec964c6245cfa0e13f0e75452eaf3bc41f" +dependencies = [ + "arrow-arith 56.1.0", + "arrow-array 56.1.0", + "arrow-buffer 56.1.0", + "arrow-cast 56.1.0", + "arrow-data 56.1.0", + "arrow-ord 56.1.0", + "arrow-row 56.1.0", + "arrow-schema 56.1.0", + "arrow-select 56.1.0", + "arrow-string 56.1.0", ] [[package]] @@ -99,10 +117,24 @@ version = "55.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "30feb679425110209ae35c3fbf82404a39a4c0436bb3ec36164d8bffed2a4ce4" dependencies = [ - "arrow-array", - "arrow-buffer", - "arrow-data", - "arrow-schema", + "arrow-array 55.2.0", + "arrow-buffer 55.2.0", + "arrow-data 55.2.0", + "arrow-schema 55.2.0", + "chrono", + "num", +] + +[[package]] +name = "arrow-arith" +version = "56.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cebf38ca279120ff522f4954b81a39527425b6e9f615e6b72842f4de1ffe02b8" +dependencies = [ + "arrow-array 56.1.0", + "arrow-buffer 56.1.0", + "arrow-data 56.1.0", + "arrow-schema 56.1.0", "chrono", "num", ] @@ -114,9 +146,25 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "70732f04d285d49054a48b72c54f791bb3424abae92d27aafdf776c98af161c8" dependencies = [ "ahash", - "arrow-buffer", - "arrow-data", - "arrow-schema", + "arrow-buffer 55.2.0", + "arrow-data 55.2.0", + "arrow-schema 55.2.0", + "chrono", + "half", + "hashbrown", + "num", +] + +[[package]] +name = "arrow-array" +version = "56.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"744109142cdf8e7b02795e240e20756c2a782ac9180d4992802954a8f871c0de" +dependencies = [ + "ahash", + "arrow-buffer 56.1.0", + "arrow-data 56.1.0", + "arrow-schema 56.1.0", "chrono", "half", "hashbrown", @@ -134,17 +182,48 @@ dependencies = [ "num", ] +[[package]] +name = "arrow-buffer" +version = "56.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "601bb103c4c374bcd1f62c66bcea67b42a2ee91a690486c37d4c180236f11ccc" +dependencies = [ + "bytes", + "half", + "num", +] + [[package]] name = "arrow-cast" version = "55.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e4f12eccc3e1c05a766cafb31f6a60a46c2f8efec9b74c6e0648766d30686af8" dependencies = [ - "arrow-array", - "arrow-buffer", - "arrow-data", - "arrow-schema", - "arrow-select", + "arrow-array 55.2.0", + "arrow-buffer 55.2.0", + "arrow-data 55.2.0", + "arrow-schema 55.2.0", + "arrow-select 55.2.0", + "atoi", + "base64", + "chrono", + "half", + "lexical-core", + "num", + "ryu", +] + +[[package]] +name = "arrow-cast" +version = "56.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eed61d9d73eda8df9e3014843def37af3050b5080a9acbe108f045a316d5a0be" +dependencies = [ + "arrow-array 56.1.0", + "arrow-buffer 56.1.0", + "arrow-data 56.1.0", + "arrow-schema 56.1.0", + "arrow-select 56.1.0", "atoi", "base64", "chrono", @@ -160,8 +239,20 @@ version = "55.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8de1ce212d803199684b658fc4ba55fb2d7e87b213de5af415308d2fee3619c2" dependencies = [ - "arrow-buffer", - "arrow-schema", + "arrow-buffer 55.2.0", + "arrow-schema 55.2.0", + "half", + "num", +] + +[[package]] +name = "arrow-data" +version = "56.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "43407f2c6ba2367f64d85d4603d6fb9c4b92ed79d2ffd21021b37efa96523e12" +dependencies = [ + "arrow-buffer 56.1.0", + "arrow-schema 56.1.0", "half", "num", ] @@ -172,11 +263,24 @@ version = 
"55.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6506e3a059e3be23023f587f79c82ef0bcf6d293587e3272d20f2d30b969b5a7" dependencies = [ - "arrow-array", - "arrow-buffer", - "arrow-data", - "arrow-schema", - "arrow-select", + "arrow-array 55.2.0", + "arrow-buffer 55.2.0", + "arrow-data 55.2.0", + "arrow-schema 55.2.0", + "arrow-select 55.2.0", +] + +[[package]] +name = "arrow-ord" +version = "56.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c142a147dceb59d057bad82400f1693847c80dca870d008bf7b91caf902810ae" +dependencies = [ + "arrow-array 56.1.0", + "arrow-buffer 56.1.0", + "arrow-data 56.1.0", + "arrow-schema 56.1.0", + "arrow-select 56.1.0", ] [[package]] @@ -185,10 +289,23 @@ version = "55.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "52bf7393166beaf79b4bed9bfdf19e97472af32ce5b6b48169d321518a08cae2" dependencies = [ - "arrow-array", - "arrow-buffer", - "arrow-data", - "arrow-schema", + "arrow-array 55.2.0", + "arrow-buffer 55.2.0", + "arrow-data 55.2.0", + "arrow-schema 55.2.0", + "half", +] + +[[package]] +name = "arrow-row" +version = "56.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dac6620667fccdab4204689ca173bd84a15de6bb6b756c3a8764d4d7d0c2fc04" +dependencies = [ + "arrow-array 56.1.0", + "arrow-buffer 56.1.0", + "arrow-data 56.1.0", + "arrow-schema 56.1.0", "half", ] @@ -201,6 +318,15 @@ dependencies = [ "bitflags", ] +[[package]] +name = "arrow-schema" +version = "56.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dfa93af9ff2bb80de539e6eb2c1c8764abd0f4b73ffb0d7c82bf1f9868785e66" +dependencies = [ + "bitflags", +] + [[package]] name = "arrow-select" version = "55.2.0" @@ -208,10 +334,24 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dd2b45757d6a2373faa3352d02ff5b54b098f5e21dccebc45a21806bc34501e5" dependencies = [ "ahash", - "arrow-array", - 
"arrow-buffer", - "arrow-data", - "arrow-schema", + "arrow-array 55.2.0", + "arrow-buffer 55.2.0", + "arrow-data 55.2.0", + "arrow-schema 55.2.0", + "num", +] + +[[package]] +name = "arrow-select" +version = "56.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "be8b2e0052cd20d36d64f32640b68a5ab54d805d24a473baee5d52017c85536c" +dependencies = [ + "ahash", + "arrow-array 56.1.0", + "arrow-buffer 56.1.0", + "arrow-data 56.1.0", + "arrow-schema 56.1.0", "num", ] @@ -221,11 +361,28 @@ version = "55.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0377d532850babb4d927a06294314b316e23311503ed580ec6ce6a0158f49d40" dependencies = [ - "arrow-array", - "arrow-buffer", - "arrow-data", - "arrow-schema", - "arrow-select", + "arrow-array 55.2.0", + "arrow-buffer 55.2.0", + "arrow-data 55.2.0", + "arrow-schema 55.2.0", + "arrow-select 55.2.0", + "memchr", + "num", + "regex", + "regex-syntax", +] + +[[package]] +name = "arrow-string" +version = "56.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c2155e26e17f053c8975c546fc70cf19c00542f9abf43c23a88a46ef7204204f" +dependencies = [ + "arrow-array 56.1.0", + "arrow-buffer 56.1.0", + "arrow-data 56.1.0", + "arrow-schema 56.1.0", + "arrow-select 56.1.0", "memchr", "num", "regex", @@ -281,9 +438,9 @@ dependencies = [ [[package]] name = "bitflags" -version = "2.9.1" +version = "2.9.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1b8e56985ec62d17e9c1001dc89c88ecd7dc08e47eba5ec7c29c7b5eeecde967" +checksum = "2261d10cca569e4643e526d8dc2e62e433cc8aba21ab764233731f8d369bf394" [[package]] name = "block-buffer" @@ -304,7 +461,7 @@ dependencies = [ "libloading", "log", "path-absolutize", - "thiserror 2.0.12", + "thiserror 2.0.16", ] [[package]] @@ -315,9 +472,9 @@ checksum = "46c5e41b57b8bba42a04676d81cb89e9ee8e859a1a66f80a5a72e1cb76b34d43" [[package]] name = "bytemuck" -version = "1.23.1" +version = "1.23.2" 
source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5c76a5792e44e4abe34d3abf15636779261d45a7450612059293d1d2cfc63422" +checksum = "3995eaeebcdf32f91f980d360f78732ddc061097ab4e39991ae7a6ace9194677" [[package]] name = "byteorder" @@ -359,10 +516,11 @@ checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5" [[package]] name = "cc" -version = "1.2.29" +version = "1.2.35" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5c1599538de2394445747c8cf7935946e3cc27e9625f889d979bfb2aaf569362" +checksum = "590f9024a68a8c40351881787f1934dc11afd69090f5edb6831464694d836ea3" dependencies = [ + "find-msvc-tools", "jobserver", "libc", "shlex", @@ -379,9 +537,9 @@ dependencies = [ [[package]] name = "cfg-if" -version = "1.0.1" +version = "1.0.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9555578bc9e57714c812a1f84e4fc5b4d21fcb063490c624de019f7464c91268" +checksum = "2fd1289c04a9ea8cb22300a459a72a385d7c73d3259e2ed7dcb2af674838cfa9" [[package]] name = "chrono" @@ -445,18 +603,18 @@ dependencies = [ [[package]] name = "clap" -version = "4.5.41" +version = "4.5.47" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "be92d32e80243a54711e5d7ce823c35c41c9d929dc4ab58e1276f625841aadf9" +checksum = "7eac00902d9d136acd712710d71823fb8ac8004ca445a89e73a41d45aa712931" dependencies = [ "clap_builder", ] [[package]] name = "clap_builder" -version = "4.5.41" +version = "4.5.47" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "707eab41e9622f9139419d573eca0900137718000c517d47da73045f54331c3d" +checksum = "2ad9bbf750e73b5884fb8a211a9424a1906c1e156724260fdae972f31d70e1d6" dependencies = [ "anstyle", "clap_lex", @@ -478,7 +636,7 @@ dependencies = [ "libc", "once_cell", "unicode-width", - "windows-sys 0.60.2", + "windows-sys", ] [[package]] @@ -533,9 +691,9 @@ dependencies = [ [[package]] name = "criterion" -version = "0.6.0" +version = "0.7.0" 
source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3bf7af66b0989381bd0be551bd7cc91912a655a58c6918420c9527b1fd8b4679" +checksum = "e1c047a62b0cc3e145fa84415a3191f628e980b194c2755aa12300a4e6cbd928" dependencies = [ "anes", "cast", @@ -556,12 +714,12 @@ dependencies = [ [[package]] name = "criterion-plot" -version = "0.5.0" +version = "0.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6b50826342786a51a89e2da3a28f1c32b06e387201bc2d19791f622c673706b1" +checksum = "9b1bcc0dc7dfae599d84ad0b1a55f80cde8af3725da8313b528da95ef783e338" dependencies = [ "cast", - "itertools 0.10.5", + "itertools 0.13.0", ] [[package]] @@ -607,9 +765,9 @@ dependencies = [ [[package]] name = "deranged" -version = "0.4.0" +version = "0.5.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9c9e6a11ca8224451684bc0d7d5a7adbf8f2fd6887261a1cfc3c0432f9d4068e" +checksum = "d630bccd429a5bb5a64b5e94f693bfc48c9f8566418fda4c494cc94f911f87cc" dependencies = [ "powerfmt", ] @@ -736,6 +894,12 @@ dependencies = [ "reborrow", ] +[[package]] +name = "find-msvc-tools" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e178e4fba8a2726903f6ba98a6d221e76f9c12c650d5dc0e6afdc50677b49650" + [[package]] name = "flate2" version = "1.1.2" @@ -879,14 +1043,14 @@ dependencies = [ "cfg-if", "libc", "r-efi", - "wasi 0.14.2+wasi-0.2.4", + "wasi 0.14.3+wasi-0.2.4", ] [[package]] name = "glob" -version = "0.3.2" +version = "0.3.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a8d1add55171497b4705a648c6b583acafb01d58050a51727785f0b2c8e0a2b2" +checksum = "0cc23270f6e1808e30a928bdc84dea0b9b4136a8bc82338574f23baf47bbd280" [[package]] name = "half" @@ -901,9 +1065,9 @@ dependencies = [ [[package]] name = "hashbrown" -version = "0.15.4" +version = "0.15.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"5971ac85611da7067dbfcabef3c70ebb5606018acd9e2a3903a0da507521e0d5" +checksum = "9229cfe53dfd69f0609a49f65461bd93001ea1ef889cd5529dd176593f5338a1" [[package]] name = "heck" @@ -972,15 +1136,6 @@ dependencies = [ "generic-array", ] -[[package]] -name = "itertools" -version = "0.10.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b0fd2260e829bddf4cb6ea802289de2f86d6a7a690192fbe91b3f46e0f2c8473" -dependencies = [ - "either", -] - [[package]] name = "itertools" version = "0.13.0" @@ -1007,9 +1162,9 @@ checksum = "4a5f13b858c8d314ee3e8f639011f7ccefe71f97f96e50151fb991f267928e2c" [[package]] name = "jobserver" -version = "0.1.33" +version = "0.1.34" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "38f262f097c174adebe41eb73d66ae9c06b2844fb0da69969647bbddd9b0538a" +checksum = "9afb3de4395d6b3e67a780b6de64b51c978ecf11cb9a462c66be7d4ca9039d33" dependencies = [ "getrandom 0.3.3", "libc", @@ -1097,9 +1252,9 @@ dependencies = [ [[package]] name = "libc" -version = "0.2.174" +version = "0.2.175" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1171693293099992e19cddea4e8b849964e9846f4acee11b3948bcc337be8776" +checksum = "6a82ae493e598baaea5209805c49bbf2ea7de956d50d7da0da1164f9c6d28543" [[package]] name = "libloading" @@ -1108,7 +1263,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "07033963ba89ebaf1584d767badaa2e8fcec21aedea6b8c0346d487d49c28667" dependencies = [ "cfg-if", - "windows-targets 0.53.2", + "windows-targets", ] [[package]] @@ -1341,9 +1496,9 @@ dependencies = [ [[package]] name = "numpy" -version = "0.25.0" +version = "0.26.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "29f1dee9aa8d3f6f8e8b9af3803006101bb3653866ef056d530d53ae68587191" +checksum = "9b2dba356160b54f5371b550575b78130a54718b4c6e46b3f33a6da74a27e78b" dependencies = [ "libc", "ndarray", @@ -1360,7 +1515,7 @@ name = "nutpie" version = "0.15.2" 
dependencies = [ "anyhow", - "arrow", + "arrow 56.1.0", "bridgestan", "criterion", "indicatif", @@ -1368,13 +1523,13 @@ dependencies = [ "numpy", "nuts-rs", "pyo3", - "rand 0.9.1", + "rand 0.9.2", "rand_chacha 0.9.0", "rand_distr", "rayon", "smallvec", "tch", - "thiserror 2.0.12", + "thiserror 2.0.16", "time-humanize", "upon", ] @@ -1386,15 +1541,15 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "acad2be84df0d14341d8de7d30c1019ecc008f4722befbd45745092a918c0a02" dependencies = [ "anyhow", - "arrow", + "arrow 55.2.0", "faer", "itertools 0.14.0", "pulp", - "rand 0.9.1", + "rand 0.9.2", "rand_chacha 0.9.0", "rand_distr", "rayon", - "thiserror 2.0.12", + "thiserror 2.0.16", ] [[package]] @@ -1522,9 +1677,9 @@ dependencies = [ [[package]] name = "prettyplease" -version = "0.2.35" +version = "0.2.37" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "061c1221631e079b26479d25bbf2275bfe5917ae8419cd7e34f13bfc2aa7539a" +checksum = "479ca8adacdd7ce8f1fb39ce9ecccbfe93a3f1344b3d0d97f20bc0196208f62b" dependencies = [ "proc-macro2", "syn", @@ -1532,9 +1687,9 @@ dependencies = [ [[package]] name = "proc-macro2" -version = "1.0.95" +version = "1.0.101" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "02b3e5e68a3a1a02aad3ec490a98007cbc13c37cbe84a3cd7b8e406d76e7f778" +checksum = "89ae43fd86e4158d6db51ad8e2b80f313af9cc74f5c0e03ccb87de09998732de" dependencies = [ "unicode-ident", ] @@ -1555,9 +1710,9 @@ dependencies = [ [[package]] name = "pyo3" -version = "0.25.1" +version = "0.26.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8970a78afe0628a3e3430376fc5fd76b6b45c4d43360ffd6cdd40bdde72b682a" +checksum = "7ba0117f4212101ee6544044dae45abe1083d30ce7b29c4b5cbdfa2354e07383" dependencies = [ "anyhow", "indoc", @@ -1573,19 +1728,18 @@ dependencies = [ [[package]] name = "pyo3-build-config" -version = "0.25.1" +version = "0.26.0" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "458eb0c55e7ece017adeba38f2248ff3ac615e53660d7c71a238d7d2a01c7598" +checksum = "4fc6ddaf24947d12a9aa31ac65431fb1b851b8f4365426e182901eabfb87df5f" dependencies = [ - "once_cell", "target-lexicon", ] [[package]] name = "pyo3-ffi" -version = "0.25.1" +version = "0.26.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7114fe5457c61b276ab77c5055f206295b812608083644a5c5b2640c3102565c" +checksum = "025474d3928738efb38ac36d4744a74a400c901c7596199e20e45d98eb194105" dependencies = [ "libc", "pyo3-build-config", @@ -1593,9 +1747,9 @@ dependencies = [ [[package]] name = "pyo3-macros" -version = "0.25.1" +version = "0.26.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a8725c0a622b374d6cb051d11a0983786448f7785336139c3c94f5aa6bef7e50" +checksum = "2e64eb489f22fe1c95911b77c44cc41e7c19f3082fc81cce90f657cdc42ffded" dependencies = [ "proc-macro2", "pyo3-macros-backend", @@ -1605,9 +1759,9 @@ dependencies = [ [[package]] name = "pyo3-macros-backend" -version = "0.25.1" +version = "0.26.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4109984c22491085343c05b0dbc54ddc405c3cf7b4374fc533f5c3313a572ccc" +checksum = "100246c0ecf400b475341b8455a9213344569af29a3c841d29270e53102e0fcf" dependencies = [ "heck", "proc-macro2", @@ -1656,9 +1810,9 @@ dependencies = [ [[package]] name = "rand" -version = "0.9.1" +version = "0.9.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9fbfd9d094a40bf3ae768db9361049ace4c0e04a4fd6b359518bd7b73a73dd97" +checksum = "6db2770f06117d490610c7488547d543617b21bfa07796d7a12f6f1bd53850d1" dependencies = [ "rand_chacha 0.9.0", "rand_core 0.9.3", @@ -1709,7 +1863,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6a8615d50dcf34fa31f7ab52692afec947c4dd0ab803cc87cb3b0b4570ff7463" dependencies = [ "num-traits", - "rand 0.9.1", + "rand 0.9.2", ] 
[[package]] @@ -1729,9 +1883,9 @@ checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3" [[package]] name = "rayon" -version = "1.10.0" +version = "1.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b418a60154510ca1a002a752ca9714984e21e4241e804d32555251faf8b78ffa" +checksum = "368f01d005bf8fd9b1206fb6fa653e6c4a81ceb1466406b81792d87c5677a58f" dependencies = [ "either", "rayon-core", @@ -1739,9 +1893,9 @@ dependencies = [ [[package]] name = "rayon-core" -version = "1.12.1" +version = "1.13.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1465873a3dfdaa8ae7cb14b4383657caab0b3e8a0aa9ae8e04b044854c8dfce2" +checksum = "22e18b0f0062d30d4230b2e85ff77fdfe4326feb054b9783a3460d8435c8ab91" dependencies = [ "crossbeam-deque", "crossbeam-utils", @@ -1755,9 +1909,9 @@ checksum = "03251193000f4bd3b042892be858ee50e8b3719f2b08e5833ac4353724632430" [[package]] name = "regex" -version = "1.11.1" +version = "1.11.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b544ef1b4eac5dc2db33ea63606ae9ffcfac26c1416a2806ae0bf5f56b201191" +checksum = "23d7fd106d8c02486a8d64e778353d1cffe08ce79ac2e82f540c86d0facf6912" dependencies = [ "aho-corasick", "memchr", @@ -1767,9 +1921,9 @@ dependencies = [ [[package]] name = "regex-automata" -version = "0.4.9" +version = "0.4.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "809e8dc61f6de73b46c85f4c96486310fe304c434cfa43669d7b40f711150908" +checksum = "6b9458fa0bfeeac22b5ca447c63aaf45f28439a709ccd244698632f9aa6394d6" dependencies = [ "aho-corasick", "memchr", @@ -1778,9 +1932,9 @@ dependencies = [ [[package]] name = "regex-syntax" -version = "0.8.5" +version = "0.8.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2b15c43186be67a4fd63bee50d0303afffcef381492ebe2c5d87f324e1b8815c" +checksum = "caf4aa5b0f434c91fe5c7f1ecb6a5ece2130b02ad2a590589dda5146df959001" [[package]] name = 
"rustc-hash" @@ -1790,9 +1944,9 @@ checksum = "357703d41365b4b27c590e3ed91eabb1b663f07c4c084095e60cbed4362dff0d" [[package]] name = "rustversion" -version = "1.0.21" +version = "1.0.22" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8a0d197bd2c9dc6e53b84da9556a69ba4cdfab8619eb41a8bd1cc2027a0f6b1d" +checksum = "b39cdef0fa800fc44525c84ccb54a029961a8215f9619753635a9c0d2538d46d" [[package]] name = "ryu" @@ -1847,9 +2001,9 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.140" +version = "1.0.143" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "20068b6e96dc6c9bd23e01df8827e6c7e1f2fddd43c21810382803c136b99373" +checksum = "d401abef1d108fbd9cbaebc3e46611f4b1021f714a0597a71f41ee463f5f4a5a" dependencies = [ "itoa", "memchr", @@ -1905,9 +2059,9 @@ checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" [[package]] name = "syn" -version = "2.0.104" +version = "2.0.106" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "17b6f705963418cdb9927482fa304bc562ece2fdd4f616084c50b7023b435a40" +checksum = "ede7c438028d4436d71104916910f5bb611972c5cfd7f89b8300a8186e6fada6" dependencies = [ "proc-macro2", "quote", @@ -1948,11 +2102,11 @@ dependencies = [ [[package]] name = "thiserror" -version = "2.0.12" +version = "2.0.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "567b8a2dae586314f7be2a752ec7474332959c6460e02bde30d702a66d488708" +checksum = "3467d614147380f2e4e374161426ff399c91084acd2363eaf549172b3d5e60c0" dependencies = [ - "thiserror-impl 2.0.12", + "thiserror-impl 2.0.16", ] [[package]] @@ -1968,9 +2122,9 @@ dependencies = [ [[package]] name = "thiserror-impl" -version = "2.0.12" +version = "2.0.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7f7cf42b4507d8ea322120659672cf1b9dbb93f8f2d4ecfd6e51350ff5b17a1d" +checksum = "6c5e1be1c48b9172ee610da68fd9cd2770e7a4056cb3fc98710ee6906f0c7960" 
dependencies = [ "proc-macro2", "quote", @@ -1979,9 +2133,9 @@ dependencies = [ [[package]] name = "time" -version = "0.3.41" +version = "0.3.42" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8a7619e19bc266e0f9c5e6686659d394bc57973859340060a69221e57dbc0c40" +checksum = "8ca967379f9d8eb8058d86ed467d81d03e81acd45757e4ca341c24affbe8e8e3" dependencies = [ "deranged", "num-conv", @@ -1992,9 +2146,9 @@ dependencies = [ [[package]] name = "time-core" -version = "0.1.4" +version = "0.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c9e9a38711f559d9e3ce1cdb06dd7c5b8ea546bc90052da6d06bb76da74bb07c" +checksum = "a9108bb380861b07264b950ded55a44a14a4adc68b9f5efd85aafc3aa4d40a68" [[package]] name = "time-humanize" @@ -2093,11 +2247,11 @@ checksum = "ccf3ec651a847eb01de73ccad15eb7d99f80485de043efb2f370cd654f4ea44b" [[package]] name = "wasi" -version = "0.14.2+wasi-0.2.4" +version = "0.14.3+wasi-0.2.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9683f9a5a998d873c0d21fcbe3c083009670149a8fab228644b8bd36b2c48cb3" +checksum = "6a51ae83037bdd272a9e28ce236db8c07016dd0d50c27038b3f407533c030c95" dependencies = [ - "wit-bindgen-rt", + "wit-bindgen", ] [[package]] @@ -2180,11 +2334,11 @@ dependencies = [ [[package]] name = "winapi-util" -version = "0.1.9" +version = "0.1.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cf221c93e13a30d793f7645a0e7762c55d169dbb0a49671918a2319d289b10bb" +checksum = "0978bf7171b3d90bac376700cb56d606feb40f251a475a5d6634613564460b22" dependencies = [ - "windows-sys 0.59.0", + "windows-sys", ] [[package]] @@ -2246,146 +2400,74 @@ dependencies = [ "windows-link", ] -[[package]] -name = "windows-sys" -version = "0.59.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1e38bc4d79ed67fd075bcc251a1c39b32a1776bbe92e5bef1f0bf1f8c531853b" -dependencies = [ - "windows-targets 0.52.6", -] - [[package]] name = 
"windows-sys" version = "0.60.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f2f500e4d28234f72040990ec9d39e3a6b950f9f22d3dba18416c35882612bcb" dependencies = [ - "windows-targets 0.53.2", -] - -[[package]] -name = "windows-targets" -version = "0.52.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9b724f72796e036ab90c1021d4780d4d3d648aca59e491e6b98e725b84e99973" -dependencies = [ - "windows_aarch64_gnullvm 0.52.6", - "windows_aarch64_msvc 0.52.6", - "windows_i686_gnu 0.52.6", - "windows_i686_gnullvm 0.52.6", - "windows_i686_msvc 0.52.6", - "windows_x86_64_gnu 0.52.6", - "windows_x86_64_gnullvm 0.52.6", - "windows_x86_64_msvc 0.52.6", + "windows-targets", ] [[package]] name = "windows-targets" -version = "0.53.2" +version = "0.53.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c66f69fcc9ce11da9966ddb31a40968cad001c5bedeb5c2b82ede4253ab48aef" +checksum = "d5fe6031c4041849d7c496a8ded650796e7b6ecc19df1a431c1a363342e5dc91" dependencies = [ - "windows_aarch64_gnullvm 0.53.0", - "windows_aarch64_msvc 0.53.0", - "windows_i686_gnu 0.53.0", - "windows_i686_gnullvm 0.53.0", - "windows_i686_msvc 0.53.0", - "windows_x86_64_gnu 0.53.0", - "windows_x86_64_gnullvm 0.53.0", - "windows_x86_64_msvc 0.53.0", + "windows-link", + "windows_aarch64_gnullvm", + "windows_aarch64_msvc", + "windows_i686_gnu", + "windows_i686_gnullvm", + "windows_i686_msvc", + "windows_x86_64_gnu", + "windows_x86_64_gnullvm", + "windows_x86_64_msvc", ] -[[package]] -name = "windows_aarch64_gnullvm" -version = "0.52.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3" - [[package]] name = "windows_aarch64_gnullvm" version = "0.53.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "86b8d5f90ddd19cb4a147a5fa63ca848db3df085e25fee3cc10b39b6eebae764" -[[package]] -name = "windows_aarch64_msvc" 
-version = "0.52.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469" - [[package]] name = "windows_aarch64_msvc" version = "0.53.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c7651a1f62a11b8cbd5e0d42526e55f2c99886c77e007179efff86c2b137e66c" -[[package]] -name = "windows_i686_gnu" -version = "0.52.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8e9b5ad5ab802e97eb8e295ac6720e509ee4c243f69d781394014ebfe8bbfa0b" - [[package]] name = "windows_i686_gnu" version = "0.53.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c1dc67659d35f387f5f6c479dc4e28f1d4bb90ddd1a5d3da2e5d97b42d6272c3" -[[package]] -name = "windows_i686_gnullvm" -version = "0.52.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66" - [[package]] name = "windows_i686_gnullvm" version = "0.53.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9ce6ccbdedbf6d6354471319e781c0dfef054c81fbc7cf83f338a4296c0cae11" -[[package]] -name = "windows_i686_msvc" -version = "0.52.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66" - [[package]] name = "windows_i686_msvc" version = "0.53.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "581fee95406bb13382d2f65cd4a908ca7b1e4c2f1917f143ba16efe98a589b5d" -[[package]] -name = "windows_x86_64_gnu" -version = "0.52.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78" - [[package]] name = "windows_x86_64_gnu" version = "0.53.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = 
"2e55b5ac9ea33f2fc1716d1742db15574fd6fc8dadc51caab1c16a3d3b4190ba" -[[package]] -name = "windows_x86_64_gnullvm" -version = "0.52.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d" - [[package]] name = "windows_x86_64_gnullvm" version = "0.53.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0a6e035dd0599267ce1ee132e51c27dd29437f63325753051e71dd9e42406c57" -[[package]] -name = "windows_x86_64_msvc" -version = "0.52.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" - [[package]] name = "windows_x86_64_msvc" version = "0.53.0" @@ -2393,13 +2475,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "271414315aff87387382ec3d271b52d7ae78726f5d44ac98b4f4030c91880486" [[package]] -name = "wit-bindgen-rt" -version = "0.39.0" +name = "wit-bindgen" +version = "0.45.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6f42320e61fe2cfd34354ecb597f86f413484a798ba44a8ca1165c58d42da6c1" -dependencies = [ - "bitflags", -] +checksum = "052283831dbae3d879dc7f51f3d92703a316ca49f91540417d38591826127814" [[package]] name = "zerocopy" diff --git a/Cargo.toml b/Cargo.toml index 6f1fe11..391dd25 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -22,13 +22,13 @@ crate-type = ["cdylib"] [dependencies] nuts-rs = "0.16.1" -numpy = "0.25.0" +numpy = "0.26.0" rand = "0.9.0" thiserror = "2.0.3" rand_chacha = "0.9.0" rayon = "1.10.0" # Keep arrow in sync with nuts-rs requirements -arrow = { version = "55.1.0", default-features = false, features = ["ffi"] } +arrow = { version = "56.1.0", default-features = false, features = ["ffi"] } anyhow = "1.0.72" itertools = "0.14.0" bridgestan = "2.6.1" @@ -40,11 +40,11 @@ indicatif = "0.18.0" tch = { version = "0.20.0", optional = true } [dependencies.pyo3] -version = "0.25.0" +version = 
"0.26.0" features = ["extension-module", "anyhow"] [dev-dependencies] -criterion = "0.6.0" +criterion = "0.7.0" [profile.release] lto = "fat" diff --git a/src/pyfunc.rs b/src/pyfunc.rs index de3ecd9..9640468 100644 --- a/src/pyfunc.rs +++ b/src/pyfunc.rs @@ -123,7 +123,7 @@ impl LogpError for PyLogpError { fn is_recoverable(&self) -> bool { match self { Self::BadLogp(_) => true, - Self::PyError(err) => Python::with_gil(|py| { + Self::PyError(err) => Python::attach(|py| { let Ok(attr) = err.value(py).getattr("is_recoverable") else { return false; }; @@ -149,7 +149,7 @@ impl PyDensity { dim: usize, transform_adapter: Option<&PyTransformAdapt>, ) -> Result { - let logp_func = Python::with_gil(|py| logp_clone_func.call0(py))?; + let logp_func = Python::attach(|py| logp_clone_func.call0(py))?; let transform_adapter = transform_adapter.cloned(); Ok(Self { logp: logp_func, @@ -164,7 +164,7 @@ impl CpuLogpFunc for PyDensity { type TransformParams = Py; fn logp(&mut self, position: &[f64], grad: &mut [f64]) -> Result { - Python::with_gil(|py| { + Python::attach(|py| { let pos_array = PyArray1::from_slice(py, position); let result = self.logp.call1(py, (pos_array,)); match result { @@ -321,7 +321,7 @@ impl PyTrace { ) -> Result { let seed1 = rng.next_u64(); let seed2 = rng.next_u64(); - let expand = Python::with_gil(|py| { + let expand = Python::attach(|py| { make_expand_func .call1(py, (seed1, seed2, chain)) .context("Failed to call expand function factory") @@ -449,7 +449,7 @@ impl ExpandDtype { impl DrawStorage for PyTrace { fn append_value(&mut self, point: &[f64]) -> Result<()> { - Python::with_gil(|py| { + Python::attach(|py| { let point = PyArray1::from_slice(py, point); let full_point = self .expand @@ -647,7 +647,7 @@ impl Model for PyModel { let seed = rng.next_u64(); - Python::with_gil(|py| { + Python::attach(|py| { let init_point = init_func .call1(py, (seed,)) .context("Failed to initialize point")?; diff --git a/src/pymc.rs b/src/pymc.rs index 220ad5f..1001a29 
100644 --- a/src/pymc.rs +++ b/src/pymc.rs @@ -12,7 +12,7 @@ use nuts_rs::{CpuLogpFunc, CpuMath, DrawStorage, LogpError, Model, Settings}; use pyo3::{ pyclass, pymethods, types::{PyAnyMethods, PyList}, - Bound, Py, PyAny, PyObject, PyResult, Python, + Bound, Py, PyAny, PyResult, Python, }; use rand_distr::num_traits::CheckedEuclid; @@ -40,7 +40,7 @@ type RawExpandFunc = unsafe extern "C" fn( #[derive(Clone)] pub(crate) struct LogpFunc { func: RawLogpFunc, - _keep_alive: Arc, + _keep_alive: Arc>, user_data_ptr: UserData, dim: usize, } @@ -51,7 +51,7 @@ unsafe impl Sync for LogpFunc {} #[pymethods] impl LogpFunc { #[new] - fn new(dim: usize, ptr: usize, user_data_ptr: usize, keep_alive: PyObject) -> Self { + fn new(dim: usize, ptr: usize, user_data_ptr: usize, keep_alive: Py) -> Self { let func = unsafe { std::mem::transmute::<*const c_void, RawLogpFunc>(ptr as *const c_void) }; Self { @@ -67,7 +67,7 @@ impl LogpFunc { #[derive(Clone)] pub(crate) struct ExpandFunc { func: RawExpandFunc, - _keep_alive: Arc, + _keep_alive: Arc>, user_data_ptr: UserData, dim: usize, expanded_dim: usize, @@ -81,7 +81,7 @@ impl ExpandFunc { expanded_dim: usize, ptr: usize, user_data_ptr: usize, - keep_alive: PyObject, + keep_alive: Py, ) -> Self { let func = unsafe { std::mem::transmute::<*const c_void, RawExpandFunc>(ptr as *const c_void) }; @@ -306,7 +306,7 @@ impl Model for PyMcModel { ) -> Result<()> { let seed = rng.next_u64(); - Python::with_gil(|py| { + Python::attach(|py| { let init_point = self .init_func .call1(py, (seed,)) diff --git a/src/wrapper.rs b/src/wrapper.rs index 67fa4b3..8d312cb 100644 --- a/src/wrapper.rs +++ b/src/wrapper.rs @@ -910,7 +910,7 @@ impl PySampler { } fn is_finished(&mut self, py: Python<'_>) -> PyResult { - py.allow_threads(|| { + py.detach(|| { let guard = &mut self.0.lock().expect("Poisond sampler state mutex"); let slot = guard.deref_mut(); @@ -939,7 +939,7 @@ impl PySampler { } fn pause(&mut self, py: Python<'_>) -> PyResult<()> { - 
py.allow_threads(|| { + py.detach(|| { if let SamplerState::Running(ref mut control) = self .0 .lock() @@ -953,7 +953,7 @@ impl PySampler { } fn resume(&mut self, py: Python<'_>) -> PyResult<()> { - py.allow_threads(|| { + py.detach(|| { if let SamplerState::Running(ref mut control) = self .0 .lock() @@ -968,7 +968,7 @@ impl PySampler { #[pyo3(signature = (timeout_seconds=None))] fn wait(&mut self, py: Python<'_>, timeout_seconds: Option) -> PyResult<()> { - py.allow_threads(|| { + py.detach(|| { let guard = &mut self.0.lock().expect("Poisond sampler state mutex"); let slot = guard.deref_mut(); @@ -1016,7 +1016,7 @@ impl PySampler { } } - if let Err(err) = Python::with_gil(|py| py.check_signals()) { + if let Err(err) = Python::attach(|py| py.check_signals()) { break (SamplerState::Running(control), Err(err)); } }; @@ -1027,7 +1027,7 @@ impl PySampler { } fn abort(&mut self, py: Python<'_>) -> PyResult<()> { - py.allow_threads(|| { + py.detach(|| { let guard = &mut self.0.lock().expect("Poisond sampler state mutex"); let slot = guard.deref_mut(); @@ -1074,7 +1074,7 @@ impl PySampler { } fn inspect<'py>(&mut self, py: Python<'py>) -> PyResult> { - let trace = py.allow_threads(|| { + let trace = py.detach(|| { let mut guard = self.0.lock().unwrap(); let SamplerState::Running(ref mut sampler) = guard.deref_mut() else { return Err(anyhow::anyhow!("Sampler is not running"))?; @@ -1107,7 +1107,7 @@ fn trace_to_list(trace: Trace, py: Python<'_>) -> PyResult> { Ok(list) } -fn export_array(py: Python<'_>, data: Arc) -> PyResult { +fn export_array(py: Python<'_>, data: Arc) -> PyResult> { let pa = py.import("pyarrow")?; let array = pa.getattr("Array")?; @@ -1148,7 +1148,7 @@ impl PyTransformAdapt { transformed_position: &mut [f64], transformed_gradient: &mut [f64], ) -> Result { - Python::with_gil(|py| { + Python::attach(|py| { let untransformed_position = PyArray1::from_slice(py, untransformed_position); let untransformed_gradient = PyArray1::from_slice(py, 
untransformed_gradient); @@ -1203,7 +1203,7 @@ impl PyTransformAdapt { transformed_position: &[f64], transformed_gradient: &mut [f64], ) -> Result<(f64, f64)> { - Python::with_gil(|py| { + Python::attach(|py| { let transformed_position = PyArray1::from_slice(py, transformed_position); let output = params @@ -1236,7 +1236,7 @@ impl PyTransformAdapt { untransformed_position: &mut [f64], transformed_position: &[f64], ) -> Result> { - Python::with_gil(|py| { + Python::attach(|py| { let transformed_position = PyArray1::from_slice(py, transformed_position); let output = params @@ -1257,7 +1257,7 @@ impl PyTransformAdapt { untransformed_gradient: &[f64], transformed_gradient: &mut [f64], ) -> Result { - Python::with_gil(|py| { + Python::attach(|py| { let untransformed_gradient = PyArray1::from_slice(py, untransformed_gradient); let output = params @@ -1279,7 +1279,7 @@ impl PyTransformAdapt { transformed_position: &mut [f64], transformed_gradient: &mut [f64], ) -> Result<(f64, f64)> { - Python::with_gil(|py| { + Python::attach(|py| { let untransformed_position = PyArray1::from_slice(py, untransformed_position); let output = params @@ -1318,7 +1318,7 @@ impl PyTransformAdapt { untransformed_logp: impl ExactSizeIterator, params: &'a mut Py, ) -> Result<()> { - Python::with_gil(|py| { + Python::attach(|py| { let positions = PyList::new( py, untransformed_positions.map(|pos| PyArray1::from_slice(py, pos)), @@ -1345,7 +1345,7 @@ impl PyTransformAdapt { untransformed_gradient: &[f64], chain: u64, ) -> Result> { - Python::with_gil(|py| { + Python::attach(|py| { let position = PyArray1::from_slice(py, untransformed_position); let gradient = PyArray1::from_slice(py, untransformed_gradient); @@ -1358,7 +1358,7 @@ impl PyTransformAdapt { } pub fn transformation_id(&self, params: &Py) -> Result { - Python::with_gil(|py| { + Python::attach(|py| { let id: i64 = params .getattr(py, intern!(py, "transformation_id"))? 
.extract(py)?; From a0afd87a7a06967c1055734436baf782191dfbff Mon Sep 17 00:00:00 2001 From: Adrian Seyboldt Date: Tue, 2 Sep 2025 20:45:31 +0200 Subject: [PATCH 05/12] feat: add argument for mindepth --- src/wrapper.rs | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/src/wrapper.rs b/src/wrapper.rs index 8d312cb..9f4eaec 100644 --- a/src/wrapper.rs +++ b/src/wrapper.rs @@ -320,6 +320,24 @@ impl PyNutsSettings { } } + #[getter] + fn mindepth(&self) -> u64 { + match &self.inner { + Settings::Diag(nuts_settings) => nuts_settings.mindepth, + Settings::LowRank(nuts_settings) => nuts_settings.mindepth, + Settings::Transforming(nuts_settings) => nuts_settings.mindepth, + } + } + + #[setter(mindepth)] + fn set_mindepth(&mut self, val: u64) { + match &mut self.inner { + Settings::Diag(nuts_settings) => nuts_settings.mindepth = val, + Settings::LowRank(nuts_settings) => nuts_settings.mindepth = val, + Settings::Transforming(nuts_settings) => nuts_settings.mindepth = val, + } + } + #[getter] fn store_gradient(&self) -> bool { match &self.inner { From ed1ae6b6cde657e7dc9d09e4d9567d2231a2944c Mon Sep 17 00:00:00 2001 From: Adrian Seyboldt Date: Tue, 2 Sep 2025 20:45:31 +0200 Subject: [PATCH 06/12] fix: no errors for unused parameters --- src/wrapper.rs | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/src/wrapper.rs b/src/wrapper.rs index 9f4eaec..00d290c 100644 --- a/src/wrapper.rs +++ b/src/wrapper.rs @@ -543,7 +543,10 @@ impl PyNutsSettings { } #[setter(mass_matrix_eigval_cutoff)] - fn set_mass_matrix_eigval_cutoff(&mut self, val: f64) -> Result<()> { + fn set_mass_matrix_eigval_cutoff(&mut self, val: Option<f64>) -> Result<()> { + let Some(val) = val else { + return Ok(()); + }; match &mut self.inner { Settings::LowRank(inner) => inner.adapt_options.mass_matrix_options.eigval_cutoff = val, Settings::Diag(_) => { @@ -570,7 +573,10 @@ impl PyNutsSettings { } #[setter(mass_matrix_gamma)] - fn set_mass_matrix_gamma(&mut self, val: f64)
-> Result<()> { + fn set_mass_matrix_gamma(&mut self, val: Option) -> Result<()> { + let Some(val) = val else { + return Ok(()); + }; match &mut self.inner { Settings::LowRank(inner) => { inner.adapt_options.mass_matrix_options.gamma = val; From f31f47cbb4e06f095103f07f1d6ec5d2f59517d6 Mon Sep 17 00:00:00 2001 From: Adrian Seyboldt Date: Tue, 2 Sep 2025 21:08:47 +0200 Subject: [PATCH 07/12] feat: support free-threaded python build --- src/wrapper.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/wrapper.rs b/src/wrapper.rs index 00d290c..725e99a 100644 --- a/src/wrapper.rs +++ b/src/wrapper.rs @@ -1392,7 +1392,7 @@ impl PyTransformAdapt { } /// A Python module implemented in Rust. -#[pymodule] +#[pymodule(gil_used = false)] pub fn _lib(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_class::()?; m.add_class::()?; From 68ff97bed66d9faa69a27bd165d84d906e3a0785 Mon Sep 17 00:00:00 2001 From: Adrian Seyboldt Date: Thu, 11 Sep 2025 09:39:27 +0200 Subject: [PATCH 08/12] feat: use new nuts-rs storage interface --- Cargo.lock | 2886 ++++++++++++++--- Cargo.toml | 19 +- pyproject.toml | 13 +- python/nutpie/compile_pymc.py | 36 +- python/nutpie/compile_stan.py | 19 +- python/nutpie/compiled_pyfunc.py | 32 +- python/nutpie/sample.py | 265 +- src/common.rs | 381 +++ src/lib.rs | 1 + src/progress.rs | 13 +- src/pyfunc.rs | 630 ++-- src/pymc.rs | 449 ++- src/stan.rs | 466 ++- src/wrapper.rs | 593 ++-- .../test_deterministic_sampling_jax.txt | 400 +-- .../test_deterministic_sampling_numba.txt | 400 +-- .../test_deterministic_sampling_stan.txt | 4 +- tests/test_pymc.py | 11 +- tests/test_stan.py | 6 +- 19 files changed, 4729 insertions(+), 1895 deletions(-) create mode 100644 src/common.rs diff --git a/Cargo.lock b/Cargo.lock index 4bd1975..2f93395 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1,6 +1,15 @@ # This file is automatically @generated by Cargo. # It is not intended for manual editing. 
-version = 3 +version = 4 + +[[package]] +name = "addr2line" +version = "0.25.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1b5d307320b3181d6d7954e663bd7c774a838b8220fe0593c86d9fb09f498b4b" +dependencies = [ + "gimli", +] [[package]] name = "adler2" @@ -43,10 +52,10 @@ dependencies = [ ] [[package]] -name = "android-tzdata" -version = "0.1.1" +name = "allocator-api2" +version = "0.2.21" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e999941b234f3131b00bc13c22d06e8c5ff726d1b6318ac7eb276997bbb4fef0" +checksum = "683d7910e743518b0e34f1186f92494becacb047c7b6bf616c96772180fef923" [[package]] name = "android_system_properties" @@ -65,128 +74,73 @@ checksum = "4b46cbb362ab8752921c97e041f5e366ee6297bd428a31275b9fcf1e380f7299" [[package]] name = "anstyle" -version = "1.0.11" +version = "1.0.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "862ed96ca487e809f1c8e5a8447f6ee2cf102f846893800b20cebdf541fc6bbd" +checksum = "5192cca8006f1fd4f7237516f40fa183bb07f8fbdfedaa0036de5ea9b0b45e78" [[package]] name = "anyhow" -version = "1.0.99" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b0674a1ddeecb70197781e945de4b3b8ffb61fa939a5597bcf48503737663100" - -[[package]] -name = "arrow" -version = "55.2.0" +version = "1.0.100" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f3f15b4c6b148206ff3a2b35002e08929c2462467b62b9c02036d9c34f9ef994" -dependencies = [ - "arrow-arith 55.2.0", - "arrow-array 55.2.0", - "arrow-buffer 55.2.0", - "arrow-cast 55.2.0", - "arrow-data 55.2.0", - "arrow-ord 55.2.0", - "arrow-row 55.2.0", - "arrow-schema 55.2.0", - "arrow-select 55.2.0", - "arrow-string 55.2.0", -] +checksum = "a23eb6b1614318a8071c9b2521f36b424b2c83db5eb3a0fead4a6c0809af6e61" [[package]] name = "arrow" -version = "56.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"c26b57282a08ae92f727497805122fec964c6245cfa0e13f0e75452eaf3bc41f" -dependencies = [ - "arrow-arith 56.1.0", - "arrow-array 56.1.0", - "arrow-buffer 56.1.0", - "arrow-cast 56.1.0", - "arrow-data 56.1.0", - "arrow-ord 56.1.0", - "arrow-row 56.1.0", - "arrow-schema 56.1.0", - "arrow-select 56.1.0", - "arrow-string 56.1.0", -] - -[[package]] -name = "arrow-arith" -version = "55.2.0" +version = "56.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "30feb679425110209ae35c3fbf82404a39a4c0436bb3ec36164d8bffed2a4ce4" +checksum = "6e833808ff2d94ed40d9379848a950d995043c7fb3e81a30b383f4c6033821cc" dependencies = [ - "arrow-array 55.2.0", - "arrow-buffer 55.2.0", - "arrow-data 55.2.0", - "arrow-schema 55.2.0", - "chrono", - "num", + "arrow-arith", + "arrow-array", + "arrow-buffer", + "arrow-cast", + "arrow-csv", + "arrow-data", + "arrow-ipc", + "arrow-json", + "arrow-ord", + "arrow-row", + "arrow-schema", + "arrow-select", + "arrow-string", ] [[package]] name = "arrow-arith" -version = "56.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cebf38ca279120ff522f4954b81a39527425b6e9f615e6b72842f4de1ffe02b8" -dependencies = [ - "arrow-array 56.1.0", - "arrow-buffer 56.1.0", - "arrow-data 56.1.0", - "arrow-schema 56.1.0", - "chrono", - "num", -] - -[[package]] -name = "arrow-array" -version = "55.2.0" +version = "56.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "70732f04d285d49054a48b72c54f791bb3424abae92d27aafdf776c98af161c8" +checksum = "ad08897b81588f60ba983e3ca39bda2b179bdd84dced378e7df81a5313802ef8" dependencies = [ - "ahash", - "arrow-buffer 55.2.0", - "arrow-data 55.2.0", - "arrow-schema 55.2.0", + "arrow-array", + "arrow-buffer", + "arrow-data", + "arrow-schema", "chrono", - "half", - "hashbrown", "num", ] [[package]] name = "arrow-array" -version = "56.1.0" +version = "56.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"744109142cdf8e7b02795e240e20756c2a782ac9180d4992802954a8f871c0de" +checksum = "8548ca7c070d8db9ce7aa43f37393e4bfcf3f2d3681df278490772fd1673d08d" dependencies = [ "ahash", - "arrow-buffer 56.1.0", - "arrow-data 56.1.0", - "arrow-schema 56.1.0", + "arrow-buffer", + "arrow-data", + "arrow-schema", "chrono", + "chrono-tz", "half", - "hashbrown", - "num", -] - -[[package]] -name = "arrow-buffer" -version = "55.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "169b1d5d6cb390dd92ce582b06b23815c7953e9dfaaea75556e89d890d19993d" -dependencies = [ - "bytes", - "half", + "hashbrown 0.16.0", "num", ] [[package]] name = "arrow-buffer" -version = "56.1.0" +version = "56.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "601bb103c4c374bcd1f62c66bcea67b42a2ee91a690486c37d4c180236f11ccc" +checksum = "e003216336f70446457e280807a73899dd822feaf02087d31febca1363e2fccc" dependencies = [ "bytes", "half", @@ -195,18 +149,19 @@ dependencies = [ [[package]] name = "arrow-cast" -version = "55.2.0" +version = "56.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e4f12eccc3e1c05a766cafb31f6a60a46c2f8efec9b74c6e0648766d30686af8" +checksum = "919418a0681298d3a77d1a315f625916cb5678ad0d74b9c60108eb15fd083023" dependencies = [ - "arrow-array 55.2.0", - "arrow-buffer 55.2.0", - "arrow-data 55.2.0", - "arrow-schema 55.2.0", - "arrow-select 55.2.0", + "arrow-array", + "arrow-buffer", + "arrow-data", + "arrow-schema", + "arrow-select", "atoi", "base64", "chrono", + "comfy-table", "half", "lexical-core", "num", @@ -214,179 +169,167 @@ dependencies = [ ] [[package]] -name = "arrow-cast" -version = "56.1.0" +name = "arrow-csv" +version = "56.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eed61d9d73eda8df9e3014843def37af3050b5080a9acbe108f045a316d5a0be" +checksum = "bfa9bf02705b5cf762b6f764c65f04ae9082c7cfc4e96e0c33548ee3f67012eb" dependencies = [ - "arrow-array 56.1.0", - 
"arrow-buffer 56.1.0", - "arrow-data 56.1.0", - "arrow-schema 56.1.0", - "arrow-select 56.1.0", - "atoi", - "base64", + "arrow-array", + "arrow-cast", + "arrow-schema", "chrono", - "half", - "lexical-core", - "num", - "ryu", + "csv", + "csv-core", + "regex", ] [[package]] name = "arrow-data" -version = "55.2.0" +version = "56.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8de1ce212d803199684b658fc4ba55fb2d7e87b213de5af415308d2fee3619c2" +checksum = "a5c64fff1d142f833d78897a772f2e5b55b36cb3e6320376f0961ab0db7bd6d0" dependencies = [ - "arrow-buffer 55.2.0", - "arrow-schema 55.2.0", + "arrow-buffer", + "arrow-schema", "half", "num", ] [[package]] -name = "arrow-data" -version = "56.1.0" +name = "arrow-ipc" +version = "56.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "43407f2c6ba2367f64d85d4603d6fb9c4b92ed79d2ffd21021b37efa96523e12" +checksum = "1d3594dcddccc7f20fd069bc8e9828ce37220372680ff638c5e00dea427d88f5" dependencies = [ - "arrow-buffer 56.1.0", - "arrow-schema 56.1.0", - "half", - "num", + "arrow-array", + "arrow-buffer", + "arrow-data", + "arrow-schema", + "arrow-select", + "flatbuffers", ] [[package]] -name = "arrow-ord" -version = "55.2.0" +name = "arrow-json" +version = "56.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6506e3a059e3be23023f587f79c82ef0bcf6d293587e3272d20f2d30b969b5a7" +checksum = "88cf36502b64a127dc659e3b305f1d993a544eab0d48cce704424e62074dc04b" dependencies = [ - "arrow-array 55.2.0", - "arrow-buffer 55.2.0", - "arrow-data 55.2.0", - "arrow-schema 55.2.0", - "arrow-select 55.2.0", + "arrow-array", + "arrow-buffer", + "arrow-cast", + "arrow-data", + "arrow-schema", + "chrono", + "half", + "indexmap", + "lexical-core", + "memchr", + "num", + "serde", + "serde_json", + "simdutf8", ] [[package]] name = "arrow-ord" -version = "56.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"c142a147dceb59d057bad82400f1693847c80dca870d008bf7b91caf902810ae" -dependencies = [ - "arrow-array 56.1.0", - "arrow-buffer 56.1.0", - "arrow-data 56.1.0", - "arrow-schema 56.1.0", - "arrow-select 56.1.0", -] - -[[package]] -name = "arrow-row" -version = "55.2.0" +version = "56.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "52bf7393166beaf79b4bed9bfdf19e97472af32ce5b6b48169d321518a08cae2" +checksum = "3c8f82583eb4f8d84d4ee55fd1cb306720cddead7596edce95b50ee418edf66f" dependencies = [ - "arrow-array 55.2.0", - "arrow-buffer 55.2.0", - "arrow-data 55.2.0", - "arrow-schema 55.2.0", - "half", + "arrow-array", + "arrow-buffer", + "arrow-data", + "arrow-schema", + "arrow-select", ] [[package]] name = "arrow-row" -version = "56.1.0" +version = "56.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dac6620667fccdab4204689ca173bd84a15de6bb6b756c3a8764d4d7d0c2fc04" +checksum = "9d07ba24522229d9085031df6b94605e0f4b26e099fb7cdeec37abd941a73753" dependencies = [ - "arrow-array 56.1.0", - "arrow-buffer 56.1.0", - "arrow-data 56.1.0", - "arrow-schema 56.1.0", + "arrow-array", + "arrow-buffer", + "arrow-data", + "arrow-schema", "half", ] [[package]] name = "arrow-schema" -version = "55.2.0" +version = "56.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "af7686986a3bf2254c9fb130c623cdcb2f8e1f15763e7c71c310f0834da3d292" +checksum = "b3aa9e59c611ebc291c28582077ef25c97f1975383f1479b12f3b9ffee2ffabe" dependencies = [ "bitflags", + "serde", + "serde_json", ] [[package]] -name = "arrow-schema" -version = "56.1.0" +name = "arrow-select" +version = "56.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dfa93af9ff2bb80de539e6eb2c1c8764abd0f4b73ffb0d7c82bf1f9868785e66" +checksum = "8c41dbbd1e97bfcaee4fcb30e29105fb2c75e4d82ae4de70b792a5d3f66b2e7a" dependencies = [ - "bitflags", + "ahash", + "arrow-array", + "arrow-buffer", + "arrow-data", + "arrow-schema", + 
"num", ] [[package]] -name = "arrow-select" -version = "55.2.0" +name = "arrow-string" +version = "56.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dd2b45757d6a2373faa3352d02ff5b54b098f5e21dccebc45a21806bc34501e5" +checksum = "53f5183c150fbc619eede22b861ea7c0eebed8eaac0333eaa7f6da5205fd504d" dependencies = [ - "ahash", - "arrow-array 55.2.0", - "arrow-buffer 55.2.0", - "arrow-data 55.2.0", - "arrow-schema 55.2.0", + "arrow-array", + "arrow-buffer", + "arrow-data", + "arrow-schema", + "arrow-select", + "memchr", "num", + "regex", + "regex-syntax", ] [[package]] -name = "arrow-select" -version = "56.1.0" +name = "async-generic" +version = "1.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "be8b2e0052cd20d36d64f32640b68a5ab54d805d24a473baee5d52017c85536c" +checksum = "ddf3728566eefa873833159754f5732fb0951d3649e6e5b891cc70d56dd41673" dependencies = [ - "ahash", - "arrow-array 56.1.0", - "arrow-buffer 56.1.0", - "arrow-data 56.1.0", - "arrow-schema 56.1.0", - "num", + "proc-macro2", + "quote", + "syn 2.0.106", ] [[package]] -name = "arrow-string" -version = "55.2.0" +name = "async-lock" +version = "3.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0377d532850babb4d927a06294314b316e23311503ed580ec6ce6a0158f49d40" +checksum = "5fd03604047cee9b6ce9de9f70c6cd540a0520c813cbd49bae61f33ab80ed1dc" dependencies = [ - "arrow-array 55.2.0", - "arrow-buffer 55.2.0", - "arrow-data 55.2.0", - "arrow-schema 55.2.0", - "arrow-select 55.2.0", - "memchr", - "num", - "regex", - "regex-syntax", + "event-listener", + "event-listener-strategy", + "pin-project-lite", ] [[package]] -name = "arrow-string" -version = "56.1.0" +name = "async-trait" +version = "0.1.89" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c2155e26e17f053c8975c546fc70cf19c00542f9abf43c23a88a46ef7204204f" +checksum = "9035ad2d096bed7955a320ee7e2230574d28fd3c3a0f186cbea1ff3c7eed5dbb" 
dependencies = [ - "arrow-array 56.1.0", - "arrow-buffer 56.1.0", - "arrow-data 56.1.0", - "arrow-schema 56.1.0", - "arrow-select 56.1.0", - "memchr", - "num", - "regex", - "regex-syntax", + "proc-macro2", + "quote", + "syn 2.0.106", ] [[package]] @@ -398,12 +341,44 @@ dependencies = [ "num-traits", ] +[[package]] +name = "atomic-waker" +version = "1.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1505bd5d3d116872e7271a6d4e16d81d0c8570876c8de68093a09ac269d8aac0" + +[[package]] +name = "auto_impl" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ffdcb70bdbc4d478427380519163274ac86e52916e10f0a8889adf0f96d3fee7" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.106", +] + [[package]] name = "autocfg" version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8" +[[package]] +name = "backtrace" +version = "0.3.76" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bb531853791a215d7c62a30daf0dde835f381ab5de4589cfe7c649d2cbe92bd6" +dependencies = [ + "addr2line", + "cfg-if", + "libc", + "miniz_oxide", + "object", + "rustc-demangle", + "windows-link", +] + [[package]] name = "base64" version = "0.22.1" @@ -418,9 +393,9 @@ checksum = "55248b47b0caf0546f7988906588779981c43bb1bc9d0c44087278f80cdb44ba" [[package]] name = "bindgen" -version = "0.71.1" +version = "0.72.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5f58bf3d7db68cfbac37cfc485a8d711e87e064c3d0fe0435b92f7a407f9d6b3" +checksum = "993776b509cfb49c750f11b8f07a46fa23e0a1386ffc01fb1e7d343efc387895" dependencies = [ "bitflags", "cexpr", @@ -433,7 +408,7 @@ dependencies = [ "regex", "rustc-hash", "shlex", - "syn", + "syn 2.0.106", ] [[package]] @@ -451,17 +426,30 @@ dependencies = [ "generic-array", ] +[[package]] +name = "blosc-src" +version = "0.3.6" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "cb68d27ab5ceb94ae9cd343f6fbc7bb84543496d547ed7c0db6718175fd41cb6" +dependencies = [ + "cc", + "libz-sys", + "lz4-sys", + "snappy_src", + "zstd-sys", +] + [[package]] name = "bridgestan" -version = "2.6.2" +version = "2.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3fcf23cdd20237d4699464b803c6aef49f547266514c7361c27b25875ee69298" +checksum = "f6d0e34116970162606ca313a4d3cf76b4828600877ae30959f6f122e434cb29" dependencies = [ "bindgen", "libloading", "log", "path-absolutize", - "thiserror 2.0.16", + "thiserror 2.0.17", ] [[package]] @@ -472,9 +460,23 @@ checksum = "46c5e41b57b8bba42a04676d81cb89e9ee8e859a1a66f80a5a72e1cb76b34d43" [[package]] name = "bytemuck" -version = "1.23.2" +version = "1.24.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1fbdf580320f38b612e485521afda1ee26d10cc9884efaaa750d383e13e3c5f4" +dependencies = [ + "bytemuck_derive", +] + +[[package]] +name = "bytemuck_derive" +version = "1.10.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3995eaeebcdf32f91f980d360f78732ddc061097ab4e39991ae7a6ace9194677" +checksum = "f9abbd1bc6865053c427f7198e6af43bfdedc55ab791faed4fbd361d789575ff" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.106", +] [[package]] name = "byteorder" @@ -516,9 +518,9 @@ checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5" [[package]] name = "cc" -version = "1.2.35" +version = "1.2.40" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "590f9024a68a8c40351881787f1934dc11afd69090f5edb6831464694d836ea3" +checksum = "e1d05d92f4b1fd76aad469d46cdd858ca761576082cd37df81416691e50199fb" dependencies = [ "find-msvc-tools", "jobserver", @@ -541,18 +543,36 @@ version = "1.0.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2fd1289c04a9ea8cb22300a459a72a385d7c73d3259e2ed7dcb2af674838cfa9" 
+[[package]] +name = "cfg_aliases" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "613afe47fcd5fac7ccf1db93babcb082c5994d996f20b8b159f2ad1658eb5724" + [[package]] name = "chrono" -version = "0.4.41" +version = "0.4.42" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c469d952047f47f91b68d1cba3f10d63c11d73e4636f24f08daf0278abf01c4d" +checksum = "145052bdd345b87320e369255277e3fb5152762ad123a901ef5c262dd38fe8d2" dependencies = [ - "android-tzdata", "iana-time-zone", + "js-sys", "num-traits", + "serde", + "wasm-bindgen", "windows-link", ] +[[package]] +name = "chrono-tz" +version = "0.10.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a6139a8597ed92cf816dfb33f5dd6cf0bb93a6adc938f11039f371bc5bcd26c3" +dependencies = [ + "chrono", + "phf", +] + [[package]] name = "ciborium" version = "0.2.2" @@ -603,18 +623,18 @@ dependencies = [ [[package]] name = "clap" -version = "4.5.47" +version = "4.5.48" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7eac00902d9d136acd712710d71823fb8ac8004ca445a89e73a41d45aa712931" +checksum = "e2134bb3ea021b78629caa971416385309e0131b351b25e01dc16fb54e1b5fae" dependencies = [ "clap_builder", ] [[package]] name = "clap_builder" -version = "4.5.47" +version = "4.5.48" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2ad9bbf750e73b5884fb8a211a9424a1906c1e156724260fdae972f31d70e1d6" +checksum = "c2ba64afa3c0a6df7fa517765e31314e983f51dda798ffba27b988194fb65dc9" dependencies = [ "anstyle", "clap_lex", @@ -626,17 +646,37 @@ version = "0.7.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b94f61472cee1439c0b966b47e3aca9ae07e45d070759512cd390ea2bebc6675" +[[package]] +name = "comfy-table" +version = "7.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e0d05af1e006a2407bedef5af410552494ce5be9090444dbbcb57258c1af3d56" 
+dependencies = [ + "strum", + "strum_macros", + "unicode-width", +] + +[[package]] +name = "concurrent-queue" +version = "2.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4ca0197aee26d1ae37445ee532fefce43251d24cc7c166799f4d46817f1d3973" +dependencies = [ + "crossbeam-utils", +] + [[package]] name = "console" -version = "0.16.0" +version = "0.16.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2e09ced7ebbccb63b4c65413d821f2e00ce54c5ca4514ddc6b3c892fdbcbc69d" +checksum = "b430743a6eb14e9764d4260d4c0d8123087d504eeb9c48f2b2a5e810dd369df4" dependencies = [ "encode_unicode", "libc", "once_cell", "unicode-width", - "windows-sys", + "windows-sys 0.61.2", ] [[package]] @@ -665,6 +705,16 @@ version = "0.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "245097e9a4535ee1e3e3931fcfcd55a796a44c643e8596ff6566d68f09b87bbc" +[[package]] +name = "core-foundation" +version = "0.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b2a6cd9ae233e7f62ba4e9353e81a88df7fc8a5987b8d445b4d90c879bd156f6" +dependencies = [ + "core-foundation-sys", + "libc", +] + [[package]] name = "core-foundation-sys" version = "0.8.7" @@ -680,6 +730,15 @@ dependencies = [ "libc", ] +[[package]] +name = "crc32c" +version = "0.6.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3a47af21622d091a8f0fb295b88bc886ac74efcc613efc19f5d0b21de5c89e47" +dependencies = [ + "rustc_version", +] + [[package]] name = "crc32fast" version = "1.5.0" @@ -722,6 +781,15 @@ dependencies = [ "itertools 0.13.0", ] +[[package]] +name = "crossbeam-channel" +version = "0.5.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "82b8f8f868b36967f9606790d1903570de9ceaf870a7bf9fbbd3016d636a2cb2" +dependencies = [ + "crossbeam-utils", +] + [[package]] name = "crossbeam-deque" version = "0.8.6" @@ -763,15 +831,57 @@ dependencies = [ "typenum", ] 
+[[package]] +name = "csv" +version = "1.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "acdc4883a9c96732e4733212c01447ebd805833b7275a73ca3ee080fd77afdaf" +dependencies = [ + "csv-core", + "itoa", + "ryu", + "serde", +] + +[[package]] +name = "csv-core" +version = "0.1.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7d02f3b0da4c6504f86e9cd789d8dbafab48c2321be74e9987593de5a894d93d" +dependencies = [ + "memchr", +] + [[package]] name = "deranged" -version = "0.5.3" +version = "0.5.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d630bccd429a5bb5a64b5e94f693bfc48c9f8566418fda4c494cc94f911f87cc" +checksum = "a41953f86f8a05768a6cda24def994fd2f424b04ec5c719cf89989779f199071" dependencies = [ "powerfmt", ] +[[package]] +name = "derive_more" +version = "2.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "093242cf7570c207c83073cf82f79706fe7b8317e98620a47d5be7c3d8497678" +dependencies = [ + "derive_more-impl", +] + +[[package]] +name = "derive_more-impl" +version = "2.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bda628edc44c4bb645fbe0f758797143e4e07926f7ebf4e9bdfbd3d2ce621df3" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.106", + "unicode-xid", +] + [[package]] name = "digest" version = "0.10.7" @@ -783,15 +893,33 @@ dependencies = [ "subtle", ] +[[package]] +name = "displaydoc" +version = "0.2.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "97369cbbc041bc366949bc74d34658d6cda5621039731c6310521892a3a20ae0" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.106", +] + [[package]] name = "dyn-stack" -version = "0.13.0" +version = "0.13.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "490bd48eb68fffcfed519b4edbfd82c69cbe741d175b84f0e0cbe8c57cbe0bdd" +checksum = 
"1c4713e43e2886ba72b8271aa66c93d722116acf7a75555cce11dcde84388fe8" dependencies = [ "bytemuck", + "dyn-stack-macros", ] +[[package]] +name = "dyn-stack-macros" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "05dbec7076f432bb132db738df90d87a4f5789e99f59e7b1219a6b8ef61eaa68" + [[package]] name = "either" version = "1.15.0" @@ -830,7 +958,7 @@ checksum = "3bf679796c0322556351f287a51b49e48f7c4986e727b5dd78c972d30e2e16cc" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.106", ] [[package]] @@ -841,14 +969,41 @@ checksum = "44f23cf4b44bfce11a86ace86f8a73ffdec849c9fd00a386a53d278bd9e81fb3" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.106", +] + +[[package]] +name = "equivalent" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f" + +[[package]] +name = "event-listener" +version = "5.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e13b66accf52311f30a0db42147dadea9850cb48cd070028831ae5f5d4b856ab" +dependencies = [ + "concurrent-queue", + "parking", + "pin-project-lite", +] + +[[package]] +name = "event-listener-strategy" +version = "0.5.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8be9f3dfaaffdae2972880079a491a1a8bb7cbed0b8dd7a347f668b4150a3b93" +dependencies = [ + "event-listener", + "pin-project-lite", ] [[package]] name = "faer" -version = "0.22.6" +version = "0.23.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "49fce40ad65c366fbc6cd70a99d09d1008f075280bf2455e558e163c82913a9f" +checksum = "3cb922206162d9405f9fc059052b3f997bdc92745da7bfd620645f5092df20d1" dependencies = [ "bytemuck", "dyn-stack", @@ -867,20 +1022,20 @@ dependencies = [ [[package]] name = "faer-macros" -version = "0.21.0" +version = "0.22.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"9d0a255d1442b5825c61812a7eafda9034ec53d969c98555251085e148428e6a" +checksum = "2cc4b8cd876795d3b19ddfd59b03faa303c0b8adb9af6e188e81fc647c485bb9" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.106", ] [[package]] name = "faer-traits" -version = "0.22.1" +version = "0.23.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "54febfcbb90edaab562d85447a94d500f1601f11db0b30d27da87ed6542c8f91" +checksum = "24b69235b5f54416286c485fb047f2f499fc935a4eee2caadf4757f3c94c7b62" dependencies = [ "bytemuck", "dyn-stack", @@ -896,31 +1051,151 @@ dependencies = [ [[package]] name = "find-msvc-tools" -version = "0.1.0" +version = "0.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e178e4fba8a2726903f6ba98a6d221e76f9c12c650d5dc0e6afdc50677b49650" +checksum = "0399f9d26e5191ce32c498bebd31e7a3ceabc2745f0ac54af3f335126c3f24b3" + +[[package]] +name = "flatbuffers" +version = "25.9.23" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09b6620799e7340ebd9968d2e0708eb82cf1971e9a16821e2091b6d6e475eed5" +dependencies = [ + "bitflags", + "rustc_version", +] [[package]] name = "flate2" -version = "1.1.2" +version = "1.1.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4a3d7db9596fecd151c5f638c0ee5d5bd487b6e0ea232e5dc96d5250f6f94b1d" +checksum = "dc5a4e564e38c699f2880d3fda590bedc2e69f3f84cd48b457bd892ce61d0aa9" dependencies = [ "crc32fast", "miniz_oxide", ] [[package]] -name = "gemm" -version = "0.18.2" +name = "fnv" +version = "1.0.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ab96b703d31950f1aeddded248bc95543c9efc7ac9c4a21fda8703a83ee35451" -dependencies = [ - "dyn-stack", - "gemm-c32", - "gemm-c64", - "gemm-common", - "gemm-f32", +checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" + +[[package]] +name = "foldhash" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" 
+checksum = "d9c4f5dac5e15c24eb999c26181a6ca40b39fe946cbe4c263c7209467bc83af2" + +[[package]] +name = "form_urlencoded" +version = "1.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cb4cb245038516f5f85277875cdaa4f7d2c9a0fa0468de06ed190163b1581fcf" +dependencies = [ + "percent-encoding", +] + +[[package]] +name = "futures" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "65bc07b1a8bc7c85c5f2e110c476c7389b4554ba72af57d8445ea63a576b0876" +dependencies = [ + "futures-channel", + "futures-core", + "futures-executor", + "futures-io", + "futures-sink", + "futures-task", + "futures-util", +] + +[[package]] +name = "futures-channel" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2dff15bf788c671c1934e366d07e30c1814a8ef514e1af724a602e8a2fbe1b10" +dependencies = [ + "futures-core", + "futures-sink", +] + +[[package]] +name = "futures-core" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "05f29059c0c2090612e8d742178b0580d2dc940c837851ad723096f87af6663e" + +[[package]] +name = "futures-executor" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e28d1d997f585e54aebc3f97d39e72338912123a67330d723fdbb564d646c9f" +dependencies = [ + "futures-core", + "futures-task", + "futures-util", +] + +[[package]] +name = "futures-io" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9e5c1b78ca4aae1ac06c48a526a655760685149f0d465d21f37abfe57ce075c6" + +[[package]] +name = "futures-macro" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "162ee34ebcb7c64a8abebc059ce0fee27c2262618d7b60ed8faf72fef13c3650" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.106", +] + +[[package]] +name = "futures-sink" +version = "0.3.31" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "e575fab7d1e0dcb8d0c7bcf9a63ee213816ab51902e6d244a95819acacf1d4f7" + +[[package]] +name = "futures-task" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f90f7dce0722e95104fcb095585910c0977252f286e354b5e3bd38902cd99988" + +[[package]] +name = "futures-util" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9fa08315bb612088cc391249efdc3bc77536f16c91f6cf495e6fbe85b20a4a81" +dependencies = [ + "futures-channel", + "futures-core", + "futures-io", + "futures-macro", + "futures-sink", + "futures-task", + "memchr", + "pin-project-lite", + "pin-utils", + "slab", +] + +[[package]] +name = "gemm" +version = "0.18.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ab96b703d31950f1aeddded248bc95543c9efc7ac9c4a21fda8703a83ee35451" +dependencies = [ + "dyn-stack", + "gemm-c32", + "gemm-c64", + "gemm-common", + "gemm-f32", "gemm-f64", "num-complex", "num-traits", @@ -1030,8 +1305,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "335ff9f135e4384c8150d6f27c6daed433577f86b4750418338c01a1a2528592" dependencies = [ "cfg-if", + "js-sys", "libc", "wasi 0.11.1+wasi-snapshot-preview1", + "wasm-bindgen", ] [[package]] @@ -1041,23 +1318,51 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "26145e563e54f2cadc477553f1ec5ee650b00862f0a58bcd12cbdc5f0ea2d2f4" dependencies = [ "cfg-if", + "js-sys", "libc", "r-efi", - "wasi 0.14.3+wasi-0.2.4", + "wasi 0.14.7+wasi-0.2.4", + "wasm-bindgen", ] +[[package]] +name = "gimli" +version = "0.32.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e629b9b98ef3dd8afe6ca2bd0f89306cec16d43d907889945bc5d6687f2f13c7" + [[package]] name = "glob" version = "0.3.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = 
"0cc23270f6e1808e30a928bdc84dea0b9b4136a8bc82338574f23baf47bbd280" +[[package]] +name = "h2" +version = "0.4.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f3c0b69cfcb4e1b9f1bf2f53f95f766e4661169728ec61cd3fe5a0166f2d1386" +dependencies = [ + "atomic-waker", + "bytes", + "fnv", + "futures-core", + "futures-sink", + "http", + "indexmap", + "slab", + "tokio", + "tokio-util", + "tracing", +] + [[package]] name = "half" version = "2.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "459196ed295495a68f7d7fe1d84f6c4b7ff0e21fe3017b2f283c6fac3ad803c9" dependencies = [ + "bytemuck", "cfg-if", "crunchy", "num-traits", @@ -1068,6 +1373,17 @@ name = "hashbrown" version = "0.15.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9229cfe53dfd69f0609a49f65461bd93001ea1ef889cd5529dd176593f5338a1" +dependencies = [ + "allocator-api2", + "equivalent", + "foldhash", +] + +[[package]] +name = "hashbrown" +version = "0.16.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5419bdc4f6a9207fbeba6d11b604d481addf78ecd10c11ad51e76c2f6482748d" [[package]] name = "heck" @@ -1084,11 +1400,120 @@ dependencies = [ "digest", ] +[[package]] +name = "http" +version = "1.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f4a85d31aea989eead29a3aaf9e1115a180df8282431156e533de47660892565" +dependencies = [ + "bytes", + "fnv", + "itoa", +] + +[[package]] +name = "http-body" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1efedce1fb8e6913f23e0c92de8e62cd5b772a67e7b3946df930a62566c93184" +dependencies = [ + "bytes", + "http", +] + +[[package]] +name = "http-body-util" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b021d93e26becf5dc7e1b75b1bed1fd93124b374ceb73f43d4d4eafec896a64a" +dependencies = [ + "bytes", + "futures-core", + "http", + "http-body", + 
"pin-project-lite", +] + +[[package]] +name = "httparse" +version = "1.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6dbf3de79e51f3d586ab4cb9d5c3e2c14aa28ed23d180cf89b4df0454a69cc87" + +[[package]] +name = "humantime" +version = "2.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "135b12329e5e3ce057a9f972339ea52bc954fe1e9358ef27f95e89716fbc5424" + +[[package]] +name = "hyper" +version = "1.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eb3aa54a13a0dfe7fbe3a59e0c76093041720fdc77b110cc0fc260fafb4dc51e" +dependencies = [ + "atomic-waker", + "bytes", + "futures-channel", + "futures-core", + "h2", + "http", + "http-body", + "httparse", + "itoa", + "pin-project-lite", + "pin-utils", + "smallvec", + "tokio", + "want", +] + +[[package]] +name = "hyper-rustls" +version = "0.27.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e3c93eb611681b207e1fe55d5a71ecf91572ec8a6705cdb6857f7d8d5242cf58" +dependencies = [ + "http", + "hyper", + "hyper-util", + "rustls", + "rustls-native-certs", + "rustls-pki-types", + "tokio", + "tokio-rustls", + "tower-service", +] + +[[package]] +name = "hyper-util" +version = "0.1.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3c6995591a8f1380fcb4ba966a252a4b29188d51d2b89e3a252f5305be65aea8" +dependencies = [ + "base64", + "bytes", + "futures-channel", + "futures-core", + "futures-util", + "http", + "http-body", + "hyper", + "ipnet", + "libc", + "percent-encoding", + "pin-project-lite", + "socket2", + "tokio", + "tower-service", + "tracing", +] + [[package]] name = "iana-time-zone" -version = "0.1.63" +version = "0.1.64" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b0c919e5debc312ad217002b8048a17b7d83f80703865bbfcfebb0458b0b27d8" +checksum = "33e57f83510bb73707521ebaffa789ec8caf86f9657cad665b092b581d40e9fb" dependencies = [ 
"android_system_properties", "core-foundation-sys", @@ -1108,6 +1533,123 @@ dependencies = [ "cc", ] +[[package]] +name = "icu_collections" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "200072f5d0e3614556f94a9930d5dc3e0662a652823904c3a75dc3b0af7fee47" +dependencies = [ + "displaydoc", + "potential_utf", + "yoke", + "zerofrom", + "zerovec", +] + +[[package]] +name = "icu_locale_core" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0cde2700ccaed3872079a65fb1a78f6c0a36c91570f28755dda67bc8f7d9f00a" +dependencies = [ + "displaydoc", + "litemap", + "tinystr", + "writeable", + "zerovec", +] + +[[package]] +name = "icu_normalizer" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "436880e8e18df4d7bbc06d58432329d6458cc84531f7ac5f024e93deadb37979" +dependencies = [ + "displaydoc", + "icu_collections", + "icu_normalizer_data", + "icu_properties", + "icu_provider", + "smallvec", + "zerovec", +] + +[[package]] +name = "icu_normalizer_data" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "00210d6893afc98edb752b664b8890f0ef174c8adbb8d0be9710fa66fbbf72d3" + +[[package]] +name = "icu_properties" +version = "2.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "016c619c1eeb94efb86809b015c58f479963de65bdb6253345c1a1276f22e32b" +dependencies = [ + "displaydoc", + "icu_collections", + "icu_locale_core", + "icu_properties_data", + "icu_provider", + "potential_utf", + "zerotrie", + "zerovec", +] + +[[package]] +name = "icu_properties_data" +version = "2.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "298459143998310acd25ffe6810ed544932242d3f07083eee1084d83a71bd632" + +[[package]] +name = "icu_provider" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"03c80da27b5f4187909049ee2d72f276f0d9f99a42c306bd0131ecfe04d8e5af" +dependencies = [ + "displaydoc", + "icu_locale_core", + "stable_deref_trait", + "tinystr", + "writeable", + "yoke", + "zerofrom", + "zerotrie", + "zerovec", +] + +[[package]] +name = "idna" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3b0875f23caa03898994f6ddc501886a45c7d3d62d04d2d90788d47be1b1e4de" +dependencies = [ + "idna_adapter", + "smallvec", + "utf8_iter", +] + +[[package]] +name = "idna_adapter" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3acae9609540aa318d1bc588455225fb2085b9ed0c4f6bd0d9d5bcd86f1a0344" +dependencies = [ + "icu_normalizer", + "icu_properties", +] + +[[package]] +name = "indexmap" +version = "2.11.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4b0f83760fb341a774ed326568e19f5a863af4a952def8c39f9ab92fd95b88e5" +dependencies = [ + "equivalent", + "hashbrown 0.16.0", +] + [[package]] name = "indicatif" version = "0.18.0" @@ -1136,6 +1678,42 @@ dependencies = [ "generic-array", ] +[[package]] +name = "inventory" +version = "0.3.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bc61209c082fbeb19919bee74b176221b27223e27b65d781eb91af24eb1fb46e" +dependencies = [ + "rustversion", +] + +[[package]] +name = "io-uring" +version = "0.7.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "046fa2d4d00aea763528b4950358d0ead425372445dc8ff86312b3c69ff7727b" +dependencies = [ + "bitflags", + "cfg-if", + "libc", +] + +[[package]] +name = "ipnet" +version = "2.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "469fb0b9cefa57e3ef31275ee7cacb78f2fdca44e4765491884a2b119d4eb130" + +[[package]] +name = "iri-string" +version = "0.7.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"dbc5ebe9c3a1a7a5127f920a418f7585e9e758e911d0466ed004f393b0e380b2" +dependencies = [ + "memchr", + "serde", +] + [[package]] name = "itertools" version = "0.13.0" @@ -1172,9 +1750,9 @@ dependencies = [ [[package]] name = "js-sys" -version = "0.3.77" +version = "0.3.81" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1cfaf33c695fc6e08064efbc1f72ec937429614f25eef83af942d0e227c3a28f" +checksum = "ec48937a97411dcb524a265206ccd4c90bb711fca92b2792c407f268825b9305" dependencies = [ "once_cell", "wasm-bindgen", @@ -1188,9 +1766,9 @@ checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe" [[package]] name = "lexical-core" -version = "1.0.5" +version = "1.0.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b765c31809609075565a70b4b71402281283aeda7ecaf4818ac14a7b2ade8958" +checksum = "7d8d125a277f807e55a77304455eb7b1cb52f2b18c143b60e766c120bd64a594" dependencies = [ "lexical-parse-float", "lexical-parse-integer", @@ -1201,69 +1779,62 @@ dependencies = [ [[package]] name = "lexical-parse-float" -version = "1.0.5" +version = "1.0.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "de6f9cb01fb0b08060209a057c048fcbab8717b4c1ecd2eac66ebfe39a65b0f2" +checksum = "52a9f232fbd6f550bc0137dcb5f99ab674071ac2d690ac69704593cb4abbea56" dependencies = [ "lexical-parse-integer", "lexical-util", - "static_assertions", ] [[package]] name = "lexical-parse-integer" -version = "1.0.5" +version = "1.0.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "72207aae22fc0a121ba7b6d479e42cbfea549af1479c3f3a4f12c70dd66df12e" +checksum = "9a7a039f8fb9c19c996cd7b2fcce303c1b2874fe1aca544edc85c4a5f8489b34" dependencies = [ "lexical-util", - "static_assertions", ] [[package]] name = "lexical-util" -version = "1.0.6" +version = "1.0.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"5a82e24bf537fd24c177ffbbdc6ebcc8d54732c35b50a3f28cc3f4e4c949a0b3" -dependencies = [ - "static_assertions", -] +checksum = "2604dd126bb14f13fb5d1bd6a66155079cb9fa655b37f875b3a742c705dbed17" [[package]] name = "lexical-write-float" -version = "1.0.5" +version = "1.0.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c5afc668a27f460fb45a81a757b6bf2f43c2d7e30cb5a2dcd3abf294c78d62bd" +checksum = "50c438c87c013188d415fbabbb1dceb44249ab81664efbd31b14ae55dabb6361" dependencies = [ "lexical-util", "lexical-write-integer", - "static_assertions", ] [[package]] name = "lexical-write-integer" -version = "1.0.5" +version = "1.0.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "629ddff1a914a836fb245616a7888b62903aae58fa771e1d83943035efa0f978" +checksum = "409851a618475d2d5796377cad353802345cba92c867d9fbcde9cf4eac4e14df" dependencies = [ "lexical-util", - "static_assertions", ] [[package]] name = "libc" -version = "0.2.175" +version = "0.2.176" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6a82ae493e598baaea5209805c49bbf2ea7de956d50d7da0da1164f9c6d28543" +checksum = "58f929b4d672ea937a23a1ab494143d968337a5f47e56d0815df1e0890ddf174" [[package]] name = "libloading" -version = "0.8.8" +version = "0.8.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "07033963ba89ebaf1584d767badaa2e8fcec21aedea6b8c0346d487d49c28667" +checksum = "d7c4b02199fee7c5d21a5ae7d8cfa79a6ef5bb2fc834d6e9058e89c825efdc55" dependencies = [ "cfg-if", - "windows-targets", + "windows-link", ] [[package]] @@ -1273,59 +1844,185 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f9fbbcab51052fe104eb5e5d351cf728d30a5be1fe14d9be8a3b097481fb97de" [[package]] -name = "log" -version = "0.4.27" +name = "libz-sys" +version = "1.1.22" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "13dc2df351e3202783a1fe0d44375f7295ffb4049267b0f3018346dc122a1d94" 
+checksum = "8b70e7a7df205e92a1a4cd9aaae7898dac0aa555503cc0a649494d0d60e7651d" +dependencies = [ + "cc", + "libc", + "pkg-config", + "vcpkg", +] [[package]] -name = "matrixmultiply" -version = "0.3.10" +name = "link-cplusplus" +version = "1.0.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a06de3016e9fae57a36fd14dba131fccf49f74b40b7fbdb472f96e361ec71a08" +checksum = "7f78c730aaa7d0b9336a299029ea49f9ee53b0ed06e9202e8cb7db9bae7b8c82" dependencies = [ - "autocfg", - "rawpointer", + "cc", ] [[package]] -name = "memchr" -version = "2.7.5" +name = "litemap" +version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "32a282da65faaf38286cf3be983213fcf1d2e2a58700e808f83f4ea9a4804bc0" +checksum = "241eaef5fd12c88705a01fc1066c48c4b36e0dd4377dcdc7ec3942cea7a69956" [[package]] -name = "memoffset" -version = "0.9.1" +name = "lock_api" +version = "0.4.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "488016bfae457b036d996092f6cb448677611ce4449e970ceaf42695203f218a" +checksum = "224399e74b87b5f3557511d98dff8b14089b3dadafcab6bb93eab67d3aace965" dependencies = [ - "autocfg", + "scopeguard", ] [[package]] -name = "minimal-lexical" -version = "0.2.1" +name = "log" +version = "0.4.28" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a" +checksum = "34080505efa8e45a4b816c349525ebe327ceaa8559756f0356cba97ef3bf7432" [[package]] -name = "miniz_oxide" -version = "0.8.9" +name = "lru" +version = "0.16.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1fa76a2c86f704bdb222d66965fb3d63269ce38518b83cb0575fca855ebb6316" +checksum = "bfe949189f46fabb938b3a9a0be30fdd93fd8a09260da863399a8cf3db756ec8" dependencies = [ - "adler2", + "hashbrown 0.15.5", ] [[package]] -name = "nano-gemm" -version = "0.1.3" +name = "lru-slab" +version = "0.1.2" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "bb5ba2bea1c00e53de11f6ab5bd0761ba87dc0045d63b0c87ee471d2d3061376" +checksum = "112b39cec0b298b6c1999fee3e31427f74f676e4cb9879ed1a121b43661a4154" + +[[package]] +name = "lz4-sys" +version = "1.11.1+lz4-1.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6bd8c0d6c6ed0cd30b3652886bb8711dc4bb01d637a68105a3d5158039b418e6" dependencies = [ - "equator 0.2.2", - "nano-gemm-c32", + "cc", + "libc", +] + +[[package]] +name = "matrixmultiply" +version = "0.3.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a06de3016e9fae57a36fd14dba131fccf49f74b40b7fbdb472f96e361ec71a08" +dependencies = [ + "autocfg", + "rawpointer", +] + +[[package]] +name = "md-5" +version = "0.10.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d89e7ee0cfbedfc4da3340218492196241d89eefb6dab27de5df917a6d2e78cf" +dependencies = [ + "cfg-if", + "digest", +] + +[[package]] +name = "memchr" +version = "2.7.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f52b00d39961fc5b2736ea853c9cc86238e165017a493d1d5c8eac6bdc4cc273" + +[[package]] +name = "memoffset" +version = "0.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "488016bfae457b036d996092f6cb448677611ce4449e970ceaf42695203f218a" +dependencies = [ + "autocfg", +] + +[[package]] +name = "minimal-lexical" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a" + +[[package]] +name = "miniz_oxide" +version = "0.8.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1fa76a2c86f704bdb222d66965fb3d63269ce38518b83cb0575fca855ebb6316" +dependencies = [ + "adler2", + "simd-adler32", +] + +[[package]] +name = "mio" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" 
+checksum = "78bed444cc8a2160f01cbcf811ef18cac863ad68ae8ca62092e8db51d51c761c" +dependencies = [ + "libc", + "wasi 0.11.1+wasi-snapshot-preview1", + "windows-sys 0.59.0", +] + +[[package]] +name = "moka" +version = "0.12.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8261cd88c312e0004c1d51baad2980c66528dfdb2bee62003e643a4d8f86b077" +dependencies = [ + "async-lock", + "crossbeam-channel", + "crossbeam-epoch", + "crossbeam-utils", + "equivalent", + "event-listener", + "futures-util", + "parking_lot", + "portable-atomic", + "rustc_version", + "smallvec", + "tagptr", + "uuid", +] + +[[package]] +name = "monostate" +version = "0.1.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3341a273f6c9d5bef1908f17b7267bbab0e95c9bf69a0d4dcf8e9e1b2c76ef67" +dependencies = [ + "monostate-impl", + "serde", + "serde_core", +] + +[[package]] +name = "monostate-impl" +version = "0.1.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e4db6d5580af57bf992f59068d4ea26fd518574ff48d7639b255a36f9de6e7e9" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.106", +] + +[[package]] +name = "nano-gemm" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bb5ba2bea1c00e53de11f6ab5bd0761ba87dc0045d63b0c87ee471d2d3061376" +dependencies = [ + "equator 0.2.2", + "nano-gemm-c32", "nano-gemm-c64", "nano-gemm-codegen", "nano-gemm-core", @@ -1500,6 +2197,7 @@ version = "0.26.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9b2dba356160b54f5371b550575b78130a54718b4c6e46b3f33a6da74a27e78b" dependencies = [ + "half", "libc", "ndarray", "num-complex", @@ -1515,7 +2213,7 @@ name = "nutpie" version = "0.15.2" dependencies = [ "anyhow", - "arrow 56.1.0", + "arrow", "bridgestan", "criterion", "indicatif", @@ -1523,33 +2221,104 @@ dependencies = [ "numpy", "nuts-rs", "pyo3", + "pyo3-arrow", + "pyo3-object_store", "rand 0.9.2", 
"rand_chacha 0.9.0", "rand_distr", "rayon", "smallvec", "tch", - "thiserror 2.0.16", + "thiserror 2.0.17", "time-humanize", + "tokio", "upon", + "zarrs", + "zarrs_object_store", +] + +[[package]] +name = "nuts-derive" +version = "0.1.0" +dependencies = [ + "nuts-storable", + "proc-macro2", + "quote", + "syn 1.0.109", ] [[package]] name = "nuts-rs" version = "0.16.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "acad2be84df0d14341d8de7d30c1019ecc008f4722befbd45745092a918c0a02" dependencies = [ "anyhow", - "arrow 55.2.0", + "arrow", + "arrow-schema", "faer", "itertools 0.14.0", + "nuts-derive", + "nuts-storable", "pulp", "rand 0.9.2", "rand_chacha 0.9.0", "rand_distr", "rayon", - "thiserror 2.0.16", + "serde", + "serde_json", + "thiserror 2.0.17", + "tokio", + "zarrs", +] + +[[package]] +name = "nuts-storable" +version = "0.1.0" + +[[package]] +name = "object" +version = "0.37.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ff76201f031d8863c38aa7f905eca4f53abbfa15f609db4277d44cd8938f33fe" +dependencies = [ + "memchr", +] + +[[package]] +name = "object_store" +version = "0.12.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4c1be0c6c22ec0817cdc77d3842f721a17fd30ab6965001415b5402a74e6b740" +dependencies = [ + "async-trait", + "base64", + "bytes", + "chrono", + "form_urlencoded", + "futures", + "http", + "http-body-util", + "httparse", + "humantime", + "hyper", + "itertools 0.14.0", + "md-5", + "parking_lot", + "percent-encoding", + "quick-xml", + "rand 0.9.2", + "reqwest", + "ring", + "rustls-pemfile", + "serde", + "serde_json", + "serde_urlencoded", + "thiserror 2.0.17", + "tokio", + "tracing", + "url", + "walkdir", + "wasm-bindgen-futures", + "web-time", ] [[package]] @@ -1564,6 +2333,51 @@ version = "11.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d6790f58c7ff633d8771f42965289203411a5e5c68388703c06e14f24770b41e" +[[package]] +name 
= "openssl-probe" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d05e27ee213611ffe7d6348b942e8f942b37114c00cc03cec254295a4a17852e" + +[[package]] +name = "page_size" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "30d5b2194ed13191c1999ae0704b7839fb18384fa22e49b57eeaa97d79ce40da" +dependencies = [ + "libc", + "winapi", +] + +[[package]] +name = "parking" +version = "2.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f38d5652c16fde515bb1ecef450ab0f6a219d619a7274976324d5e377f7dceba" + +[[package]] +name = "parking_lot" +version = "0.12.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "93857453250e3077bd71ff98b6a65ea6621a19bb0f559a85248955ac12c45a1a" +dependencies = [ + "lock_api", + "parking_lot_core", +] + +[[package]] +name = "parking_lot_core" +version = "0.9.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2621685985a2ebf1c516881c026032ac7deafcda1a2c9b7850dc81e3dfcb64c1" +dependencies = [ + "cfg-if", + "libc", + "redox_syscall", + "smallvec", + "windows-link", +] + [[package]] name = "password-hash" version = "0.4.2" @@ -1599,6 +2413,12 @@ dependencies = [ "once_cell", ] +[[package]] +name = "pathdiff" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df94ce210e5bc13cb6651479fa48d14f601d9858cfe0467f43ae157023b938d3" + [[package]] name = "pbkdf2" version = "0.11.0" @@ -1611,6 +2431,42 @@ dependencies = [ "sha2", ] +[[package]] +name = "percent-encoding" +version = "2.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9b4f627cb1b25917193a259e49bdad08f671f8d9708acfd5fe0a8c1455d87220" + +[[package]] +name = "phf" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "913273894cec178f401a31ec4b656318d95473527be05c0752cc41cdc32be8b7" +dependencies = 
[ + "phf_shared", +] + +[[package]] +name = "phf_shared" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "06005508882fb681fd97892ecff4b7fd0fee13ef1aa569f8695dae7ab9099981" +dependencies = [ + "siphasher", +] + +[[package]] +name = "pin-project-lite" +version = "0.2.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3b3cff922bd51709b605d9ead9aa71031d81447142d828eb4a6eba76fe619f9b" + +[[package]] +name = "pin-utils" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" + [[package]] name = "pkg-config" version = "0.3.32" @@ -1660,6 +2516,15 @@ dependencies = [ "portable-atomic", ] +[[package]] +name = "potential_utf" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "84df19adbe5b5a0782edcab45899906947ab039ccf4573713735ee7de1e6b08a" +dependencies = [ + "zerovec", +] + [[package]] name = "powerfmt" version = "0.2.0" @@ -1682,7 +2547,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "479ca8adacdd7ce8f1fb39ce9ecccbfe93a3f1344b3d0d97f20bc0196208f62b" dependencies = [ "proc-macro2", - "syn", + "syn 2.0.106", ] [[package]] @@ -1715,6 +2580,9 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7ba0117f4212101ee6544044dae45abe1083d30ce7b29c4b5cbdfa2354e07383" dependencies = [ "anyhow", + "chrono", + "chrono-tz", + "indexmap", "indoc", "libc", "memoffset", @@ -1726,6 +2594,38 @@ dependencies = [ "unindent", ] +[[package]] +name = "pyo3-arrow" +version = "0.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cbbf9d6d0573f13480184e789095d6b5cfa11403d8d8311931bd5d111dbf007a" +dependencies = [ + "arrow-array", + "arrow-buffer", + "arrow-cast", + "arrow-data", + "arrow-schema", + "arrow-select", + "half", + "indexmap", + "numpy", + "pyo3", + "thiserror 
1.0.69", +] + +[[package]] +name = "pyo3-async-runtimes" +version = "0.26.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6ee6d4cb3e8d5b925f5cdb38da183e0ff18122eb2048d4041c9e7034d026e23" +dependencies = [ + "futures", + "once_cell", + "pin-project-lite", + "pyo3", + "tokio", +] + [[package]] name = "pyo3-build-config" version = "0.26.0" @@ -1754,7 +2654,7 @@ dependencies = [ "proc-macro2", "pyo3-macros-backend", "quote", - "syn", + "syn 2.0.106", ] [[package]] @@ -1767,7 +2667,30 @@ dependencies = [ "proc-macro2", "pyo3-build-config", "quote", - "syn", + "syn 2.0.106", +] + +[[package]] +name = "pyo3-object_store" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3cda46869b9ce0e94ca68a8c2f48fdc940a543ed5e2d9272c3e7cc4bcc579fd6" +dependencies = [ + "async-trait", + "bytes", + "chrono", + "futures", + "http", + "humantime", + "itertools 0.14.0", + "object_store", + "percent-encoding", + "pyo3", + "pyo3-async-runtimes", + "serde", + "thiserror 1.0.69", + "tokio", + "url", ] [[package]] @@ -1782,11 +2705,88 @@ dependencies = [ "pulp", ] +[[package]] +name = "quick-xml" +version = "0.38.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42a232e7487fc2ef313d96dde7948e7a3c05101870d8985e4fd8d26aedd27b89" +dependencies = [ + "memchr", + "serde", +] + +[[package]] +name = "quick_cache" +version = "0.6.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba15f5bccfb18c666351668b97bbff66da5093f96757ca15299e4e594fe1316e" +dependencies = [ + "ahash", + "equivalent", + "hashbrown 0.16.0", + "parking_lot", +] + +[[package]] +name = "quinn" +version = "0.11.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b9e20a958963c291dc322d98411f541009df2ced7b5a4f2bd52337638cfccf20" +dependencies = [ + "bytes", + "cfg_aliases", + "pin-project-lite", + "quinn-proto", + "quinn-udp", + "rustc-hash", + "rustls", + "socket2", 
+ "thiserror 2.0.17", + "tokio", + "tracing", + "web-time", +] + +[[package]] +name = "quinn-proto" +version = "0.11.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f1906b49b0c3bc04b5fe5d86a77925ae6524a19b816ae38ce1e426255f1d8a31" +dependencies = [ + "bytes", + "getrandom 0.3.3", + "lru-slab", + "rand 0.9.2", + "ring", + "rustc-hash", + "rustls", + "rustls-pki-types", + "slab", + "thiserror 2.0.17", + "tinyvec", + "tracing", + "web-time", +] + +[[package]] +name = "quinn-udp" +version = "0.5.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "addec6a0dcad8a8d96a771f815f0eaf55f9d1805756410b39f5fa81332574cbd" +dependencies = [ + "cfg_aliases", + "libc", + "once_cell", + "socket2", + "tracing", + "windows-sys 0.60.2", +] + [[package]] name = "quote" -version = "1.0.40" +version = "1.0.41" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1885c039570dc00dcb4ff087a89e185fd56bae234ddc7f056a945bf36467248d" +checksum = "ce25767e7b499d1b604768e7cde645d14cc8584231ea6b295e9c9eb22c02e1d1" dependencies = [ "proc-macro2", ] @@ -1868,9 +2868,9 @@ dependencies = [ [[package]] name = "raw-cpuid" -version = "11.5.0" +version = "11.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c6df7ab838ed27997ba19a4664507e6f82b41fe6e20be42929332156e5e85146" +checksum = "498cd0dc59d73224351ee52a95fee0f1a617a2eae0e7d9d720cc622c73a54186" dependencies = [ "bitflags", ] @@ -1902,16 +2902,34 @@ dependencies = [ ] [[package]] -name = "reborrow" -version = "0.5.5" +name = "rayon_iter_concurrent_limit" +version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "03251193000f4bd3b042892be858ee50e8b3719f2b08e5833ac4353724632430" - +checksum = "d09ee01023de07fa073ce14c37cbe0a9e099c6b0b60a29cf4af6d04d9553fed7" +dependencies = [ + "rayon", +] + +[[package]] +name = "reborrow" +version = "0.5.5" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "03251193000f4bd3b042892be858ee50e8b3719f2b08e5833ac4353724632430" + +[[package]] +name = "redox_syscall" +version = "0.5.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ed2bf2547551a7053d6fdfafda3f938979645c44812fbfcda098faae3f1a362d" +dependencies = [ + "bitflags", +] + [[package]] name = "regex" -version = "1.11.2" +version = "1.11.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "23d7fd106d8c02486a8d64e778353d1cffe08ce79ac2e82f540c86d0facf6912" +checksum = "8b5288124840bee7b386bc413c487869b360b2b4ec421ea56425128692f2a82c" dependencies = [ "aho-corasick", "memchr", @@ -1921,9 +2939,9 @@ dependencies = [ [[package]] name = "regex-automata" -version = "0.4.10" +version = "0.4.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6b9458fa0bfeeac22b5ca447c63aaf45f28439a709ccd244698632f9aa6394d6" +checksum = "833eb9ce86d40ef33cb1306d8accf7bc8ec2bfea4355cbdebb3df68b40925cad" dependencies = [ "aho-corasick", "memchr", @@ -1936,12 +2954,139 @@ version = "0.8.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "caf4aa5b0f434c91fe5c7f1ecb6a5ece2130b02ad2a590589dda5146df959001" +[[package]] +name = "reqwest" +version = "0.12.23" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d429f34c8092b2d42c7c93cec323bb4adeb7c67698f70839adec842ec10c7ceb" +dependencies = [ + "base64", + "bytes", + "futures-core", + "futures-util", + "h2", + "http", + "http-body", + "http-body-util", + "hyper", + "hyper-rustls", + "hyper-util", + "js-sys", + "log", + "percent-encoding", + "pin-project-lite", + "quinn", + "rustls", + "rustls-native-certs", + "rustls-pki-types", + "serde", + "serde_json", + "serde_urlencoded", + "sync_wrapper", + "tokio", + "tokio-rustls", + "tokio-util", + "tower", + "tower-http", + "tower-service", + "url", + "wasm-bindgen", + "wasm-bindgen-futures", + 
"wasm-streams", + "web-sys", +] + +[[package]] +name = "ring" +version = "0.17.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a4689e6c2294d81e88dc6261c768b63bc4fcdb852be6d1352498b114f61383b7" +dependencies = [ + "cc", + "cfg-if", + "getrandom 0.2.16", + "libc", + "untrusted", + "windows-sys 0.52.0", +] + +[[package]] +name = "rustc-demangle" +version = "0.1.26" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "56f7d92ca342cea22a06f2121d944b4fd82af56988c270852495420f961d4ace" + [[package]] name = "rustc-hash" version = "2.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "357703d41365b4b27c590e3ed91eabb1b663f07c4c084095e60cbed4362dff0d" +[[package]] +name = "rustc_version" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cfcb3a22ef46e85b45de6ee7e79d063319ebb6594faafcf1c225ea92ab6e9b92" +dependencies = [ + "semver", +] + +[[package]] +name = "rustls" +version = "0.23.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cd3c25631629d034ce7cd9940adc9d45762d46de2b0f57193c4443b92c6d4d40" +dependencies = [ + "once_cell", + "ring", + "rustls-pki-types", + "rustls-webpki", + "subtle", + "zeroize", +] + +[[package]] +name = "rustls-native-certs" +version = "0.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7fcff2dd52b58a8d98a70243663a0d234c4e2b79235637849d15913394a247d3" +dependencies = [ + "openssl-probe", + "rustls-pki-types", + "schannel", + "security-framework", +] + +[[package]] +name = "rustls-pemfile" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dce314e5fee3f39953d46bb63bb8a46d40c2f8fb7cc5a3b6cab2bde9721d6e50" +dependencies = [ + "rustls-pki-types", +] + +[[package]] +name = "rustls-pki-types" +version = "1.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"229a4a4c221013e7e1f1a043678c5cc39fe5171437c88fb47151a21e6f5b5c79" +dependencies = [ + "web-time", + "zeroize", +] + +[[package]] +name = "rustls-webpki" +version = "0.103.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e10b3f4191e8a80e6b43eebabfac91e5dcecebb27a71f04e820c47ec41d314bf" +dependencies = [ + "ring", + "rustls-pki-types", + "untrusted", +] + [[package]] name = "rustversion" version = "1.0.22" @@ -1973,6 +3118,50 @@ dependencies = [ "winapi-util", ] +[[package]] +name = "schannel" +version = "0.1.28" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "891d81b926048e76efe18581bf793546b4c0eaf8448d72be8de2bbee5fd166e1" +dependencies = [ + "windows-sys 0.61.2", +] + +[[package]] +name = "scopeguard" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" + +[[package]] +name = "security-framework" +version = "3.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b3297343eaf830f66ede390ea39da1d462b6b0c1b000f420d0a83f898bbbe6ef" +dependencies = [ + "bitflags", + "core-foundation", + "core-foundation-sys", + "libc", + "security-framework-sys", +] + +[[package]] +name = "security-framework-sys" +version = "2.15.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cc1f0cbffaac4852523ce30d8bd3c5cdc873501d96ff467ca09b6767bb8cd5c0" +dependencies = [ + "core-foundation-sys", + "libc", +] + +[[package]] +name = "semver" +version = "1.0.27" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d767eb0aabc880b29956c35734170f26ed551a859dbd361d140cdbeca61ab1e2" + [[package]] name = "seq-macro" version = "0.3.6" @@ -1981,34 +3170,69 @@ checksum = "1bc711410fbe7399f390ca1c3b60ad0f53f80e95c5eb935e52268a0e2cd49acc" [[package]] name = "serde" -version = "1.0.219" +version = "1.0.228" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "9a8e94ea7f378bd32cbbd37198a4a91436180c5bb472411e48b5ec2e2124ae9e" +dependencies = [ + "serde_core", + "serde_derive", +] + +[[package]] +name = "serde_core" +version = "1.0.228" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5f0e2c6ed6606019b4e29e69dbaba95b11854410e5347d525002456dbbb786b6" +checksum = "41d385c7d4ca58e59fc732af25c3983b67ac852c1a25000afe1175de458b67ad" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.219" +version = "1.0.228" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5b0276cf7f2c73365f7157c8123c21cd9a50fbbd844757af28ca1f5925fc2a00" +checksum = "d540f220d3187173da220f885ab66608367b6574e925011a9353e4badda91d79" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.106", ] [[package]] name = "serde_json" -version = "1.0.143" +version = "1.0.145" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d401abef1d108fbd9cbaebc3e46611f4b1021f714a0597a71f41ee463f5f4a5a" +checksum = "402a6f66d8c709116cf22f558eab210f5a50187f702eb4d7e5ef38d9a7f1c79c" dependencies = [ + "indexmap", "itoa", "memchr", "ryu", "serde", + "serde_core", +] + +[[package]] +name = "serde_repr" +version = "0.1.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "175ee3e80ae9982737ca543e96133087cbd9a485eecc3bc4de9c1a37b47ea59c" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.106", +] + +[[package]] +name = "serde_urlencoded" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d3491c14715ca2294c4d6a88f15e84739788c1d030eed8c110436aafdaa2f3fd" +dependencies = [ + "form_urlencoded", + "itoa", + "ryu", + "serde", ] [[package]] @@ -2039,6 +3263,30 @@ version = "1.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" 
+[[package]] +name = "simd-adler32" +version = "0.3.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d66dc143e6b11c1eddc06d5c423cfc97062865baf299914ab64caa38182078fe" + +[[package]] +name = "simdutf8" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e3a9fe34e3e7a50316060351f37187a3f546bce95496156754b601a5fa71b76e" + +[[package]] +name = "siphasher" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "56199f7ddabf13fe5074ce809e7d3f42b42ae711800501b5b16ea82ad029c39d" + +[[package]] +name = "slab" +version = "0.4.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7a2ae44ef20feb57a68b23d846850f861394c2e02dc425a50098ae8c90267589" + [[package]] name = "smallvec" version = "1.15.1" @@ -2046,10 +3294,49 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "67b1b7a3b5fe4f1376887184045fcf45c69e92af734b7aaddc05fb777b6fbd03" [[package]] -name = "static_assertions" -version = "1.1.0" +name = "snappy_src" +version = "0.2.5+snappy.1.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4e1432067a55bcfb1fd522d2aca6537a4fcea32bba87ea86921226d14f9bad53" +dependencies = [ + "cc", + "link-cplusplus", +] + +[[package]] +name = "socket2" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "233504af464074f9d066d7b5416c5f9b894a5862a6506e306f7b816cdd6f1807" +dependencies = [ + "libc", + "windows-sys 0.59.0", +] + +[[package]] +name = "stable_deref_trait" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a8f112729512f8e442d81f95a8a7ddf2b7c6b8a1a6f509a95864142b30cab2d3" + +[[package]] +name = "strum" +version = "0.26.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f" +checksum = 
"8fec0f0aef304996cf250b31b5a10dee7980c85da9d759361292b8bca5a18f06" + +[[package]] +name = "strum_macros" +version = "0.26.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4c6bee85a5a24955dc440386795aa378cd9cf82acd5f764469152d2270e581be" +dependencies = [ + "heck", + "proc-macro2", + "quote", + "rustversion", + "syn 2.0.106", +] [[package]] name = "subtle" @@ -2057,6 +3344,17 @@ version = "2.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" +[[package]] +name = "syn" +version = "1.0.109" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72b64191b275b66ffe2469e8af2c1cfe3bafa67b529ead792a6d0160888b4237" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + [[package]] name = "syn" version = "2.0.106" @@ -2069,10 +3367,36 @@ dependencies = [ ] [[package]] -name = "target-lexicon" +name = "sync_wrapper" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0bf256ce5efdfa370213c1dabab5935a12e49f2c58d15e9eac2870d3b4f27263" +dependencies = [ + "futures-core", +] + +[[package]] +name = "synstructure" version = "0.13.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e502f78cdbb8ba4718f566c418c52bc729126ffd16baee5baa718cf25dd5a69a" +checksum = "728a70f3dbaf5bab7f0c4b1ac8d7ae5ea60a4b5549c8a5914361c99147a709d2" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.106", +] + +[[package]] +name = "tagptr" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7b2093cf4c8eb1e67749a6762251bc9cd836b6fc171623bd0a9d324d37af2417" + +[[package]] +name = "target-lexicon" +version = "0.13.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df7f62577c25e07834649fc3b39fafdc597c0a3527dc1c60129201ccfcbaa50c" [[package]] name = "tch" @@ -2102,11 +3426,11 @@ dependencies = 
[ [[package]] name = "thiserror" -version = "2.0.16" +version = "2.0.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3467d614147380f2e4e374161426ff399c91084acd2363eaf549172b3d5e60c0" +checksum = "f63587ca0f12b72a0600bcba1d40081f830876000bb46dd2337a3051618f4fc8" dependencies = [ - "thiserror-impl 2.0.16", + "thiserror-impl 2.0.17", ] [[package]] @@ -2117,25 +3441,34 @@ checksum = "4fee6c4efc90059e10f81e6d42c60a18f76588c3d74cb83a0b242a2b6c7504c1" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.106", ] [[package]] name = "thiserror-impl" -version = "2.0.16" +version = "2.0.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6c5e1be1c48b9172ee610da68fd9cd2770e7a4056cb3fc98710ee6906f0c7960" +checksum = "3ff15c8ecd7de3849db632e14d18d2571fa09dfc5ed93479bc4485c7a517c913" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.106", +] + +[[package]] +name = "thread_local" +version = "1.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f60246a4944f24f6e018aa17cdeffb7818b76356965d03b07d6a9886e8962185" +dependencies = [ + "cfg-if", ] [[package]] name = "time" -version = "0.3.42" +version = "0.3.44" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8ca967379f9d8eb8058d86ed467d81d03e81acd45757e4ca341c24affbe8e8e3" +checksum = "91e7d9e3bb61134e77bde20dd4825b97c010155709965fedf0f49bb138e52a9d" dependencies = [ "deranged", "num-conv", @@ -2146,9 +3479,9 @@ dependencies = [ [[package]] name = "time-core" -version = "0.1.5" +version = "0.1.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a9108bb380861b07264b950ded55a44a14a4adc68b9f5efd85aafc3aa4d40a68" +checksum = "40868e7c1d2f0b8d73e4a8c7f0ff63af4f6d19be117e90bd73eb1d62cf831c6b" [[package]] name = "time-humanize" @@ -2165,6 +3498,16 @@ dependencies = [ "crunchy", ] +[[package]] +name = "tinystr" +version = "0.8.1" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "5d4f6d1145dcb577acf783d4e601bc1d76a13337bb54e6233add580b07344c8b" +dependencies = [ + "displaydoc", + "zerovec", +] + [[package]] name = "tinytemplate" version = "1.2.1" @@ -2176,67 +3519,278 @@ dependencies = [ ] [[package]] -name = "torch-sys" -version = "0.20.0" +name = "tinyvec" +version = "1.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ad6fa4ac5662b84047081375b007f102d4968d5a0191f567a9776294445af9ac" +checksum = "bfa5fdc3bce6191a1dbc8c02d5c8bffcf557bafa17c124c5264a458f1b0613fa" dependencies = [ - "anyhow", - "cc", - "libc", - "zip", + "tinyvec_macros", ] [[package]] -name = "typenum" -version = "1.18.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1dccffe3ce07af9386bfd29e80c0ab1a8205a2fc34e4bcd40364df902cfa8f3f" - -[[package]] -name = "unicode-ident" -version = "1.0.18" +name = "tinyvec_macros" +version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5a5f39404a5da50712a4c1eecf25e90dd62b613502b7e925fd4e4d19b5c96512" +checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" [[package]] -name = "unicode-width" -version = "0.2.1" +name = "tokio" +version = "1.47.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4a1a07cc7db3810833284e8d372ccdc6da29741639ecc70c9ec107df0fa6154c" +checksum = "89e49afdadebb872d3145a5638b59eb0691ea23e46ca484037cfab3b76b95038" +dependencies = [ + "backtrace", + "bytes", + "io-uring", + "libc", + "mio", + "pin-project-lite", + "slab", + "socket2", + "tokio-macros", + "windows-sys 0.59.0", +] [[package]] -name = "unindent" -version = "0.2.4" +name = "tokio-macros" +version = "2.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7264e107f553ccae879d21fbea1d6724ac785e8c3bfc762137959b5802826ef3" +checksum = "6e06d43f1345a3bcd39f6a56dbb7dcab2ba47e68e8ac134855e7e2bdbaf8cab8" 
+dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.106", +] [[package]] -name = "unit-prefix" -version = "0.5.1" +name = "tokio-rustls" +version = "0.26.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "323402cff2dd658f39ca17c789b502021b3f18707c91cdf22e3838e1b4023817" +checksum = "1729aa945f29d91ba541258c8df89027d5792d85a8841fb65e8bf0f4ede4ef61" +dependencies = [ + "rustls", + "tokio", +] [[package]] -name = "upon" -version = "0.10.0" +name = "tokio-util" +version = "0.7.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f3ead40aa15464f4d808014183fa0b030761ff6f57e162f7fc76d6a900df7a28" +checksum = "14307c986784f72ef81c89db7d9e28d6ac26d16213b109ea501696195e6e3ce5" +dependencies = [ + "bytes", + "futures-core", + "futures-sink", + "pin-project-lite", + "tokio", +] [[package]] -name = "version_check" -version = "0.9.5" +name = "torch-sys" +version = "0.20.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a" +checksum = "ad6fa4ac5662b84047081375b007f102d4968d5a0191f567a9776294445af9ac" +dependencies = [ + "anyhow", + "cc", + "libc", + "zip", +] [[package]] -name = "walkdir" -version = "2.5.0" +name = "tower" +version = "0.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "29790946404f91d9c5d06f9874efddea1dc06c5efe94541a7d6863108e3a5e4b" +checksum = "d039ad9159c98b70ecfd540b2573b97f7f52c3e8d9f8ad57a24b916a536975f9" dependencies = [ - "same-file", - "winapi-util", + "futures-core", + "futures-util", + "pin-project-lite", + "sync_wrapper", + "tokio", + "tower-layer", + "tower-service", +] + +[[package]] +name = "tower-http" +version = "0.6.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "adc82fd73de2a9722ac5da747f12383d2bfdb93591ee6c58486e0097890f05f2" +dependencies = [ + "bitflags", + "bytes", + "futures-util", + "http", + "http-body", + 
"iri-string", + "pin-project-lite", + "tower", + "tower-layer", + "tower-service", +] + +[[package]] +name = "tower-layer" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "121c2a6cda46980bb0fcd1647ffaf6cd3fc79a013de288782836f6df9c48780e" + +[[package]] +name = "tower-service" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8df9b6e13f2d32c91b9bd719c00d1958837bc7dec474d94952798cc8e69eeec3" + +[[package]] +name = "tracing" +version = "0.1.41" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "784e0ac535deb450455cbfa28a6f0df145ea1bb7ae51b821cf5e7927fdcfbdd0" +dependencies = [ + "pin-project-lite", + "tracing-attributes", + "tracing-core", +] + +[[package]] +name = "tracing-attributes" +version = "0.1.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "81383ab64e72a7a8b8e13130c49e3dab29def6d0c7d76a03087b3cf71c5c6903" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.106", +] + +[[package]] +name = "tracing-core" +version = "0.1.34" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b9d12581f227e93f094d3af2ae690a574abb8a2b9b7a96e7cfe9647b2b617678" +dependencies = [ + "once_cell", +] + +[[package]] +name = "try-lock" +version = "0.2.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b" + +[[package]] +name = "typenum" +version = "1.19.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "562d481066bde0658276a35467c4af00bdc6ee726305698a55b86e61d7ad82bb" + +[[package]] +name = "unicode-ident" +version = "1.0.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f63a545481291138910575129486daeaf8ac54aee4387fe7906919f7830c7d9d" + +[[package]] +name = "unicode-width" +version = "0.2.2" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "b4ac048d71ede7ee76d585517add45da530660ef4390e49b098733c6e897f254" + +[[package]] +name = "unicode-xid" +version = "0.2.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ebc1c04c71510c7f702b52b7c350734c9ff1295c464a03335b00bb84fc54f853" + +[[package]] +name = "unindent" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7264e107f553ccae879d21fbea1d6724ac785e8c3bfc762137959b5802826ef3" + +[[package]] +name = "unit-prefix" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "323402cff2dd658f39ca17c789b502021b3f18707c91cdf22e3838e1b4023817" + +[[package]] +name = "unsafe_cell_slice" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6659959f702dcdaad77bd6e42a9409a32ceccc06943ec93c8a4306be00eb6cf1" + +[[package]] +name = "untrusted" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1" + +[[package]] +name = "upon" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f3ead40aa15464f4d808014183fa0b030761ff6f57e162f7fc76d6a900df7a28" + +[[package]] +name = "url" +version = "2.5.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "08bc136a29a3d1758e07a9cca267be308aeebf5cfd5a10f3f67ab2097683ef5b" +dependencies = [ + "form_urlencoded", + "idna", + "percent-encoding", + "serde", +] + +[[package]] +name = "utf8_iter" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6c140620e7ffbb22c2dee59cafe6084a59b5ffc27a8859a5f0d494b5d52b6be" + +[[package]] +name = "uuid" +version = "1.18.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"2f87b8aa10b915a06587d0dec516c282ff295b475d94abf425d62b57710070a2" +dependencies = [ + "getrandom 0.3.3", + "js-sys", + "wasm-bindgen", +] + +[[package]] +name = "vcpkg" +version = "0.2.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426" + +[[package]] +name = "version_check" +version = "0.9.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a" + +[[package]] +name = "walkdir" +version = "2.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "29790946404f91d9c5d06f9874efddea1dc06c5efe94541a7d6863108e3a5e4b" +dependencies = [ + "same-file", + "winapi-util", +] + +[[package]] +name = "want" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bfa7760aed19e106de2c7c0b581b509f2f25d3dacaf737cb82ac61bc6d760b0e" +dependencies = [ + "try-lock", ] [[package]] @@ -2247,44 +3801,67 @@ checksum = "ccf3ec651a847eb01de73ccad15eb7d99f80485de043efb2f370cd654f4ea44b" [[package]] name = "wasi" -version = "0.14.3+wasi-0.2.4" +version = "0.14.7+wasi-0.2.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6a51ae83037bdd272a9e28ce236db8c07016dd0d50c27038b3f407533c030c95" +checksum = "883478de20367e224c0090af9cf5f9fa85bed63a95c1abf3afc5c083ebc06e8c" +dependencies = [ + "wasip2", +] + +[[package]] +name = "wasip2" +version = "1.0.1+wasi-0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0562428422c63773dad2c345a1882263bbf4d65cf3f42e90921f787ef5ad58e7" dependencies = [ "wit-bindgen", ] [[package]] name = "wasm-bindgen" -version = "0.2.100" +version = "0.2.104" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1edc8929d7499fc4e8f0be2262a241556cfc54a0bea223790e71446f2aab1ef5" +checksum = 
"c1da10c01ae9f1ae40cbfac0bac3b1e724b320abfcf52229f80b547c0d250e2d" dependencies = [ "cfg-if", "once_cell", "rustversion", "wasm-bindgen-macro", + "wasm-bindgen-shared", ] [[package]] name = "wasm-bindgen-backend" -version = "0.2.100" +version = "0.2.104" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2f0a0651a5c2bc21487bde11ee802ccaf4c51935d0d3d42a6101f98161700bc6" +checksum = "671c9a5a66f49d8a47345ab942e2cb93c7d1d0339065d4f8139c486121b43b19" dependencies = [ "bumpalo", "log", "proc-macro2", "quote", - "syn", + "syn 2.0.106", "wasm-bindgen-shared", ] +[[package]] +name = "wasm-bindgen-futures" +version = "0.4.54" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7e038d41e478cc73bae0ff9b36c60cff1c98b8f38f8d7e8061e79ee63608ac5c" +dependencies = [ + "cfg-if", + "js-sys", + "once_cell", + "wasm-bindgen", + "web-sys", +] + [[package]] name = "wasm-bindgen-macro" -version = "0.2.100" +version = "0.2.104" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7fe63fc6d09ed3792bd0897b314f53de8e16568c2b3f7982f468c0bf9bd0b407" +checksum = "7ca60477e4c59f5f2986c50191cd972e3a50d8a95603bc9434501cf156a9a119" dependencies = [ "quote", "wasm-bindgen-macro-support", @@ -2292,31 +3869,44 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro-support" -version = "0.2.100" +version = "0.2.104" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8ae87ea40c9f689fc23f209965b6fb8a99ad69aeeb0231408be24920604395de" +checksum = "9f07d2f20d4da7b26400c9f4a0511e6e0345b040694e8a75bd41d578fa4421d7" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.106", "wasm-bindgen-backend", "wasm-bindgen-shared", ] [[package]] name = "wasm-bindgen-shared" -version = "0.2.100" +version = "0.2.104" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1a05d73b933a847d6cccdda8f838a22ff101ad9bf93e33684f39c1f5f0eece3d" +checksum = 
"bad67dc8b2a1a6e5448428adec4c3e84c43e561d8c9ee8a9e5aabeb193ec41d1" dependencies = [ "unicode-ident", ] +[[package]] +name = "wasm-streams" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "15053d8d85c7eccdbefef60f06769760a563c7f0a9d6902a13d35c7800b0ad65" +dependencies = [ + "futures-util", + "js-sys", + "wasm-bindgen", + "wasm-bindgen-futures", + "web-sys", +] + [[package]] name = "web-sys" -version = "0.3.77" +version = "0.3.81" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "33b6dd2ef9186f1f2072e409e99cd22a975331a6b3591b12c764e0e55c60d5d2" +checksum = "9367c417a924a74cae129e6a2ae3b47fabb1f8995595ab474029da749a8be120" dependencies = [ "js-sys", "wasm-bindgen", @@ -2332,20 +3922,42 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "winapi" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c839a674fcd7a98952e593242ea400abe93992746761e38641405d28b00f419" +dependencies = [ + "winapi-i686-pc-windows-gnu", + "winapi-x86_64-pc-windows-gnu", +] + +[[package]] +name = "winapi-i686-pc-windows-gnu" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6" + [[package]] name = "winapi-util" -version = "0.1.10" +version = "0.1.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0978bf7171b3d90bac376700cb56d606feb40f251a475a5d6634613564460b22" +checksum = "c2a7b1c03c876122aa43f3020e6c3c3ee5c05081c9a00739faf7503aeba10d22" dependencies = [ - "windows-sys", + "windows-sys 0.61.2", ] +[[package]] +name = "winapi-x86_64-pc-windows-gnu" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" + [[package]] name = "windows-core" -version = "0.61.2" +version = "0.62.2" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "c0fdd3ddb90610c7638aa2b3a3ab2904fb9e5cdbecc643ddb3647212781c4ae3" +checksum = "b8e83a14d34d0623b51dce9581199302a221863196a1dde71a7663a4c2be9deb" dependencies = [ "windows-implement", "windows-interface", @@ -2356,148 +3968,486 @@ dependencies = [ [[package]] name = "windows-implement" -version = "0.60.0" +version = "0.60.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a47fddd13af08290e67f4acabf4b459f647552718f683a7b415d290ac744a836" +checksum = "053e2e040ab57b9dc951b72c264860db7eb3b0200ba345b4e4c3b14f67855ddf" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.106", ] [[package]] name = "windows-interface" -version = "0.59.1" +version = "0.59.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bd9211b69f8dcdfa817bfd14bf1c97c9188afa36f4750130fcdf3f400eca9fa8" +checksum = "3f316c4a2570ba26bbec722032c4099d8c8bc095efccdc15688708623367e358" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.106", ] [[package]] name = "windows-link" -version = "0.1.3" +version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5e6ad25900d524eaabdbbb96d20b4311e1e7ae1699af4fb28c17ae66c80d798a" +checksum = "f0805222e57f7521d6a62e36fa9163bc891acd422f971defe97d64e70d0a4fe5" [[package]] name = "windows-result" -version = "0.3.4" +version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "56f42bd332cc6c8eac5af113fc0c1fd6a8fd2aa08a0119358686e5160d0586c6" +checksum = "7781fa89eaf60850ac3d2da7af8e5242a5ea78d1a11c49bf2910bb5a73853eb5" dependencies = [ "windows-link", ] [[package]] name = "windows-strings" -version = "0.4.2" +version = "0.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "56e6c93f3a0c3b36176cb1327a4958a0353d5d166c2a35cb268ace15e91d3b57" +checksum = "7837d08f69c77cf6b07689544538e017c1bfcf57e34b4c0ff58e6c2cd3b37091" dependencies = [ 
"windows-link", ] +[[package]] +name = "windows-sys" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d" +dependencies = [ + "windows-targets 0.52.6", +] + +[[package]] +name = "windows-sys" +version = "0.59.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e38bc4d79ed67fd075bcc251a1c39b32a1776bbe92e5bef1f0bf1f8c531853b" +dependencies = [ + "windows-targets 0.52.6", +] + [[package]] name = "windows-sys" version = "0.60.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f2f500e4d28234f72040990ec9d39e3a6b950f9f22d3dba18416c35882612bcb" dependencies = [ - "windows-targets", + "windows-targets 0.53.5", +] + +[[package]] +name = "windows-sys" +version = "0.61.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ae137229bcbd6cdf0f7b80a31df61766145077ddf49416a728b02cb3921ff3fc" +dependencies = [ + "windows-link", +] + +[[package]] +name = "windows-targets" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9b724f72796e036ab90c1021d4780d4d3d648aca59e491e6b98e725b84e99973" +dependencies = [ + "windows_aarch64_gnullvm 0.52.6", + "windows_aarch64_msvc 0.52.6", + "windows_i686_gnu 0.52.6", + "windows_i686_gnullvm 0.52.6", + "windows_i686_msvc 0.52.6", + "windows_x86_64_gnu 0.52.6", + "windows_x86_64_gnullvm 0.52.6", + "windows_x86_64_msvc 0.52.6", ] [[package]] name = "windows-targets" -version = "0.53.3" +version = "0.53.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d5fe6031c4041849d7c496a8ded650796e7b6ecc19df1a431c1a363342e5dc91" +checksum = "4945f9f551b88e0d65f3db0bc25c33b8acea4d9e41163edf90dcd0b19f9069f3" dependencies = [ "windows-link", - "windows_aarch64_gnullvm", - "windows_aarch64_msvc", - "windows_i686_gnu", - "windows_i686_gnullvm", - "windows_i686_msvc", - "windows_x86_64_gnu", - 
"windows_x86_64_gnullvm", - "windows_x86_64_msvc", + "windows_aarch64_gnullvm 0.53.1", + "windows_aarch64_msvc 0.53.1", + "windows_i686_gnu 0.53.1", + "windows_i686_gnullvm 0.53.1", + "windows_i686_msvc 0.53.1", + "windows_x86_64_gnu 0.53.1", + "windows_x86_64_gnullvm 0.53.1", + "windows_x86_64_msvc 0.53.1", ] [[package]] name = "windows_aarch64_gnullvm" -version = "0.53.0" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3" + +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.53.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "86b8d5f90ddd19cb4a147a5fa63ca848db3df085e25fee3cc10b39b6eebae764" +checksum = "a9d8416fa8b42f5c947f8482c43e7d89e73a173cead56d044f6a56104a6d1b53" [[package]] name = "windows_aarch64_msvc" -version = "0.53.0" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c7651a1f62a11b8cbd5e0d42526e55f2c99886c77e007179efff86c2b137e66c" +checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469" + +[[package]] +name = "windows_aarch64_msvc" +version = "0.53.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b9d782e804c2f632e395708e99a94275910eb9100b2114651e04744e9b125006" [[package]] name = "windows_i686_gnu" -version = "0.53.0" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e9b5ad5ab802e97eb8e295ac6720e509ee4c243f69d781394014ebfe8bbfa0b" + +[[package]] +name = "windows_i686_gnu" +version = "0.53.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "960e6da069d81e09becb0ca57a65220ddff016ff2d6af6a223cf372a506593a3" + +[[package]] +name = "windows_i686_gnullvm" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c1dc67659d35f387f5f6c479dc4e28f1d4bb90ddd1a5d3da2e5d97b42d6272c3" +checksum 
= "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66" [[package]] name = "windows_i686_gnullvm" -version = "0.53.0" +version = "0.53.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fa7359d10048f68ab8b09fa71c3daccfb0e9b559aed648a8f95469c27057180c" + +[[package]] +name = "windows_i686_msvc" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9ce6ccbdedbf6d6354471319e781c0dfef054c81fbc7cf83f338a4296c0cae11" +checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66" [[package]] name = "windows_i686_msvc" -version = "0.53.0" +version = "0.53.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e7ac75179f18232fe9c285163565a57ef8d3c89254a30685b57d83a38d326c2" + +[[package]] +name = "windows_x86_64_gnu" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "581fee95406bb13382d2f65cd4a908ca7b1e4c2f1917f143ba16efe98a589b5d" +checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78" [[package]] name = "windows_x86_64_gnu" -version = "0.53.0" +version = "0.53.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2e55b5ac9ea33f2fc1716d1742db15574fd6fc8dadc51caab1c16a3d3b4190ba" +checksum = "9c3842cdd74a865a8066ab39c8a7a473c0778a3f29370b5fd6b4b9aa7df4a499" [[package]] name = "windows_x86_64_gnullvm" -version = "0.53.0" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0a6e035dd0599267ce1ee132e51c27dd29437f63325753051e71dd9e42406c57" +checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d" + +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.53.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0ffa179e2d07eee8ad8f57493436566c7cc30ac536a3379fdf008f47f6bb7ae1" + +[[package]] +name = "windows_x86_64_msvc" +version = "0.52.6" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" [[package]] name = "windows_x86_64_msvc" -version = "0.53.0" +version = "0.53.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "271414315aff87387382ec3d271b52d7ae78726f5d44ac98b4f4030c91880486" +checksum = "d6bbff5f0aada427a1e5a6da5f1f98158182f26556f345ac9e04d36d0ebed650" [[package]] name = "wit-bindgen" -version = "0.45.0" +version = "0.46.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f17a85883d4e6d00e8a97c586de764dabcc06133f7f1d55dce5cdc070ad7fe59" + +[[package]] +name = "writeable" +version = "0.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ea2f10b9bb0928dfb1b42b65e1f9e36f7f54dbdf08457afefb38afcdec4fa2bb" + +[[package]] +name = "yoke" +version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "052283831dbae3d879dc7f51f3d92703a316ca49f91540417d38591826127814" +checksum = "5f41bb01b8226ef4bfd589436a297c53d118f65921786300e427be8d487695cc" +dependencies = [ + "serde", + "stable_deref_trait", + "yoke-derive", + "zerofrom", +] + +[[package]] +name = "yoke-derive" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "38da3c9736e16c5d3c8c597a9aaa5d1fa565d0532ae05e27c24aa62fb32c0ab6" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.106", + "synstructure", +] + +[[package]] +name = "zarrs" +version = "0.22.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ad12c7c2b91d2f6871f21efc28fd5a302809d7246974fdd99ef55bd3b16b78a0" +dependencies = [ + "async-generic", + "async-lock", + "async-trait", + "blosc-src", + "bytemuck", + "bytes", + "crc32c", + "derive_more", + "flate2", + "futures", + "getrandom 0.3.3", + "half", + "inventory", + "itertools 0.14.0", + "itoa", + "lru", + "moka", + "ndarray", + "num", + "num-complex", + 
"quick_cache", + "rayon", + "rayon_iter_concurrent_limit", + "serde", + "serde_json", + "thiserror 2.0.17", + "thread_local", + "unsafe_cell_slice", + "uuid", + "zarrs_data_type", + "zarrs_filesystem", + "zarrs_metadata", + "zarrs_metadata_ext", + "zarrs_plugin", + "zarrs_registry", + "zarrs_storage", + "zstd 0.13.3", +] + +[[package]] +name = "zarrs_data_type" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "22e68a3b9e663cf4933afcd90f460ee72986fdf6c2b4d43d0441ad049b802342" +dependencies = [ + "derive_more", + "half", + "inventory", + "num", + "thiserror 2.0.17", + "zarrs_metadata", + "zarrs_plugin", +] + +[[package]] +name = "zarrs_filesystem" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2e135c32621a3a5796d917768d5c7aa7f58be9480ae00778956b82ec6409150b" +dependencies = [ + "bytes", + "derive_more", + "itertools 0.14.0", + "libc", + "page_size", + "parking_lot", + "pathdiff", + "thiserror 2.0.17", + "walkdir", + "zarrs_storage", +] + +[[package]] +name = "zarrs_metadata" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "708b938e5af9e6564d7135fb2d1e05c0deff3d7124694ff6822aa01614a6c991" +dependencies = [ + "derive_more", + "half", + "monostate", + "serde", + "serde_json", + "thiserror 2.0.17", +] + +[[package]] +name = "zarrs_metadata_ext" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d4fb56ca32761b64c4b2a3db1097fbd29adfb321a129279b1db99be0a61d361a" +dependencies = [ + "derive_more", + "half", + "monostate", + "num", + "serde", + "serde_json", + "serde_repr", + "thiserror 2.0.17", + "zarrs_metadata", + "zarrs_registry", +] + +[[package]] +name = "zarrs_object_store" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a5d9d0d3db426dd50dfb0b5d7cc1660bda368caeb5cd8645c60c46bc4f261a19" +dependencies = [ + "async-trait", + 
"futures", + "object_store", + "zarrs_storage", +] + +[[package]] +name = "zarrs_plugin" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d3c9e0514d4c50f44d11285d5df70e4e586486a39826579c9d87ddc3f3dac561" +dependencies = [ + "thiserror 2.0.17", +] + +[[package]] +name = "zarrs_registry" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ebe4e55522eeb87eefab89017bef78cb823f86861fd8a3cc12e9f6538c348d57" +dependencies = [ + "regex", +] + +[[package]] +name = "zarrs_storage" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0bc1037a8fa8c44ccb8f5c6c85753a63ddf296fb43280f28150f0f29fda8d301" +dependencies = [ + "async-trait", + "auto_impl", + "bytes", + "derive_more", + "futures", + "itertools 0.14.0", + "parking_lot", + "thiserror 2.0.17", + "unsafe_cell_slice", +] [[package]] name = "zerocopy" -version = "0.8.26" +version = "0.8.27" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1039dd0d3c310cf05de012d8a39ff557cb0d23087fd44cad61df08fc31907a2f" +checksum = "0894878a5fa3edfd6da3f88c4805f4c8558e2b996227a3d864f47fe11e38282c" dependencies = [ "zerocopy-derive", ] [[package]] name = "zerocopy-derive" -version = "0.8.26" +version = "0.8.27" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9ecf5b4cc5364572d7f4c329661bcc82724222973f2cab6f050a4e5c22f75181" +checksum = "88d2b8d9c68ad2b9e4340d7832716a4d21a22a1154777ad56ea55c51a9cf3831" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.106", +] + +[[package]] +name = "zerofrom" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "50cc42e0333e05660c3587f3bf9d0478688e15d870fab3346451ce7f8c9fbea5" +dependencies = [ + "zerofrom-derive", +] + +[[package]] +name = "zerofrom-derive" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"d71e5d6e06ab090c67b5e44993ec16b72dcbaabc526db883a360057678b48502" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.106", + "synstructure", +] + +[[package]] +name = "zeroize" +version = "1.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b97154e67e32c85465826e8bcc1c59429aaaf107c1e4a9e53c8d8ccd5eff88d0" + +[[package]] +name = "zerotrie" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "36f0bbd478583f79edad978b407914f61b2972f5af6fa089686016be8f9af595" +dependencies = [ + "displaydoc", + "yoke", + "zerofrom", +] + +[[package]] +name = "zerovec" +version = "0.11.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e7aa2bd55086f1ab526693ecbe444205da57e25f4489879da80635a46d90e73b" +dependencies = [ + "yoke", + "zerofrom", + "zerovec-derive", +] + +[[package]] +name = "zerovec-derive" +version = "0.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b96237efa0c878c64bd89c436f661be4e46b2f3eff1ebb976f7ef2321d2f58f" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.106", ] [[package]] @@ -2517,7 +4467,7 @@ dependencies = [ "pbkdf2", "sha1", "time", - "zstd", + "zstd 0.11.2+zstd.1.5.2", ] [[package]] @@ -2526,7 +4476,16 @@ version = "0.11.2+zstd.1.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "20cc960326ece64f010d2d2107537f26dc589a6573a316bd5b1dba685fa5fde4" dependencies = [ - "zstd-safe", + "zstd-safe 5.0.2+zstd.1.5.2", +] + +[[package]] +name = "zstd" +version = "0.13.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e91ee311a569c327171651566e07972200e76fcfe2242a4fa446149a3881c08a" +dependencies = [ + "zstd-safe 7.2.4", ] [[package]] @@ -2539,11 +4498,20 @@ dependencies = [ "zstd-sys", ] +[[package]] +name = "zstd-safe" +version = "7.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"8f49c4d5f0abb602a93fb8736af2a4f4dd9512e36f7f570d66e65ff867ed3b9d" +dependencies = [ + "zstd-sys", +] + [[package]] name = "zstd-sys" -version = "2.0.15+zstd.1.5.7" +version = "2.0.16+zstd.1.5.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eb81183ddd97d0c74cedf1d50d85c8d08c1b8b68ee863bdee9e706eedba1a237" +checksum = "91e19ebc2adc8f83e43039e79776e3fda8ca919132d68a1fed6a5faca2683748" dependencies = [ "cc", "pkg-config", diff --git a/Cargo.toml b/Cargo.toml index 391dd25..a09f519 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,7 +10,7 @@ license = "MIT" repository = "https://github.com/pymc-devs/nutpie" keywords = ["statistics", "bayes"] description = "Python wrapper for nuts-rs -- a NUTS sampler written in Rust." -rust-version = "1.76" +rust-version = "1.90" [features] extension-module = ["pyo3/extension-module"] @@ -21,23 +21,28 @@ name = "_lib" crate-type = ["cdylib"] [dependencies] -nuts-rs = "0.16.1" +nuts-rs = { version = "0.16.1", features = ["zarr", "arrow"] } numpy = "0.26.0" rand = "0.9.0" thiserror = "2.0.3" rand_chacha = "0.9.0" -rayon = "1.10.0" -# Keep arrow in sync with nuts-rs requirements -arrow = { version = "56.1.0", default-features = false, features = ["ffi"] } +rayon = "1.11.0" anyhow = "1.0.72" itertools = "0.14.0" -bridgestan = "2.6.1" +bridgestan = "2.7.0" rand_distr = "0.5.0" -smallvec = "1.14.0" +smallvec = "1.15.0" upon = { version = "0.10.0", default-features = false, features = [] } time-humanize = { version = "0.1.3", default-features = false } indicatif = "0.18.0" tch = { version = "0.20.0", optional = true } +pyo3-object_store = "0.6.0" +# Keep zarrs crates in sync with nuts-rs requirements +zarrs = { version = "0.22.2", features = ["async"] } +zarrs_object_store = "0.5.0" +tokio = { version = "1.47.1", features = ["rt", "rt-multi-thread"] } +pyo3-arrow = "0.12.0" +arrow = { version = "56.2.0", features = ["json"] } [dependencies.pyo3] version = "0.26.0" diff --git a/pyproject.toml b/pyproject.toml 
index c53f390..822d2e4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,6 +17,7 @@ classifiers = [ dependencies = [ "pyarrow >= 12.0.0", + "arro3-core >= 0.6.0", "pandas >= 2.0", "xarray >= 2025.01.2", "arviz >= 0.20.0", @@ -28,12 +29,12 @@ Homepage = "https://pymc-devs.github.io/nutpie/" Repository = "https://github.com/pymc-devs/nutpie" [project.optional-dependencies] -stan = ["bridgestan >= 2.6.1", "stanio >= 0.5.1"] +stan = ["bridgestan >= 2.7.0", "stanio >= 0.5.1"] pymc = ["pymc >= 5.20.1", "numba >= 0.60.0"] pymc-jax = ["pymc >= 5.20.1", "jax >= 0.4.27"] nnflow = ["flowjax >= 17.1.0", "equinox >= 0.11.12"] dev = [ - "bridgestan >= 2.6.1", + "bridgestan >= 2.7.0", "stanio >= 0.5.1", "pymc >= 5.20.1", "numba >= 0.60.0", @@ -44,7 +45,7 @@ dev = [ "pytest-arraydiff", ] all = [ - "bridgestan >= 2.6.1", + "bridgestan >= 2.7.0", "stanio >= 0.5.1", "pymc >= 5.20.1", "numba >= 0.60.0", @@ -76,7 +77,7 @@ features = ["pyo3/extension-module"] [tool.pytest.ini_options] markers = [ - "flow: tests for normalizing flows", - "stan: tests for Stan models", - "pymc: tests for PyMC models", + "flow: tests for normalizing flows", + "stan: tests for Stan models", + "pymc: tests for PyMC models", ] diff --git a/python/nutpie/compile_pymc.py b/python/nutpie/compile_pymc.py index 36a3a09..5028f33 100644 --- a/python/nutpie/compile_pymc.py +++ b/python/nutpie/compile_pymc.py @@ -111,6 +111,7 @@ class CompiledPyMCModel(CompiledModel): _n_dim: int _shapes: dict[str, tuple[int, ...]] _coords: Optional[dict[str, Any]] + _transform_adapt_args: dict | None = None @property def n_dim(self): @@ -146,13 +147,14 @@ def with_data(self, **updates): user_data=user_data, ) - def _make_sampler(self, settings, init_mean, cores, progress_type): + def _make_sampler(self, settings, init_mean, cores, progress_type, store): model = self._make_model(init_mean) return _lib.PySampler.from_pymc( settings, cores, model, progress_type, + store, ) def _make_model(self, init_mean): @@ -164,24 +166,46 @@ 
def _make_model(self, init_mean): self, ) logp_fn = _lib.LogpFunc( - self.n_dim, self.compiled_logp_func.address, self.user_data.ctypes.data, self, ) - var_sizes = [prod(shape) for shape in self.shape_info[2]] var_names = self.shape_info[0] + coords = self._coords.copy() if self._coords is not None else {} + dim_sizes = {name: len(vals) for name, vals in coords.items()} + dims = self.dims.copy() if self.dims is not None else {} + var_types = ["float64"] * len(var_names) + var_shapes = self.shape_info[2] + + variables = _lib.PyVariable.new_variables( + var_names, var_types, var_shapes, dim_sizes, dims + ) + + outer_kwargs = self._transform_adapt_args + if outer_kwargs is None: + outer_kwargs = {} + + def make_adapter(*args, **kwargs): + from nutpie.transform_adapter import make_transform_adapter + + return make_transform_adapter(**outer_kwargs)(*args, **kwargs, logp_fn=None) + return _lib.PyMcModel( - self.n_dim, logp_fn, expand_fn, + variables, + self.n_dim, + dim_sizes, + coords, self.initial_point_func, - var_sizes, - var_names, + make_adapter, ) + def with_transform_adapt(self, **kwargs): + return dataclasses.replace(self, _transform_adapt_args=kwargs) + def update_user_data(user_data, user_data_storage): user_data = user_data[()] diff --git a/python/nutpie/compile_stan.py b/python/nutpie/compile_stan.py index 9e65b1e..35dfa18 100644 --- a/python/nutpie/compile_stan.py +++ b/python/nutpie/compile_stan.py @@ -52,13 +52,25 @@ def make_adapter(*args, **kwargs): return make_transform_adapter(**outer_kwargs)(*args, **kwargs, logp_fn=None) - model = _lib.StanModel(self.library, seed, data_json, make_adapter) + coords = self._coords + if coords is None: + coords = {} + coords = coords.copy() + + dims = self.dims + if dims is None: + dims = {} + dims = dims.copy() + dim_sizes = {name: len(dim) for name, dim in coords.items()} + + model = _lib.StanModel( + self.library, dim_sizes, dims, coords, seed, data_json, make_adapter + ) coords = self._coords if coords is None: 
coords = {} else: coords = coords.copy() - coords["unconstrained_parameter"] = pd.Index(model.param_unc_names()) return CompiledStanModel( _coords=coords, @@ -93,13 +105,14 @@ def _make_model(self, init_mean): return self.with_data().model return self.model - def _make_sampler(self, settings, init_mean, cores, progress_type): + def _make_sampler(self, settings, init_mean, cores, progress_type, store): model = self._make_model(init_mean) return _lib.PySampler.from_stan( settings, cores, model, progress_type, + store, ) @property diff --git a/python/nutpie/compiled_pyfunc.py b/python/nutpie/compiled_pyfunc.py index 618feea..db58c28 100644 --- a/python/nutpie/compiled_pyfunc.py +++ b/python/nutpie/compiled_pyfunc.py @@ -19,6 +19,7 @@ class PyFuncModel(CompiledModel): _shared_data: dict[str, Any] _n_dim: int _variables: list[_lib.PyVariable] + _dim_sizes: dict[str, int] _coords: dict[str, Any] _raw_logp_fn: Callable | None _transform_adapt_args: dict | None = None @@ -47,13 +48,14 @@ def with_data(self, **updates): def with_transform_adapt(self, **kwargs): return dataclasses.replace(self, _transform_adapt_args=kwargs) - def _make_sampler(self, settings, init_mean, cores, progress_type): + def _make_sampler(self, settings, init_mean, cores, progress_type, store): model = self._make_model(init_mean) return _lib.PySampler.from_pyfunc( settings, cores, model, progress_type, + store, ) def _make_model(self, init_mean): @@ -85,6 +87,8 @@ def make_adapter(*args, **kwargs): make_expand_func, self._variables, self.n_dim, + dim_sizes=self._dim_sizes, + coords=self._coords, init_point_func=self._make_initial_points, transform_adapter=make_adapter, ) @@ -105,19 +109,6 @@ def from_pyfunc( make_transform_adapter=None, raw_logp_fn=None, ): - variables = [] - for name, shape, dtype in zip( - expanded_names, expanded_shapes, expanded_dtypes, strict=True - ): - shape = _lib.TensorShape(list(shape)) - if dtype == np.float64: - dtype = _lib.ExpandDtype.float64_array(shape) - elif dtype == 
np.float32: - dtype = _lib.ExpandDtype.float32_array(shape) - elif dtype == np.int64: - dtype = _lib.ExpandDtype.int64_array(shape) - variables.append(_lib.PyVariable(name, dtype)) - if coords is None: coords = {} if dims is None: @@ -125,10 +116,23 @@ def from_pyfunc( if shared_data is None: shared_data = {} + coords = coords.copy() + + dim_sizes = {k: len(v) for k, v in coords.items()} + shapes = [tuple(shape) for shape in expanded_shapes] + variables = _lib.PyVariable.new_variables( + expanded_names, + [str(dtype) for dtype in expanded_dtypes], + shapes, + dim_sizes, + dims, + ) + return PyFuncModel( _n_dim=ndim, dims=dims, _coords=coords, + _dim_sizes=dim_sizes, _make_logp_func=make_logp_fn, _make_expand_func=make_expand_fn, _make_initial_points=make_initial_point_fn, diff --git a/python/nutpie/sample.py b/python/nutpie/sample.py index 0655173..5b25404 100644 --- a/python/nutpie/sample.py +++ b/python/nutpie/sample.py @@ -54,75 +54,102 @@ def benchmark_logp(self, point, num_evals, cores): return pd.concat(times) -def _trace_to_arviz(traces, n_tune, shapes, **kwargs): - n_chains = len(traces) - - data_dict = {} - data_dict_tune = {} - stats_dict = {} - stats_dict_tune = {} - - draw_batches = [] - stats_batches = [] - for draws, stats in traces: - draw_batches.append(pyarrow.RecordBatch.from_struct_array(draws)) - stats_batches.append(pyarrow.RecordBatch.from_struct_array(stats)) - - table = pyarrow.Table.from_batches(draw_batches) - table_stats = pyarrow.Table.from_batches(stats_batches) - for name, col in zip(table.column_names, table.columns): - lengths = [len(chunk) for chunk in col.chunks] - length = max(lengths) - dtype = col.chunks[0].values.to_numpy().dtype - if dtype in [np.float64, np.float32]: - data = np.full( - (n_chains, length, *tuple(shapes[name])), np.nan, dtype=dtype - ) - else: - data = np.zeros((n_chains, length, *tuple(shapes[name])), dtype=dtype) - for i, chunk in enumerate(col.chunks): - data[i, : len(chunk)] = 
chunk.values.to_numpy().reshape( - (len(chunk),) + shapes[name] - ) - - data_dict[name] = data[:, n_tune:] - data_dict_tune[name] = data[:, :n_tune] +def _arrow_to_arviz(draw_batches, stat_batches, skip_vars=None, **kwargs): + if skip_vars is None: + skip_vars = [] + + n_chains = len(draw_batches) + assert n_chains == len(stat_batches) + + max_tuning = 0 + max_posterior = 0 + num_tuning = [] + + for draw, stat in zip(draw_batches, stat_batches): + tuning = stat.column("tuning") + _num_tuning = tuning.to_numpy().sum() + assert draw.num_rows == stat.num_rows + max_tuning = max(max_tuning, _num_tuning) + max_posterior = max(max_posterior, draw.num_rows - _num_tuning) + num_tuning.append(_num_tuning) + + data_tune = {} + data_posterior = {} + + stats_tune = {} + stats_posterior = {} + + dims = {} + + for i, draw in enumerate(draw_batches): + draw_tune = draw.slice(0, num_tuning[i]) + _add_arrow_data(data_tune, max_tuning, draw_tune, i, n_chains, dims, []) + draw_posterior = draw.slice(num_tuning[i], draw.num_rows - num_tuning[i]) + _add_arrow_data( + data_posterior, max_posterior, draw_posterior, i, n_chains, dims, [] + ) + for i, stat in enumerate(stat_batches): + stat_tune = stat.slice(0, num_tuning[i]) + _add_arrow_data(stats_tune, max_tuning, stat_tune, i, n_chains, dims, skip_vars) + stat_posterior = stat.slice(num_tuning[i], stat.num_rows - num_tuning[i]) + _add_arrow_data( + stats_posterior, max_posterior, stat_posterior, i, n_chains, dims, skip_vars + ) - for name, col in zip(table_stats.column_names, table_stats.columns): - if name in ["chain", "draw", "divergence_message"]: - continue - col_type = col.type - if hasattr(col_type, "list_size"): - last_shape = (col_type.list_size,) - dtype = col_type.field(0).type.to_pandas_dtype() - else: - dtype = col_type.to_pandas_dtype() - last_shape = () + return arviz.from_dict( + data_posterior, + sample_stats=stats_posterior, + warmup_posterior=data_tune, + warmup_sample_stats=stats_tune, + dims=dims, + **kwargs, + ) - 
lengths = [len(chunk) for chunk in col.chunks] - length = max(lengths) - if dtype in [np.float64, np.float32]: - data = np.full((n_chains, length, *last_shape), np.nan, dtype=dtype) - else: - data = np.zeros((n_chains, length, *last_shape), dtype=dtype) +def _add_arrow_data(data_dict, max_length, batch, chain, n_chains, dims, skip_vars): + num_draws = batch.num_rows - for i, chunk in enumerate(col.chunks): - if hasattr(chunk, "values"): - values = chunk.values.to_numpy(False) + for name in batch.column_names: + if name in skip_vars: + continue + col = batch.column(name) + meta = col.field.metadata + item_dims = meta.get(b"dims", []) + if item_dims: + item_dims = item_dims.decode("utf-8").split(",") + item_shape = meta.get(b"shape", []) + if item_shape: + item_shape = item_shape.decode("utf-8").split(",") + item_shape = [int(s) for s in item_shape] + total_shape = [n_chains, max_length, *item_shape] + + col = pyarrow.array(col) + + is_null = col.is_null() + + if hasattr(col, "flatten"): + col = col.flatten() + dtype = col.type.to_pandas_dtype() + + if name not in data_dict: + if dtype in [np.float64, np.float32]: + data = np.full(total_shape, np.nan, dtype=dtype) else: - values = chunk.to_numpy(False) - data[i, : len(chunk)] = values.reshape((len(chunk), *last_shape)) - stats_dict[name] = data[:, n_tune:] - stats_dict_tune[name] = data[:, :n_tune] + data = np.zeros(total_shape, dtype=dtype) + data_dict[name] = data - return arviz.from_dict( - data_dict, - sample_stats=stats_dict, - warmup_posterior=data_dict_tune, - warmup_sample_stats=stats_dict_tune, - **kwargs, - ) + dims[name] = item_dims + + values = col.to_numpy(False) + if is_null.sum() == 0: + data_dict[name][chain, :num_draws] = values.reshape( + (num_draws,) + tuple(item_shape) + ) + else: + is_null = is_null.to_numpy(False) + data_dict[name][chain, :num_draws][~is_null] = values.reshape( + ((~is_null).sum(),) + tuple(item_shape) + ) _progress_style = """ @@ -360,6 +387,15 @@ def in_colab(): return False # 
Probably standard Python interpreter +_ZarrStoreType = ( + _lib.store.S3Store + | _lib.store.LocalStore + | _lib.store.HTTPStore + | _lib.store.GCSStore + | _lib.store.AzureStore +) + + class _BackgroundSampler: _sampler: Any _num_divs: int @@ -369,6 +405,8 @@ class _BackgroundSampler: _chains_finished: int _compiled_model: CompiledModel _save_warmup: bool + _store: _lib.PyStorage + _zarr_store: _ZarrStoreType | None = None def __init__( self, @@ -383,6 +421,7 @@ def __init__( progress_template=None, progress_style=None, progress_rate=100, + store=None, ): self._settings = settings self._compiled_model = compiled_model @@ -391,6 +430,14 @@ def __init__( self._html = None + if store is None: + store = _lib.PyStorage.arrow() + elif type(store).__module__ == "_lib.store": + self._zarr_store = store + store = _lib.PyStorage.zarr(store) + + self._store = store + if not progress_bar: progress_type = _lib.ProgressType.none() @@ -411,8 +458,11 @@ def __init__( self.display_id = IPython.display.display(self, display_id=True) def callback(formatted): - self._html = formatted - self.display_id.update(self) + try: + self._html = formatted + self.display_id.update(self) + except Exception as e: + print(f"Error updating progress display: {e}") progress_type = _lib.ProgressType.template_callback( progress_rate, progress_template, cores, callback @@ -447,6 +497,7 @@ def callback(formatted): init_mean, cores, progress_type, + self._store, ) def wait(self, *, timeout=None): @@ -460,35 +511,64 @@ def wait(self, *, timeout=None): This resumes the sampler in case it had been paused. 
""" self._sampler.wait(timeout) - results = self._sampler.extract_results() + results = self._sampler.take_results() return self._extract(results) def _extract(self, results): - dims = {name: list(dim) for name, dim in self._compiled_model.dims.items()} - dims["mass_matrix_inv"] = ["unconstrained_parameter"] - dims["gradient"] = ["unconstrained_parameter"] - dims["unconstrained_draw"] = ["unconstrained_parameter"] - dims["divergence_start"] = ["unconstrained_parameter"] - dims["divergence_start_gradient"] = ["unconstrained_parameter"] - dims["divergence_end"] = ["unconstrained_parameter"] - dims["divergence_momentum"] = ["unconstrained_parameter"] - dims["transformed_gradient"] = ["unconstrained_parameter"] - dims["transformed_position"] = ["unconstrained_parameter"] - if self._return_raw_trace: return results else: - return _trace_to_arviz( - results, - self._settings.num_tune, - self._compiled_model.shapes, - dims=dims, - coords={ - name: pd.Index(vals) - for name, vals in self._compiled_model.coords.items() - }, - save_warmup=self._save_warmup, - ) + if results.is_zarr(): + from zarr.storage import ObjectStore + import obstore + import xarray as xr + + assert self._zarr_store is not None + + args, kwargs = self._zarr_store.__getnewargs_ex__() + name = self._zarr_store.__class__.__name__ + cls = getattr(obstore.store, name) + store = cls(*args, **kwargs) + + obj_store = ObjectStore(store, read_only=True) + ds = xr.open_datatree(obj_store, engine="zarr", consolidated=False) + return arviz.from_datatree(ds) + + elif results.is_arrow(): + skip_vars = [] + skips = { + "store_gradient": ["gradient"], + "store_unconstrained": ["unconstrained"], + "store_mass_matrix": [ + "mass_matrix_inv", + "mass_matrix_eigvals", + "mass_matrix_stds", + ], + "store_divergences": [ + "divergence_start", + "divergence_end", + "divergence_momentum", + "divergence_start_gradient", + ], + } + + for setting, names in skips.items(): + if not getattr(self._settings, setting, False): + 
skip_vars.extend(names) + + draw_batches, stat_batches = results.get_arrow_trace() + return _arrow_to_arviz( + draw_batches, + stat_batches, + skip_vars=skip_vars, + coords={ + name: pd.Index(vals) + for name, vals in self._compiled_model.coords.items() + }, + save_warmup=self._save_warmup, + ) + else: + raise ValueError("Unknown results type") def inspect(self): """Get a copy of the current state of the trace""" @@ -543,6 +623,7 @@ def sample( progress_template: str | None = None, progress_style: str | None = None, progress_rate: int = 100, + zarr_store: _ZarrStoreType | None = None, ) -> arviz.InferenceData: ... @@ -565,6 +646,7 @@ def sample( progress_template: str | None = None, progress_style: str | None = None, progress_rate: int = 100, + zarr_store: _ZarrStoreType | None = None, **kwargs, ) -> arviz.InferenceData: ... @@ -588,6 +670,7 @@ def sample( progress_template: str | None = None, progress_style: str | None = None, progress_rate: int = 100, + zarr_store: _ZarrStoreType | None = None, **kwargs, ) -> _BackgroundSampler: ... @@ -610,6 +693,7 @@ def sample( progress_template: str | None = None, progress_style: str | None = None, progress_rate: int = 100, + zarr_store: _ZarrStoreType | None = None, **kwargs, ) -> arviz.InferenceData | _BackgroundSampler: """Sample the posterior distribution for a compiled model. @@ -694,6 +778,10 @@ def sample( transform_adapt: bool, default=False Use the experimental transform adaptation algorithm during tuning. + zarr_store: nutpie.store.Store + A store created using nutpie.store to store the samples + in. If None (default), the samples will be stored in + memory using an arrow table. 
**kwargs Pass additional arguments to nutpie._lib.PySamplerArgs @@ -750,6 +838,7 @@ def sample( progress_template=progress_template, progress_style=progress_style, progress_rate=progress_rate, + store=zarr_store, ) if not blocking: diff --git a/src/common.rs b/src/common.rs new file mode 100644 index 0000000..855352c --- /dev/null +++ b/src/common.rs @@ -0,0 +1,381 @@ +use std::collections::HashMap; + +use anyhow::{bail, Context, Result}; +use numpy::{PyArray1, PyReadonlyArray1}; +use nuts_rs::Value; +use pyo3::{ + exceptions::PyRuntimeError, + pyclass, pymethods, + types::{PyAnyMethods, PyDict, PyDictMethods, PyList, PyListMethods, PyType}, + Bound, FromPyObject, IntoPyObject, IntoPyObjectExt, Py, PyAny, PyErr, Python, +}; +use smallvec::SmallVec; + +#[derive(Debug, Clone)] +pub struct Dims(pub SmallVec<[String; 4]>); + +impl Dims { + pub fn as_slice(&self) -> &[String] { + &self.0 + } +} + +#[derive(Debug, Clone)] +pub struct Shape(pub SmallVec<[u64; 4]>); + +impl Shape { + pub fn as_slice(&self) -> &[u64] { + &self.0 + } +} + +#[derive(Debug, Clone)] +pub struct ItemType(pub nuts_rs::ItemType); + +impl ItemType { + pub fn into_inner(self) -> nuts_rs::ItemType { + self.0 + } + + pub fn as_inner(&self) -> &nuts_rs::ItemType { + &self.0 + } +} + +impl<'py> IntoPyObject<'py> for &Dims { + type Target = PyList; + type Output = Bound<'py, PyList>; + type Error = PyErr; + + fn into_pyobject(self, py: Python<'py>) -> std::result::Result { + PyList::new(py, self.0.iter()) + } +} + +impl<'py> IntoPyObject<'py> for &Shape { + type Target = PyList; + type Output = Bound<'py, PyList>; + type Error = PyErr; + + fn into_pyobject(self, py: Python<'py>) -> std::result::Result { + PyList::new(py, self.0.iter()) + } +} + +impl<'py> IntoPyObject<'py> for &ItemType { + type Target = PyAny; + type Output = Bound<'py, PyAny>; + type Error = PyErr; + + fn into_pyobject(self, py: Python<'py>) -> std::result::Result { + let dtype_str = match self.0 { + nuts_rs::ItemType::U64 => "uint64", 
+ nuts_rs::ItemType::I64 => "int64", + nuts_rs::ItemType::F64 => "float64", + nuts_rs::ItemType::F32 => "float32", + nuts_rs::ItemType::Bool => "bool", + nuts_rs::ItemType::String => "object", + }; + let numpy = py.import("numpy")?; + let dtype = numpy.getattr("dtype")?.call1((dtype_str,))?; + Ok(dtype) + } +} + +impl<'py> FromPyObject<'py> for ItemType { + fn extract_bound(ob: &Bound<'_, PyAny>) -> std::result::Result { + let dtype_str: &str = ob.extract()?; + let item_type = match dtype_str { + "uint64" => nuts_rs::ItemType::U64, + "int64" => nuts_rs::ItemType::I64, + "float64" => nuts_rs::ItemType::F64, + "float32" => nuts_rs::ItemType::F32, + "bool" => nuts_rs::ItemType::Bool, + "object" => nuts_rs::ItemType::String, + _ => { + return Err(PyRuntimeError::new_err(format!( + "Unsupported item type: {}", + dtype_str + ))) + } + }; + Ok(ItemType(item_type)) + } +} + +#[pyclass] +pub struct PyValue(Value); + +impl<'py> FromPyObject<'py> for PyValue { + fn extract_bound(ob: &Bound<'py, PyAny>) -> std::result::Result { + let ob = if ob.hasattr("values")? { + &ob.getattr("values")? 
+ } else { + ob + }; + if let Ok(arr) = ob.extract::>() { + let vec = arr + .as_slice() + .map_err(|_| PyRuntimeError::new_err("Array is not contiguous"))?; + return Ok(PyValue(Value::F64(vec.to_vec()))); + } + if let Ok(arr) = ob.extract::>() { + let vec = arr + .as_slice() + .map_err(|_| PyRuntimeError::new_err("Array is not contiguous"))?; + return Ok(PyValue(Value::F32(vec.to_vec()))); + } + if let Ok(arr) = ob.extract::>() { + let vec = arr + .as_slice() + .map_err(|_| PyRuntimeError::new_err("Array is not contiguous"))?; + return Ok(PyValue(Value::I64(vec.to_vec()))); + } + if let Ok(arr) = ob.extract::>() { + let vec = arr + .as_slice() + .map_err(|_| PyRuntimeError::new_err("Array is not contiguous"))?; + return Ok(PyValue(Value::U64(vec.to_vec()))); + } + if let Ok(arr) = ob.extract::>() { + let vec = arr + .as_slice() + .map_err(|_| PyRuntimeError::new_err("Array is not contiguous"))?; + return Ok(PyValue(Value::Bool(vec.to_vec()))); + } + if let Ok(list) = ob.extract::>() { + let vec: Vec = list + .iter() + .map(|item| { + item.extract::() + .map_err(|_| PyRuntimeError::new_err("List item is not a string")) + }) + .collect::>()?; + return Ok(PyValue(Value::Strings(vec))); + } + if let Ok(arr) = ob.extract::>>() { + let vec = arr + .as_slice() + .map_err(|_| PyRuntimeError::new_err("Array is not contiguous"))?; + let vals_as_str = vec + .iter() + .map(|item| { + item.extract::(ob.py()) + .map_err(|_| PyRuntimeError::new_err("Array item is not a string")) + }) + .collect::>()?; + return Ok(PyValue(Value::Strings(vals_as_str))); + } + Err(PyRuntimeError::new_err( + "Could not convert to Value. 
Unsupported type.", + )) + } +} + +impl PyValue { + pub fn into_value(self) -> Value { + self.0 + } + + pub fn into_array(self, py: Python) -> Result> { + match self.0 { + Value::F64(vec) => Ok(PyArray1::from_vec(py, vec).into_any()), + Value::F32(vec) => Ok(PyArray1::from_vec(py, vec).into_any()), + Value::I64(vec) => Ok(PyArray1::from_vec(py, vec).into_any()), + Value::U64(vec) => Ok(PyArray1::from_vec(py, vec).into_any()), + Value::Bool(vec) => Ok(PyArray1::from_vec(py, vec).into_any()), + Value::Strings(vec) => Ok(PyList::new(py, vec)?.into_any()), + Value::ScalarString(val) => Ok(val.into_bound_py_any(py)?), + Value::ScalarU64(val) => Ok(val.into_bound_py_any(py)?), + Value::ScalarI64(val) => Ok(val.into_bound_py_any(py)?), + Value::ScalarF64(val) => Ok(val.into_bound_py_any(py)?), + Value::ScalarF32(val) => Ok(val.into_bound_py_any(py)?), + Value::ScalarBool(val) => Ok(val.into_bound_py_any(py)?), + } + } +} + +#[pyclass] +#[derive(Debug, Clone)] +#[non_exhaustive] +pub struct PyVariable { + #[pyo3(get)] + pub name: String, + pub item_type: ItemType, + #[pyo3(get)] + pub dims: Dims, + #[pyo3(get)] + pub shape: Shape, + #[pyo3(get)] + pub num_elements: usize, + #[pyo3(get)] + pub start_idx: Option, + #[pyo3(get)] + pub end_idx: Option, +} + +impl PyVariable { + pub fn new( + name: String, + item_type: ItemType, + shape: Option>, + all_dims: &mut HashMap>, + dim_sizes: &mut HashMap, + start_idx: Option, + ) -> anyhow::Result { + let dims = all_dims.get(&name); + + let (dims, shape) = match (dims, shape) { + (Some(dims), Some(shape)) => { + if dims.len() != shape.len() { + bail!( + "Variable '{}': number of dims ({}) does not match number of shape entries ({})", + name, + dims.len(), + shape.len(), + ); + } + for (dim, size) in dims.iter().zip(shape.iter()) { + if let Some(existing_size) = dim_sizes.get(dim) { + if *existing_size != *size { + bail!("Variable '{}': dimension '{}' has inconsistent size. 
Expected {}, but previously defined as {}", + name, dim, size, existing_size); + } + } + } + (dims.clone(), shape) + } + (Some(dims), None) => { + let mut inferred_shape = Vec::new(); + for dim in dims.iter() { + if let Some(size) = dim_sizes.get(dim) { + inferred_shape.push(*size); + } else { + bail!( + "Variable '{}': dimension '{}' size unknown and no shape provided", + name, + dim + ); + } + } + (dims.clone(), inferred_shape) + } + (None, Some(shape)) => { + let mut inferred_dims = Vec::new(); + for (i, size) in shape.iter().enumerate() { + let generated_name = format!("{}_dim_{}", name, i); + if dim_sizes.contains_key(&generated_name) { + bail!("Variable '{}': generated anonymous dimension name '{}' already exists.", + name, generated_name); + } + dim_sizes.insert(generated_name.clone(), *size); + inferred_dims.push(generated_name); + } + all_dims.insert(name.clone(), inferred_dims.clone()); + (inferred_dims, shape) + } + (None, None) => { + bail!("Variable '{}': no dims or shape provided", name); + } + }; + + let num_elements = shape.iter().product::() as usize; + + Ok(PyVariable { + name, + item_type, + dims: Dims(dims.into()), + shape: Shape(shape.into()), + num_elements, + start_idx, + end_idx: start_idx.map(|idx| idx + num_elements), + }) + } +} + +#[pymethods] +impl PyVariable { + #[classmethod] + fn new_variables<'py>( + cls: &Bound<'py, PyType>, + names: Vec, + item_types: Vec, + shapes: Vec>>, + dim_sizes: Py, + dims: Py, + ) -> Result> { + let mut rust_all_dims = HashMap::new(); + let mut rust_dim_sizes = HashMap::new(); + + let py = cls.py(); + + for (key, value) in dims.bind(py).iter() { + let key: String = key.extract().context("Dimension key is not a string")?; + let value: Vec = value + .extract() + .context("Dimension value is not a list of strings")?; + rust_all_dims.insert(key, value); + } + + for (key, value) in dim_sizes.bind(py).iter() { + let key: String = key + .extract() + .context("Dimension size key is not a string")?; + let value: 
u64 = value + .extract() + .context("Dimension size value is not an integer")?; + rust_dim_sizes.insert(key, value); + } + + let mut current_idx = 0; + + let variables = names + .into_iter() + .zip(item_types) + .zip(shapes) + .map(|((name, item_type), shape)| { + let item_type = match item_type.as_str() { + "uint64" => ItemType(nuts_rs::ItemType::U64), + "int64" => ItemType(nuts_rs::ItemType::I64), + "float64" => ItemType(nuts_rs::ItemType::F64), + "float32" => ItemType(nuts_rs::ItemType::F32), + "bool" => ItemType(nuts_rs::ItemType::Bool), + "string" => ItemType(nuts_rs::ItemType::String), + _ => bail!("Unsupported item type: {}", item_type), + }; + + let start_idx = Some(current_idx); + let var = Self::new( + name, + item_type, + shape, + &mut rust_all_dims, + &mut rust_dim_sizes, + start_idx, + ) + .context("Could not create variable")?; + current_idx += var.num_elements; + Ok(var) + }) + .collect::>>()?; + + let dim_sizes = dim_sizes.bind(py); + for key in rust_dim_sizes.keys() { + if !dim_sizes.contains(key).unwrap_or(false) { + dim_sizes + .set_item(key, rust_dim_sizes[key]) + .context("Could not update dimension sizes")?; + } + } + + let all_dims = dims.bind(py); + for key in rust_all_dims.keys() { + if !all_dims.contains(key).unwrap_or(false) { + all_dims + .set_item(key, rust_all_dims[key].clone()) + .context("Could not update all_dims")?; + } + } + Ok(variables) + } +} diff --git a/src/lib.rs b/src/lib.rs index 6154f92..287118f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,3 +1,4 @@ +mod common; mod progress; mod pyfunc; mod pymc; diff --git a/src/progress.rs b/src/progress.rs index 575efcc..2881c75 100644 --- a/src/progress.rs +++ b/src/progress.rs @@ -30,13 +30,14 @@ impl ProgressHandler { let (update_tx, update_rx) = sync_channel(1); spawn(move || { - Python::with_gil(move |py| { - py.allow_threads(move || { + // We keep an extra gil reference alive, to ensure the + // python ThreadState is not destroyed. 
+ // See https://github.com/PyO3/pyo3/issues/5467 + Python::attach(move |py| { + py.detach(move || loop { let update = update_rx.recv(); - let Ok(update) = update else { - return; - }; - let res = Python::with_gil(|py| callback.call1(py, (update,))); + let Ok(update) = update else { break }; + let res = Python::attach(|py| callback.call1(py, (update,))); if let Err(err) = res { eprintln!("Error in progress callback: {err}"); } diff --git a/src/pyfunc.rs b/src/pyfunc.rs index 9640468..8a21875 100644 --- a/src/pyfunc.rs +++ b/src/pyfunc.rs @@ -1,75 +1,24 @@ -use std::sync::Arc; - -use anyhow::{anyhow, bail, Context, Result}; -use arrow::{ - array::{ - Array, ArrayBuilder, BooleanBuilder, Float32Builder, Float64Builder, Int64Builder, - LargeListBuilder, PrimitiveBuilder, StructBuilder, - }, - datatypes::{DataType, Field, Float32Type, Float64Type, Int64Type}, +use std::{collections::HashMap, sync::Arc}; + +use anyhow::{bail, Context, Result}; +use numpy::{ + NotContiguousError, PyArray1, PyReadonlyArray1, PyReadonlyArrayDyn, PyUntypedArrayMethods, }; -use numpy::{NotContiguousError, PyArray1, PyReadonlyArray1}; -use nuts_rs::{CpuLogpFunc, CpuMath, DrawStorage, LogpError, Model}; +use nuts_rs::{CpuLogpFunc, CpuMath, HasDims, LogpError, Model, Storable, Value}; use pyo3::{ exceptions::PyRuntimeError, pyclass, pymethods, - types::{PyAnyMethods, PyDict, PyDictMethods}, + types::{PyAnyMethods, PyDict, PyDictMethods, PyList, PyListMethods}, Bound, Py, PyAny, PyErr, Python, }; use rand::Rng; use rand_distr::{Distribution, Uniform}; -use smallvec::SmallVec; use thiserror::Error; -use crate::wrapper::PyTransformAdapt; - -#[pyclass] -#[derive(Debug, Clone)] -#[non_exhaustive] -pub struct PyVariable { - #[pyo3(get)] - pub name: String, - #[pyo3(get)] - pub dtype: ExpandDtype, -} - -impl PyVariable { - fn arrow_dtype(&self) -> DataType { - match &self.dtype { - ExpandDtype::Boolean {} => DataType::Boolean, - ExpandDtype::Float64 {} => DataType::Float64, - ExpandDtype::Float32 {} 
=> DataType::Float32, - ExpandDtype::Int64 {} => DataType::Int64, - ExpandDtype::BooleanArray { tensor_type: _ } => { - let field = Arc::new(Field::new("item", DataType::Boolean, false)); - DataType::LargeList(field) - } - ExpandDtype::ArrayFloat64 { tensor_type: _ } => { - let field = Arc::new(Field::new("item", DataType::Float64, true)); - DataType::LargeList(field) - } - ExpandDtype::ArrayFloat32 { tensor_type: _ } => { - let field = Arc::new(Field::new("item", DataType::Float32, false)); - DataType::LargeList(field) - } - ExpandDtype::ArrayInt64 { tensor_type: _ } => { - let field = Arc::new(Field::new("item", DataType::Int64, false)); - DataType::LargeList(field) - } - } - } -} - -#[pymethods] -impl PyVariable { - #[new] - fn new(name: String, value_type: ExpandDtype) -> Self { - Self { - name, - dtype: value_type, - } - } -} +use crate::{ + common::{PyValue, PyVariable}, + wrapper::PyTransformAdapt, +}; #[pyclass] #[derive(Debug, Clone)] @@ -80,28 +29,59 @@ pub struct PyModel { variables: Arc>, transform_adapter: Option, ndim: usize, + dim_sizes: HashMap, + coords: HashMap, } #[pymethods] impl PyModel { #[new] - #[pyo3(signature = (make_logp_func, make_expand_func, variables, ndim, *, init_point_func=None, transform_adapter=None))] + #[pyo3(signature = (make_logp_func, make_expand_func, variables, ndim, dim_sizes, coords, *, init_point_func=None, transform_adapter=None))] fn new<'py>( + py: Python<'py>, make_logp_func: Py, make_expand_func: Py, variables: Vec, ndim: usize, + dim_sizes: Py, + coords: Py, init_point_func: Option>, transform_adapter: Option>, - ) -> Self { - Self { + ) -> Result { + let dim_sizes = dim_sizes + .bind(py) + .iter() + .map(|(key, value)| { + let key: String = key.extract().context("Dimension key is not a string")?; + let value: u64 = value + .extract() + .context("Dimension size value is not an integer")?; + Ok((key, value)) + }) + .collect::>>()?; + + let coords = coords + .bind(py) + .iter() + .map(|(key, value)| { + let key: 
String = key.extract().context("Coordinate key is not a string")?; + let value: PyValue = value + .extract() + .context("Coordinate value has incorrect type")?; + Ok((key, value.into_value())) + }) + .collect::>>()?; + + Ok(Self { make_logp_func: Arc::new(make_logp_func), make_expand_func: Arc::new(make_expand_func), init_point_func: init_point_func.map(|x| x.into()), variables: Arc::new(variables), ndim, transform_adapter: transform_adapter.map(PyTransformAdapt::new), - } + dim_sizes, + coords, + }) } } @@ -139,29 +119,91 @@ impl LogpError for PyLogpError { pub struct PyDensity { logp: Py, + expand_func: Py, transform_adapter: Option, dim: usize, + variables: Arc>, + dim_sizes: HashMap, + coords: HashMap, } impl PyDensity { fn new( logp_clone_func: &Py, + expand_clone_func: &Py, dim: usize, transform_adapter: Option<&PyTransformAdapt>, + variables: Arc>, + dim_sizes: HashMap, + coords: HashMap, ) -> Result { let logp_func = Python::attach(|py| logp_clone_func.call0(py))?; + let expand_func = Python::attach(|py| expand_clone_func.call1(py, (0u64, 0u64, 0u64)))?; let transform_adapter = transform_adapter.cloned(); Ok(Self { logp: logp_func, + expand_func, transform_adapter, dim, + variables, + dim_sizes, + coords, }) } } +impl HasDims for PyDensity { + fn dim_sizes(&self) -> HashMap { + self.dim_sizes.clone() + } + + fn coords(&self) -> HashMap { + self.coords.clone() + } +} + +pub struct ExpandedVector(Vec>); + +impl Storable for ExpandedVector { + fn names(parent: &PyDensity) -> Vec<&str> { + parent + .variables + .iter() + .map(|var| var.name.as_str()) + .collect() + } + + fn item_type(parent: &PyDensity, item: &str) -> nuts_rs::ItemType { + parent + .variables + .iter() + .find(|var| var.name == item) + .map(|var| var.item_type.as_inner().clone()) + .expect("Item not found") + } + + fn dims<'a>(parent: &'a PyDensity, item: &str) -> Vec<&'a str> { + parent + .variables + .iter() + .find(|var| var.name == item) + .map(|var| var.dims.as_slice().iter().map(|s| 
s.as_str()).collect()) + .expect("Item not found") + } + + fn get_all<'a>(&'a mut self, parent: &'a PyDensity) -> Vec<(&'a str, Option)> { + self.0 + .iter_mut() + .zip(parent.variables.iter()) + .map(|(val, var)| (var.name.as_str(), val.take())) + .collect() + } +} + impl CpuLogpFunc for PyDensity { type LogpError = PyLogpError; - type TransformParams = Py; + type FlowParameters = Py; + type ExpandedVector = ExpandedVector; fn logp(&mut self, position: &[f64], grad: &mut [f64]) -> Result { Python::attach(|py| { @@ -193,6 +235,139 @@ impl CpuLogpFunc for PyDensity { self.dim } + fn expand_vector( + &mut self, + _rng: &mut R, + array: &[f64], + ) -> std::result::Result + where + R: rand::Rng + ?Sized, + { + Python::attach(|py| { + let expanded = self + .expand_func + .call1(py, (PyArray1::from_slice(py, array),)); + let Ok(expanded) = expanded else { + return Err(nuts_rs::CpuMathError::ExpandError( + "Expanding function raised an error".into(), + )); + }; + let expanded: Bound = expanded.extract(py).map_err(|_| { + nuts_rs::CpuMathError::ExpandError("Expand function did not return a dict".into()) + })?; + let values = expanded.iter(); + let vars = self.variables.iter(); + + let mut expanded = Vec::with_capacity(self.variables.len()); + for (var, (name2, val)) in vars.zip(values) { + let name2 = name2.extract::<&str>().map_err(|_| { + nuts_rs::CpuMathError::ExpandError("expand key was not a string".into()) + })?; + if var.name != name2 { + return Err(nuts_rs::CpuMathError::ExpandError(format!( + "Unexpected expand key: expected {} but found {}", + var.name, name2 + ))); + } + + if val.is_none() { + expanded.push(None); + continue; + } + + fn as_value<'py, 'a, T>( + var: &'a PyVariable, + val: &'a Bound<'py, PyAny>, + ) -> Result, nuts_rs::CpuMathError> + where + T: numpy::Element + Clone, + { + let arr: PyReadonlyArrayDyn = val.extract().map_err(|_| { + nuts_rs::CpuMathError::ExpandError(format!( + "variable {} had incorrect type", + var.name + )) + })?; + if 
!arr.is_c_contiguous() { + return Err(nuts_rs::CpuMathError::ExpandError( + "not c contiguous".into(), + )); + } + if !arr + .shape() + .iter() + .zip(var.shape.as_slice()) + .all(|(a, &b)| *a as u64 == b) + { + return Err(nuts_rs::CpuMathError::ExpandError("upected shape".into())); + } + Ok(arr) + } + + let val_array = match var.item_type.as_inner() { + nuts_rs::ItemType::F64 => { + let arr = as_value::(var, &val)?; + let slice = arr.as_slice().map_err(|_| { + nuts_rs::CpuMathError::ExpandError("Could not read as slice".into()) + })?; + Some(Value::F64(slice.to_vec())) + } + nuts_rs::ItemType::F32 => { + let arr = as_value::(var, &val)?; + let slice = arr.as_slice().map_err(|_| { + nuts_rs::CpuMathError::ExpandError("Could not read as slice".into()) + })?; + Some(Value::F32(slice.to_vec())) + } + nuts_rs::ItemType::I64 => { + let arr = as_value::(var, &val)?; + let slice = arr.as_slice().map_err(|_| { + nuts_rs::CpuMathError::ExpandError("Could not read as slice".into()) + })?; + Some(Value::I64(slice.to_vec())) + } + nuts_rs::ItemType::Bool => { + let arr = as_value::(var, &val)?; + let slice = arr.as_slice().map_err(|_| { + nuts_rs::CpuMathError::ExpandError("Could not read as slice".into()) + })?; + Some(Value::Bool(slice.to_vec())) + } + nuts_rs::ItemType::U64 => { + let arr = as_value::(var, &val)?; + let slice = arr.as_slice().map_err(|_| { + nuts_rs::CpuMathError::ExpandError("Could not read as slice".into()) + })?; + Some(Value::U64(slice.to_vec())) + } + nuts_rs::ItemType::String => { + let list: Bound = val.extract().map_err(|_| { + nuts_rs::CpuMathError::ExpandError("did not return list".into()) + })?; + if list.len() != var.shape.as_slice().iter().product::() as usize { + return Err(nuts_rs::CpuMathError::ExpandError( + "Incorrect number of items".into(), + )); + } + let vec: Vec = list + .iter() + .map(|item| { + item.extract::().map_err(|_| { + nuts_rs::CpuMathError::ExpandError( + "items were not all strings".into(), + ) + }) + }) + .collect::>()?; 
+ Some(Value::Strings(vec)) + } + }; + expanded.push(val_array); + } + Ok(ExpandedVector(expanded)) + }) + } + fn inv_transform_normalize( &mut self, params: &Py, @@ -305,332 +480,21 @@ impl CpuLogpFunc for PyDensity { } } -pub struct PyTrace { - expand: Py, - variables: Arc>, - builder: StructBuilder, -} - -impl PyTrace { - pub fn new( - rng: &mut R, - chain: u64, - variables: Arc>, - make_expand_func: &Py, - capacity: usize, - ) -> Result { - let seed1 = rng.next_u64(); - let seed2 = rng.next_u64(); - let expand = Python::attach(|py| { - make_expand_func - .call1(py, (seed1, seed2, chain)) - .context("Failed to call expand function factory") - })?; - - let fields: Vec = variables - .iter() - .map(|variable| Field::new(variable.name.clone(), variable.arrow_dtype(), false)) - .collect(); - let builder = StructBuilder::from_fields(fields, capacity); - - Ok(Self { - expand, - variables, - builder, - }) - } -} - -pub type ShapeVec = SmallVec<[usize; 4]>; - -#[derive(Debug, Clone)] -#[non_exhaustive] -#[pyclass] -pub struct TensorShape { - pub shape: ShapeVec, - pub dims: Vec>, - size: usize, -} - -impl TensorShape { - pub fn new(shape: ShapeVec, dims: Vec>) -> Self { - let size = shape.iter().product(); - Self { shape, dims, size } - } - pub fn size(&self) -> usize { - self.size - } -} - -#[pymethods] -impl TensorShape { - #[new] - #[pyo3(signature = (shape, dims=None))] - fn py_new(shape: Vec, dims: Option>>) -> Result { - let dims = dims.unwrap_or(shape.iter().map(|_| None).collect()); - if dims.len() != shape.len() { - bail!("Number of dimensions must be the same as the shape"); - } - - let size = shape.iter().product(); - Ok(Self { - shape: shape.into(), - dims, - size, - }) - } -} - -#[non_exhaustive] -#[pyclass] -#[derive(Debug, Clone)] -pub enum ExpandDtype { - Boolean {}, - Float64 {}, - Float32 {}, - Int64 {}, - BooleanArray { tensor_type: TensorShape }, - ArrayFloat64 { tensor_type: TensorShape }, - ArrayFloat32 { tensor_type: TensorShape }, - ArrayInt64 { 
tensor_type: TensorShape }, -} - -#[pymethods] -impl ExpandDtype { - #[staticmethod] - fn boolean() -> Self { - Self::Boolean {} - } - - #[staticmethod] - fn float64() -> Self { - Self::Float64 {} - } - - #[staticmethod] - fn float32() -> Self { - Self::Float32 {} - } - - #[staticmethod] - fn int64() -> Self { - Self::Int64 {} - } - - #[staticmethod] - fn boolean_array(shape: TensorShape) -> Self { - Self::BooleanArray { tensor_type: shape } - } - - #[staticmethod] - fn float64_array(shape: TensorShape) -> Self { - Self::ArrayFloat64 { tensor_type: shape } - } - #[staticmethod] - fn float32_array(shape: TensorShape) -> Self { - Self::ArrayFloat32 { tensor_type: shape } - } - #[staticmethod] - fn int64_array(shape: TensorShape) -> Self { - Self::ArrayInt64 { tensor_type: shape } - } - - #[getter] - fn shape(&self) -> Option> { - match self { - Self::BooleanArray { tensor_type } => Some(tensor_type.shape.iter().cloned().collect()), - Self::ArrayFloat64 { tensor_type } => Some(tensor_type.shape.iter().cloned().collect()), - Self::ArrayFloat32 { tensor_type } => Some(tensor_type.shape.iter().cloned().collect()), - Self::ArrayInt64 { tensor_type } => Some(tensor_type.shape.iter().cloned().collect()), - _ => None, - } - } -} - -impl DrawStorage for PyTrace { - fn append_value(&mut self, point: &[f64]) -> Result<()> { - Python::attach(|py| { - let point = PyArray1::from_slice(py, point); - let full_point = self - .expand - .call1(py, (point,)) - .context("Failed to call expand function")? 
- .into_bound(py); - let point: &Bound = full_point - .downcast() - .map_err(|_| anyhow!("expand function must return a dict")) - .context("Expand function must return dict")?; - point - .iter() - .zip(self.variables.iter()) - .enumerate() - .try_for_each(|(i, ((key, value), variable))| { - let key: &str = key.extract()?; - if key != variable.name { - return Err(anyhow!("Incorrectly ordered expanded point")); - } - - match &variable.dtype { - ExpandDtype::Boolean {} => { - let builder: &mut BooleanBuilder = - self.builder.field_builder(i).context( - "Builder has incorrect type", - )?; - let value = value - .extract() - .expect("Return value from expand function could not be converted to boolean"); - builder.append_value(value) - }, - ExpandDtype::Float64 {} => { - let builder: &mut Float64Builder = - self.builder.field_builder(i).context( - "Builder has incorrect type", - )?; - builder.append_value( - value - .extract() - .expect("Return value from expand function could not be converted to float64") - ) - }, - ExpandDtype::Float32 {} => { - let builder: &mut Float32Builder = - self.builder.field_builder(i).context( - "Builder has incorrect type", - )?; - builder.append_value( - value - .extract() - .expect("Return value from expand function could not be converted to float32") - ) - }, - ExpandDtype::Int64 {} => { - let builder: &mut Int64Builder = - self.builder.field_builder(i).context( - "Builder has incorrect type", - )?; - let value = value.extract().expect("Return value from expand function could not be converted to int64"); - builder.append_value(value) - }, - ExpandDtype::BooleanArray { tensor_type } => { - let builder: &mut LargeListBuilder> = - self.builder.field_builder(i).context( - "Builder has incorrect type. 
Expected LargeListBuilder of Bool", - )?; - let value_builder = builder - .values() - .as_any_mut() - .downcast_mut::() - .context("Could not downcast builder to boolean type")?; - let values: PyReadonlyArray1 = value.extract().context("Could not convert object to array")?; - if values.len()? != tensor_type.size() { - bail!("Extracted array has incorrect shape"); - } - value_builder.append_slice(values.as_slice().context("Extracted array is not contiguous")?); - builder.append(true); - }, - ExpandDtype::ArrayFloat64 { tensor_type } => { - //let builder: &mut FixedSizeListBuilder> = - let builder: &mut LargeListBuilder> = - self.builder.field_builder(i).context( - "Builder has incorrect type. Expected LargeListBuilder of Float64", - )?; - let value_builder = builder - .values() - .as_any_mut() - .downcast_mut::>() - .context("Could not downcast builder to float64 type")?; - let values: PyReadonlyArray1 = value.extract().context("Could not convert object to array")?; - if values.len()? != tensor_type.size() { - bail!("Extracted array has incorrect shape"); - } - value_builder.append_slice(values.as_slice().context("Extracted array is not contiguous")?); - builder.append(true); - }, - ExpandDtype::ArrayFloat32 { tensor_type } => { - let builder: &mut LargeListBuilder> = - self.builder.field_builder(i).context( - "Builder has incorrect type. Expected LargeListBuilder of Float32", - )?; - let value_builder = builder - .values() - .as_any_mut() - .downcast_mut::>() - .context("Could not downcast builder to float32 type")?; - let values: PyReadonlyArray1 = value.extract().context("Could not convert object to array")?; - if values.len()? 
!= tensor_type.size() { - bail!("Extracted array has incorrect shape"); - } - value_builder.append_slice(values.as_slice().context("Extracted array is not contiguous")?); - builder.append(true); - }, - ExpandDtype::ArrayInt64 {tensor_type} => { - let builder: &mut LargeListBuilder> = - self.builder.field_builder(i).context( - "Builder has incorrect type. Expected LargeListBuilder of Int64", - )?; - let value_builder = builder - .values() - .as_any_mut() - .downcast_mut::>() - .context("Could not downcast builder to i64 type")?; - let values: PyReadonlyArray1 = value.extract().context("Could not convert object to array")?; - if values.len()? != tensor_type.size() { - bail!("Extracted array has incorrect shape"); - } - value_builder.append_slice(values.as_slice().context("Extracted array is not contiguous")?); - builder.append(true); - }, - } - - Ok(()) - }).context("Could not save output of expand function to trace")?; - self.builder.append(true); - Ok(()) - }) - } - - fn finalize(mut self) -> Result> { - Ok(Arc::new(self.builder.finish())) - } - - fn inspect(&self) -> Result> { - Ok(Arc::new(self.builder.finish_cloned())) - } -} - impl Model for PyModel { type Math<'model> = CpuMath where Self: 'model; - type DrawStorage<'model, S: nuts_rs::Settings> - = PyTrace - where - Self: 'model; - - fn new_trace<'model, S: nuts_rs::Settings, R: rand::prelude::Rng + ?Sized>( - &'model self, - rng: &mut R, - chain_id: u64, - settings: &'model S, - ) -> Result> { - let draws = settings.hint_num_tune() + settings.hint_num_draws(); - PyTrace::new( - rng, - chain_id, - self.variables.clone(), - &self.make_expand_func, - draws, - ) - .context("Could not create PyTrace object") - } - - fn math(&self) -> Result> { + fn math(&self, _rng: &mut R) -> Result> { Ok(CpuMath::new(PyDensity::new( &self.make_logp_func, + &self.make_expand_func, self.ndim, self.transform_adapter.as_ref(), + self.variables.clone(), + self.dim_sizes.clone(), + self.coords.clone(), )?)) } diff --git a/src/pymc.rs 
b/src/pymc.rs index 1001a29..ce6db33 100644 --- a/src/pymc.rs +++ b/src/pymc.rs @@ -1,23 +1,23 @@ -use std::{ffi::c_void, fmt::Display, sync::Arc}; +use std::{collections::HashMap, ffi::c_void, sync::Arc}; use anyhow::{bail, Context, Result}; -use arrow::{ - array::{Array, Float64Array, LargeListArray, StructArray}, - buffer::OffsetBuffer, - datatypes::{DataType, Field, Fields}, -}; -use itertools::{izip, Itertools}; -use numpy::PyReadonlyArray1; -use nuts_rs::{CpuLogpFunc, CpuMath, DrawStorage, LogpError, Model, Settings}; +use numpy::{NotContiguousError, PyReadonlyArray1}; +use nuts_rs::{CpuLogpFunc, CpuMath, HasDims, LogpError, Model, Storable, Value}; use pyo3::{ + exceptions::PyRuntimeError, pyclass, pymethods, - types::{PyAnyMethods, PyList}, - Bound, Py, PyAny, PyResult, Python, + types::{PyAnyMethods, PyDict, PyDictMethods}, + Py, PyAny, PyErr, PyResult, Python, }; -use rand_distr::num_traits::CheckedEuclid; +use rand::Rng; use thiserror::Error; +use crate::{ + common::{PyValue, PyVariable}, + wrapper::PyTransformAdapt, +}; + type UserData = *const std::ffi::c_void; type RawLogpFunc = unsafe extern "C" fn( @@ -42,7 +42,6 @@ pub(crate) struct LogpFunc { func: RawLogpFunc, _keep_alive: Arc>, user_data_ptr: UserData, - dim: usize, } unsafe impl Send for LogpFunc {} @@ -51,15 +50,15 @@ unsafe impl Sync for LogpFunc {} #[pymethods] impl LogpFunc { #[new] - fn new(dim: usize, ptr: usize, user_data_ptr: usize, keep_alive: Py) -> Self { + fn new(ptr: usize, user_data_ptr: usize, keep_alive: Py) -> Result { let func = unsafe { std::mem::transmute::<*const c_void, RawLogpFunc>(ptr as *const c_void) }; - Self { + + Ok(Self { func, _keep_alive: Arc::new(keep_alive), user_data_ptr: user_data_ptr as UserData, - dim, - } + }) } } @@ -98,142 +97,297 @@ impl ExpandFunc { unsafe impl Send for ExpandFunc {} unsafe impl Sync for ExpandFunc {} -#[derive(Error, Debug)] -pub(crate) struct ErrorCode(std::os::raw::c_int); +impl HasDims for PyMcModelRef<'_> { + fn dim_sizes(&self) 
-> HashMap { + self.model.dim_sizes.clone() + } + + fn coords(&self) -> HashMap { + self.model.coords.clone() + } +} + +pub struct ExpandedVector(Vec>); + +impl<'f> Storable> for ExpandedVector { + fn names<'a>(parent: &'a PyMcModelRef<'f>) -> Vec<&'a str> { + parent + .model + .variables + .iter() + .map(|var| var.name.as_str()) + .collect() + } + + fn item_type(parent: &PyMcModelRef<'f>, item: &str) -> nuts_rs::ItemType { + parent + .model + .variables + .iter() + .find(|var| var.name == item) + .map(|var| var.item_type.as_inner().clone()) + .expect("Item not found") + } -impl Display for ErrorCode { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "Logp function returned error code {}", self.0) + fn dims<'a>(parent: &'a PyMcModelRef<'f>, item: &str) -> Vec<&'a str> { + parent + .model + .variables + .iter() + .find(|var| var.name == item) + .map(|var| var.dims.as_slice().iter().map(|s| s.as_str()).collect()) + .expect("Item not found") + } + + fn get_all<'a>( + &'a mut self, + parent: &'a PyMcModelRef<'f>, + ) -> Vec<(&'a str, Option)> { + self.0 + .iter_mut() + .zip(parent.model.variables.iter()) + .map(|(val, var)| (var.name.as_str(), val.take())) + .collect() } } -impl LogpError for ErrorCode { +#[derive(Debug, Error)] +pub enum PyMcLogpError { + #[error("Python error: {0}")] + PyError(#[from] PyErr), + #[error("Python retured a non-contigous array")] + NotContiguousError(#[from] NotContiguousError), + #[error("Unknown error: {0}")] + Anyhow(#[from] anyhow::Error), + #[error("Logp function returned error code: {0}")] + ErrorCode(std::os::raw::c_int), +} + +impl LogpError for PyMcLogpError { fn is_recoverable(&self) -> bool { - self.0 > 0 + match self { + Self::PyError(err) => Python::attach(|py| { + let Ok(attr) = err.value(py).getattr("is_recoverable") else { + return false; + }; + attr.is_truthy() + .expect("Could not access is_recoverable in error check") + }), + Self::NotContiguousError(_) => false, + Self::Anyhow(_) => 
false, + Self::ErrorCode(code) => *code > (0 as std::os::raw::c_int), + } } } -impl CpuLogpFunc for &LogpFunc { - type LogpError = ErrorCode; - type TransformParams = (); +pub struct PyMcModelRef<'a> { + model: &'a PyMcModel, + transform_adapter: Option, +} + +impl CpuLogpFunc for PyMcModelRef<'_> { + type LogpError = PyMcLogpError; + type FlowParameters = Py; + type ExpandedVector = ExpandedVector; fn dim(&self) -> usize { - self.dim + self.model.dim } fn logp(&mut self, position: &[f64], gradient: &mut [f64]) -> Result { let mut logp = 0f64; let logp_ptr = (&mut logp) as *mut f64; - assert!(position.len() == self.dim); - assert!(gradient.len() == self.dim); + assert!(position.len() == self.model.dim); + assert!(gradient.len() == self.model.dim); let retcode = unsafe { - (self.func)( - self.dim, + (self.model.density.func)( + self.model.dim, position.as_ptr(), gradient.as_mut_ptr(), logp_ptr, - self.user_data_ptr, + self.model.density.user_data_ptr, ) }; if retcode == 0 { return Ok(logp); } - Err(ErrorCode(retcode)) + Err(PyMcLogpError::ErrorCode(retcode)) } -} - -#[derive(Clone)] -pub(crate) struct PyMcTrace<'model> { - dim: usize, - data: Vec>, - var_sizes: Vec, - var_names: Vec, - expand: &'model ExpandFunc, - count: usize, -} - -impl<'model> DrawStorage for PyMcTrace<'model> { - fn append_value(&mut self, point: &[f64]) -> Result<()> { - assert!(point.len() == self.dim); - let point = self - .expand_draw(point) - .context("Could not compute deterministic variables")?; + fn expand_vector( + &mut self, + _rng: &mut R, + array: &[f64], + ) -> std::result::Result + where + R: rand::Rng + ?Sized, + { + let mut out = vec![0f64; self.model.expand.expanded_dim].into_boxed_slice(); + let retcode = unsafe { + (self.model.expand.func)( + self.model.expand.dim, + self.model.expand.expanded_dim, + array.as_ptr(), + out.as_mut_ptr(), + self.model.expand.user_data_ptr, + ) + }; - let mut start: usize = 0; - for (&size, data) in 
self.var_sizes.iter().zip_eq(self.data.iter_mut()) { - let end = start.checked_add(size).unwrap(); - let vals = &point[start..end]; - data.extend_from_slice(vals); - start = end; + let mut values = Vec::new(); + for var in self.model.variables.iter() { + let start = var.start_idx.expect("Variable has no start index"); + let end = var.end_idx.expect("Variable has no end index"); + let slice = &out[start..end]; + + let value = match var.item_type.as_inner() { + nuts_rs::ItemType::U64 => { + let vec: Vec = slice.iter().map(|&x| x as u64).collect(); + nuts_rs::Value::U64(vec.into()) + } + nuts_rs::ItemType::I64 => { + let vec: Vec = slice.iter().map(|&x| x as i64).collect(); + nuts_rs::Value::I64(vec.into()) + } + nuts_rs::ItemType::F64 => { + let vec: Vec = slice.iter().map(|&x| x as f64).collect(); + nuts_rs::Value::F64(vec.into()) + } + nuts_rs::ItemType::F32 => { + let vec: Vec = slice.iter().map(|&x| x as f32).collect(); + nuts_rs::Value::F32(vec.into()) + } + nuts_rs::ItemType::Bool => { + let vec: Vec = slice.iter().map(|&x| x != 0.0).collect(); + nuts_rs::Value::Bool(vec.into()) + } + nuts_rs::ItemType::String => { + return Err(nuts_rs::CpuMathError::ExpandError( + "String type not supported in expansion".into(), + )); + } + }; + + values.push(Some(value)); } - self.count += 1; - Ok(()) + if retcode == 0 { + Ok(ExpandedVector(values)) + } else { + Err(nuts_rs::CpuMathError::ExpandError(format!( + "Expand function returned error code {}", + retcode + ))) + } + } + fn inv_transform_normalize( + &mut self, + params: &Py, + untransformed_position: &[f64], + untransformed_gradient: &[f64], + transformed_position: &mut [f64], + transformed_gradient: &mut [f64], + ) -> std::result::Result { + let logdet = self + .transform_adapter + .as_mut() + .ok_or_else(|| PyRuntimeError::new_err("No transformation adapter specified"))? 
+ .inv_transform_normalize( + params, + untransformed_position, + untransformed_gradient, + transformed_position, + transformed_gradient, + )?; + Ok(logdet) } - fn finalize(self) -> Result> { - let (fields, arrays): (Vec<_>, _) = izip!(self.data, self.var_names, self.var_sizes) - .map(|(data, name, size)| { - let (num_arrays, rem) = data - .len() - .checked_div_rem_euclid(&size) - .unwrap_or((self.count, 0)); - assert!(rem == 0); - assert!(num_arrays == self.count); - let data = Float64Array::from(data); - let item_field = Arc::new(Field::new("item", DataType::Float64, false)); - let offsets = OffsetBuffer::from_lengths((0..num_arrays).map(|_| size)); - let array = LargeListArray::new(item_field.clone(), offsets, Arc::new(data), None); - let field = Field::new(name, DataType::LargeList(item_field), false); - (Arc::new(field), Arc::new(array) as Arc) - }) - .unzip(); + fn init_from_transformed_position( + &mut self, + params: &Py, + untransformed_position: &mut [f64], + untransformed_gradient: &mut [f64], + transformed_position: &[f64], + transformed_gradient: &mut [f64], + ) -> std::result::Result<(f64, f64), Self::LogpError> { + let (logp, logdet) = self + .transform_adapter + .as_mut() + .ok_or_else(|| PyRuntimeError::new_err("No transformation adapter specified"))? 
+ .init_from_transformed_position( + params, + untransformed_position, + untransformed_gradient, + transformed_position, + transformed_gradient, + )?; + Ok((logp, logdet)) + } - let fields = Fields::from(fields); - Ok(Arc::new( - StructArray::try_new(fields, arrays, None).context("Could not create arrow struct")?, - )) + fn init_from_untransformed_position( + &mut self, + params: &Py, + untransformed_position: &[f64], + untransformed_gradient: &mut [f64], + transformed_position: &mut [f64], + transformed_gradient: &mut [f64], + ) -> std::result::Result<(f64, f64), Self::LogpError> { + let (logp, logdet) = self + .transform_adapter + .as_mut() + .ok_or_else(|| PyRuntimeError::new_err("No transformation adapter specified"))? + .init_from_untransformed_position( + params, + untransformed_position, + untransformed_gradient, + transformed_position, + transformed_gradient, + )?; + Ok((logp, logdet)) } - fn inspect(&self) -> Result> { - self.clone().finalize() + fn update_transformation<'a, R: rand::Rng + ?Sized>( + &'a mut self, + rng: &mut R, + untransformed_positions: impl ExactSizeIterator, + untransformed_gradients: impl ExactSizeIterator, + untransformed_logp: impl ExactSizeIterator, + params: &'a mut Py, + ) -> std::result::Result<(), Self::LogpError> { + self.transform_adapter + .as_mut() + .ok_or_else(|| PyRuntimeError::new_err("No transformation adapter specified"))? 
+ .update_transformation( + rng, + untransformed_positions, + untransformed_gradients, + untransformed_logp, + params, + )?; + Ok(()) } -} -impl<'model> PyMcTrace<'model> { - fn new(model: &'model PyMcModel, settings: &impl Settings) -> Self { - let draws = settings.hint_num_draws() + settings.hint_num_tune(); - Self { - dim: model.dim, - data: model - .var_sizes - .iter() - .map(|&size| Vec::with_capacity(size * draws)) - .collect(), - var_sizes: model.var_sizes.clone(), - var_names: model.var_names.clone(), - expand: &model.expand, - count: 0, - } + fn new_transformation( + &mut self, + rng: &mut R, + untransformed_position: &[f64], + untransformed_gradient: &[f64], + chain: u64, + ) -> std::result::Result, Self::LogpError> { + let trafo = self + .transform_adapter + .as_mut() + .ok_or_else(|| PyRuntimeError::new_err("No transformation adapter specified"))? + .new_transformation(rng, untransformed_position, untransformed_gradient, chain)?; + Ok(trafo) } - fn expand_draw(&mut self, point: &[f64]) -> Result> { - let mut out = vec![0f64; self.expand.expanded_dim].into_boxed_slice(); - let retcode = unsafe { - (self.expand.func)( - self.expand.dim, - self.expand.expanded_dim, - point.as_ptr(), - out.as_mut_ptr(), - self.expand.user_data_ptr, - ) - }; - if retcode == 0 { - Ok(out) - } else { - Err(anyhow::Error::msg("Failed to expand a draw.")) - } + fn transformation_id(&self, params: &Py) -> std::result::Result { + let id = self + .transform_adapter + .as_ref() + .ok_or_else(|| PyRuntimeError::new_err("No transformation adapter specified"))? 
+ .transformation_id(params)?; + Ok(id) } } @@ -244,28 +398,59 @@ pub(crate) struct PyMcModel { density: LogpFunc, expand: ExpandFunc, init_func: Arc>, - var_sizes: Vec, - var_names: Vec, + transform_adapter: Option, + variables: Arc>, + dim_sizes: HashMap, + coords: HashMap, } #[pymethods] impl PyMcModel { #[new] fn new<'py>( - dim: usize, + py: Python<'py>, density: LogpFunc, expand: ExpandFunc, + variables: Vec, + dim: usize, + dim_sizes: Py, + coords: Py, init_func: Py, - var_sizes: &Bound<'py, PyList>, - var_names: &Bound<'py, PyList>, + transform_adapter: Option>, ) -> PyResult { + let dim_sizes = dim_sizes + .bind(py) + .iter() + .map(|(key, value)| { + let key: String = key.extract().context("Dimension key is not a string")?; + let value: u64 = value + .extract() + .context("Dimension size value is not an integer")?; + Ok((key, value)) + }) + .collect::>>()?; + + let coords = coords + .bind(py) + .iter() + .map(|(key, value)| { + let key: String = key.extract().context("Coordinate key is not a string")?; + let value: PyValue = value + .extract() + .context("Coordinate value has incorrect type")?; + Ok((key, value.into_value())) + }) + .collect::>>()?; + Ok(Self { dim, density, expand, init_func: init_func.into(), - var_names: var_names.extract()?, - var_sizes: var_sizes.extract()?, + coords, + dim_sizes, + transform_adapter: transform_adapter.map(PyTransformAdapt::new), + variables: Arc::new(variables), }) } @@ -291,12 +476,13 @@ impl PyMcModel { } impl Model for PyMcModel { - type Math<'model> = CpuMath<&'model LogpFunc>; + type Math<'model> = CpuMath>; - type DrawStorage<'model, S: Settings> = PyMcTrace<'model>; - - fn math(&self) -> Result> { - Ok(CpuMath::new(&self.density)) + fn math(&self, _rng: &mut R) -> Result> { + Ok(CpuMath::new(PyMcModelRef { + model: self, + transform_adapter: self.transform_adapter.clone(), + })) } fn init_position( @@ -329,13 +515,4 @@ impl Model for PyMcModel { })?; Ok(()) } - - fn new_trace<'model, S: Settings, R: 
rand::prelude::Rng + ?Sized>( - &'model self, - _rng: &mut R, - _chain_id: u64, - settings: &'model S, - ) -> Result> { - Ok(PyMcTrace::new(self, settings)) - } } diff --git a/src/stan.rs b/src/stan.rs index b10ac44..37d4be5 100644 --- a/src/stan.rs +++ b/src/stan.rs @@ -1,23 +1,23 @@ +use std::collections::HashMap; use std::sync::Arc; use std::{ffi::CString, path::PathBuf}; -use anyhow::{bail, Context}; -use arrow::array::{Array, FixedSizeListArray, Float64Array, StructArray}; -use arrow::datatypes::{DataType, Field}; +use anyhow::{bail, Context, Result}; use bridgestan::open_library; -use itertools::{izip, Itertools}; -use nuts_rs::{CpuLogpFunc, CpuMath, DrawStorage, LogpError, Model, Settings}; +use itertools::Itertools; +use nuts_rs::{CpuLogpFunc, CpuMath, HasDims, LogpError, Model, Storable, Value}; use pyo3::exceptions::PyRuntimeError; use pyo3::prelude::*; use pyo3::types::{PyDict, PyTuple}; use pyo3::{exceptions::PyValueError, pyclass, pymethods, PyResult}; use rand::prelude::Distribution; -use rand::{rng, RngCore}; +use rand::{rng, Rng, RngCore}; use rand_distr::StandardNormal; use smallvec::{SmallVec, ToSmallVec}; use thiserror::Error; +use crate::common::{ItemType, PyValue, PyVariable}; use crate::wrapper::PyTransformAdapt; type InnerModel = bridgestan::Model>; @@ -79,13 +79,22 @@ impl StanVariable { #[pyclass] #[derive(Clone)] pub struct StanModel { - model: Arc, - variables: Vec, + inner: Arc, + variables: Vec, transform_adapter: Option, + dim_sizes: HashMap, + coords: HashMap, + #[pyo3(get)] + dims: HashMap>, + unc_names: Value, } /// Return meta information about the constrained parameters of the model -fn params(var_string: &str) -> anyhow::Result> { +fn params( + var_string: &str, + all_dims: &mut HashMap>, + dim_sizes: &mut HashMap, +) -> anyhow::Result> { if var_string.is_empty() { return Ok(vec![]); } @@ -143,35 +152,38 @@ fn params(var_string: &str) -> anyhow::Result> { .context(format!("Error while parsing stan variable {name}"))?; // 
Calculate total size of this variable - let size = shape.iter().product(); + let size: usize = shape.iter().product(); let mut end_idx = start_idx + size; // Create Parameter objects (one for real and one for imag if complex) if is_complex { - variables.push(Parameter { - name: format!("{name}.real"), - shape: shape.clone(), - size, - start_idx, - end_idx, - }); + variables.push(PyVariable::new( + format!("{name}.real"), + ItemType(nuts_rs::ItemType::F64), + Some(shape.iter().map(|&d| d as u64).collect()), + all_dims, + dim_sizes, + Some(start_idx), + )?); start_idx = end_idx; end_idx = start_idx + size; - variables.push(Parameter { - name: format!("{name}.imag"), - shape, - size, - start_idx, - end_idx, - }); + variables.push(PyVariable::new( + format!("{name}.imag"), + ItemType(nuts_rs::ItemType::F64), + Some(shape.iter().map(|&d| d as u64).collect()), + all_dims, + dim_sizes, + Some(start_idx), + )?); } else { - variables.push(Parameter { - name: name.to_string(), - shape, - size, - start_idx, - end_idx, - }); + variables.push(PyVariable::new( + name.to_string(), + ItemType(nuts_rs::ItemType::F64), + Some(shape.iter().map(|&d| d as u64).collect()), + all_dims, + dim_sizes, + Some(start_idx), + )?); } // Move to the next variable @@ -240,29 +252,85 @@ where #[pymethods] impl StanModel { #[new] - #[pyo3(signature = (lib, seed=None, data=None, transform_adapter=None))] + #[pyo3(signature = (lib, dim_sizes, dims, coords, seed=None, data=None, transform_adapter=None))] pub fn new( + py: Python<'_>, lib: StanLibrary, + dim_sizes: Py, + dims: Py, + coords: Py, seed: Option, data: Option, transform_adapter: Option>, ) -> anyhow::Result { + let mut dim_sizes = dim_sizes + .bind(py) + .iter() + .map(|(key, value)| { + let key: String = key.extract().context("Dimension key is not a string")?; + let value: u64 = value + .extract() + .context("Dimension size value is not an integer")?; + Ok((key, value)) + }) + .collect::>>()?; + + let mut dims = dims + .bind(py) + .iter() + 
.map(|(key, value)| { + let key: String = key.extract().context("Dimension key is not a string")?; + let value: Vec = value + .extract() + .context("Dimension value is not a list of strings")?; + Ok((key, value)) + }) + .collect::>>()?; + + let coords = coords + .bind(py) + .iter() + .map(|(key, value)| { + let key: String = key.extract().context("Coordinate key is not a string")?; + let value: PyValue = value + .extract() + .context("Coordinate value has incorrect type")?; + Ok((key, value.into_value())) + }) + .collect::>>()?; + let seed = match seed { Some(seed) => seed, None => rng().next_u32(), }; let data: Option = data.map(CString::new).transpose()?; - let model = Arc::new( - bridgestan::Model::new(lib.0, data.as_ref(), seed).map_err(anyhow::Error::new)?, - ); + let mut model = + bridgestan::Model::new(lib.0, data.as_ref(), seed).map_err(anyhow::Error::new)?; + + // TODO: bridgestan should not require mut self here + let names = model.param_unc_names(); + let mut names: Vec<_> = names.split(',').map(|v| v.to_string()).collect(); + if let Some(first) = names.first() { + if first.is_empty() { + names = vec![]; + } + }; + let unc_names = Value::Strings(names); + + let model = Arc::new(model); let var_string = model.param_names(true, true); - let variables = params(var_string)?; + let variables = params(var_string, &mut dims, &mut dim_sizes)?; let transform_adapter = transform_adapter.map(PyTransformAdapt::new); + Ok(StanModel { - model, + inner: model, variables, transform_adapter, + dim_sizes, + coords, + dims, + unc_names, }) } @@ -271,29 +339,14 @@ impl StanModel { let results: Result, _> = self .variables .iter() - .map(|var| { - out.set_item( - var.name.clone(), - StanVariable(var.clone()).into_pyobject(py)?, - ) - }) + .map(|var| out.set_item(var.name.clone(), var.clone())) .collect(); results?; Ok(out) } pub fn ndim(&self) -> usize { - self.model.param_unc_num() - } - - pub fn param_unc_names(&mut self) -> anyhow::Result> { - Ok(Arc::get_mut(&mut 
self.model) - .ok_or_else(|| anyhow::format_err!("Model is currently in use")) - .context("Failed to access the names of unconstrained parameters")? - .param_unc_names() - .split(',') - .map(|name| name.to_string()) - .collect()) + self.inner.param_unc_num() } /* @@ -318,8 +371,10 @@ impl StanModel { } pub struct StanDensity<'model> { - inner: &'model InnerModel, + model: &'model StanModel, + rng: bridgestan::Rng<&'model bridgestan::StanLibrary>, transform_adapter: Option, + expanded_buffer: Vec, } #[derive(Debug, Error)] @@ -340,12 +395,65 @@ impl LogpError for StanLogpError { } } +pub struct ExpandedVector(Vec>); + +impl<'model> Storable> for ExpandedVector { + fn names<'a>(parent: &'a StanDensity<'model>) -> Vec<&'a str> { + parent + .model + .variables + .iter() + .map(|var| var.name.as_str()) + .collect() + } + + fn item_type(parent: &StanDensity<'model>, item: &str) -> nuts_rs::ItemType { + parent + .model + .variables + .iter() + .find(|var| var.name == item) + .map(|var| var.item_type.as_inner().clone()) + .expect("Item not found") + } + + fn dims<'a>(parent: &'a StanDensity<'model>, item: &str) -> Vec<&'a str> { + parent + .model + .variables + .iter() + .find(|var| var.name == item) + .map(|var| var.dims.as_slice().iter().map(|s| s.as_str()).collect()) + .expect("Item not found") + } + + fn get_all<'a>(&'a mut self, parent: &'a StanDensity<'model>) -> Vec<(&'a str, Option)> { + self.0 + .iter_mut() + .zip(parent.model.variables.iter()) + .map(|(val, var)| (var.name.as_str(), val.take())) + .collect() + } +} + +impl<'model> HasDims for StanDensity<'model> { + fn dim_sizes(&self) -> HashMap { + self.model.dim_sizes.clone() + } + + fn coords(&self) -> HashMap { + self.model.coords.clone() + } +} + impl<'model> CpuLogpFunc for StanDensity<'model> { type LogpError = StanLogpError; - type TransformParams = Py; + type FlowParameters = Py; + type ExpandedVector = ExpandedVector; fn logp(&mut self, position: &[f64], grad: &mut [f64]) -> Result { let logp = self + 
.model .inner .log_density_gradient(position, true, true, grad)?; if !logp.is_finite() { @@ -355,7 +463,60 @@ impl<'model> CpuLogpFunc for StanDensity<'model> { } fn dim(&self) -> usize { - self.inner.param_unc_num() + self.model.inner.param_unc_num() + } + + fn vector_coord(&self) -> Option { + Some(self.model.unc_names.clone()) + } + + fn expand_vector( + &mut self, + _rng: &mut R, + array: &[f64], + ) -> Result + where + R: rand::Rng + ?Sized, + { + self.model + .inner + .param_constrain( + array, + true, + true, + &mut self.expanded_buffer, + Some(&mut self.rng), + ) + .context("Failed to constrain the parameters of the draw") + .map_err(|e| nuts_rs::CpuMathError::ExpandError(format!("{}", e)))?; + + let mut vars = Vec::new(); + + for var in self.model.variables.iter() { + let mut out = Vec::with_capacity(var.num_elements); + let start = var.start_idx.expect("Variable start index not set"); + let end = var.end_idx.expect("Variable end index not set"); + let slice = &self.expanded_buffer[start..end]; + assert!(slice.len() == var.num_elements); + + if var.num_elements == 0 { + vars.push(Some(Value::F64(out))); + continue; + } + + // The slice is in fortran order. 
This doesn't matter if it low dim + if var.shape.as_slice().len() < 2 { + out.extend_from_slice(slice); + vars.push(Some(Value::F64(out))); + continue; + } + + // We need to transpose + fortran_to_c_order(slice, var.shape.as_slice(), &mut out); + vars.push(Some(Value::F64(out))); + } + + Ok(ExpandedVector(vars)) } fn inv_transform_normalize( @@ -495,11 +656,11 @@ impl<'model> CpuLogpFunc for StanDensity<'model> { } } -fn fortran_to_c_order(data: &[f64], shape: &[usize], out: &mut Vec) { +fn fortran_to_c_order(data: &[f64], shape: &[u64], out: &mut Vec) { let rank = shape.len(); let strides = { - let mut strides: SmallVec<[usize; 8]> = SmallVec::with_capacity(rank); - let mut current: usize = 1; + let mut strides: SmallVec<[u64; 8]> = SmallVec::with_capacity(rank); + let mut current: u64 = 1; for &length in shape.iter() { strides.push(current); current = current @@ -510,33 +671,34 @@ fn fortran_to_c_order(data: &[f64], shape: &[usize], out: &mut Vec) { strides }; - let mut shape: SmallVec<[usize; 8]> = shape.to_smallvec(); + let mut shape: SmallVec<[u64; 8]> = shape.to_smallvec(); shape.reverse(); - let mut idx: SmallVec<[usize; 8]> = shape.iter().map(|_| 0usize).collect(); - let mut position: usize = 0; + let mut idx: SmallVec<[u64; 8]> = shape.iter().map(|_| 0u64).collect(); + let mut position: u64 = 0; 'iterate: loop { - out.push(data[position]); + out.push(data[position as usize]); - let mut axis: usize = 0; + let mut axis: u64 = 0; 'nextidx: loop { - idx[axis] += 1; - position += strides[axis]; + idx[axis as usize] += 1; + position += strides[axis as usize]; - if idx[axis] < shape[axis] { + if idx[axis as usize] < shape[axis as usize] { break 'nextidx; } - idx[axis] = 0; - position -= shape[axis] * strides[axis]; + idx[axis as usize] = 0; + position -= shape[axis as usize] * strides[axis as usize]; axis += 1; - if axis == rank { + if axis == rank as u64 { break 'iterate; } } } } +/* pub struct StanTrace<'model> { inner: &'model InnerModel, model: &'model 
StanModel, @@ -546,28 +708,6 @@ pub struct StanTrace<'model> { count: usize, } -impl<'model> Clone for StanTrace<'model> { - fn clone(&self) -> Self { - // TODO We should avoid this Clone implementation. - // We only need it for `StanTrace.inspect`, which - // doesn't need rng, so we could avoid this strange - // seed of zeros. - let rng = self - .model - .model - .new_rng(0) - .expect("Could not create stan rng"); - Self { - inner: self.inner, - model: self.model, - trace: self.trace.clone(), - expanded_buffer: self.expanded_buffer.clone(), - rng, - count: self.count, - } - } -} - impl<'model> DrawStorage for StanTrace<'model> { fn append_value(&mut self, point: &[f64]) -> anyhow::Result<()> { self.inner @@ -599,41 +739,13 @@ impl<'model> DrawStorage for StanTrace<'model> { self.count += 1; Ok(()) } - - fn finalize(self) -> anyhow::Result> { - let (fields, arrays): (Vec<_>, Vec<_>) = izip!(self.trace, &self.model.variables) - .map(|(data, variable)| { - let data = Float64Array::from(data); - let item_field = Arc::new(Field::new("item", DataType::Float64, false)); - let array = FixedSizeListArray::new( - item_field.clone(), - variable.size as _, - Arc::new(data), - None, - ); - let dtype = DataType::FixedSizeList(item_field, variable.size as i32); - let field = Arc::new(Field::new(variable.name.clone(), dtype.clone(), false)); - let list: Arc = Arc::new(array); - (field, list) - }) - .unzip(); - - Ok(Arc::new( - StructArray::try_new_with_length(fields.into(), arrays, None, self.count) - .context("Could not create arrow StructArray")?, - )) - } - - fn inspect(&self) -> anyhow::Result> { - self.clone().finalize() - } } +*/ impl Model for StanModel { type Math<'model> = CpuMath>; - type DrawStorage<'model, S: nuts_rs::Settings> = StanTrace<'model>; - + /* fn new_trace<'a, S: Settings, R: rand::Rng + ?Sized>( &'a self, rng: &mut R, @@ -658,11 +770,16 @@ impl Model for StanModel { count: 0, }) } + */ - fn math(&self) -> anyhow::Result> { + fn math(&self, rng: &mut R) -> 
anyhow::Result> { + let rng = self.inner.new_rng(rng.next_u32())?; + let num_expanded = self.inner.param_num(true, true); Ok(CpuMath::new(StanDensity { - inner: &self.model, + model: &self, + rng, transform_adapter: self.transform_adapter.clone(), + expanded_buffer: vec![0f64; num_expanded], })) } @@ -681,6 +798,8 @@ impl Model for StanModel { #[cfg(test)] mod tests { + use std::collections::HashMap; + use itertools::Itertools; use super::fortran_to_c_order; @@ -741,48 +860,51 @@ mod tests { #[test] fn parse_vars() { + let mut dims = HashMap::new(); + let mut dim_sizes = HashMap::new(); + let vars = ""; - let parsed = super::params(vars).unwrap(); + let parsed = super::params(vars, &mut dims, &mut dim_sizes).unwrap(); assert!(parsed.len() == 0); let vars = "x.1.1,x.2.1,x.3.1,x.1.2,x.2.2,x.3.2"; - let parsed = super::params(vars).unwrap(); + let parsed = super::params(vars, &mut dims, &mut dim_sizes).unwrap(); assert!(parsed.len() == 1); let parsed = parsed[0].clone(); assert!(parsed.name == "x"); - assert!(parsed.shape == vec![3, 2]); + assert!(parsed.shape.as_slice() == vec![3, 2]); // Incorrect order let vars = "x.1.2,x.1.1,x.2.1,x.2.2,x.3.1,x.3.2"; - assert!(super::params(vars).is_err()); + assert!(super::params(vars, &mut dims, &mut dim_sizes).is_err()); // Incorrect order let vars = "x.1.2.real,x.1.2.imag"; - assert!(super::params(vars).is_err()); + assert!(super::params(vars, &mut dims, &mut dim_sizes).is_err()); let vars = "x.1.1.real,x.1.1.imag,x.2.1.real,x.2.1.imag,x.3.1.real,x.3.1.imag"; - let parsed = super::params(vars).unwrap(); + let parsed = super::params(vars, &mut dims, &mut dim_sizes).unwrap(); assert!(parsed.len() == 2); let var = parsed[0].clone(); assert!(var.name == "x.real"); - assert!(var.shape == vec![3, 1]); + assert!(var.shape.as_slice() == vec![3, 1]); let var = parsed[1].clone(); assert!(var.name == "x.imag"); - assert!(var.shape == vec![3, 1]); + assert!(var.shape.as_slice() == vec![3, 1]); // Test single variable let vars = "alpha"; - 
let parsed = super::params(vars).unwrap(); + let parsed = super::params(vars, &mut dims, &mut dim_sizes).unwrap(); assert_eq!(parsed.len(), 1); let var = &parsed[0]; assert_eq!(var.name, "alpha"); - assert_eq!(var.shape, Vec::::new()); - assert_eq!(var.size, 1); + assert_eq!(var.shape.as_slice(), vec![0; 0]); + assert_eq!(var.num_elements, 1); // Test multiple scalar variables let vars = "alpha,beta,gamma"; - let parsed = super::params(vars).unwrap(); + let parsed = super::params(vars, &mut dims, &mut dim_sizes).unwrap(); assert_eq!(parsed.len(), 3); assert_eq!(parsed[0].name, "alpha"); assert_eq!(parsed[1].name, "beta"); @@ -790,21 +912,21 @@ mod tests { // Test 1D array let vars = "theta.1,theta.2,theta.3,theta.4"; - let parsed = super::params(vars).unwrap(); + let parsed = super::params(vars, &mut dims, &mut dim_sizes).unwrap(); assert_eq!(parsed.len(), 1); let var = &parsed[0]; assert_eq!(var.name, "theta"); - assert_eq!(var.shape, vec![4]); - assert_eq!(var.size, 4); + assert_eq!(var.shape.as_slice(), vec![4]); + assert_eq!(var.num_elements, 4); // Test variable name with colons and dots let vars = "x:1:2.4:1.1,x:1:2.4:1.2,x:1:2.4:1.3"; - let parsed = super::params(vars).unwrap(); + let parsed = super::params(vars, &mut dims, &mut dim_sizes).unwrap(); assert_eq!(parsed.len(), 1); let var = &parsed[0]; assert_eq!(var.name, "x:1:2.4:1"); - assert_eq!(var.shape, vec![3]); - assert_eq!(var.size, 3); + assert_eq!(var.shape.as_slice(), vec![3]); + assert_eq!(var.num_elements, 3); let vars = " a, @@ -1009,89 +1131,89 @@ mod tests { ultimate.2.3:2.3.5, ultimate.2.3:2.4.5 "; - let parsed = super::params(vars).unwrap(); + let parsed = super::params(vars, &mut dims, &mut dim_sizes).unwrap(); assert_eq!(parsed[0].name, "a"); - assert_eq!(parsed[0].shape, vec![0usize; 0]); + assert_eq!(parsed[0].shape.as_slice(), vec![0; 0]); assert_eq!(parsed[1].name, "base"); - assert_eq!(parsed[1].shape, vec![0usize; 0]); + assert_eq!(parsed[1].shape.as_slice(), vec![0; 0]); 
assert_eq!(parsed[2].name, "base_i"); - assert_eq!(parsed[2].shape, vec![0usize; 0]); + assert_eq!(parsed[2].shape.as_slice(), vec![0; 0]); assert_eq!(parsed[3].name, "pair:1"); - assert_eq!(parsed[3].shape, vec![0usize; 0]); + assert_eq!(parsed[3].shape.as_slice(), vec![0; 0]); assert_eq!(parsed[4].name, "pair:2"); - assert_eq!(parsed[4].shape, vec![0usize; 0]); + assert_eq!(parsed[4].shape.as_slice(), vec![0; 0]); assert_eq!(parsed[5].name, "nested:1"); - assert_eq!(parsed[5].shape, vec![0usize; 0]); + assert_eq!(parsed[5].shape.as_slice(), vec![0; 0]); assert_eq!(parsed[6].name, "nested:2:1"); - assert_eq!(parsed[6].shape, vec![0usize; 0]); + assert_eq!(parsed[6].shape.as_slice(), vec![0; 0]); assert_eq!(parsed[7].name, "nested:2:2.real"); - assert_eq!(parsed[7].shape, vec![0usize; 0]); + assert_eq!(parsed[7].shape.as_slice(), vec![0; 0]); assert_eq!(parsed[8].name, "nested:2:2.imag"); - assert_eq!(parsed[8].shape, vec![0usize; 0]); + assert_eq!(parsed[8].shape.as_slice(), vec![0; 0]); assert_eq!(parsed[9].name, "arr_pair.1:1"); - assert_eq!(parsed[9].shape, vec![0usize; 0]); + assert_eq!(parsed[9].shape.as_slice(), vec![0; 0]); assert_eq!(parsed[10].name, "arr_pair.1:2"); - assert_eq!(parsed[10].shape, vec![0usize; 0]); + assert_eq!(parsed[10].shape.as_slice(), vec![0; 0]); assert_eq!(parsed[11].name, "arr_pair.2:1"); - assert_eq!(parsed[11].shape, vec![0usize; 0]); + assert_eq!(parsed[11].shape.as_slice(), vec![0; 0]); assert_eq!(parsed[12].name, "arr_pair.2:2"); - assert_eq!(parsed[12].shape, vec![0usize; 0]); + assert_eq!(parsed[12].shape.as_slice(), vec![0; 0]); assert_eq!(parsed[13].name, "arr_very_nested.1:1:1"); - assert_eq!(parsed[13].shape, vec![0usize; 0]); + assert_eq!(parsed[13].shape.as_slice(), vec![0; 0]); assert_eq!(parsed[14].name, "arr_very_nested.1:1:2:1"); - assert_eq!(parsed[14].shape, vec![0usize; 0]); + assert_eq!(parsed[14].shape.as_slice(), vec![0; 0]); assert_eq!(parsed[15].name, "arr_very_nested.1:1:2:2.real"); - 
assert_eq!(parsed[15].shape, vec![0usize; 0]); + assert_eq!(parsed[15].shape.as_slice(), vec![0; 0]); assert_eq!(parsed[16].name, "arr_very_nested.1:1:2:2.imag"); - assert_eq!(parsed[16].shape, vec![0usize; 0]); + assert_eq!(parsed[16].shape.as_slice(), vec![0; 0]); assert_eq!(parsed[17].name, "arr_very_nested.1:2"); - assert_eq!(parsed[17].shape, vec![0usize; 0]); + assert_eq!(parsed[17].shape.as_slice(), vec![0; 0]); assert_eq!(parsed[18].name, "arr_very_nested.2:1:1"); - assert_eq!(parsed[18].shape, vec![0usize; 0]); + assert_eq!(parsed[18].shape.as_slice(), vec![0; 0]); assert_eq!(parsed[19].name, "arr_very_nested.2:1:2:1"); - assert_eq!(parsed[19].shape, vec![0usize; 0]); + assert_eq!(parsed[19].shape.as_slice(), vec![0; 0]); assert_eq!(parsed[20].name, "arr_very_nested.2:1:2:2.real"); - assert_eq!(parsed[20].shape, vec![0usize; 0]); + assert_eq!(parsed[20].shape.as_slice(), vec![0; 0]); assert_eq!(parsed[21].name, "arr_very_nested.2:1:2:2.imag"); - assert_eq!(parsed[21].shape, vec![0usize; 0]); + assert_eq!(parsed[21].shape.as_slice(), vec![0; 0]); assert_eq!(parsed[22].name, "arr_very_nested.2:2"); - assert_eq!(parsed[22].shape, vec![0usize; 0]); + assert_eq!(parsed[22].shape.as_slice(), vec![0; 0]); assert_eq!(parsed[23].name, "arr_very_nested.3:1:1"); - assert_eq!(parsed[23].shape, vec![0usize; 0]); + assert_eq!(parsed[23].shape.as_slice(), vec![0; 0]); assert_eq!(parsed[24].name, "arr_very_nested.3:1:2:1"); - assert_eq!(parsed[24].shape, vec![0usize; 0]); + assert_eq!(parsed[24].shape.as_slice(), vec![0; 0]); assert_eq!(parsed[25].name, "arr_very_nested.3:1:2:2.real"); - assert_eq!(parsed[25].shape, vec![0usize; 0]); + assert_eq!(parsed[25].shape.as_slice(), vec![0; 0]); assert_eq!(parsed[26].name, "arr_very_nested.3:1:2:2.imag"); - assert_eq!(parsed[26].shape, vec![0usize; 0]); + assert_eq!(parsed[26].shape.as_slice(), vec![0; 0]); assert_eq!(parsed[27].name, "arr_very_nested.3:2"); - assert_eq!(parsed[27].shape, vec![0usize; 0]); + 
assert_eq!(parsed[27].shape.as_slice(), vec![0; 0]); } } diff --git a/src/wrapper.rs b/src/wrapper.rs index 725e99a..37322d6 100644 --- a/src/wrapper.rs +++ b/src/wrapper.rs @@ -6,27 +6,31 @@ use std::{ }; use crate::{ + common::PyVariable, progress::{IndicatifHandler, ProgressHandler}, - pyfunc::{ExpandDtype, PyModel, PyVariable, TensorShape}, + pyfunc::PyModel, pymc::{ExpandFunc, LogpFunc, PyMcModel}, stan::{StanLibrary, StanModel}, }; -use anyhow::{bail, Context, Result}; -use arrow::array::Array; +use anyhow::{anyhow, bail, Context, Result}; use numpy::{PyArray1, PyReadonlyArray1}; use nuts_rs::{ - ChainProgress, DiagGradNutsSettings, LowRankNutsSettings, ProgressCallback, Sampler, - SamplerWaitResult, StepSizeAdaptMethod, Trace, TransformedNutsSettings, + ArrowConfig, ArrowTrace, ChainProgress, DiagGradNutsSettings, LowRankNutsSettings, Model, + ProgressCallback, Sampler, SamplerWaitResult, StepSizeAdaptMethod, TransformedNutsSettings, + ZarrAsyncConfig, }; use pyo3::{ - exceptions::PyTimeoutError, - ffi::Py_uintptr_t, + exceptions::{PyTimeoutError, PyValueError}, intern, prelude::*, - types::{PyList, PyTuple}, + types::PyList, }; +use pyo3_arrow::PyRecordBatch; +use pyo3_object_store::AnyObjectStore; use rand::{rng, RngCore}; +use tokio::runtime::Runtime; +use zarrs_object_store::{object_store::limit::LimitStore, AsyncObjectStore}; #[pyclass] struct PyChainProgress(ChainProgress); @@ -449,9 +453,7 @@ impl PyNutsSettings { Settings::Diag(settings) => { Ok(settings.adapt_options.mass_matrix_options.store_mass_matrix) } - Settings::Transforming(_) => { - bail!("Option store_mass_matrix not availbale for transformation adaptation") - } + Settings::Transforming(_) => Ok(false), } } @@ -663,7 +665,7 @@ impl PyNutsSettings { #[setter(step_size_adapt_method)] fn set_step_size_adapt_method(&mut self, method: Py) -> Result<()> { - let method = Python::with_gil(|py| { + let method = Python::attach(|py| { if let Ok(method) = method.extract::(py) { match method.as_str() { 
"dual_average" => Ok(StepSizeAdaptMethod::DualAverage), @@ -785,8 +787,10 @@ impl PyNutsSettings { } pub(crate) enum SamplerState { - Running(Sampler), - Finished(Option), + RunningZarr(Sampler<()>), + RunningArrow(Sampler>), + FinishedZarr, + FinishedArrow(Vec), Empty, } @@ -856,57 +860,221 @@ impl ProgressType { } } +enum InnerPyStorage { + Zarr(Option), + Arrow, +} + #[pyclass] -struct PySampler(Mutex); +struct PyStorage(InnerPyStorage); #[pymethods] -impl PySampler { +impl PyStorage { #[staticmethod] - fn from_pymc( + fn zarr(object_store: AnyObjectStore) -> Self { + Self(InnerPyStorage::Zarr(Some(object_store))) + } + + #[staticmethod] + fn arrow() -> Self { + Self(InnerPyStorage::Arrow) + } +} + +#[pyclass] +struct PySampler(Mutex<(SamplerState, Runtime)>); + +impl PySampler { + fn new( settings: PyNutsSettings, cores: usize, - model: PyMcModel, + model: M, progress_type: ProgressType, - ) -> PyResult { + store: &mut PyStorage, + ) -> PyResult { let callback = progress_type.into_callback()?; - match settings.inner { - Settings::LowRank(settings) => { - let sampler = Sampler::new(model, settings, cores, callback)?; - Ok(PySampler(SamplerState::Running(sampler).into())) + let tokio_rt = Runtime::new().context("Failed to create Tokio runtime")?; + match &mut store.0 { + InnerPyStorage::Arrow => { + let storage_config = ArrowConfig::new(); + match settings.inner { + Settings::LowRank(settings) => { + let sampler = + Sampler::new(model, settings, storage_config, cores, callback)?; + Ok(PySampler(Mutex::new(( + SamplerState::RunningArrow(sampler).into(), + tokio_rt, + )))) + } + Settings::Diag(settings) => { + let sampler = + Sampler::new(model, settings, storage_config, cores, callback)?; + Ok(PySampler(Mutex::new(( + SamplerState::RunningArrow(sampler).into(), + tokio_rt, + )))) + } + Settings::Transforming(settings) => { + let sampler = + Sampler::new(model, settings, storage_config, cores, callback)?; + Ok(PySampler(Mutex::new(( + 
SamplerState::RunningArrow(sampler).into(), + tokio_rt, + )))) + } + } } - Settings::Diag(settings) => { - let sampler = Sampler::new(model, settings, cores, callback)?; - Ok(PySampler(SamplerState::Running(sampler).into())) + InnerPyStorage::Zarr(store) => { + let object_store = store + .take() + .ok_or_else(|| anyhow!("Can not use storage configuration twice"))? + .into_dyn(); + let object_store = LimitStore::new(object_store, 50); + let store = AsyncObjectStore::new(object_store); + let store = Arc::new(store); + let storage_config = ZarrAsyncConfig::new(tokio_rt.handle().clone(), store); + match settings.inner { + Settings::LowRank(settings) => { + let sampler = + Sampler::new(model, settings, storage_config, cores, callback)?; + Ok(PySampler(Mutex::new(( + SamplerState::RunningZarr(sampler).into(), + tokio_rt, + )))) + } + Settings::Diag(settings) => { + let sampler = + Sampler::new(model, settings, storage_config, cores, callback)?; + Ok(PySampler(Mutex::new(( + SamplerState::RunningZarr(sampler).into(), + tokio_rt, + )))) + } + Settings::Transforming(settings) => { + let sampler = + Sampler::new(model, settings, storage_config, cores, callback)?; + Ok(PySampler(Mutex::new(( + SamplerState::RunningZarr(sampler).into(), + tokio_rt, + )))) + } + } + } + } + } +} + +impl PySampler { + fn wait_inner_arrow( + &self, + mut control: Sampler>, + timeout: Option, + ) -> (PyResult<()>, SamplerState) { + let start_time = Instant::now(); + let step = Duration::from_millis(100); + + loop { + let time_so_far = Instant::now().saturating_duration_since(start_time); + let next_timeout = match timeout { + Some(timeout) => { + let Some(remaining) = timeout.checked_sub(time_so_far) else { + return ( + Err(PyTimeoutError::new_err( + "Timeout while waiting for sampler to finish", + )), + SamplerState::RunningArrow(control), + ); + }; + remaining.min(step) + } + None => step, + }; + + match control.wait_timeout(next_timeout) { + SamplerWaitResult::Trace(trace) => { + return 
(Ok(()), SamplerState::FinishedArrow(trace)) + } + SamplerWaitResult::Timeout(new_control) => { + control = new_control; + } + SamplerWaitResult::Err(err, trace) => { + return ( + Err(err.into()), + SamplerState::FinishedArrow(trace.unwrap_or_default()), + ) + } } - Settings::Transforming(settings) => { - let sampler = Sampler::new(model, settings, cores, callback)?; - Ok(PySampler(SamplerState::Running(sampler).into())) + + if let Err(err) = Python::attach(|py| py.check_signals()) { + return (Err(err), SamplerState::RunningArrow(control)); } } } + fn wait_inner_zarr( + &self, + mut control: Sampler<()>, + timeout: Option, + ) -> (PyResult<()>, SamplerState) { + let start_time = Instant::now(); + let step = Duration::from_millis(100); + + loop { + let time_so_far = Instant::now().saturating_duration_since(start_time); + let next_timeout = match timeout { + Some(timeout) => { + let Some(remaining) = timeout.checked_sub(time_so_far) else { + return ( + Err(PyTimeoutError::new_err( + "Timeout while waiting for sampler to finish", + )), + SamplerState::RunningZarr(control), + ); + }; + remaining.min(step) + } + None => step, + }; + + match control.wait_timeout(next_timeout) { + SamplerWaitResult::Trace(_trace) => return (Ok(()), SamplerState::FinishedZarr), + SamplerWaitResult::Timeout(new_control) => { + control = new_control; + } + SamplerWaitResult::Err(err, _trace) => { + return (Err(err.into()), SamplerState::FinishedZarr) + } + } + + if let Err(err) = Python::attach(|py| py.check_signals()) { + return (Err(err), SamplerState::RunningZarr(control)); + } + } + } +} + +#[pymethods] +impl PySampler { + #[staticmethod] + fn from_pymc( + settings: PyNutsSettings, + cores: usize, + model: PyMcModel, + progress_type: ProgressType, + store: &mut PyStorage, + ) -> PyResult { + PySampler::new(settings, cores, model, progress_type, store) + } + #[staticmethod] fn from_stan( settings: PyNutsSettings, cores: usize, model: StanModel, progress_type: ProgressType, + store: &mut 
PyStorage, ) -> PyResult { - let callback = progress_type.into_callback()?; - match settings.inner { - Settings::LowRank(settings) => { - let sampler = Sampler::new(model, settings, cores, callback)?; - Ok(PySampler(SamplerState::Running(sampler).into())) - } - Settings::Diag(settings) => { - let sampler = Sampler::new(model, settings, cores, callback)?; - Ok(PySampler(SamplerState::Running(sampler).into())) - } - Settings::Transforming(settings) => { - let sampler = Sampler::new(model, settings, cores, callback)?; - Ok(PySampler(SamplerState::Running(sampler).into())) - } - } + PySampler::new(settings, cores, model, progress_type, store) } #[staticmethod] @@ -915,78 +1083,61 @@ impl PySampler { cores: usize, model: PyModel, progress_type: ProgressType, + store: &mut PyStorage, ) -> PyResult { - let callback = progress_type.into_callback()?; - match settings.inner { - Settings::LowRank(settings) => { - let sampler = Sampler::new(model, settings, cores, callback)?; - Ok(PySampler(SamplerState::Running(sampler).into())) - } - Settings::Diag(settings) => { - let sampler = Sampler::new(model, settings, cores, callback)?; - Ok(PySampler(SamplerState::Running(sampler).into())) - } - Settings::Transforming(settings) => { - let sampler = Sampler::new(model, settings, cores, callback)?; - Ok(PySampler(SamplerState::Running(sampler).into())) - } - } + PySampler::new(settings, cores, model, progress_type, store) } fn is_finished(&mut self, py: Python<'_>) -> PyResult { + self.wait(py, Some(0.001))?; py.detach(|| { let guard = &mut self.0.lock().expect("Poisond sampler state mutex"); - let slot = guard.deref_mut(); - - let state = std::mem::replace(slot, SamplerState::Empty); - - let SamplerState::Running(sampler) = state else { - let _ = std::mem::replace(slot, state); - return Ok(true); - }; - - match sampler.wait_timeout(Duration::from_millis(1)) { - SamplerWaitResult::Trace(trace) => { - let _ = std::mem::replace(slot, SamplerState::Finished(Some(trace))); - Ok(true) - } - 
SamplerWaitResult::Timeout(sampler) => { - let _ = std::mem::replace(slot, SamplerState::Running(sampler)); - Ok(false) - } - SamplerWaitResult::Err(err, trace) => { - let _ = std::mem::replace(slot, SamplerState::Finished(trace)); - Err(err.into()) - } - } + Ok(matches!( + guard.deref_mut().0, + SamplerState::FinishedZarr | SamplerState::FinishedArrow(_) | SamplerState::Empty + )) }) } fn pause(&mut self, py: Python<'_>) -> PyResult<()> { py.detach(|| { - if let SamplerState::Running(ref mut control) = self + match self .0 .lock() - .expect("Poised sampler state mutex") + .expect("Poisond sampler state mutex") .deref_mut() { - control.pause()? + (SamplerState::RunningZarr(control), _) => { + control.pause()?; + return Ok(()); + } + (SamplerState::RunningArrow(control), _) => { + control.pause()?; + return Ok(()); + } + _ => return Ok(()), } - Ok(()) }) } fn resume(&mut self, py: Python<'_>) -> PyResult<()> { py.detach(|| { - if let SamplerState::Running(ref mut control) = self + match self .0 .lock() .expect("Poisond sampler state mutex") .deref_mut() { - control.resume()? 
+ (SamplerState::RunningZarr(control), _) => { + control.resume()?; + return Ok(()); + } + (SamplerState::RunningArrow(control), _) => { + control.resume()?; + return Ok(()); + } + _ => return Ok(()), } - Ok(()) }) } @@ -995,6 +1146,7 @@ impl PySampler { py.detach(|| { let guard = &mut self.0.lock().expect("Poisond sampler state mutex"); let slot = guard.deref_mut(); + let slot = &mut slot.0; let timeout = match timeout_seconds { Some(val) => Some(Duration::try_from_secs_f64(val).context("Invalid timeout")?), @@ -1003,46 +1155,12 @@ impl PySampler { let state = std::mem::replace(slot, SamplerState::Empty); - let SamplerState::Running(mut control) = state else { - let _ = std::mem::replace(slot, state); - return Ok(()); - }; - - let start_time = Instant::now(); - let step = Duration::from_millis(100); - - let (final_state, retval) = loop { - let time_so_far = Instant::now().saturating_duration_since(start_time); - let next_timeout = match timeout { - Some(timeout) => { - let Some(remaining) = timeout.checked_sub(time_so_far) else { - break ( - SamplerState::Running(control), - Err(PyTimeoutError::new_err( - "Timeout while waiting for sampler to finish", - )), - ); - }; - remaining.min(step) - } - None => step, - }; - - match control.wait_timeout(next_timeout) { - SamplerWaitResult::Trace(trace) => { - break (SamplerState::Finished(Some(trace)), Ok(())) - } - SamplerWaitResult::Timeout(new_control) => { - control = new_control; - } - SamplerWaitResult::Err(err, trace) => { - break (SamplerState::Finished(trace), Err(err.into())) - } - } - - if let Err(err) = Python::attach(|py| py.check_signals()) { - break (SamplerState::Running(control), Err(err)); - } + let (retval, final_state) = match state { + SamplerState::FinishedZarr + | SamplerState::FinishedArrow(_) + | SamplerState::Empty => (Ok(()), state), + SamplerState::RunningZarr(control) => self.wait_inner_zarr(control, timeout), + SamplerState::RunningArrow(control) => self.wait_inner_arrow(control, timeout), }; 
let _ = std::mem::replace(slot, final_state); @@ -1054,101 +1172,160 @@ impl PySampler { py.detach(|| { let guard = &mut self.0.lock().expect("Poisond sampler state mutex"); let slot = guard.deref_mut(); + let slot = &mut slot.0; let state = std::mem::replace(slot, SamplerState::Empty); - let SamplerState::Running(control) = state else { - let _ = std::mem::replace(slot, state); - return Ok(()); - }; - - let (result, trace) = control.abort(); - let _ = std::mem::replace(slot, SamplerState::Finished(trace)); - result?; - Ok(()) + match state { + SamplerState::FinishedZarr + | SamplerState::FinishedArrow(_) + | SamplerState::Empty => { + let _ = std::mem::replace(slot, state); + return Ok(()); + } + SamplerState::RunningZarr(control) => { + let (result, _) = control.abort()?; + let _ = std::mem::replace(slot, SamplerState::FinishedZarr); + if let Some(err) = result { + Err(err)?; + } + Ok(()) + } + SamplerState::RunningArrow(control) => { + let (result, trace) = control.abort()?; + let _ = std::mem::replace(slot, SamplerState::FinishedArrow(trace)); + if let Some(err) = result { + Err(err)?; + } + Ok(()) + } + } }) } - fn extract_results<'py>(&mut self, py: Python<'py>) -> PyResult> { - let guard = &mut self.0.lock().expect("Poisond sampler state mutex"); - let slot = guard.deref_mut(); - - let state = std::mem::replace(slot, SamplerState::Empty); - - let SamplerState::Finished(trace) = state else { - let _ = std::mem::replace(slot, state); - return Err(anyhow::anyhow!("Sampler is not finished"))?; - }; - - let Some(trace) = trace else { - return Err(anyhow::anyhow!( - "Sampler failed and did not produce a trace" - ))?; - }; + fn is_empty(&self) -> bool { + matches!( + self.0.lock().expect("Poisoned sampler state lock").deref(), + (SamplerState::Empty, _) + ) + } - trace_to_list(trace, py) + fn flush<'py>(&mut self, py: Python<'py>) -> PyResult<()> { + match self + .0 + .lock() + .expect("Poisond sampler state mutex") + .deref_mut() + .0 + { + 
SamplerState::FinishedZarr => Ok(()), + SamplerState::FinishedArrow(_) => Ok(()), + SamplerState::Empty => Ok(()), + SamplerState::RunningZarr(ref mut control) => { + py.detach(|| control.flush())?; + Ok(()) + } + SamplerState::RunningArrow(ref mut control) => { + py.detach(|| control.flush())?; + Ok(()) + } + } } - fn is_empty(&self) -> bool { - match self.0.lock().expect("Poisoned sampler state lock").deref() { - SamplerState::Running(_) => false, - SamplerState::Finished(_) => false, - SamplerState::Empty => true, + fn inspect<'py>(&self, py: Python<'py>) -> PyResult> { + match &mut self + .0 + .lock() + .expect("Poisond sampler state mutex") + .deref_mut() + .0 + { + SamplerState::FinishedZarr => Ok(Some(PyTrace(InnerPyTrace::Zarr))), + SamplerState::FinishedArrow(trace) => { + Ok(Some(PyTrace(InnerPyTrace::Arrow(Some(trace.clone()))))) + } + SamplerState::Empty => Ok(None), + SamplerState::RunningZarr(control) => { + let (res, _) = py.detach(|| control.inspect())?; + if let Some(err) = res { + return Err(err.into()); + } + Ok(Some(PyTrace(InnerPyTrace::Zarr))) + } + SamplerState::RunningArrow(control) => { + let (res, trace) = py.detach(|| control.inspect())?; + if let Some(err) = res { + return Err(err.into()); + } + Ok(Some(PyTrace(InnerPyTrace::Arrow(Some(trace))))) + } } } - fn inspect<'py>(&mut self, py: Python<'py>) -> PyResult> { - let trace = py.detach(|| { - let mut guard = self.0.lock().unwrap(); - let SamplerState::Running(ref mut sampler) = guard.deref_mut() else { - return Err(anyhow::anyhow!("Sampler is not running"))?; - }; + fn take_results(&mut self) -> PyResult { + let state = &mut self.0.lock().expect("Poisond sampler state mutex"); - sampler.inspect_trace() - })?; - trace_to_list(trace, py) + match &state.0 { + SamplerState::FinishedZarr => { + let _ = std::mem::replace(&mut state.0, SamplerState::Empty); + Ok(PyTrace(InnerPyTrace::Zarr)) + } + SamplerState::FinishedArrow(_) => { + let state = std::mem::replace(&mut state.0, 
SamplerState::Empty); + let SamplerState::FinishedArrow(trace) = state else { + unreachable!(); + }; + Ok(PyTrace(InnerPyTrace::Arrow(Some(trace)))) + } + SamplerState::Empty => Err(PyErr::new::( + "Sampler has no results to take", + )), + SamplerState::RunningZarr(_) => Err(PyErr::new::( + "Sampler is still running, can only take results after it has finished", + )), + SamplerState::RunningArrow(_) => Err(PyErr::new::( + "Sampler is still running, can only take results after it has finished", + )), + } } } -fn trace_to_list(trace: Trace, py: Python<'_>) -> PyResult> { - let list = PyList::new( - py, - trace - .chains - .into_iter() - .map(|chain| { - Ok(PyTuple::new( - py, - [ - export_array(py, chain.draws)?, - export_array(py, chain.stats)?, - ] - .into_iter(), - )?) - }) - .collect::>>()?, - )?; - Ok(list) +enum InnerPyTrace { + Zarr, + Arrow(Option>), } -fn export_array(py: Python<'_>, data: Arc) -> PyResult> { - let pa = py.import("pyarrow")?; - let array = pa.getattr("Array")?; - - let data = data.into_data(); - - let (data, schema) = arrow::ffi::to_ffi(&data).context("Could not convert to arrow ffi")?; +#[pyclass] +pub struct PyTrace(InnerPyTrace); - let data = array - .call_method1( - "_import_from_c", - ( - (&data as *const _ as Py_uintptr_t).into_pyobject(py)?, - (&schema as *const _ as Py_uintptr_t).into_pyobject(py)?, - ), - ) - .context("Could not import arrow trace in python")?; - Ok(data.unbind()) +#[pymethods] +impl PyTrace { + fn is_zarr(&self) -> bool { + matches!(self.0, InnerPyTrace::Zarr) + } + + fn is_arrow(&self) -> bool { + matches!(self.0, InnerPyTrace::Arrow(_)) + } + + fn get_arrow_trace(&mut self) -> PyResult<(Vec, Vec)> { + match &mut self.0 { + InnerPyTrace::Zarr => Err(PyErr::new::( + "Trace is not stored in Arrow format", + )), + InnerPyTrace::Arrow(trace) => Ok(trace + .take() + .ok_or_else(|| PyValueError::new_err("The trace was already taken"))? 
+ .into_iter() + .map(|array| { + ( + PyRecordBatch::new(array.posterior), + PyRecordBatch::new(array.sample_stats), + ) + }) + .collect()), + } + } } #[pyclass] @@ -1403,10 +1580,12 @@ pub fn _lib(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_class::()?; m.add_class::()?; m.add_class::()?; - m.add_class::()?; m.add_class::()?; m.add_class::()?; - m.add_class::()?; + m.add_class::()?; + m.add_class::()?; m.add("__version__", env!("CARGO_PKG_VERSION"))?; + pyo3_object_store::register_store_module(m.py(), m, "_lib", "store")?; + pyo3_object_store::register_exceptions_module(m.py(), m, "_lib", "exceptions")?; Ok(()) } diff --git a/tests/reference/test_deterministic_sampling_jax.txt b/tests/reference/test_deterministic_sampling_jax.txt index 114966e..0c6237a 100644 --- a/tests/reference/test_deterministic_sampling_jax.txt +++ b/tests/reference/test_deterministic_sampling_jax.txt @@ -1,200 +1,200 @@ -0.941959 -0.559649 -0.534203 -0.561444 -0.561444 -0.418685 -0.827896 -0.847014 -0.738508 -0.961291 -0.923931 -1.00584 -1.16386 -1.10065 -1.6348 -1.13139 -0.993458 -0.993458 -0.966241 -1.10922 -1.10922 -1.05723 -1.05723 -2.32492 -0.0700824 -0.0860656 -1.36431 -0.829624 -0.584658 -0.531506 -0.507961 -0.543701 -0.510104 -2.46898 -0.820341 -0.490474 -0.343958 -0.300549 -2.60267 -0.588131 -0.430013 -0.618032 -1.27527 -1.80449 -1.80449 -0.855217 -0.556106 -1.77619 -2.03761 -1.02106 -0.774811 -1.78438 -1.61398 -0.712683 -1.04966 -1.17936 -1.5425 -1.5425 -1.26262 -1.39659 -0.337024 -0.177694 -0.0424286 -0.180403 -0.140553 -0.367095 -0.348732 -0.341436 -1.82764 -0.692738 -0.629186 -0.245706 -0.732305 -0.56873 -0.498757 -0.204131 -0.417031 -0.184895 -0.208768 -0.238139 -1.95089 -1.95089 -0.593379 -0.593379 -0.750063 -0.69929 -0.490359 -0.478709 -0.361632 -0.346159 -0.728965 -1.58228 -0.985676 -1.58468 -0.709012 -0.700483 -0.805006 -1.70347 -1.26293 -1.24837 -0.23989 -0.881025 -1.39084 -1.37812 -0.969265 -0.969265 -0.938487 -0.846447 -1.61945 -0.108473 -0.173496 -0.897353 -0.455899 
-0.571886 -0.891672 -0.891672 -0.864419 -0.739099 -1.49009 -1.49009 -0.385499 -0.228701 -1.83156 -1.83156 -0.947635 -0.805623 -0.714762 -0.853477 -1.45906 -0.908818 -0.540951 -1.40995 -1.22564 -0.26496 -0.159994 -0.423836 -0.350158 -0.388884 -1.39507 -0.727701 -1.80674 -0.466389 -1.61574 -1.61574 -0.42774 -0.217983 -0.14579 -1.01321 -1.01321 -1.19713 -0.390791 -0.223687 -0.149019 -0.103866 -0.153768 -0.12942 -0.346371 -0.814553 -2.41042 -0.42739 -0.322291 -0.248911 -0.854404 -1.35372 -1.35372 -2.00546 -0.0457881 -0.0415644 -0.0797551 -0.0913076 -0.070948 -0.00993872 -0.421448 -0.550377 -0.609387 -0.490487 -2.6607 -0.32804 -0.385999 -0.497294 -1.67109 -1.14328 -1.14328 -0.903063 -0.903063 -0.903063 -0.691269 -2.00151 -0.587672 -0.79679 -1.35563 -0.598471 -0.681826 -0.818296 -1.14265 -0.113094 -0.250861 -0.284491 -0.00420445 -0.00566936 +0.00293185 +0.00896809 +0.00768812 +0.00190588 +0.00320792 +0.00740496 +0.0958083 +0.109615 +0.0105327 +0.0192266 +0.0214682 +0.0218331 +0.0585783 +0.0251717 +0.0280682 +0.0729859 +0.425133 +0.457443 +2.52983 +1.15484 +1.15484 +1.18416 +0.104273 +1.20247 +1.1064 +1.67005 +1.05586 +2.55089 +1.83339 +0.971751 +0.470398 +0.284519 +0.253759 +2.29193 +1.29672 +1.29672 +0.432495 +0.411462 +1.10822 +1.10822 +0.698466 +1.01384 +0.422528 +0.471828 +0.354965 +0.370006 +0.932942 +0.924415 +0.821473 +2.34528 +1.8362 +0.329965 +0.427145 +0.995745 +1.17653 +0.937676 +0.937676 +0.71568 +0.916428 +1.05491 +0.479239 +0.488732 +1.07755 +1.05904 +0.269731 +0.197423 +0.303258 +0.0738098 +0.0535444 +0.0704248 +0.083286 +0.158385 +0.149845 +0.416708 +0.349628 +0.31117 +0.304837 +0.0724371 +1.5569 +1.20564 +2.12525 +0.303531 +0.712031 +0.844468 +0.434198 +0.277141 +0.593882 +0.648409 +1.02533 +0.692478 +0.367875 +0.316403 +0.351662 +0.117319 +1.85435 +0.413934 +0.409025 +0.661536 +0.650092 +0.766712 +0.594595 +0.501872 +0.515377 +0.236945 +0.689338 +2.99054 +0.172018 +0.0528735 +0.0579658 +0.0581689 +0.0497977 +0.063146 +0.311101 +0.347411 +0.763051 
+0.734721 +1.17926 +1.02504 +1.02504 +0.645771 +0.970169 +1.20163 +1.1179 +0.385697 +0.410691 +0.471671 +0.540587 +0.250604 +0.254267 +0.220907 +0.673968 +0.265055 +0.766607 +1.50436 +1.58131 +0.719291 +0.958127 +0.546963 +1.60432 +1.60432 +1.45897 +0.717682 +0.668208 +0.71339 +0.276479 +0.255967 +0.799242 +1.32658 +0.724295 +0.36085 +0.217894 +0.254816 +0.125993 +1.31909 +1.56969 +0.750499 +1.11993 +1.87465 +1.472 +0.950422 +0.754906 +0.270587 +0.231469 +1.19634 +1.19634 +1.19634 +1.51182 +1.34804 +1.42657 +0.544703 +1.66443 +1.66443 +1.14928 +1.10046 +1.16557 +1.5537 +0.629914 +0.880496 +0.525169 +0.312335 +0.797038 +0.733363 +1.6496 +0.0602699 +0.0840557 +0.107319 +0.0324205 +0.0929894 +0.226149 +0.202803 +0.217807 +0.366175 +0.158146 +0.160235 +0.175013 +0.148804 +0.526506 +0.785313 +1.23336 +0.733001 diff --git a/tests/reference/test_deterministic_sampling_numba.txt b/tests/reference/test_deterministic_sampling_numba.txt index 6426e8c..5bea297 100644 --- a/tests/reference/test_deterministic_sampling_numba.txt +++ b/tests/reference/test_deterministic_sampling_numba.txt @@ -1,200 +1,200 @@ -0.862203 -0.743827 -0.985284 -0.864159 -1.11537 -1.46228 -1.46228 -0.731645 -0.618394 -0.70658 -1.58816 -1.58816 -1.58816 -1.58816 -1.02597 -1.02597 -2.38965 -0.0442154 -0.0556998 -1.20147 -0.878239 -0.595919 -0.542086 -0.520452 -0.56279 -0.539904 -0.129453 -0.136407 -0.408806 -0.34263 -0.929525 -0.947864 -0.947864 -1.94444 -0.911973 -0.429576 -0.776378 -0.452981 -0.985476 -1.74745 -1.74095 -1.74095 -0.9855 -0.886535 -0.617313 -0.86405 -2.00577 -0.839407 -0.745118 -1.49611 -1.74491 -1.40854 -0.631877 -1.95302 -1.01379 -1.1063 -0.930275 -0.315935 -0.225544 -0.136821 -0.180021 -0.498635 -0.462448 -0.445633 -0.0878991 -0.105731 -0.355683 -0.750934 -0.750934 -0.874486 -1.15119 -0.657067 -0.500027 -1.28332 -1.28332 -0.919994 -1.09658 -1.73803 -1.13439 -1.21956 -0.643106 -0.329788 -0.456239 -0.596018 -0.180103 -0.388767 -1.03772 -1.03192 -1.03192 -1.04759 -1.04759 -1.13558 
-0.673716 -0.871073 -0.50739 -0.625146 -0.999657 -1.00779 -2.06182 -0.707917 -0.107437 -0.0772623 -0.10719 -0.36616 -0.14863 -0.0333724 -0.0295763 -0.0205304 -0.127619 -0.164319 -0.241143 -0.376838 -0.87369 -1.64165 -0.106128 -0.170459 -0.916833 -0.458599 -0.575215 -0.894488 -0.894488 -0.865427 -0.739365 -0.681649 -0.72888 -1.38352 -1.38352 -2.28238 -2.28238 -2.28238 -0.567775 -0.41864 -1.41709 -1.41709 -1.41709 -1.41709 -0.600311 -0.598689 -0.627731 -0.460137 -1.86219 -1.81783 -1.78092 -1.78092 -1.78092 -0.492732 -1.37953 -1.16762 -0.597573 -0.627465 -0.617661 -0.649115 -0.608255 -0.685365 -0.685365 -0.685365 -0.685365 -0.685365 -2.2227 -0.971606 -0.4219 -0.879055 -0.74434 -2.08679 -1.34952 -1.34952 -1.34952 -1.34952 -0.513284 -0.16734 -0.174037 -0.626756 -0.913504 -0.271423 -0.200176 -0.132462 -0.465497 -0.406755 -0.493296 -0.0175891 -0.0234891 -0.0220327 -0.132404 -0.0788943 -0.0949265 -0.103031 -0.0760492 -0.377155 -1.90599 -1.58063 -1.58063 -1.17038 -0.556726 -0.55085 -0.24632 -0.375951 -0.339243 -0.747524 -1.82921 -0.794344 +0.00293185 +0.00896808 +0.00768811 +0.00190587 +0.00320792 +0.00740495 +0.0958081 +0.109615 +0.0105327 +0.0192265 +0.0214681 +0.021833 +0.0585782 +0.0251717 +0.0280681 +0.0729857 +0.425132 +0.457442 +2.52983 +1.15483 +1.15483 +1.18416 +0.104274 +1.20247 +1.1064 +1.67005 +1.05586 +2.55089 +1.8334 +0.971751 +0.470398 +0.284519 +0.25376 +2.29193 +1.29672 +1.29672 +0.432495 +0.411463 +1.10822 +1.10822 +0.698467 +1.01384 +0.422528 +0.471828 +0.354965 +0.370006 +0.932942 +0.924415 +0.821473 +2.34528 +1.8362 +0.329965 +0.427145 +0.995744 +1.17653 +0.937677 +0.937677 +0.71568 +0.916428 +1.05491 +0.479239 +0.488732 +1.07755 +1.05904 +0.269731 +0.197423 +0.303257 +0.0738098 +0.0535443 +0.0704248 +0.083286 +0.158385 +0.149844 +0.416707 +0.349628 +0.31117 +0.304836 +0.072437 +1.5569 +1.20564 +2.12525 +0.303531 +0.712031 +0.844469 +0.434198 +0.277141 +0.593882 +0.648409 +1.02533 +0.692478 +0.367875 +0.316403 +0.351662 +0.117319 +1.85435 +0.413932 
+0.409023 +0.661534 +0.650092 +0.766712 +0.594595 +0.501872 +0.515377 +0.236945 +0.689338 +2.99054 +0.172018 +0.0528735 +0.0579658 +0.0581689 +0.0497977 +0.063146 +0.311101 +0.347411 +0.763051 +0.734721 +1.17926 +1.02504 +1.02504 +0.645771 +0.970169 +1.20163 +1.1179 +0.385697 +0.410691 +0.471671 +0.540587 +0.250604 +0.254267 +0.220907 +0.673968 +0.265055 +0.766607 +1.50436 +1.58131 +0.719291 +0.958127 +0.546963 +1.60432 +1.60432 +1.45897 +0.717682 +0.668208 +0.71339 +0.276479 +0.255967 +0.799242 +1.32658 +0.724295 +0.36085 +0.217894 +0.254816 +0.125993 +1.31909 +1.56969 +0.750499 +1.11993 +1.87465 +1.472 +0.950422 +0.754906 +0.270587 +0.231469 +1.19634 +1.19634 +1.19634 +1.51182 +1.34804 +1.42657 +0.544703 +1.66443 +1.66443 +1.14928 +1.10046 +1.16557 +1.5537 +0.629914 +0.880496 +0.525169 +0.312335 +0.797038 +0.733363 +1.6496 +0.0602699 +0.0840557 +0.107319 +0.0324205 +0.0929894 +0.226149 +0.202803 +0.217807 +0.366175 +0.158146 +0.160235 +0.175013 +0.148804 +0.526506 +0.785313 +1.23336 +0.733001 diff --git a/tests/reference/test_deterministic_sampling_stan.txt b/tests/reference/test_deterministic_sampling_stan.txt index 3bed2a2..dd85d53 100644 --- a/tests/reference/test_deterministic_sampling_stan.txt +++ b/tests/reference/test_deterministic_sampling_stan.txt @@ -1,2 +1,2 @@ -1.21572 1.03376 1.60518 1.60518 1.59553 1.35023 0.761056 1.41688 1.41688 1.41688 -0.252389 0.999663 0.999663 0.999663 0.740026 0.387763 0.944247 0.289785 1.52909 0.683129 +0.754944 0.746804 0.687211 1.56984 2.15413 2.15413 0.186138 1.19976 1.19976 0.818806 +0.185979 1.20179 0.236474 0.240597 0.416886 0.529295 0.574728 0.59912 1.02193 0.902788 diff --git a/tests/test_pymc.py b/tests/test_pymc.py index e710aae..9014124 100644 --- a/tests/test_pymc.py +++ b/tests/test_pymc.py @@ -88,7 +88,12 @@ def test_low_rank(backend, gradient_backend): model, backend=backend, gradient_backend=gradient_backend ) trace = nutpie.sample(compiled, chains=1, low_rank_modified_mass_matrix=True) - trace.posterior.a # 
noqa: B018 + + assert "mass_matrix_eigvals" not in trace.sample_stats + trace = nutpie.sample( + compiled, chains=1, low_rank_modified_mass_matrix=True, store_mass_matrix=True + ) + assert "mass_matrix_eigvals" in trace.sample_stats @pytest.mark.pymc @@ -421,7 +426,7 @@ def test_missing(backend, gradient_backend): @pytest.mark.pymc -@pytest.mark.array_compare +@pytest.mark.array_compare(atol=1e-4, rtol=1e-4) def test_deterministic_sampling_numba(): with pm.Model() as model: pm.HalfNormal("a") @@ -432,7 +437,7 @@ def test_deterministic_sampling_numba(): @pytest.mark.pymc -@pytest.mark.array_compare +@pytest.mark.array_compare(atol=1e-4, rtol=1e-4) def test_deterministic_sampling_jax(): with pm.Model() as model: pm.HalfNormal("a") diff --git a/tests/test_stan.py b/tests/test_stan.py index 53b6b40..66cc4d5 100644 --- a/tests/test_stan.py +++ b/tests/test_stan.py @@ -278,7 +278,7 @@ def test_stan_flow(): # TODO: There are small numerical differences between linux and windows. # We should figure out if they originate in stan or in nutpie. 
-@pytest.mark.array_compare(atol=1e-4) +@pytest.mark.array_compare(atol=1e-4, rtol=1e-4) @pytest.mark.stan def test_deterministic_sampling_stan(): model = """ @@ -296,6 +296,6 @@ def test_deterministic_sampling_stan(): compiled_model = nutpie.compile_stan_model(code=model) trace = nutpie.sample(compiled_model, chains=2, seed=123, draws=100, tune=100) trace2 = nutpie.sample(compiled_model, chains=2, seed=123, draws=100, tune=100) - np.testing.assert_allclose(trace.posterior.a.values, trace2.posterior.a.values) - np.testing.assert_allclose(trace.posterior.b.values, trace2.posterior.b.values) + np.testing.assert_array_max_ulp(trace.posterior.a.values, trace2.posterior.a.values) + np.testing.assert_array_max_ulp(trace.posterior.b.values, trace2.posterior.b.values) return trace.posterior.a.isel(draw=slice(None, 10)).values From 121c81c2f1b1d61f583a7590e1cf9cbb7fc37b6b Mon Sep 17 00:00:00 2001 From: Adrian Seyboldt Date: Wed, 8 Oct 2025 19:02:32 +0200 Subject: [PATCH 09/12] feat: add zarr_store argument to write trace while sampling --- Cargo.lock | 8 ++++- Cargo.toml | 2 +- docs/sampling-options.qmd | 56 ++++++++++++++++++++++++++++++++++- pyproject.toml | 2 ++ python/nutpie/__init__.py | 9 +++++- python/nutpie/compile_stan.py | 1 - python/nutpie/sample.py | 8 +++-- tests/test_pymc.py | 25 ++++++++++++++++ 8 files changed, 103 insertions(+), 8 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 2f93395..0752e67 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2240,6 +2240,8 @@ dependencies = [ [[package]] name = "nuts-derive" version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "64eac5046c75ced9bdaede15ebc30c4ce982a13e75032ae8d5c1312d1e05d82e" dependencies = [ "nuts-storable", "proc-macro2", @@ -2249,7 +2251,9 @@ dependencies = [ [[package]] name = "nuts-rs" -version = "0.16.1" +version = "0.17.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"0cbacdcc02ea7e33cf6c3389b1c4944d028d5ca038e503e656856a9407c5b2b6" dependencies = [ "anyhow", "arrow", @@ -2273,6 +2277,8 @@ dependencies = [ [[package]] name = "nuts-storable" version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cb6cf9fc84ca313648ddb112f8728eb2f9531f2e4533959dd01127eb34290b5b" [[package]] name = "object" diff --git a/Cargo.toml b/Cargo.toml index a09f519..fd06e25 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -21,7 +21,7 @@ name = "_lib" crate-type = ["cdylib"] [dependencies] -nuts-rs = { version = "0.16.1", features = ["zarr", "arrow"] } +nuts-rs = { version = "0.17.0", features = ["zarr", "arrow"] } numpy = "0.26.0" rand = "0.9.0" thiserror = "2.0.3" diff --git a/docs/sampling-options.qmd b/docs/sampling-options.qmd index ef6bfe4..458c6df 100644 --- a/docs/sampling-options.qmd +++ b/docs/sampling-options.qmd @@ -25,7 +25,7 @@ trace = nutpie.sample( tune=500, # Number of warmup draws for adaptation chains=6, # Number of independent chains cores=None, # Number chains that are allowed to run simultainiously - seed=12345 # Random seed for reproducibility + seed=12345 # Random seed for reproducibility ) ``` @@ -143,6 +143,60 @@ trace = nutpie.sample( ) ``` +## Zarr Storage (Experimental) + +Nutpie includes experimental support for writing traces directly to zarr storage, which can be useful for large traces that don't fit in memory or for distributed storage scenarios. The zarr format provides efficient, chunked, compressed storage for multi-dimensional arrays. 
+ +### Basic Usage + +You can write traces directly to zarr storage by providing a `zarr_store` parameter: + +```python +import nutpie +import pymc as pm + +with pm.Model() as model: + pm.HalfNormal("a") + +compiled = nutpie.compile_pymc_model(model, backend="numba") + +# Create a local zarr store +path = "trace.zarr" +store = nutpie.zarr_store.LocalStore(path) + +trace = nutpie.sample( + compiled, + chains=2, + seed=123, + draws=100, + tune=100, + zarr_store=store +) +``` + +### Memory Considerations + +When using zarr storage, the trace object supports lazy loading: + +```python +# The trace is not loaded into memory by default +posterior_data = trace.posterior.a # Lazy access + +# Explicitly load the entire trace into memory (optional) +loaded_trace = trace.load() +posterior_data = loaded_trace.posterior.a # In-memory access +``` + +### Available Store Types + +Nutpie supports several zarr store backends: + +- `nutpie.zarr_store.LocalStore(path)` - Local filesystem storage +- `nutpie.zarr_store.S3Store(...)` - Amazon S3 storage +- `nutpie.zarr_store.GCSStore(...)` - Google Cloud Storage +- `nutpie.zarr_store.AzureStore(...)` - Azure Blob Storage +- `nutpie.zarr_store.HTTPStore(...)` - HTTP-based storage + ## Progress Monitoring Customize the sampling progress display: diff --git a/pyproject.toml b/pyproject.toml index 822d2e4..84e5e33 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,6 +21,8 @@ dependencies = [ "pandas >= 2.0", "xarray >= 2025.01.2", "arviz >= 0.20.0", + "obstore >= 0.8.0", + "zarr >= 3.1.0", ] dynamic = ["version"] diff --git a/python/nutpie/__init__.py b/python/nutpie/__init__.py index 443f099..4084b57 100644 --- a/python/nutpie/__init__.py +++ b/python/nutpie/__init__.py @@ -2,6 +2,13 @@ from nutpie.compile_pymc import compile_pymc_model from nutpie.compile_stan import compile_stan_model from nutpie.sample import sample +from nutpie._lib import store as zarr_store __version__: str = _lib.__version__ -__all__ = ["__version__", 
"compile_pymc_model", "compile_stan_model", "sample"] +__all__ = [ + "__version__", + "compile_pymc_model", + "compile_stan_model", + "sample", + "zarr_store", +] diff --git a/python/nutpie/compile_stan.py b/python/nutpie/compile_stan.py index 35dfa18..82de950 100644 --- a/python/nutpie/compile_stan.py +++ b/python/nutpie/compile_stan.py @@ -4,7 +4,6 @@ from pathlib import Path from typing import Any, Optional -import pandas as pd from numpy.typing import NDArray from nutpie import _lib diff --git a/python/nutpie/sample.py b/python/nutpie/sample.py index 5b25404..edc80f0 100644 --- a/python/nutpie/sample.py +++ b/python/nutpie/sample.py @@ -778,10 +778,12 @@ def sample( transform_adapt: bool, default=False Use the experimental transform adaptation algorithm during tuning. - zarr_store: nutpie.store.Store - A store created using nutpie.store to store the samples + zarr_store: nutpie.zarr_store.* + A store created using nutpie.zarr_store to store the samples in. If None (default), the samples will be stored in - memory using an arrow table. + memory using an arrow table. This can be used to write + the trace directly into a zarr store, for instance + on disk or to S3 or GCS. 
**kwargs Pass additional arguments to nutpie._lib.PySamplerArgs diff --git a/tests/test_pymc.py b/tests/test_pymc.py index 9014124..8d697b0 100644 --- a/tests/test_pymc.py +++ b/tests/test_pymc.py @@ -445,3 +445,28 @@ def test_deterministic_sampling_jax(): compiled = nutpie.compile_pymc_model(model, backend="jax", gradient_backend="jax") trace = nutpie.sample(compiled, chains=2, seed=123, draws=100, tune=100) return trace.posterior.a.values.ravel() + + +@pytest.mark.pymc +def test_zarr_store(tmp_path): + with pm.Model() as model: + pm.HalfNormal("a") + + compiled = nutpie.compile_pymc_model(model, backend="numba") + + path = tmp_path / "trace.zarr" + path.mkdir() + store = nutpie.zarr_store.LocalStore(str(path)) + trace = nutpie.sample( + compiled, chains=2, seed=123, draws=100, tune=100, zarr_store=store + ) + trace.load().posterior.a # noqa: B018 + + +@pytest.fixture +def tmp_path(): + import tempfile + from pathlib import Path + + with tempfile.TemporaryDirectory() as tmpdirname: + yield Path(tmpdirname) From 0738b93ab1a040ab9d564f43e7c4bef764d49ad0 Mon Sep 17 00:00:00 2001 From: Adrian Seyboldt Date: Wed, 8 Oct 2025 19:47:50 +0200 Subject: [PATCH 10/12] chore(release): Bump version --- Cargo.lock | 2 +- Cargo.toml | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 0752e67..11948d3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2210,7 +2210,7 @@ dependencies = [ [[package]] name = "nutpie" -version = "0.15.2" +version = "0.16.0" dependencies = [ "anyhow", "arrow", diff --git a/Cargo.toml b/Cargo.toml index fd06e25..b8410b4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "nutpie" -version = "0.15.2" +version = "0.16.0" authors = [ "Adrian Seyboldt ", "PyMC Developers ", @@ -10,7 +10,7 @@ license = "MIT" repository = "https://github.com/pymc-devs/nutpie" keywords = ["statistics", "bayes"] description = "Python wrapper for nuts-rs -- a NUTS sampler written in Rust." 
-rust-version = "1.90" +rust-version = "1.89" [features] extension-module = ["pyo3/extension-module"] From 025445d947e6201cbeb69de3af53eae595f5580f Mon Sep 17 00:00:00 2001 From: Adrian Seyboldt Date: Wed, 8 Oct 2025 19:47:50 +0200 Subject: [PATCH 11/12] doc: fix typo in pymc usage docs --- docs/pymc-usage.qmd | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/pymc-usage.qmd b/docs/pymc-usage.qmd index 428143c..a045490 100644 --- a/docs/pymc-usage.qmd +++ b/docs/pymc-usage.qmd @@ -129,7 +129,7 @@ pixi add jax We can select the backend by passing the `backend` argument to the `compile_pymc_model`: ```python -compiled_jax = nutpie.compiled_pymc_model(model, backend="jax") +compiled_jax = nutpie.compile_pymc_model(model, backend="jax") trace = nutpie.sample(compiled_jax) ``` From c009f22bd3769a0570074f0318edac01e48bb903 Mon Sep 17 00:00:00 2001 From: Adrian Seyboldt Date: Wed, 8 Oct 2025 19:47:50 +0200 Subject: [PATCH 12/12] chore(release): update changelog --- CHANGELOG.md | 40 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 40 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 356c9c2..c76ed4f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,46 @@ All notable changes to this project will be documented in this file. 
+## [0.16.0] - 2025-10-08 + +### Bug Fixes + +- Keep python thread handle alive (Adrian Seyboldt) + +- No errors for unused parameters (Adrian Seyboldt) + + +### Documentation + +- Fix typo in pymc usage docs (Adrian Seyboldt) + + +### Features + +- Support step size adaptation method (Adrian Seyboldt) + +- Add argument for mindepth (Adrian Seyboldt) + +- Support free-threaded python build (Adrian Seyboldt) + +- Use new nuts-rs storage interface (Adrian Seyboldt) + +- Add zarr_store argument to write trace while sampling (Adrian Seyboldt) + + +### Miscellaneous Tasks + +- Bump actions/checkout from 4 to 5 (dependabot[bot]) + +- Bump actions/download-artifact from 4 to 5 (dependabot[bot]) + +- Bump actions/setup-python from 5 to 6 (#240) (dependabot[bot]) + +- Bump actions/attest-build-provenance from 2 to 3 (#239) (dependabot[bot]) + +- Update pyo3 (Adrian Seyboldt) + + ## [0.15.2] - 2025-07-16 ### Features