Skip to content

Commit 5cd96b5

Browse files
committed
Update SMC notebook
1 parent f70c58f commit 5cd96b5

File tree

1 file changed

+8
-8
lines changed

1 file changed

+8
-8
lines changed

book/algorithms/TemperedSMCWithOptimizedInnerKernel.md

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,7 @@ Although the proposal distribution is always normal, the mean and diagonal of th
158158
the particles outcome of the $i-th$ step, in order to mutate them in the step $i+1$
159159

160160
```{code-cell} ipython3
161-
from blackjax.smc.inner_kernel_tuning import inner_kernel_tuning
161+
from blackjax.smc.inner_kernel_tuning import as_top_level_api as inner_kernel_tuning
162162
from blackjax.smc.tuning.from_particles import (
163163
particles_covariance_matrix,
164164
particles_stds,
@@ -212,10 +212,10 @@ def tuned_irmh_experiment(dimensions, target_ess, num_mcmc_steps):
212212
mcmc_init_fn=irmh.init,
213213
resampling_fn=resampling.systematic,
214214
smc_algorithm=adaptive_tempered_smc,
215-
mcmc_parameter_update_fn=lambda state, info: extend_params(n_particles,
216-
{"means":particles_means(state.particles),
217-
"stds":particles_stds(state.particles)}),
218-
initial_parameter_value=extend_params(n_particles, {"means":jnp.zeros(dimensions), "stds":jnp.ones(dimensions) * 2}),
215+
mcmc_parameter_update_fn=lambda _, state, info: extend_params(
216+
{"means":particles_means(state.particles), "stds":particles_stds(state.particles)}),
217+
initial_parameter_value=extend_params(
218+
{"means":jnp.zeros(dimensions), "stds":jnp.ones(dimensions) * 2}),
219219
target_ess=target_ess,
220220
num_mcmc_steps=num_mcmc_steps,
221221
)
@@ -246,9 +246,9 @@ def irmh_full_cov_experiment(dimensions, target_ess, num_mcmc_steps):
246246
return kernel(key, state, logdensity, proposal_distribution, proposal_logdensity_fn)
247247
248248
249-
def mcmc_parameter_update_fn(state, info):
249+
def mcmc_parameter_update_fn(_, state, info):
250250
covariance = jnp.atleast_2d(particles_covariance_matrix(state.particles))
251-
return extend_params(n_particles, {"means":particles_means(state.particles), "cov":covariance})
251+
return extend_params({"means":particles_means(state.particles), "cov":covariance})
252252
253253
kernel_tuned_proposal = inner_kernel_tuning(
254254
logprior_fn=prior_log_prob,
@@ -258,7 +258,7 @@ def irmh_full_cov_experiment(dimensions, target_ess, num_mcmc_steps):
258258
resampling_fn=resampling.systematic,
259259
smc_algorithm=adaptive_tempered_smc,
260260
mcmc_parameter_update_fn=mcmc_parameter_update_fn,
261-
initial_parameter_value=extend_params(n_particles, {"means":jnp.zeros(dimensions), "cov":jnp.eye(dimensions) * 2}),
261+
initial_parameter_value=extend_params({"means":jnp.zeros(dimensions), "cov":jnp.eye(dimensions) * 2}),
262262
target_ess=target_ess,
263263
num_mcmc_steps=num_mcmc_steps,
264264
)

0 commit comments

Comments
 (0)