diff --git a/book/algorithms/TemperedSMC.md b/book/algorithms/TemperedSMC.md index d769f09b..474862cd 100644 --- a/book/algorithms/TemperedSMC.md +++ b/book/algorithms/TemperedSMC.md @@ -4,7 +4,7 @@ jupytext: extension: .md format_name: myst format_version: 0.13 - jupytext_version: 1.15.2 + jupytext_version: 1.16.2 kernelspec: display_name: Python 3 (ipykernel) language: python @@ -158,7 +158,7 @@ We now use a NUTS kernel. ```{code-cell} ipython3 %%time -nuts_parameters = dict(step_size=1e-4, inverse_mass_matrix=inv_mass_matrix) +nuts_parameters = dict(step_size=1e-4, inverse_mass_matrix=inv_mass_matrix, max_num_doublings=6) nuts = blackjax.nuts(full_logdensity, **nuts_parameters) nuts_state = nuts.init(jnp.ones((1,))) @@ -219,7 +219,7 @@ tempered = blackjax.adaptive_tempered_smc( loglikelihood, blackjax.hmc.build_kernel(), blackjax.hmc.init, - extend_params(n_samples, hmc_parameters), + extend_params(hmc_parameters), resampling.systematic, 0.5, num_mcmc_steps=1, @@ -367,7 +367,7 @@ tempered = blackjax.adaptive_tempered_smc( loglikelihood, blackjax.hmc.build_kernel(), blackjax.hmc.init, - extend_params(n_samples, hmc_parameters), + extend_params(hmc_parameters), resampling.systematic, 0.75, num_mcmc_steps=1, diff --git a/book/algorithms/TemperedSMCWithOptimizedInnerKernel.md b/book/algorithms/TemperedSMCWithOptimizedInnerKernel.md index 769af90b..b6750d77 100644 --- a/book/algorithms/TemperedSMCWithOptimizedInnerKernel.md +++ b/book/algorithms/TemperedSMCWithOptimizedInnerKernel.md @@ -4,7 +4,7 @@ jupytext: extension: .md format_name: myst format_version: 0.13 - jupytext_version: 1.15.2 + jupytext_version: 1.16.2 kernelspec: display_name: Python 3 (ipykernel) language: python @@ -158,7 +158,7 @@ Although the proposal distribution is always normal, the mean and diagonal of th the particles outcome of the $i-th$ step, in order to mutate them in the step $i+1$ ```{code-cell} ipython3 -from blackjax.smc.inner_kernel_tuning import inner_kernel_tuning +from blackjax import inner_kernel_tuning from blackjax.smc.tuning.from_particles import ( particles_covariance_matrix, particles_stds, @@ -212,10 +212,9 @@ def tuned_irmh_experiment(dimensions, target_ess, num_mcmc_steps): mcmc_init_fn=irmh.init, resampling_fn=resampling.systematic, smc_algorithm=adaptive_tempered_smc, - mcmc_parameter_update_fn=lambda state, info: extend_params(n_particles, - {"means":particles_means(state.particles), - "stds":particles_stds(state.particles)}), - initial_parameter_value=extend_params(n_particles, {"means":jnp.zeros(dimensions), "stds":jnp.ones(dimensions) * 2}), + mcmc_parameter_update_fn=lambda state, info: extend_params({"means":particles_means(state.particles), + "stds":particles_stds(state.particles)}), + initial_parameter_value=extend_params({"means":jnp.zeros(dimensions), "stds":jnp.ones(dimensions) * 2}), target_ess=target_ess, num_mcmc_steps=num_mcmc_steps, ) @@ -248,7 +247,7 @@ def irmh_full_cov_experiment(dimensions, target_ess, num_mcmc_steps): def mcmc_parameter_update_fn(state, info): covariance = jnp.atleast_2d(particles_covariance_matrix(state.particles)) - return extend_params(n_particles, {"means":particles_means(state.particles), "cov":covariance}) + return extend_params({"means":particles_means(state.particles), "cov":covariance}) kernel_tuned_proposal = inner_kernel_tuning( logprior_fn=prior_log_prob, @@ -258,7 +257,7 @@ def irmh_full_cov_experiment(dimensions, target_ess, num_mcmc_steps): resampling_fn=resampling.systematic, smc_algorithm=adaptive_tempered_smc, mcmc_parameter_update_fn=mcmc_parameter_update_fn, - initial_parameter_value=extend_params(n_particles, {"means":jnp.zeros(dimensions), "cov":jnp.eye(dimensions) * 2}), + initial_parameter_value=extend_params({"means":jnp.zeros(dimensions), "cov":jnp.eye(dimensions) * 2}), target_ess=target_ess, num_mcmc_steps=num_mcmc_steps, )