Commit 08be5e9

Commit message: adjusted
1 parent 994085a commit 08be5e9

1 file changed: book/algorithms/mclmc.md (136 additions, 15 deletions)
@@ -44,7 +44,7 @@ MCLMC in Blackjax comes with a tuning algorithm which attempts to find optimal v
 
 An example is given below of tuning and running a chain for a 1000-dimensional Gaussian target (of which a 2-dimensional marginal is plotted):
 
-```{code-cell} ipython3
+```{code-cell}
 :tags: [hide-cell]
 
 import matplotlib.pyplot as plt
@@ -66,7 +66,7 @@ from numpyro.infer.util import initialize_model
 rng_key = jax.random.key(int(date.today().strftime("%Y%m%d")))
 ```
 
-```{code-cell} ipython3
+```{code-cell}
 def run_mclmc(logdensity_fn, num_steps, initial_position, key, transform, desired_energy_variance= 5e-4):
     init_key, tune_key, run_key = jax.random.split(key, 3)
 
@@ -115,7 +115,7 @@ def run_mclmc(logdensity_fn, num_steps, initial_position, key, transform, desire
     return samples, blackjax_state_after_tuning, blackjax_mclmc_sampler_params, run_key
 ```
 
-```{code-cell} ipython3
+```{code-cell}
 # run the algorithm on a high dimensional gaussian, and show two of the dimensions
 
 logdensity_fn = lambda x: -0.5 * jnp.sum(jnp.square(x))
@@ -134,13 +134,13 @@ samples, initial_state, params, chain_key = run_mclmc(
 samples.mean()
 ```
 
-```{code-cell} ipython3
+```{code-cell}
 plt.scatter(x=samples[:, 0], y=samples[:, 1], alpha=0.1)
 plt.axis("equal")
 plt.title("Scatter Plot of Samples")
 ```
 
-```{code-cell} ipython3
+```{code-cell}
 def visualize_results_gauss(samples, label, color):
     x1 = samples[:, 0]
     plt.hist(x1, bins= 30, density= True, histtype= 'step', lw= 4, color= color, label= label)
@@ -165,12 +165,12 @@ ground_truth_gauss()
 
 A natural sanity check is to see if reducing $\epsilon$ changes the inferred distribution to an extent you care about. For example, we can inspect the 1D marginal with a stepsize $\epsilon$ as above, and compare it to a stepsize $\epsilon/2$ (and double the number of steps). We show this comparison below:
 
-```{code-cell} ipython3
+```{code-cell}
 new_params = params._replace(step_size= params.step_size / 2)
 new_num_steps = num_steps * 2
 ```
 
-```{code-cell} ipython3
+```{code-cell}
 sampling_alg = blackjax.mclmc(
     logdensity_fn,
     L=new_params.L,
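
The diff context cuts this cell off after `L=new_params.L,`. For orientation only, here is a sketch (not part of this commit) of how the halved-step-size re-run presumably continues, mirroring the `blackjax.util.run_inference_algorithm` call and the `visualize_results_gauss` helper that appear elsewhere in this diff; the exact keyword values and the reuse of `chain_key` and `initial_state` are assumptions.

```python
# Sketch only: assumed continuation of the truncated cell above.
# `samples`, `initial_state`, `params`, `chain_key`, and `visualize_results_gauss`
# come from earlier cells of the notebook being diffed.
sampling_alg = blackjax.mclmc(
    logdensity_fn,
    L=new_params.L,
    step_size=new_params.step_size,  # halved step size
)

_, new_samples = blackjax.util.run_inference_algorithm(
    rng_key=chain_key,
    initial_state=initial_state,
    inference_algorithm=sampling_alg,
    num_steps=new_num_steps,  # twice the number of steps
    transform=lambda state, info: state.position,
    progress_bar=True,
)

# Overlay the 1D marginals of the two runs for comparison.
visualize_results_gauss(samples, "MCLMC", "teal")
visualize_results_gauss(new_samples, "MCLMC (stepsize/2)", "red")
plt.legend()
plt.show()
```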
@@ -211,7 +211,7 @@ Our task is to find the posterior of the parameters $\{R_n\}_{n =1}^N$, $\sigma$
 
 First, we get the data, define a model using NumPyro, and draw samples:
 
-```{code-cell} ipython3
+```{code-cell}
 import matplotlib.dates as mdates
 from numpyro.examples.datasets import SP500, load_dataset
 from numpyro.distributions import StudentT
@@ -243,7 +243,7 @@ def setup():
 setup()
 ```
 
-```{code-cell} ipython3
+```{code-cell}
 def from_numpyro(model, rng_key, model_args):
     init_params, potential_fn_gen, *_ = initialize_model(
         rng_key,
@@ -272,13 +272,13 @@ rng_key = jax.random.key(42)
 logp_sv, x_init = from_numpyro(stochastic_volatility, rng_key, model_args)
 ```
 
-```{code-cell} ipython3
+```{code-cell}
 num_steps = 20000
 
 samples, initial_state, params, chain_key = run_mclmc(logdensity_fn= logp_sv, num_steps= num_steps, initial_position= x_init, key= sample_key, transform=lambda state, info: state.position)
 ```
 
-```{code-cell} ipython3
+```{code-cell}
 def visualize_results_sv(samples, color, label):
 
     R = np.exp(np.array(samples['s'])) # take an exponent to get R
@@ -297,7 +297,7 @@ plt.legend()
 plt.show()
 ```
 
-```{code-cell} ipython3
+```{code-cell}
 new_params = params._replace(step_size = params.step_size/2)
 new_num_steps = num_steps * 2
 
@@ -318,10 +318,9 @@ _, new_samples = blackjax.util.run_inference_algorithm(
     transform=lambda state, info : state.position,
     progress_bar=True,
 )
-
 ```
 
-```{code-cell} ipython3
+```{code-cell}
 setup()
 visualize_results_sv(new_samples,'red', 'MCLMC', )
 visualize_results_sv(samples,'teal', 'MCLMC (stepsize/2)', )
@@ -332,7 +331,7 @@ plt.show()
 
 Here, we have again inspected the effect of halving $\epsilon$. This looks OK, but suppose we are interested in the hierarchical parameters in particular, which tend to be harder to infer. We now inspect the marginal of a hierarchical parameter:
 
-```{code-cell} ipython3
+```{code-cell}
 def visualize_results_sv_marginal(samples, color, label):
     # plt.subplot(1, 2, 1)
     # plt.hist(samples['nu'], bins = 20, histtype= 'step', lw= 4, density= True, color= color, label= label)
@@ -354,9 +353,131 @@ If we care about this parameter in particular, we should reduce step size furthe
 
 +++
 
+## Adjusted MCLMC
+
+Blackjax also provides an adjusted version of the algorithm (i.e. with a Metropolis-Hastings accept/reject step). This also has two hyperparameters, `step_size` and `L`. `L` is related to the `L` parameter of the unadjusted version, but not identical. The tuning algorithm is also similar, but uses a dual averaging scheme to tune the step size. We find in practice that a target MH acceptance rate of 0.9 is a good choice.
+
+```{code-cell}
+from blackjax.mcmc.adjusted_mclmc import rescale
+from blackjax.util import run_inference_algorithm
+
+def run_adjusted_mclmc(
+    logdensity_fn,
+    num_steps,
+    initial_position,
+    key,
+    transform=lambda state, _ : state.position,
+    diagonal_preconditioning=False,
+    random_trajectory_length=True,
+    L_proposal_factor=jnp.inf
+):
+
+    init_key, tune_key, run_key = jax.random.split(key, 3)
+
+    initial_state = blackjax.mcmc.adjusted_mclmc.init(
+        position=initial_position,
+        logdensity_fn=logdensity_fn,
+        random_generator_arg=init_key,
+    )
+
+    if random_trajectory_length:
+        integration_steps_fn = lambda avg_num_integration_steps: lambda k: jnp.ceil(
+            jax.random.uniform(k) * rescale(avg_num_integration_steps))
+    else:
+        integration_steps_fn = lambda avg_num_integration_steps: lambda _: jnp.ceil(avg_num_integration_steps)
+
+    kernel = lambda rng_key, state, avg_num_integration_steps, step_size, sqrt_diag_cov: blackjax.mcmc.adjusted_mclmc.build_kernel(
+        integration_steps_fn=integration_steps_fn(avg_num_integration_steps),
+        sqrt_diag_cov=sqrt_diag_cov,
+    )(
+        rng_key=rng_key,
+        state=state,
+        step_size=step_size,
+        logdensity_fn=logdensity_fn,
+        L_proposal_factor=L_proposal_factor,
+    )
+
+    target_acc_rate = 0.9  # our recommendation
+
+    (
+        blackjax_state_after_tuning,
+        blackjax_mclmc_sampler_params,
+    ) = blackjax.adjusted_mclmc_find_L_and_step_size(
+        mclmc_kernel=kernel,
+        num_steps=num_steps,
+        state=initial_state,
+        rng_key=tune_key,
+        target=target_acc_rate,
+        frac_tune1=0.1,
+        frac_tune2=0.1,
+        frac_tune3=0.0,  # our recommendation
+        diagonal_preconditioning=diagonal_preconditioning,
+    )
+
+    step_size = blackjax_mclmc_sampler_params.step_size
+    L = blackjax_mclmc_sampler_params.L
+
+    alg = blackjax.adjusted_mclmc(
+        logdensity_fn=logdensity_fn,
+        step_size=step_size,
+        integration_steps_fn=lambda key: jnp.ceil(
+            jax.random.uniform(key) * rescale(L / step_size)
+        ),
+        sqrt_diag_cov=blackjax_mclmc_sampler_params.sqrt_diag_cov,
+        L_proposal_factor=L_proposal_factor,
+    )
+
+    _, out = run_inference_algorithm(
+        rng_key=run_key,
+        initial_state=blackjax_state_after_tuning,
+        inference_algorithm=alg,
+        num_steps=num_steps,
+        transform=transform,
+        progress_bar=False,
+    )
+
+    return out
+```
+
+```{code-cell}
+# run the algorithm on a high-dimensional Gaussian, and show two of the dimensions
+
+sample_key, rng_key = jax.random.split(rng_key)
+samples = run_adjusted_mclmc(
+    logdensity_fn=lambda x: -0.5 * jnp.sum(jnp.square(x)),
+    num_steps=1000,
+    initial_position=jnp.ones((1000,)),
+    key=sample_key,
+)
+plt.scatter(x=samples[:, 0], y=samples[:, 1], alpha=0.1)
+plt.axis("equal")
+plt.title("Scatter Plot of Samples")
+```
+
+```{code-cell}
+# run the algorithm again on the same Gaussian, this time with a fixed trajectory length and a finite L_proposal_factor
+
+sample_key, rng_key = jax.random.split(rng_key)
+samples = run_adjusted_mclmc(
+    logdensity_fn=lambda x: -0.5 * jnp.sum(jnp.square(x)),
+    num_steps=1000,
+    initial_position=jnp.ones((1000,)),
+    key=sample_key,
+    random_trajectory_length=False,
+    L_proposal_factor=1.25,
+)
+plt.scatter(x=samples[:, 0], y=samples[:, 1], alpha=0.1)
+plt.axis("equal")
+plt.title("Scatter Plot of Samples")
+```
+
 ```{bibliography}
 :filter: docname in docnames
 ```
 
 
+```
+
+```{code-cell}
+
 ```
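
A note on the new cells above: with `random_trajectory_length=True` the number of integration steps per proposal is drawn as `ceil(U * rescale(L / step_size))` with `U ~ Uniform(0, 1)`, while with `random_trajectory_length=False` it is simply `ceil(L / step_size)`. The snippet below is a minimal standalone sketch of that draw; it assumes (as the code suggests, though the commit does not state it) that `rescale` keeps the average draw close to `L / step_size`, and the tuned values used here are hypothetical.

```python
import jax
import jax.numpy as jnp
from blackjax.mcmc.adjusted_mclmc import rescale  # same import as in the new cell

L, step_size = 5.0, 0.25   # hypothetical tuned values
avg_steps = L / step_size  # intended average number of integration steps per proposal

# random_trajectory_length=True: randomized number of steps per proposal
u = jax.random.uniform(jax.random.key(0), (10_000,))
random_lengths = jnp.ceil(u * rescale(avg_steps))

# random_trajectory_length=False: fixed number of steps per proposal
fixed_length = jnp.ceil(avg_steps)

# the mean of the random draws should sit near avg_steps if the assumption above holds
print(float(random_lengths.mean()), float(fixed_length))
```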
