---
jupytext:
  text_representation:
    extension: .md
    format_name: myst
    format_version: 0.13
    jupytext_version: 1.16.0
kernelspec:
  display_name: mclmc
  language: python
  name: python3
---

# Ensemble Microcanonical Adjusted-Unadjusted Sampler (EMAUS)

MCMC algorithms can be run in parallel as an ensemble, in the sense of running many chains at once. Once all chains have converged to the typical set, this parallelism improves wallclock time by a factor of the number of chains: each chain draws samples just as well as any other, so we collect proportionally more samples in the same time.

Reaching the typical set, on the other hand, is not as easily parallelizable, and for ensemble methods it is the bottleneck. EMAUS is an algorithm, based on [microcanonical](https://blackjax-devs.github.io/sampling-book/algorithms/mclmc.html) dynamics, designed to target this problem.

The idea is to first run a batch (or ensemble) of chains of microcanonical dynamics without Metropolis-Hastings (MH) adjustment and, based on convergence diagnostics, to then switch all of the chains to the adjusted kernel. Without adjustment, microcanonical dynamics converge quickly to the target; with adjustment, the chains are guaranteed to be asymptotically unbiased.

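
To make the switch concrete, here is a minimal toy sketch of the two-phase idea. It is not the blackjax implementation, and it uses Langevin/MALA moves on a one-dimensional Gaussian in place of microcanonical dynamics: an ensemble of chains runs unadjusted moves until a crude across-chain diagnostic stabilizes, after which every chain switches to MH-adjusted moves.

```python
import jax
import jax.numpy as jnp

logp = lambda x: -0.5 * x**2                # toy 1D standard-normal target
grad_logp = jax.grad(logp)
step = 0.5

def unadjusted_move(key, x):
    # Langevin proposal, always accepted (converges fast, but is biased)
    return x + 0.5 * step * grad_logp(x) + jnp.sqrt(step) * jax.random.normal(key)

def adjusted_move(key, x):
    # the same proposal plus a Metropolis-Hastings correction (asymptotically unbiased)
    key_prop, key_acc = jax.random.split(key)
    y = x + 0.5 * step * grad_logp(x) + jnp.sqrt(step) * jax.random.normal(key_prop)
    log_q = lambda a, b: -((b - a - 0.5 * step * grad_logp(a)) ** 2) / (2 * step)
    log_alpha = logp(y) + log_q(y, x) - logp(x) - log_q(x, y)
    return jnp.where(jnp.log(jax.random.uniform(key_acc)) < log_alpha, y, x)

key = jax.random.key(0)
chains = jax.random.normal(jax.random.key(1), (512,)) * 20.0  # overdispersed initialization
prev_var, adjusted = jnp.inf, False
for _ in range(200):
    key, subkey = jax.random.split(key)
    keys = jax.random.split(subkey, chains.shape[0])
    move = adjusted_move if adjusted else unadjusted_move
    chains = jax.vmap(move)(keys, chains)
    var = jnp.var(chains)
    # crude convergence diagnostic: switch every chain to the adjusted kernel
    # once the ensemble variance stops changing
    if not adjusted and jnp.abs(var - prev_var) / var < 1e-2:
        adjusted = True
    prev_var = var
```

EMAUS follows the same two-phase schedule, but with microcanonical dynamics and with the convergence diagnostics and tuning handled by the `emaus` adaptation used below.
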
This code is designed to be run on GPUs, and even across multiple nodes.

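
The mesh in the next cell shards the ensemble over all devices visible to the process. For a multi-node run, one possible setup (a sketch under the assumption of one process per node; nothing here is specific to blackjax) is to initialize JAX's distributed runtime before building the mesh, so that `jax.devices()` reports the devices of every node:

```python
import jax

# On a managed cluster, launch one process per node and initialize the
# distributed runtime first (coordinator address, process count, etc. are
# typically picked up from the environment):
# jax.distributed.initialize()

# After initialization, jax.devices() spans all nodes, and the same mesh
# construction as in the cell below shards the chains over every device.
mesh = jax.sharding.Mesh(jax.devices(), 'chains')
```
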
```{code-cell} ipython3
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

from blackjax.adaptation.ensemble_mclmc import emaus

# shard the ensemble of chains across all available devices
mesh = jax.sharding.Mesh(jax.devices(), 'chains')

# overdispersed initialization for each chain
sample_init = lambda key: jax.random.normal(key, shape=(2,)) * jnp.array([10.0, 5.0]) * 2


# banana-shaped target density
def logdensity_fn(x):
    mu2 = 0.03 * (x[0] ** 2 - 100)
    return -0.5 * (jnp.square(x[0] / 10.0) + jnp.square(x[1] - mu2))


def run_emaus(
    chains=4096,
    alpha=1.9,
    C=0.1,
    early_stop=1,
    r_end=1e-2,  # switch parameters
    diagonal_preconditioning=1,
    steps_per_sample=15,
    acc_prob=None,  # adjusted parameters
):
    key = jax.random.split(jax.random.key(42), 100)[2]

    info, grads_per_step, _acc_prob, final_state = emaus(
        logdensity_fn=logdensity_fn,
        sample_init=sample_init,
        ndims=2,
        num_steps1=100,  # length of the unadjusted phase
        num_steps2=300,  # length of the adjusted phase
        num_chains=chains,
        mesh=mesh,
        rng_key=key,
        alpha=alpha,
        C=C,
        early_stop=early_stop,
        r_end=r_end,
        diagonal_preconditioning=diagonal_preconditioning,
        integrator_coefficients=None,
        steps_per_sample=steps_per_sample,
        acc_prob=acc_prob,
        ensemble_observables=lambda x: x,
    )

    return final_state.position


samples = run_emaus()
```

The code above runs EMAUS with 4096 chains on a banana-shaped density and returns only the final state of each chain. These final positions can be plotted:

```{code-cell} ipython3
import seaborn as sns

sns.scatterplot(x=samples[:, 0], y=samples[:, 1], alpha=0.1)
```
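
As a quick sanity check, the ensemble moments can be compared with the known marginals of this target: under `logdensity_fn`, the first coordinate is marginally Gaussian with mean 0 and standard deviation 10, and the second coordinate also has mean 0.

```{code-cell} ipython3
# Ensemble means and standard deviations of the final states.
# Expected: mean close to (0, 0) and std of the first coordinate close to 10.
print("mean:", samples.mean(axis=0))
print("std: ", samples.std(axis=0))
```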