Skip to content

Conversation

@Cole-Monnahan-NOAA
Copy link
Contributor

This PR fixes a small bug in the check_hmc_diagnostics function where the number of iterations exceeding the treedepth was calculated as

max_treedepths <- sum(draws_df$treedepth__ > max_treedepth)

but should be

max_treedepths <- sum(draws_df$treedepth__ >= max_treedepth)

The sampler will never exceed the max_treedepth so perhaps == would be more sensible. Furthermore it adds a new print=TRUE argument to toggle console printing, and instead of returning nothing it returns a data.frame containing the percent divergences, number of treedepth exceedences, and EBMFI info. These two are useful for users running a lot of models and wanting an easy way to track this diagnostic information. I would suggest exporting this function as well, as a user may want to rerun it (e.g., in a saved session).

Reprex:

fit <- stan_sample(loglik_fun, inits, additional_args = list(y),
                   lower = c(-Inf, 0), # Enforce a positivity constraint for SD
                   num_chains = 1, seed = 1234,
                   max_treedepth = 2)

results in

857 of 1000 (85.7%) transitions hit the maximum treedepth limit of 2, or 2^2 leapfrog steps.
Trajectories that are prematurely terminated due to this limit will result in slow exploration.
For optimal performance, increase this limit.

whereas previously it printed no warnings.

This PR is backwards compatible and passes tests locally.

@andrjohns andrjohns merged commit 75d76c4 into andrjohns:main Jan 4, 2026
7 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants