@@ -158,7 +158,7 @@ Although the proposal distribution is always normal, the mean and diagonal of th
158158the particles outcome of the $i-th$ step, in order to mutate them in the step $i+1$
159159
160160``` {code-cell} ipython3
161- from blackjax.smc.inner_kernel_tuning import inner_kernel_tuning
161+ from blackjax.smc.inner_kernel_tuning import as_top_level_api as inner_kernel_tuning
162162from blackjax.smc.tuning.from_particles import (
163163 particles_covariance_matrix,
164164 particles_stds,
@@ -212,10 +212,10 @@ def tuned_irmh_experiment(dimensions, target_ess, num_mcmc_steps):
212212 mcmc_init_fn=irmh.init,
213213 resampling_fn=resampling.systematic,
214214 smc_algorithm=adaptive_tempered_smc,
215- mcmc_parameter_update_fn=lambda state, info: extend_params(n_particles,
216- {"means":particles_means(state.particles),
217- "stds":particles_stds(state.particles)}),
218- initial_parameter_value=extend_params(n_particles, {"means":jnp.zeros(dimensions), "stds":jnp.ones(dimensions) * 2}),
215+ mcmc_parameter_update_fn=lambda _, state, info: extend_params(
216+ {"means":particles_means(state.particles), "stds":particles_stds(state.particles)} ),
217+ initial_parameter_value=extend_params(
218+ {"means":jnp.zeros(dimensions), "stds":jnp.ones(dimensions) * 2}),
219219 target_ess=target_ess,
220220 num_mcmc_steps=num_mcmc_steps,
221221 )
@@ -246,9 +246,9 @@ def irmh_full_cov_experiment(dimensions, target_ess, num_mcmc_steps):
246246 return kernel(key, state, logdensity, proposal_distribution, proposal_logdensity_fn)
247247
248248
249- def mcmc_parameter_update_fn(state, info):
249+ def mcmc_parameter_update_fn(_, state, info):
250250 covariance = jnp.atleast_2d(particles_covariance_matrix(state.particles))
251- return extend_params(n_particles, {"means":particles_means(state.particles), "cov":covariance})
251+ return extend_params({"means":particles_means(state.particles), "cov":covariance})
252252
253253 kernel_tuned_proposal = inner_kernel_tuning(
254254 logprior_fn=prior_log_prob,
@@ -258,7 +258,7 @@ def irmh_full_cov_experiment(dimensions, target_ess, num_mcmc_steps):
258258 resampling_fn=resampling.systematic,
259259 smc_algorithm=adaptive_tempered_smc,
260260 mcmc_parameter_update_fn=mcmc_parameter_update_fn,
261- initial_parameter_value=extend_params(n_particles, {"means":jnp.zeros(dimensions), "cov":jnp.eye(dimensions) * 2}),
261+ initial_parameter_value=extend_params({"means":jnp.zeros(dimensions), "cov":jnp.eye(dimensions) * 2}),
262262 target_ess=target_ess,
263263 num_mcmc_steps=num_mcmc_steps,
264264 )
0 commit comments