Skip to content

Commit b5dbaef

Browse files
committed
feat: add mindepth option for nuts
1 parent 9a4ccb8 commit b5dbaef

File tree

3 files changed

+18
-3
lines changed

3 files changed

+18
-3
lines changed

src/adapt_strategy.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -501,6 +501,7 @@ mod test {
501501
EuclideanHamiltonian::new(&mut math, mass_matrix, max_energy_error, step_size);
502502
let options = NutsOptions {
503503
maxdepth: 10u64,
504+
mindepth: 0,
504505
store_gradient: true,
505506
store_unconstrained: true,
506507
check_turning: true,

src/nuts.rs

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -250,6 +250,7 @@ impl<M: Math, H: Hamiltonian<M>, C: Collector<M, H::Point>> NutsTree<M, H, C> {
250250

251251
pub struct NutsOptions {
252252
pub maxdepth: u64,
253+
pub mindepth: u64,
253254
pub store_gradient: bool,
254255
pub store_unconstrained: bool,
255256
pub check_turning: bool,
@@ -286,9 +287,13 @@ where
286287
tree = match tree.extend(math, rng, hamiltonian, direction, collector, options) {
287288
ExtendResult::Ok(tree) => tree,
288289
ExtendResult::Turning(tree) => {
289-
let info = tree.info(false, None);
290-
collector.register_draw(math, &tree.draw, &info);
291-
return Ok((tree.draw, info));
290+
if tree.depth < options.mindepth {
291+
tree
292+
} else {
293+
let info = tree.info(false, None);
294+
collector.register_draw(math, &tree.draw, &info);
295+
return Ok((tree.draw, info));
296+
}
292297
}
293298
ExtendResult::Diverging(tree, info) => {
294299
let info = tree.info(false, Some(info));

src/sampler.rs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,9 @@ pub struct NutsSettings<A: Debug + Copy + Default> {
8787
/// The maximum tree depth during sampling. The number of leapfrog steps
8888
/// is smaller than 2 ^ maxdepth.
8989
pub maxdepth: u64,
90+
/// The minimum tree depth during sampling. The number of leapfrog steps
91+
/// is larger than 2 ^ mindepth.
92+
pub mindepth: u64,
9093
/// Store the gradient in the SampleStats
9194
pub store_gradient: bool,
9295
/// Store each unconstrained parameter vector in the sampler stats
@@ -114,6 +117,7 @@ impl Default for DiagGradNutsSettings {
114117
num_tune: 400,
115118
num_draws: 1000,
116119
maxdepth: 10,
120+
mindepth: 0,
117121
max_energy_error: 1000f64,
118122
store_gradient: false,
119123
store_unconstrained: false,
@@ -132,6 +136,7 @@ impl Default for LowRankNutsSettings {
132136
num_tune: 800,
133137
num_draws: 1000,
134138
maxdepth: 10,
139+
mindepth: 0,
135140
max_energy_error: 1000f64,
136141
store_gradient: false,
137142
store_unconstrained: false,
@@ -152,6 +157,7 @@ impl Default for TransformedNutsSettings {
152157
num_tune: 1500,
153158
num_draws: 1000,
154159
maxdepth: 10,
160+
mindepth: 0,
155161
max_energy_error: 20f64,
156162
store_gradient: false,
157163
store_unconstrained: false,
@@ -187,6 +193,7 @@ impl Settings for LowRankNutsSettings {
187193

188194
let options = NutsOptions {
189195
maxdepth: self.maxdepth,
196+
mindepth: self.mindepth,
190197
store_gradient: self.store_gradient,
191198
store_divergences: self.store_divergences,
192199
store_unconstrained: self.store_unconstrained,
@@ -246,6 +253,7 @@ impl Settings for DiagGradNutsSettings {
246253

247254
let options = NutsOptions {
248255
maxdepth: self.maxdepth,
256+
mindepth: self.mindepth,
249257
store_gradient: self.store_gradient,
250258
store_divergences: self.store_divergences,
251259
store_unconstrained: self.store_unconstrained,
@@ -302,6 +310,7 @@ impl Settings for TransformedNutsSettings {
302310

303311
let options = NutsOptions {
304312
maxdepth: self.maxdepth,
313+
mindepth: self.mindepth,
305314
store_gradient: self.store_gradient,
306315
store_divergences: self.store_divergences,
307316
store_unconstrained: self.store_unconstrained,

0 commit comments

Comments
 (0)