Skip to content

Commit be2c293

Browse files
committed
fix: mindepth when check_turning=True was misbehaving
1 parent b768151 commit be2c293

File tree

1 file changed

+34
-8
lines changed

1 file changed

+34
-8
lines changed

src/nuts.rs

Lines changed: 34 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -257,6 +257,19 @@ pub struct NutsOptions {
257257
pub store_divergences: bool,
258258
}
259259

260+
impl Default for NutsOptions {
261+
fn default() -> Self {
262+
NutsOptions {
263+
maxdepth: 10,
264+
mindepth: 0,
265+
store_gradient: false,
266+
store_unconstrained: false,
267+
check_turning: true,
268+
store_divergences: false,
269+
}
270+
}
271+
}
272+
260273
pub(crate) fn draw<M, H, R, C>(
261274
math: &mut M,
262275
init: &mut State<M, H::Point>,
@@ -282,18 +295,31 @@ where
282295
return Ok((init.clone(), info));
283296
}
284297

298+
let options_no_check = NutsOptions {
299+
check_turning: false,
300+
..*options
301+
};
302+
285303
while tree.depth < options.maxdepth {
286304
let direction: Direction = rng.random();
287-
tree = match tree.extend(math, rng, hamiltonian, direction, collector, options) {
305+
let current_options = if tree.depth < options.mindepth {
306+
&options_no_check
307+
} else {
308+
options
309+
};
310+
tree = match tree.extend(
311+
math,
312+
rng,
313+
hamiltonian,
314+
direction,
315+
collector,
316+
current_options,
317+
) {
288318
ExtendResult::Ok(tree) => tree,
289319
ExtendResult::Turning(tree) => {
290-
if tree.depth < options.mindepth {
291-
tree
292-
} else {
293-
let info = tree.info(false, None);
294-
collector.register_draw(math, &tree.draw, &info);
295-
return Ok((tree.draw, info));
296-
}
320+
let info = tree.info(false, None);
321+
collector.register_draw(math, &tree.draw, &info);
322+
return Ok((tree.draw, info));
297323
}
298324
ExtendResult::Diverging(tree, info) => {
299325
let info = tree.info(false, Some(info));

0 commit comments

Comments
 (0)