Commit 7ada9be (parent: e5cf1f9)

feat: implement step size adaptation with adam

9 files changed: +422 −48 lines

Cargo.toml

Lines changed: 2 additions & 2 deletions
@@ -22,7 +22,7 @@ rand = { version = "0.9.0", features = ["small_rng"] }
 rand_distr = "0.5.0"
 itertools = "0.14.0"
 thiserror = "2.0.3"
-arrow = { version = "55.1.0", default-features = false, features = ["ffi"] }
+arrow = { version = "56.1.0", default-features = false, features = ["ffi"] }
 rand_chacha = "0.9.0"
 anyhow = "1.0.72"
 faer = { version = "0.22.6", default-features = false, features = ["linalg"] }
@@ -32,7 +32,7 @@ rayon = "1.10.0"
 [dev-dependencies]
 proptest = "1.6.0"
 pretty_assertions = "1.4.0"
-criterion = "0.6.0"
+criterion = "0.7.0"
 nix = { version = "0.30.0", features = ["sched"] }
 approx = "0.5.1"
 ndarray = "0.16.1"

examples/adam_adaptation.rs

Lines changed: 157 additions & 0 deletions
//! Example demonstrating the Adam optimizer for step size adaptation.
//!
//! This example shows how to use the Adam optimizer instead of dual averaging
//! for adapting the step size in NUTS.

use nuts_rs::{
    AdamOptions, Chain, CpuLogpFunc, CpuMath, DiagGradNutsSettings, LogpError, Settings,
    StepSizeAdaptMethod,
};
use thiserror::Error;

// Define a function that computes the unnormalized posterior density
// and its gradient.
#[derive(Debug)]
struct PosteriorDensity {}

// The density might fail in a recoverable or non-recoverable manner...
#[derive(Debug, Error)]
enum PosteriorLogpError {}
impl LogpError for PosteriorLogpError {
    fn is_recoverable(&self) -> bool {
        false
    }
}

impl CpuLogpFunc for PosteriorDensity {
    type LogpError = PosteriorLogpError;

    // Only used for transforming adaptation.
    type TransformParams = ();

    // We define a 10 dimensional normal distribution
    fn dim(&self) -> usize {
        10
    }

    // The normal likelihood with mean 3 and its gradient.
    fn logp(&mut self, position: &[f64], grad: &mut [f64]) -> Result<f64, Self::LogpError> {
        let mu = 3f64;
        let logp = position
            .iter()
            .copied()
            .zip(grad.iter_mut())
            .map(|(x, grad)| {
                let diff = x - mu;
                *grad = -diff;
                -diff * diff / 2f64
            })
            .sum();
        Ok(logp)
    }
}

fn main() {
    println!("Running NUTS with Adam step size adaptation...");

    // Create sampler settings with Adam optimizer
    let mut settings = DiagGradNutsSettings::default();

    // Configure for Adam adaptation
    settings
        .adapt_options
        .step_size_settings
        .adapt_options
        .method = StepSizeAdaptMethod::Adam;

    // Set Adam options
    let adam_options = AdamOptions {
        beta1: 0.9,
        beta2: 0.999,
        epsilon: 1e-8,
        learning_rate: 0.05,
    };

    settings.adapt_options.step_size_settings.adapt_options.adam = adam_options;

    // Standard MCMC settings
    settings.num_tune = 1000;
    settings.num_draws = 1000;
    settings.maxdepth = 10;

    // Create the posterior density function
    let logp_func = PosteriorDensity {};
    let math = CpuMath::new(logp_func);

    // Initialize the sampler
    let chain = 0;
    let mut rng = rand::rng();
    let mut sampler = settings.new_chain(chain, math, &mut rng);

    // Set initial position
    let initial_position = vec![0f64; 10];
    sampler
        .set_position(&initial_position)
        .expect("Unrecoverable error during init");

    // Collect samples
    let mut trace = vec![];
    let mut stats = vec![];

    // Sampling with progress reporting
    println!("Warmup phase:");
    for i in 0..settings.num_tune {
        if i % 100 == 0 {
            println!("\rWarmup: {}/{}", i, settings.num_tune);
        }

        let (draw, info) = sampler.draw().expect("Unrecoverable error during sampling");
        println!("{:?}", info.step_size);
        trace.push(draw);
        stats.push(info);
    }
    println!("\rWarmup: {}/{}", settings.num_tune, settings.num_tune);

    println!("\nSampling phase:");
    for i in 0..settings.num_draws {
        if i % 100 == 0 {
            print!("\rSampling: {}/{}", i, settings.num_draws);
        }

        let (draw, info) = sampler.draw().expect("Unrecoverable error during sampling");
        trace.push(draw);
        stats.push(info);
    }
    println!("\rSampling: {}/{}", settings.num_draws, settings.num_draws);

    // Calculate mean of samples (post-warmup)
    let warmup_samples = settings.num_tune as usize;
    let mut means = vec![0.0; 10];

    for i in warmup_samples..trace.len() {
        for (j, mean) in means.iter_mut().enumerate() {
            *mean += trace[i][j];
        }
    }

    for mean in &mut means {
        *mean /= settings.num_draws as f64;
    }

    // Print results
    println!("\nResults after {} samples:", settings.num_draws);
    println!("Target mean: 3.0 for all dimensions");
    println!("Estimated means:");
    for (i, mean) in means.iter().enumerate() {
        println!("Dimension {}: {:.4}", i, mean);
    }

    // Print adaptation statistics
    let last_stats = &stats[stats.len() - 1];
    println!("\nFinal adaptation statistics:");
    println!("Step size: {:.6}", last_stats.step_size);
    // Note: the full acceptance stats are in the Progress struct, but we don't
    // have direct access to mean_tree_accept
    println!("Number of steps: {}", last_stats.num_steps);

    println!("\nSampling completed successfully!");
}
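
Since this file lives under examples/, it can be run with cargo run --example adam_adaptation. Switching the same program back to the previous behavior should only require changing the method field. A minimal sketch, assuming the dual-averaging variant of StepSizeAdaptMethod is named DualAverage (that name does not appear in this diff):

// Hypothetical: select dual averaging instead of Adam. Only the variant
// name is assumed; the field path is the one used in main() above.
settings
    .adapt_options
    .step_size_settings
    .adapt_options
    .method = StepSizeAdaptMethod::DualAverage;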

src/adapt_strategy.rs

Lines changed: 5 additions & 5 deletions
@@ -14,10 +14,10 @@ use crate::{
     sampler::Settings,
     sampler_stats::{SamplerStats, StatTraceBuilder},
     state::State,
-    stepsize::AcceptanceRateCollector,
     stepsize_adapt::{
-        DualAverageSettings, StatsBuilder as StepSizeStatsBuilder, Strategy as StepSizeStrategy,
+        StatsBuilder as StepSizeStatsBuilder, StepSizeSettings, Strategy as StepSizeStrategy,
     },
+    stepsize_dual_avg::AcceptanceRateCollector,
     NutsError,
 };
@@ -38,7 +38,7 @@ pub struct GlobalStrategy<M: Math, A: MassMatrixAdaptStrategy<M>> {
 
 #[derive(Debug, Clone, Copy)]
 pub struct EuclideanAdaptOptions<S: Debug + Default> {
-    pub dual_average_options: DualAverageSettings,
+    pub step_size_settings: StepSizeSettings,
     pub mass_matrix_options: S,
     pub early_window: f64,
     pub step_size_window: f64,
@@ -50,7 +50,7 @@ impl<S: Debug + Default> Default for EuclideanAdaptOptions<S> {
 impl<S: Debug + Default> Default for EuclideanAdaptOptions<S> {
     fn default() -> Self {
         Self {
-            dual_average_options: DualAverageSettings::default(),
+            step_size_settings: StepSizeSettings::default(),
             mass_matrix_options: S::default(),
             early_window: 0.3,
             step_size_window: 0.15,
@@ -97,7 +97,7 @@ impl<M: Math, A: MassMatrixAdaptStrategy<M>> AdaptStrategy<M> for GlobalStrategy
         assert!(early_end < num_tune);
 
         Self {
-            step_size: StepSizeStrategy::new(options.dual_average_options),
+            step_size: StepSizeStrategy::new(options.step_size_settings),
             mass_matrix: A::new(math, options.mass_matrix_options, num_tune, chain),
             options,
             num_tune,
src/lib.rs

Lines changed: 4 additions & 2 deletions
@@ -98,8 +98,9 @@ mod nuts;
 mod sampler;
 mod sampler_stats;
 mod state;
-mod stepsize;
+mod stepsize_adam;
 mod stepsize_adapt;
+mod stepsize_dual_avg;
 mod transform_adapt_strategy;
 mod transformed_hamiltonian;
@@ -117,5 +118,6 @@ pub use sampler::{
 
 pub use low_rank_mass_matrix::LowRankSettings;
 pub use mass_matrix_adapt::DiagAdaptExpSettings;
-pub use stepsize_adapt::DualAverageSettings;
+pub use stepsize_adam::AdamOptions;
+pub use stepsize_adapt::{StepSizeAdaptMethod, StepSizeAdaptOptions, StepSizeSettings};
 pub use transform_adapt_strategy::TransformedSettings;
src/mass_matrix.rs

Lines changed: 0 additions & 4 deletions
@@ -24,10 +24,6 @@ pub trait MassMatrix<M: Math>: SamplerStats<M> {
     );
 }
 
-pub struct NullCollector {}
-
-impl<M: Math, P: Point<M>> Collector<M, P> for NullCollector {}
-
 #[derive(Debug)]
 pub struct DiagMassMatrix<M: Math> {
     inv_stds: M::Vector,
src/stepsize_adam.rs

Lines changed: 110 additions & 0 deletions
//! Adam optimizer for step size adaptation.
//!
//! This implements a single-parameter version of the Adam optimizer
//! for adapting the step size in the NUTS algorithm. Unlike dual averaging,
//! Adam maintains both first and second moment estimates of gradients,
//! which can potentially lead to better adaptation in some scenarios.

use std::f64;

/// Settings for Adam step size adaptation
#[derive(Debug, Clone, Copy)]
pub struct AdamOptions {
    /// First moment decay rate (default: 0.9)
    pub beta1: f64,
    /// Second moment decay rate (default: 0.999)
    pub beta2: f64,
    /// Small constant for numerical stability (default: 1e-8)
    pub epsilon: f64,
    /// Learning rate (default: 0.05)
    pub learning_rate: f64,
}

impl Default for AdamOptions {
    fn default() -> Self {
        Self {
            beta1: 0.9,
            beta2: 0.999,
            epsilon: 1e-8,
            learning_rate: 0.05,
        }
    }
}

/// Adam optimizer for step size adaptation.
///
/// This implements the Adam optimizer for a single parameter (the step size).
/// The adaptation takes the acceptance probability statistic and adjusts
/// the step size to reach the target acceptance rate.
#[derive(Clone)]
pub struct Adam {
    /// Current log step size
    log_step: f64,
    /// First moment estimate
    m: f64,
    /// Second moment estimate
    v: f64,
    /// Iteration counter
    t: u64,
    /// Adam settings
    settings: AdamOptions,
}

impl Adam {
    /// Create a new Adam optimizer with given settings and initial step size
    pub fn new(settings: AdamOptions, initial_step: f64) -> Self {
        Self {
            log_step: initial_step.ln(),
            m: 0.0,
            v: 0.0,
            t: 0,
            settings,
        }
    }

    /// Advance the optimizer by one step using the current acceptance statistic
    ///
    /// This updates the step size to move towards the target acceptance rate.
    /// The error signal is the difference between the current and target
    /// acceptance rates.
    pub fn advance(&mut self, accept_stat: f64, target: f64) {
        // We want to minimize (target - accept_stat)², whose gradient with
        // respect to the acceptance rate is -2 * (target - accept_stat).
        // Dropping the constant factor, we use (accept_stat - target) as the
        // update signal.
        let gradient = accept_stat - target;

        // Increment timestep
        self.t += 1;

        // Update biased first moment estimate
        self.m = self.settings.beta1 * self.m + (1.0 - self.settings.beta1) * gradient;

        // Update biased second moment estimate
        self.v = self.settings.beta2 * self.v + (1.0 - self.settings.beta2) * gradient * gradient;

        // Compute bias-corrected first moment estimate
        let m_hat = self.m / (1.0 - self.settings.beta1.powi(self.t as i32));

        // Compute bias-corrected second moment estimate
        let v_hat = self.v / (1.0 - self.settings.beta2.powi(self.t as i32));

        // Update log step size.
        // If accept_stat > target the signal is positive and the step size
        // grows (we can afford larger steps); if accept_stat < target the
        // step size shrinks.
        self.log_step +=
            self.settings.learning_rate * m_hat / (v_hat.sqrt() + self.settings.epsilon);
    }

    /// Get the current step size
    pub fn current_step_size(&self) -> f64 {
        self.log_step.exp()
    }

    /// Reset the optimizer with a new initial step size and bias factor
    #[allow(dead_code)]
    pub fn reset(&mut self, initial_step: f64, _bias_factor: f64) {
        self.log_step = initial_step.ln();
        self.m = 0.0;
        self.v = 0.0;
        self.t = 0;
    }
}
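
In math form, advance() performs the textbook Adam update on the log step size, with signal $g_t = \text{accept\_stat} - \text{target}$, learning rate $\alpha$, and stability constant $\varepsilon$ (read directly off the code above):

    $m_t = \beta_1 m_{t-1} + (1 - \beta_1) g_t$
    $v_t = \beta_2 v_{t-1} + (1 - \beta_2) g_t^2$
    $\hat m_t = m_t / (1 - \beta_1^t), \quad \hat v_t = v_t / (1 - \beta_2^t)$
    $\log \epsilon_t = \log \epsilon_{t-1} + \alpha \, \hat m_t / (\sqrt{\hat v_t} + \varepsilon)$

A self-contained toy demo (not part of nuts-rs) that replays this update against a made-up acceptance model, where larger step sizes accept less, to show the step size settling near the target rate:

// Standalone sketch of the update in Adam::advance. The acceptance model
// accept = exp(-step) is invented for illustration only.
fn main() {
    let (beta1, beta2, eps, lr) = (0.9, 0.999, 1e-8, 0.05);
    let target = 0.8;
    let (mut log_step, mut m, mut v) = (0.1f64.ln(), 0.0, 0.0);
    for t in 1..=500u64 {
        // Toy stand-in for NUTS: acceptance falls as the step size grows.
        let accept = (-log_step.exp()).exp();
        let g = accept - target;
        // Same update as Adam::advance above.
        m = beta1 * m + (1.0 - beta1) * g;
        v = beta2 * v + (1.0 - beta2) * g * g;
        let m_hat = m / (1.0 - beta1.powi(t as i32));
        let v_hat = v / (1.0 - beta2.powi(t as i32));
        log_step += lr * m_hat / (v_hat.sqrt() + eps);
    }
    // Fixed point of the toy model: exp(-step) = 0.8, i.e. step ≈ 0.223.
    println!("final step size: {:.4}", log_step.exp());
}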
