Commit 4a2ed34

add emaus docs
1 parent a3cbaec commit 4a2ed34

File tree

1 file changed (+87, -0 lines)

book/algorithms/emaus.md

Lines changed: 87 additions & 0 deletions
@@ -0,0 +1,87 @@
---
jupytext:
  text_representation:
    extension: .md
    format_name: myst
    format_version: 0.13
    jupytext_version: 1.16.0
kernelspec:
  display_name: mclmc
  language: python
  name: python3
---

# Ensemble Microcanonical Adjusted-Unadjusted Sampler (EMAUS)

MCMC algorithms can be run in parallel (as an ensemble), in the sense of running multiple chains at once. During the phase where all chains have converged to the typical set, this parallelism improves wallclock time by a factor of the number of chains: each chain draws samples just as well as any other, so we get proportionally more samples in the same time.
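
To see why this phase parallelizes so well, here is a minimal sketch (a toy illustration, not EMAUS, assuming nothing beyond JAX itself): a single random-walk Metropolis transition is vmapped over an ensemble of chains, so one call advances every chain at once.

```{code-cell} ipython3
import jax
import jax.numpy as jnp

def rw_metropolis_step(key, x, logdensity_fn, step_size=0.5):
    """One random-walk Metropolis transition for a single chain."""
    key_proposal, key_accept = jax.random.split(key)
    proposal = x + step_size * jax.random.normal(key_proposal, x.shape)
    log_ratio = logdensity_fn(proposal) - logdensity_fn(x)
    accept = jnp.log(jax.random.uniform(key_accept)) < log_ratio
    return jnp.where(accept, proposal, x)

toy_logdensity = lambda x: -0.5 * jnp.sum(x**2)  # standard Gaussian target
num_chains = 1024
keys = jax.random.split(jax.random.key(0), num_chains)  # one key per chain
positions = jax.random.normal(jax.random.key(1), (num_chains, 2))

# One vmapped call advances all 1024 chains in parallel on the device.
step_all_chains = jax.vmap(lambda k, x: rw_metropolis_step(k, x, toy_logdensity))
positions = step_all_chains(keys, positions)
```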

Reaching the typical set, on the other hand, is not as easily parallelizable, and for ensemble methods this is the bottleneck. EMAUS is an algorithm, based on [microcanonical](https://blackjax-devs.github.io/sampling-book/algorithms/mclmc.html) dynamics, designed to target this problem.

The idea is to first run a batch (or ensemble) of chains of microcanonical dynamics without Metropolis-Hastings (MH) adjustment, and then, based on convergence diagnostics, to switch all the chains to adjusted dynamics. Without adjustment, microcanonical dynamics converge quickly to the target; with adjustment, the chains are guaranteed to be asymptotically unbiased.
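
Schematically, the control flow looks like the sketch below. The `unadjusted_step`, `adjusted_step`, and `diagnostic` arguments are hypothetical placeholders, not the blackjax API, and the real implementation additionally handles adaptation of the hyperparameters of the dynamics, which this sketch omits.

```{code-cell} ipython3
import jax

def two_phase_sketch(key, states, unadjusted_step, adjusted_step,
                     diagnostic, r_end, num_steps1, num_steps2):
    """Control-flow sketch of EMAUS: unadjusted burn-in, then adjusted sampling."""
    # Phase 1: unadjusted microcanonical dynamics converge quickly to the
    # typical set, at the price of some bias.
    for _ in range(num_steps1):
        key, subkey = jax.random.split(key)
        states = unadjusted_step(subkey, states)
        if diagnostic(states) < r_end:  # ensemble-level convergence check
            break
    # Phase 2: MH-adjusted dynamics yield asymptotically unbiased samples.
    samples = []
    for _ in range(num_steps2):
        key, subkey = jax.random.split(key)
        states = adjusted_step(subkey, states)
        samples.append(states)
    return samples
```

The `num_steps1`, `num_steps2`, and `r_end` names mirror the arguments of `emaus` used below.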

This code is designed to run on GPU, and even across multiple nodes: the ensemble of chains is sharded across the available devices via a `jax.sharding.Mesh`.
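
To make the device layout concrete, here is a small standalone sketch (with illustrative array sizes; given the `'chains'` axis name, presumably the same layout `emaus` uses for the ensemble) of sharding the chain axis of an array across all available devices:

```{code-cell} ipython3
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# A mesh with a single named axis, 'chains'.
mesh = Mesh(jax.devices(), "chains")
# Split the leading (chain) axis of an array along that mesh axis.
sharding = NamedSharding(mesh, PartitionSpec("chains"))

toy_positions = jnp.zeros((4 * jax.device_count(), 2))  # (chains, ndims)
toy_positions = jax.device_put(toy_positions, sharding)
print(toy_positions.sharding)  # shows how the chain axis is split over devices
```

With that in place, the full example: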

```{code-cell} ipython3
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

from blackjax.adaptation.ensemble_mclmc import emaus

# Shard the ensemble of chains across all available devices.
mesh = jax.sharding.Mesh(jax.devices(), 'chains')

# Initial positions: a broad Gaussian, much wider than the target.
sample_init = lambda key: jax.random.normal(key, shape=(2,)) * jnp.array([10.0, 5.0]) * 2

# Banana-shaped target density.
def logdensity_fn(x):
    mu2 = 0.03 * (x[0] ** 2 - 100)
    return -0.5 * (jnp.square(x[0] / 10.0) + jnp.square(x[1] - mu2))


def run_emaus(
    chains=4096,
    alpha=1.9,
    C=0.1,
    early_stop=1,
    r_end=1e-2,  # switch parameters
    diagonal_preconditioning=1,
    steps_per_sample=15,
    acc_prob=None,  # adjusted parameters
):
    key = jax.random.split(jax.random.key(42), 100)[2]

    info, grads_per_step, _acc_prob, final_state = emaus(
        logdensity_fn=logdensity_fn,
        sample_init=sample_init,
        ndims=2,
        num_steps1=100,  # unadjusted phase
        num_steps2=300,  # adjusted phase
        num_chains=chains,
        mesh=mesh,
        rng_key=key,
        alpha=alpha,
        C=C,
        early_stop=early_stop,
        r_end=r_end,
        diagonal_preconditioning=diagonal_preconditioning,
        integrator_coefficients=None,
        steps_per_sample=steps_per_sample,
        acc_prob=acc_prob,
        ensemble_observables=lambda x: x,
    )

    return final_state.position


samples = run_emaus()
```

The above code runs EMAUS with 4096 chains on a banana-shaped density function and returns only the final state of each chain. These can be plotted:

```{code-cell} ipython3
import seaborn as sns

# One point per chain: the final position of each of the 4096 chains.
sns.scatterplot(x=samples[:, 0], y=samples[:, 1], alpha=0.1)
```
