Skip to content

Commit 36a0302

Browse files
committed
feat: generalize sample and stats storage
1 parent 7ada9be commit 36a0302

34 files changed

+4382
-1751
lines changed

Cargo.toml

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ authors = [
55
"Adrian Seyboldt <[email protected]>",
66
"PyMC Developers <[email protected]>",
77
]
8-
edition = "2021"
8+
edition = "2024"
99
license = "MIT"
1010
repository = "https://github.com/pymc-devs/nuts-rs"
1111
keywords = ["statistics", "bayes"]
@@ -28,16 +28,34 @@ anyhow = "1.0.72"
2828
faer = { version = "0.22.6", default-features = false, features = ["linalg"] }
2929
pulp = "0.21.4"
3030
rayon = "1.10.0"
31+
zarrs = { version = "0.21.0", features = [
32+
"filesystem",
33+
"gzip",
34+
"sharding",
35+
], optional = true }
36+
ndarray = { version = "0.16.1", optional = true }
37+
nuts-derive = { path = "./nuts-derive" }
38+
nuts-storable = { path = "./nuts-storable" }
39+
serde = { version = "1.0.219", features = ["derive"] }
40+
serde_json = "1.0"
3141

3242
[dev-dependencies]
3343
proptest = "1.6.0"
3444
pretty_assertions = "1.4.0"
3545
criterion = "0.7.0"
3646
nix = { version = "0.30.0", features = ["sched"] }
3747
approx = "0.5.1"
38-
ndarray = "0.16.1"
3948
equator = "0.4.2"
49+
serde_json = "1.0"
50+
ndarray = "0.16.1"
51+
tempfile = "3.0"
52+
53+
[features]
54+
zarr = ["dep:zarrs"]
55+
ndarray = ["dep:ndarray"]
4056

4157
[[bench]]
4258
name = "sample"
4359
harness = false
60+
61+
[workspace]

benches/sample.rs

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
use std::hint::black_box;
22

3-
use criterion::{criterion_group, criterion_main, Criterion};
4-
use nix::sched::{sched_setaffinity, CpuSet};
3+
use criterion::{Criterion, criterion_group, criterion_main};
4+
use nix::sched::{CpuSet, sched_setaffinity};
55
use nix::unistd::Pid;
66
use nuts_rs::{Chain, CpuLogpFunc, CpuMath, LogpError, Math, Settings};
7+
use nuts_storable::HasDims;
78
use rand::SeedableRng;
89
use rayon::ThreadPoolBuilder;
910
use thiserror::Error;
@@ -22,11 +23,20 @@ impl LogpError for PosteriorLogpError {
2223
}
2324
}
2425

26+
impl HasDims for PosteriorDensity {
27+
fn dim_sizes(&self) -> std::collections::HashMap<String, u64> {
28+
vec![("unconstrained_parameter".to_string(), self.dim() as u64)]
29+
.into_iter()
30+
.collect()
31+
}
32+
}
33+
2534
impl CpuLogpFunc for PosteriorDensity {
2635
type LogpError = PosteriorLogpError;
36+
type ExpandedVector = Vec<f64>;
2737

2838
// Only used for transforming adaptation.
29-
type TransformParams = ();
39+
type FlowParameters = ();
3040

3141
// We define a 10 dimensional normal distribution
3242
fn dim(&self) -> usize {
@@ -48,6 +58,17 @@ impl CpuLogpFunc for PosteriorDensity {
4858
.sum();
4959
return Ok(logp);
5060
}
61+
62+
fn expand_vector<R>(
63+
&mut self,
64+
_rng: &mut R,
65+
array: &[f64],
66+
) -> Result<Self::ExpandedVector, nuts_rs::CpuMathError>
67+
where
68+
R: rand::Rng + ?Sized,
69+
{
70+
Ok(array.to_vec())
71+
}
5172
}
5273

5374
fn make_sampler(dim: usize) -> impl Chain<CpuMath<PosteriorDensity>> {

examples/adam_adaptation.rs

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ use nuts_rs::{
77
AdamOptions, Chain, CpuLogpFunc, CpuMath, DiagGradNutsSettings, LogpError, Settings,
88
StepSizeAdaptMethod,
99
};
10+
use nuts_storable::HasDims;
1011
use thiserror::Error;
1112

1213
// Define a function that computes the unnormalized posterior density
@@ -23,11 +24,20 @@ impl LogpError for PosteriorLogpError {
2324
}
2425
}
2526

27+
impl HasDims for PosteriorDensity {
28+
fn dim_sizes(&self) -> std::collections::HashMap<String, u64> {
29+
vec![("unconstrained_parameter".to_string(), self.dim() as u64)]
30+
.into_iter()
31+
.collect()
32+
}
33+
}
34+
2635
impl CpuLogpFunc for PosteriorDensity {
2736
type LogpError = PosteriorLogpError;
37+
type ExpandedVector = Vec<f64>;
2838

2939
// Only used for transforming adaptation.
30-
type TransformParams = ();
40+
type FlowParameters = ();
3141

3242
// We define a 10 dimensional normal distribution
3343
fn dim(&self) -> usize {
@@ -49,6 +59,17 @@ impl CpuLogpFunc for PosteriorDensity {
4959
.sum();
5060
return Ok(logp);
5161
}
62+
63+
fn expand_vector<R>(
64+
&mut self,
65+
_rng: &mut R,
66+
array: &[f64],
67+
) -> Result<Self::ExpandedVector, nuts_rs::CpuMathError>
68+
where
69+
R: rand::Rng + ?Sized,
70+
{
71+
Ok(array.to_vec())
72+
}
5273
}
5374

5475
fn main() {

0 commit comments

Comments
 (0)