-
Notifications
You must be signed in to change notification settings - Fork 12
Expand file tree
/
Copy pathlib.rs
More file actions
151 lines (144 loc) · 4.99 KB
/
lib.rs
File metadata and controls
151 lines (144 loc) · 4.99 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
//! Sample from posterior distributions using the No U-turn Sampler (NUTS).
//! For details see the original [NUTS paper](https://arxiv.org/abs/1111.4246)
//! and the more recent [introduction](https://arxiv.org/abs/1701.02434).
//!
//! This crate was developed as a faster replacement of the sampler in PyMC,
//! to be used with the new numba backend of PyTensor. The python wrapper
//! for this sampler is [nutpie](https://github.com/pymc-devs/nutpie).
//!
//! ## Usage
//!
//! ```
//! use nuts_rs::{CpuLogpFunc, CpuMath, LogpError, DiagGradNutsSettings, Chain, Progress,
//! Settings, HasDims};
//! use thiserror::Error;
//! use rand::rng;
//! use std::collections::HashMap;
//!
//! // Define a function that computes the unnormalized posterior density
//! // and its gradient.
//! #[derive(Debug)]
//! struct PosteriorDensity {}
//!
//! // The density might fail in a recoverable or non-recoverable manner...
//! #[derive(Debug, Error)]
//! enum PosteriorLogpError {}
//! impl LogpError for PosteriorLogpError {
//! fn is_recoverable(&self) -> bool { false }
//! }
//!
//! impl HasDims for PosteriorDensity {
//! fn dim_sizes(&self) -> HashMap<String, u64> {
//! vec![("unconstrained_parameter".to_string(), self.dim() as u64)].into_iter().collect()
//! }
//! }
//!
//! impl CpuLogpFunc for PosteriorDensity {
//! type LogpError = PosteriorLogpError;
//! type ExpandedVector = Vec<f64>;
//!
//! // Only used for transforming adaptation.
//! type FlowParameters = ();
//!
//! // We define a 10 dimensional normal distribution
//! fn dim(&self) -> usize { 10 }
//!
//! // The normal likelihood with mean 3 and its gradient.
//! fn logp(&mut self, position: &[f64], grad: &mut [f64]) -> Result<f64, Self::LogpError> {
//! let mu = 3f64;
//! let logp = position
//! .iter()
//! .copied()
//! .zip(grad.iter_mut())
//! .map(|(x, grad)| {
//! let diff = x - mu;
//! *grad = -diff;
//! -diff * diff / 2f64
//! })
//! .sum();
//! return Ok(logp)
//! }
//!
//! fn expand_vector<R: rand::Rng + ?Sized>(&mut self, rng: &mut R, position: &[f64]) -> Result<Vec<f64>, nuts_rs::CpuMathError> {
//! Ok(position.to_vec())
//! }
//! }
//!
//! // We get the default sampler arguments
//! let mut settings = DiagGradNutsSettings::default();
//!
//! // and modify as we like
//! settings.num_tune = 1000;
//! settings.maxdepth = 3; // small value just for testing...
//!
//! // We instantiate our posterior density function
//! let logp_func = PosteriorDensity {};
//! let math = CpuMath::new(logp_func);
//!
//! let chain = 0;
//! let mut rng = rng();
//! let mut sampler = settings.new_chain(chain, math, &mut rng);
//!
//! // Set to some initial position and start drawing samples.
//! sampler.set_position(&vec![0f64; 10]).expect("Unrecoverable error during init");
//! let mut trace = vec![]; // Collection of all draws
//! for _ in 0..2000 {
//! let (draw, info) = sampler.draw().expect("Unrecoverable error during sampling");
//! trace.push(draw);
//! }
//! ```
//!
//! Users can also implement the `Model` trait for more control and parallel sampling.
//!
//! See the examples directory in the repository for more examples.
//!
//! ## Implementation details
//!
//! This crate mostly follows the implementation of NUTS in [Stan](https://mc-stan.org) and
//! [PyMC](https://docs.pymc.io/en/v3/), only tuning of mass matrix and step size differs
//! somewhat.
// Private implementation modules (kept alphabetical). None of these are part
// of the public API directly; the supported surface is the set of `pub use`
// re-exports that follow this list.
mod adapt_strategy;
mod chain;
mod cpu_math;
mod euclidean_hamiltonian;
mod hamiltonian;
mod mass_matrix;
mod math;
mod math_base;
mod model;
mod nuts;
mod sampler;
mod sampler_stats;
mod state;
mod stepsize;
mod storage;
mod transform_adapt_strategy;
mod transformed_hamiltonian;
// Re-exports from the companion crates. Note that `Storable` is exported
// twice on purpose: the derive macro (macro namespace) and the trait
// (type namespace) share one name, like serde's `Serialize`.
pub use nuts_derive::Storable;
pub use nuts_storable::{DateTimeUnit, HasDims, ItemType, Storable, Value};
// Expose `rand` so callers can construct RNGs with the exact version we use.
pub use rand;

// Core sampling API, alphabetical by module.
pub use adapt_strategy::EuclideanAdaptOptions;
pub use chain::Chain;
pub use cpu_math::{CpuLogpFunc, CpuMath, CpuMathError};
pub use hamiltonian::DivergenceInfo;
pub use mass_matrix::{DiagAdaptExpSettings, LowRankSettings};
pub use math_base::{LogpError, Math};
pub use model::Model;
pub use nuts::NutsError;
pub use sampler::{
    ChainProgress, DiagGradNutsSettings, LowRankNutsSettings, NutsSettings, Progress,
    ProgressCallback, SampleData, Sampler, SamplerWaitResult, Settings, TransformedNutsSettings,
    sample_sequentially,
};
pub use sampler_stats::SamplerStats;
pub use stepsize::{AdamOptions, StepSizeAdaptMethod, StepSizeAdaptOptions, StepSizeSettings};
pub use transform_adapt_strategy::TransformedSettings;

// Trace-storage backends; several are gated behind cargo features.
pub use storage::{CsvConfig, CsvTraceStorage, HashMapConfig, HashMapValue};
#[cfg(feature = "arrow")]
pub use storage::{ArrowConfig, ArrowTrace, ArrowTraceStorage};
#[cfg(feature = "ndarray")]
pub use storage::{NdarrayConfig, NdarrayTrace, NdarrayValue};
#[cfg(feature = "zarr")]
pub use storage::{ZarrAsyncConfig, ZarrAsyncTraceStorage, ZarrConfig, ZarrTraceStorage};