@@ -44,7 +44,7 @@ MCLMC in Blackjax comes with a tuning algorithm which attempts to find optimal v

An example is given below of tuning and running a chain for a 1000-dimensional Gaussian target (of which a 2-dimensional marginal is plotted):

- ```{code-cell} ipython3
+ ```{code-cell}
:tags: [hide-cell]

import matplotlib.pyplot as plt
@@ -66,7 +66,7 @@ from numpyro.infer.util import initialize_model
rng_key = jax.random.key(int(date.today().strftime("%Y%m%d")))
```

- ```{code-cell} ipython3
+ ```{code-cell}
def run_mclmc(logdensity_fn, num_steps, initial_position, key, transform, desired_energy_variance=5e-4):
    init_key, tune_key, run_key = jax.random.split(key, 3)

@@ -115,7 +115,7 @@ def run_mclmc(logdensity_fn, num_steps, initial_position, key, transform, desire
    return samples, blackjax_state_after_tuning, blackjax_mclmc_sampler_params, run_key
```

- ```{code-cell} ipython3
+ ```{code-cell}
# run the algorithm on a high-dimensional Gaussian, and show two of the dimensions

logdensity_fn = lambda x: -0.5 * jnp.sum(jnp.square(x))
@@ -134,13 +134,13 @@ samples, initial_state, params, chain_key = run_mclmc(
samples.mean()
```

- ```{code-cell} ipython3
+ ```{code-cell}
plt.scatter(x=samples[:, 0], y=samples[:, 1], alpha=0.1)
plt.axis("equal")
plt.title("Scatter Plot of Samples")
```

- ```{code-cell} ipython3
+ ```{code-cell}
def visualize_results_gauss(samples, label, color):
    x1 = samples[:, 0]
    plt.hist(x1, bins=30, density=True, histtype='step', lw=4, color=color, label=label)
@@ -165,12 +165,12 @@ ground_truth_gauss()

A natural sanity check is to see if reducing $\epsilon$ changes the inferred distribution to an extent you care about. For example, we can inspect the 1D marginal with a step size $\epsilon$ as above, and compare it to a step size $\epsilon/2$ (and double the number of steps). We show this comparison below:

- ```{code-cell} ipython3
+ ```{code-cell}
new_params = params._replace(step_size=params.step_size / 2)
new_num_steps = num_steps * 2
```

- ```{code-cell} ipython3
+ ```{code-cell}
sampling_alg = blackjax.mclmc(
    logdensity_fn,
    L=new_params.L,
@@ -211,7 +211,7 @@ Our task is to find the posterior of the parameters $\{R_n\}_{n =1}^N$, $\sigma$

First, we get the data, define a model using NumPyro, and draw samples:

- ```{code-cell} ipython3
+ ```{code-cell}
import matplotlib.dates as mdates
from numpyro.examples.datasets import SP500, load_dataset
from numpyro.distributions import StudentT
@@ -243,7 +243,7 @@ def setup():
setup()
```

- ```{code-cell} ipython3
+ ```{code-cell}
def from_numpyro(model, rng_key, model_args):
    init_params, potential_fn_gen, *_ = initialize_model(
        rng_key,
@@ -272,13 +272,13 @@ rng_key = jax.random.key(42)
logp_sv, x_init = from_numpyro(stochastic_volatility, rng_key, model_args)
```

- ```{code-cell} ipython3
+ ```{code-cell}
num_steps = 20000

samples, initial_state, params, chain_key = run_mclmc(logdensity_fn=logp_sv, num_steps=num_steps, initial_position=x_init, key=sample_key, transform=lambda state, info: state.position)
```

- ```{code-cell} ipython3
+ ```{code-cell}
def visualize_results_sv(samples, color, label):

    R = np.exp(np.array(samples['s']))  # take an exponent to get R
@@ -297,7 +297,7 @@ plt.legend()
plt.show()
```

- ```{code-cell} ipython3
+ ```{code-cell}
new_params = params._replace(step_size=params.step_size / 2)
new_num_steps = num_steps * 2

@@ -318,10 +318,9 @@ _, new_samples = blackjax.util.run_inference_algorithm(
    transform=lambda state, info: state.position,
    progress_bar=True,
)
-
```

- ```{code-cell} ipython3
+ ```{code-cell}
setup()
visualize_results_sv(new_samples, 'red', 'MCLMC (stepsize/2)')
visualize_results_sv(samples, 'teal', 'MCLMC')
@@ -332,7 +331,7 @@ plt.show()

Here, we have again inspected the effect of halving $\epsilon$. This looks OK, but suppose we are interested in the hierarchical parameters in particular, which tend to be harder to infer. We now inspect the marginal of a hierarchical parameter:

- ```{code-cell} ipython3
+ ```{code-cell}
def visualize_results_sv_marginal(samples, color, label):
    # plt.subplot(1, 2, 1)
    # plt.hist(samples['nu'], bins=20, histtype='step', lw=4, density=True, color=color, label=label)
@@ -354,9 +353,131 @@ If we care about this parameter in particular, we should reduce step size furthe

+++

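+ As a concrete illustration of that further check, the sketch below halves the step size once more (a quarter of the tuned value), doubles the number of steps again, and re-draws the hierarchical marginal. It is only a sketch: it reuses `logp_sv`, `initial_state`, `chain_key`, `new_params` and `new_num_steps` from the cells above, and the names `newer_params`, `newer_num_steps` and the `stepsize/4` label are illustrative.
+
+ ```{code-cell}
+ # halve the tuned step size a second time and double the number of steps again
+ newer_params = new_params._replace(step_size=new_params.step_size / 2)
+ newer_num_steps = new_num_steps * 2
+
+ sampling_alg = blackjax.mclmc(
+     logp_sv,
+     L=newer_params.L,
+     step_size=newer_params.step_size,
+ )
+
+ _, newer_samples = blackjax.util.run_inference_algorithm(
+     rng_key=chain_key,
+     initial_state=initial_state,
+     inference_algorithm=sampling_alg,
+     num_steps=newer_num_steps,
+     transform=lambda state, info: state.position,
+     progress_bar=True,
+ )
+
+ visualize_results_sv_marginal(newer_samples, color='green', label='MCLMC (stepsize/4)')
+ plt.legend()
+ plt.show()
+ ```
+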
+ ## Adjusted MCLMC
+
+ Blackjax also provides an adjusted version of the algorithm. It also has two hyperparameters, `step_size` and `L`. The `L` here is related to the `L` parameter of the unadjusted version, but is not identical to it. The tuning algorithm is similar, but uses a dual averaging scheme to tune the step size. We find in practice that a target MH acceptance rate of 0.9 is a good choice.
+
+ ```{code-cell}
+ from blackjax.mcmc.adjusted_mclmc import rescale
+ from blackjax.util import run_inference_algorithm
+
+ def run_adjusted_mclmc(
+     logdensity_fn,
+     num_steps,
+     initial_position,
+     key,
+     transform=lambda state, _: state.position,
+     diagonal_preconditioning=False,
+     random_trajectory_length=True,
+     L_proposal_factor=jnp.inf,
+ ):
+     init_key, tune_key, run_key = jax.random.split(key, 3)
+
+     initial_state = blackjax.mcmc.adjusted_mclmc.init(
+         position=initial_position,
+         logdensity_fn=logdensity_fn,
+         random_generator_arg=init_key,
+     )
+
+     if random_trajectory_length:
+         integration_steps_fn = lambda avg_num_integration_steps: lambda k: jnp.ceil(
+             jax.random.uniform(k) * rescale(avg_num_integration_steps))
+     else:
+         integration_steps_fn = lambda avg_num_integration_steps: lambda _: jnp.ceil(avg_num_integration_steps)
+
+     kernel = lambda rng_key, state, avg_num_integration_steps, step_size, sqrt_diag_cov: blackjax.mcmc.adjusted_mclmc.build_kernel(
+         integration_steps_fn=integration_steps_fn(avg_num_integration_steps),
+         sqrt_diag_cov=sqrt_diag_cov,
+     )(
+         rng_key=rng_key,
+         state=state,
+         step_size=step_size,
+         logdensity_fn=logdensity_fn,
+         L_proposal_factor=L_proposal_factor,
+     )
+
+     target_acc_rate = 0.9  # our recommendation
+
+     (
+         blackjax_state_after_tuning,
+         blackjax_mclmc_sampler_params,
+     ) = blackjax.adjusted_mclmc_find_L_and_step_size(
+         mclmc_kernel=kernel,
+         num_steps=num_steps,
+         state=initial_state,
+         rng_key=tune_key,
+         target=target_acc_rate,
+         frac_tune1=0.1,
+         frac_tune2=0.1,
+         frac_tune3=0.0,  # our recommendation
+         diagonal_preconditioning=diagonal_preconditioning,
+     )
+
+     step_size = blackjax_mclmc_sampler_params.step_size
+     L = blackjax_mclmc_sampler_params.L
+
+     alg = blackjax.adjusted_mclmc(
+         logdensity_fn=logdensity_fn,
+         step_size=step_size,
+         integration_steps_fn=lambda key: jnp.ceil(
+             jax.random.uniform(key) * rescale(L / step_size)
+         ),
+         sqrt_diag_cov=blackjax_mclmc_sampler_params.sqrt_diag_cov,
+         L_proposal_factor=L_proposal_factor,
+     )
+
+     _, out = run_inference_algorithm(
+         rng_key=run_key,
+         initial_state=blackjax_state_after_tuning,
+         inference_algorithm=alg,
+         num_steps=num_steps,
+         transform=transform,
+         progress_bar=False,
+     )
+
+     return out
+ ```
+
+ ```{code-cell}
+ # run the algorithm on a high-dimensional Gaussian, and show two of the dimensions
+
+ sample_key, rng_key = jax.random.split(rng_key)
+ samples = run_adjusted_mclmc(
+     logdensity_fn=lambda x: -0.5 * jnp.sum(jnp.square(x)),
+     num_steps=1000,
+     initial_position=jnp.ones((1000,)),
+     key=sample_key,
+ )
+ plt.scatter(x=samples[:, 0], y=samples[:, 1], alpha=0.1)
+ plt.axis("equal")
+ plt.title("Scatter Plot of Samples")
+ ```
+
+ The same target can also be sampled with a fixed (non-random) trajectory length and a finite `L_proposal_factor`:
+
+ ```{code-cell}
+ # as above, but with a fixed trajectory length and a finite L_proposal_factor
+
+ sample_key, rng_key = jax.random.split(rng_key)
+ samples = run_adjusted_mclmc(
+     logdensity_fn=lambda x: -0.5 * jnp.sum(jnp.square(x)),
+     num_steps=1000,
+     initial_position=jnp.ones((1000,)),
+     key=sample_key,
+     random_trajectory_length=False,
+     L_proposal_factor=1.25,
+ )
+ plt.scatter(x=samples[:, 0], y=samples[:, 1], alpha=0.1)
+ plt.axis("equal")
+ plt.title("Scatter Plot of Samples")
+ ```
+
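+ As with the unadjusted sampler, a quick sanity check on this standard Gaussian target is to compare a 1D marginal of the adjusted samples against the exact density, mirroring the `ground_truth_gauss()` comparison earlier. A minimal sketch, reusing the `visualize_results_gauss` helper defined above (the colour and label are arbitrary):
+
+ ```{code-cell}
+ visualize_results_gauss(samples, label='Adjusted MCLMC', color='purple')
+
+ # exact N(0, 1) marginal density for comparison
+ x = jnp.linspace(-4.0, 4.0, 200)
+ plt.plot(x, jnp.exp(-0.5 * x**2) / jnp.sqrt(2 * jnp.pi), color='black', label='exact N(0, 1) marginal')
+ plt.legend()
+ plt.show()
+ ```
+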
```{bibliography}
:filter: docname in docnames
```