Skip to content

Commit 40a04a3

Browse files
committed
feat: generalize sample and stats storage
1 parent 7ada9be commit 40a04a3

34 files changed

+4390
-1752
lines changed

Cargo.toml

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ authors = [
55
"Adrian Seyboldt <[email protected]>",
66
"PyMC Developers <[email protected]>",
77
]
8-
edition = "2021"
8+
edition = "2024"
99
license = "MIT"
1010
repository = "https://github.com/pymc-devs/nuts-rs"
1111
keywords = ["statistics", "bayes"]
@@ -22,22 +22,39 @@ rand = { version = "0.9.0", features = ["small_rng"] }
2222
rand_distr = "0.5.0"
2323
itertools = "0.14.0"
2424
thiserror = "2.0.3"
25-
arrow = { version = "56.1.0", default-features = false, features = ["ffi"] }
2625
rand_chacha = "0.9.0"
2726
anyhow = "1.0.72"
2827
faer = { version = "0.22.6", default-features = false, features = ["linalg"] }
2928
pulp = "0.21.4"
3029
rayon = "1.10.0"
30+
zarrs = { version = "0.21.0", features = [
31+
"filesystem",
32+
"gzip",
33+
"sharding",
34+
], optional = true }
35+
ndarray = { version = "0.16.1", optional = true }
36+
nuts-derive = { path = "./nuts-derive" }
37+
nuts-storable = { path = "./nuts-storable" }
38+
serde = { version = "1.0.219", features = ["derive"] }
39+
serde_json = "1.0"
3140

3241
[dev-dependencies]
3342
proptest = "1.6.0"
3443
pretty_assertions = "1.4.0"
3544
criterion = "0.7.0"
3645
nix = { version = "0.30.0", features = ["sched"] }
3746
approx = "0.5.1"
38-
ndarray = "0.16.1"
3947
equator = "0.4.2"
48+
serde_json = "1.0"
49+
ndarray = "0.16.1"
50+
tempfile = "3.0"
51+
52+
[features]
53+
zarr = ["dep:zarrs"]
54+
ndarray = ["dep:ndarray"]
4055

4156
[[bench]]
4257
name = "sample"
4358
harness = false
59+
60+
[workspace]

benches/sample.rs

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
use std::hint::black_box;
22

3-
use criterion::{criterion_group, criterion_main, Criterion};
4-
use nix::sched::{sched_setaffinity, CpuSet};
3+
use criterion::{Criterion, criterion_group, criterion_main};
4+
use nix::sched::{CpuSet, sched_setaffinity};
55
use nix::unistd::Pid;
66
use nuts_rs::{Chain, CpuLogpFunc, CpuMath, LogpError, Math, Settings};
7+
use nuts_storable::HasDims;
78
use rand::SeedableRng;
89
use rayon::ThreadPoolBuilder;
910
use thiserror::Error;
@@ -22,11 +23,20 @@ impl LogpError for PosteriorLogpError {
2223
}
2324
}
2425

26+
impl HasDims for PosteriorDensity {
27+
fn dim_sizes(&self) -> std::collections::HashMap<String, u64> {
28+
vec![("unconstrained_parameter".to_string(), self.dim() as u64)]
29+
.into_iter()
30+
.collect()
31+
}
32+
}
33+
2534
impl CpuLogpFunc for PosteriorDensity {
2635
type LogpError = PosteriorLogpError;
36+
type ExpandedVector = Vec<f64>;
2737

2838
// Only used for transforming adaptation.
29-
type TransformParams = ();
39+
type FlowParameters = ();
3040

3141
// We define a 10 dimensional normal distribution
3242
fn dim(&self) -> usize {
@@ -48,6 +58,17 @@ impl CpuLogpFunc for PosteriorDensity {
4858
.sum();
4959
return Ok(logp);
5060
}
61+
62+
fn expand_vector<R>(
63+
&mut self,
64+
_rng: &mut R,
65+
array: &[f64],
66+
) -> Result<Self::ExpandedVector, nuts_rs::CpuMathError>
67+
where
68+
R: rand::Rng + ?Sized,
69+
{
70+
Ok(array.to_vec())
71+
}
5172
}
5273

5374
fn make_sampler(dim: usize) -> impl Chain<CpuMath<PosteriorDensity>> {

examples/adam_adaptation.rs

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ use nuts_rs::{
77
AdamOptions, Chain, CpuLogpFunc, CpuMath, DiagGradNutsSettings, LogpError, Settings,
88
StepSizeAdaptMethod,
99
};
10+
use nuts_storable::HasDims;
1011
use thiserror::Error;
1112

1213
// Define a function that computes the unnormalized posterior density
@@ -23,11 +24,20 @@ impl LogpError for PosteriorLogpError {
2324
}
2425
}
2526

27+
impl HasDims for PosteriorDensity {
28+
fn dim_sizes(&self) -> std::collections::HashMap<String, u64> {
29+
vec![("unconstrained_parameter".to_string(), self.dim() as u64)]
30+
.into_iter()
31+
.collect()
32+
}
33+
}
34+
2635
impl CpuLogpFunc for PosteriorDensity {
2736
type LogpError = PosteriorLogpError;
37+
type ExpandedVector = Vec<f64>;
2838

2939
// Only used for transforming adaptation.
30-
type TransformParams = ();
40+
type FlowParameters = ();
3141

3242
// We define a 10 dimensional normal distribution
3343
fn dim(&self) -> usize {
@@ -49,6 +59,17 @@ impl CpuLogpFunc for PosteriorDensity {
4959
.sum();
5060
return Ok(logp);
5161
}
62+
63+
fn expand_vector<R>(
64+
&mut self,
65+
_rng: &mut R,
66+
array: &[f64],
67+
) -> Result<Self::ExpandedVector, nuts_rs::CpuMathError>
68+
where
69+
R: rand::Rng + ?Sized,
70+
{
71+
Ok(array.to_vec())
72+
}
5273
}
5374

5475
fn main() {

0 commit comments

Comments
 (0)