
Commit fa175ca

Add gaussian mixture illustration
1 parent d6e3a1e commit fa175ca

File tree

3 files changed: +279 -167 lines

gaussian_mixture.py

Lines changed: 240 additions & 0 deletions
@@ -0,0 +1,240 @@
from functools import partial
import utils.toy_plot_helpers as toy
from flax import linen as nn
from flax.training import train_state
import optax
import jax
import jax.numpy as jnp
from tqdm import trange
import matplotlib.pyplot as plt
import os


@jax.jit
def U(xs, beta=1.0):
    # Two metastable wells at (-0.5, 0) and (0.5, 0), a Gaussian barrier at
    # the origin, and x**6 + y**6 walls confining the domain.
    x, y = xs[:, 0], xs[:, 1]
    borders = x ** 6 + y ** 6
    e1 = +2.0 * jnp.exp(-(12.0 * (x - 0.00) ** 2 + 12.0 * (y - 0.00) ** 2))
    e2 = -1.0 * jnp.exp(-(12.0 * (x + 0.50) ** 2 + 12.0 * (y + 0.00) ** 2))
    e3 = -1.0 * jnp.exp(-(12.0 * (x - 0.50) ** 2 + 12.0 * (y + 0.00) ** 2))
    return beta * (borders + e1 + e2 + e3)


dUdx_fn = jax.jit(jax.grad(lambda _x: U(_x).sum()))

plot_energy_surface = partial(toy.plot_energy_surface, U, [], jnp.array((-1, 1)), jnp.array((-1, 1)), levels=20)


def create_mlp_q(A, B, T, num_mixtures):
    class MLPq(nn.Module):
        @nn.compact
        def __call__(self, t):
            """
            in_shape: (batch, t)
            out_shape: (batch, num_mixtures, data)
            """
            t = t / T
            h = nn.Dense(128)(t - 0.5)
            h = nn.swish(h)
            h = nn.Dense(128)(h)
            h = nn.swish(h)
            h = nn.Dense(128)(h)
            h = nn.swish(h)
            h = nn.Dense(2 * num_mixtures + num_mixtures)(h)

            # Means interpolate A -> B; the (1 - t) * t factors pin the path to
            # A at t=0 and B at t=T and to width 2.5e-2 at both endpoints, so
            # the network only deforms the path in between.
            mu = (((1 - t) * A)[:, None, :] + (t * B)[:, None, :] +
                  ((1 - t) * t * h[:, :2 * num_mixtures]).reshape(-1, num_mixtures, A.shape[-1]))
            sigma = (1 - t) * 1e-2 * 2.5 + t * 1e-2 * 2.5 + (1 - t) * t * jnp.exp(h[:, 2 * num_mixtures:])
            return mu, sigma[:, :, None]

    return MLPq(), jnp.zeros(num_mixtures)


def train(q, w_logits, epochs):
    BS = 512
    key = jax.random.PRNGKey(1)
    key, *init_key = jax.random.split(key, 3)
    params_q = q.init(init_key[0], jnp.ones([BS, 1]))

    optimizer_q = optax.adam(learning_rate=1e-4)
    state_q = train_state.TrainState.create(apply_fn=q.apply,
                                            params=params_q,
                                            tx=optimizer_q)

    def loss_fn(params_q, key):
        # One key each for t, the Gaussian noise, and the mixture component.
        key = jax.random.split(key, 3)

        t = T * jax.random.uniform(key[0], [BS, 1])
        eps = jax.random.normal(key[1], [BS, 1, A.shape[-1]])
        i = jax.random.categorical(key[2], w_logits, shape=(BS,))

        mu_t = lambda _t: state_q.apply_fn(params_q, _t)[0]
        sigma_t = lambda _t: state_q.apply_fn(params_q, _t)[1]

        def dmudt(_t):
            _dmudt = jax.jacrev(lambda _t: mu_t(_t).sum(0).T)
            return _dmudt(_t).squeeze(axis=-1).T

        def dsigmadt(_t):
            _dsigmadt = jax.jacrev(lambda _t: sigma_t(_t).sum(0).T)
            return _dsigmadt(_t).squeeze(axis=-1).T

        dUdx_fn = jax.grad(lambda _x: U(_x).sum())

        def v_t(_eps, _t, _i, _w_logits):
            """This function is equal to v_t * xi ** 2."""
            _mu_t = mu_t(_t)
            _sigma_t = sigma_t(_t)
            # Reparameterized sample from mixture component _i at time _t.
            _x = _mu_t[jnp.arange(BS), _i, None] + _sigma_t[jnp.arange(BS), _i, None] * _eps

            log_q_i = jax.scipy.stats.norm.logpdf(_x, _mu_t, _sigma_t).sum(-1)
            relative_mixture_weights = jax.nn.softmax(_w_logits + log_q_i)[:, :, None]

            log_q_t = -(relative_mixture_weights / (_sigma_t ** 2) * (_x - _mu_t)).sum(axis=1)
            u_t = (relative_mixture_weights * (1 / _sigma_t * dsigmadt(_t) * (_x - _mu_t) + dmudt(_t))).sum(axis=1)
            b_t = -dUdx_fn(_x.reshape(BS, A.shape[-1]))

            return u_t - b_t + 0.5 * (xi ** 2) * log_q_t

        loss = 0.5 * ((v_t(eps, t, i, w_logits) / xi) ** 2).sum(-1, keepdims=True)
        return loss.mean()

    @jax.jit
    def train_step(state_q, key):
        grad_fn = jax.value_and_grad(loss_fn, argnums=0)
        loss, grads = grad_fn(state_q.params, key)
        state_q = state_q.apply_gradients(grads=grads)
        return state_q, loss

    # One warm-up step triggers JIT compilation before the progress-bar loop.
    key, loc_key = jax.random.split(key)
    state_q, loss = train_step(state_q, loc_key)

    loss_plot = []
    for _ in trange(epochs):
        key, loc_key = jax.random.split(key)
        state_q, loss = train_step(state_q, loc_key)
        loss_plot.append(loss)

    return state_q, loss_plot


def draw_samples_qt(state_q, w_logits, num_samples, T, A, key):
    num_mixtures = len(w_logits)
    w = jax.nn.softmax(w_logits)[None, :, None]
    t = T * jnp.linspace(0, 1, num_samples).reshape((-1, 1))
    eps = jax.random.normal(key, [num_samples, num_mixtures, A.shape[-1]])
    mu_t, sigma_t = state_q.apply_fn(state_q.params, t)
    return (w * (mu_t + sigma_t * eps)).sum(axis=1)


def sample_stochastically(state_q, w_logits, num_samples, T, A, xi, dt, key):
    mu_t = lambda _t: state_q.apply_fn(state_q.params, _t)[0]
    sigma_t = lambda _t: state_q.apply_fn(state_q.params, _t)[1]

    def dmudt(_t):
        _dmudt = jax.jacrev(lambda _t: mu_t(_t).sum(0).T, argnums=0)
        return _dmudt(_t).squeeze(axis=-1).T

    def dsigmadt(_t):
        _dsigmadt = jax.jacrev(lambda _t: sigma_t(_t).sum(0).T)
        return _dsigmadt(_t).squeeze(axis=-1).T

    @jax.jit
    def u_t(_t, _x):
        _mu_t = mu_t(_t)
        _sigma_t = sigma_t(_t)
        _x = _x[:, None, :]

        log_q_i = jax.scipy.stats.norm.logpdf(_x, _mu_t, _sigma_t).sum(-1)
        relative_mixture_weights = jax.nn.softmax(w_logits + log_q_i)[:, :, None]

        log_q_t = -(relative_mixture_weights / (_sigma_t ** 2) * (_x - _mu_t)).sum(axis=1)
        _u_t = (relative_mixture_weights * (1 / _sigma_t * dsigmadt(_t) * (_x - _mu_t) + dmudt(_t))).sum(axis=1)

        return _u_t + 0.5 * (xi ** 2) * log_q_t

    N = int(T / dt)

    key, loc_key = jax.random.split(key)

    # Initialize every path at A, perturbed by the t=0 width of q.
    x_t = jnp.ones((num_samples, N + 1, 2)) * A
    eps = jax.random.normal(loc_key, shape=(num_samples, A.shape[-1]))
    x_t = x_t.at[:, 0, :].set(x_t[:, 0, :] + sigma_t(jnp.zeros((num_samples, 1)))[:, 0, :] * eps)

    t = jnp.zeros((num_samples, 1))
    for i in trange(N):
        key, loc_key = jax.random.split(key)
        eps = jax.random.normal(loc_key, shape=(num_samples, A.shape[-1]))

        # Euler-Maruyama step with the learned drift u_t.
        dx = dt * u_t(t, x_t[:, i, :]) + jnp.sqrt(dt) * xi * eps
        x_t = x_t.at[:, i + 1, :].set(x_t[:, i, :] + dx)
        t += dt

    return x_t


if __name__ == '__main__':
    savedir = './out/var_doobs/mixture'
    os.makedirs(savedir, exist_ok=True)

    A = jnp.array([[-0.5, 0]])
    B = jnp.array([[0.5, 0]])
    dt = 5e-4
    T = 1.0
    xi = 0.1
    epochs = 20_000

    q_single, w_logits_single = create_mlp_q(A, B, T, 1)
    q_mixture, w_logits_mixture = create_mlp_q(A, B, T, 2)

    state_q_single, loss_plot_single = train(q_single, w_logits_single, epochs=epochs)
    state_q_mixture, loss_plot_mixture = train(q_mixture, w_logits_mixture, epochs=epochs)

    plt.plot(loss_plot_single, label='single')
    plt.plot(loss_plot_mixture, label='mixture')
    plt.legend()
    plt.show()

    samples_qt_single = draw_samples_qt(state_q_single, w_logits_single, num_samples=1000, T=T, A=A,
                                        key=jax.random.PRNGKey(0))
    samples_qt_mixture = draw_samples_qt(state_q_mixture, w_logits_mixture, num_samples=1000, T=T, A=A,
                                         key=jax.random.PRNGKey(0))

    plot_energy_surface()
    plt.scatter(samples_qt_single[:, 0], samples_qt_single[:, 1], label='single')
    plt.scatter(samples_qt_mixture[:, 0], samples_qt_mixture[:, 1], label='mixture')
    plt.legend()
    plt.show()

    samples_single = sample_stochastically(state_q_single, w_logits_single, num_samples=1000, T=T, A=A, xi=xi, dt=dt,
                                           key=jax.random.PRNGKey(0))
    samples_mixture = sample_stochastically(state_q_mixture, w_logits_mixture, num_samples=1000, T=T, A=A, xi=xi, dt=dt,
                                            key=jax.random.PRNGKey(0))

    plot_energy_surface(trajectories=samples_single)
    plt.savefig(f'{savedir}/toy-gaussian-single.pdf', bbox_inches='tight')
    plt.show()

    plot_energy_surface(trajectories=samples_mixture)
    plt.savefig(f'{savedir}/toy-gaussian-mixture.pdf', bbox_inches='tight')
    plt.show()

    t = T * jnp.linspace(0, 1, 10 * int(T / dt)).reshape((-1, 1))
    _mu_t_single, _sigma_t_single = state_q_single.apply_fn(state_q_single.params, t)
    _mu_t_mixture, _sigma_t_mixture = state_q_mixture.apply_fn(state_q_mixture.params, t)

    vmin = min(_sigma_t_single.min(), _sigma_t_mixture.min())
    vmax = max(_sigma_t_single.max(), _sigma_t_mixture.max())

    plot_energy_surface()
    plt.scatter(_mu_t_single[:, :, 0], _mu_t_single[:, :, 1], c=_sigma_t_single, vmin=vmin, vmax=vmax, rasterized=True)
    plt.colorbar(label=r'$\sigma$')
    plt.savefig(f'{savedir}/toy-gaussian-single-mu.pdf', bbox_inches='tight')
    plt.show()

    plot_energy_surface()
    plt.scatter(_mu_t_mixture[:, :, 0], _mu_t_mixture[:, :, 1], c=_sigma_t_mixture, vmin=vmin, vmax=vmax, rasterized=True)
    plt.colorbar(label=r'$\sigma$')
    plt.savefig(f'{savedir}/toy-gaussian-mixture-mu.pdf', bbox_inches='tight')
    plt.show()
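
For orientation, the objective implemented in `loss_fn` above admits a compact reading (our gloss of the code; the commit itself states no formula). With reference dynamics $dx = b(x)\,dt + \xi\,dW$, $b = -\nabla U$, and a Gaussian-mixture path $q_t$ with mixture drift $u_t$ and score $\nabla_x \log q_t$, the code minimizes

$$
\mathcal{L} = \mathbb{E}_{t \sim \mathcal{U}(0,T),\ x \sim q_t}\left[ \frac{1}{2\xi^2} \left\lVert u_t(x) - b(x) + \frac{\xi^2}{2} \nabla_x \log q_t(x) \right\rVert^2 \right].
$$

In the code, `v_t` returns the residual inside the norm (the variable `log_q_t` holds the score, not a log-density), `b_t = -dUdx_fn(x)`, and `loss = 0.5 * ((v_t / xi) ** 2)`. The output directory name `var_doobs` suggests a variational approximation to a Doob h-transform bridge between the wells at A and B.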

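The mixture-score identity used twice above (in `v_t` and in `u_t`), namely $\nabla_x \log \sum_i w_i \mathcal{N}(x;\mu_i,\sigma_i^2) = \sum_i r_i(x)\,(-(x-\mu_i)/\sigma_i^2)$ with posterior weights $r_i = \mathrm{softmax}(\texttt{w\_logits} + \log \mathcal{N}_i)$, can be checked against autodiff. A minimal, self-contained sketch; the function names here are ours, not part of the commit:

import jax
import jax.numpy as jnp


def mixture_logpdf(x, mu, sigma, w_logits):
    # x: (d,), mu: (K, d), sigma: (K, 1), w_logits: (K,)
    log_q_i = jax.scipy.stats.norm.logpdf(x, mu, sigma).sum(-1)  # per-component log-density, (K,)
    log_w = jax.nn.log_softmax(w_logits)                         # normalized log-weights
    return jax.scipy.special.logsumexp(log_w + log_q_i)


def mixture_score(x, mu, sigma, w_logits):
    # Closed form mirroring relative_mixture_weights / log_q_t above.
    log_q_i = jax.scipy.stats.norm.logpdf(x, mu, sigma).sum(-1)
    r = jax.nn.softmax(w_logits + log_q_i)[:, None]              # posterior over components
    return -(r / sigma ** 2 * (x - mu)).sum(0)


mu = jnp.array([[-0.5, 0.0], [0.5, 0.0]])
sigma = jnp.array([[0.1], [0.2]])
w_logits = jnp.zeros(2)
x = jax.random.normal(jax.random.PRNGKey(0), (2,))
print(jnp.allclose(jax.grad(mixture_logpdf)(x, mu, sigma, w_logits),
                   mixture_score(x, mu, sigma, w_logits), atol=1e-5))  # True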
tps_baseline_mueller.py

Lines changed: 5 additions & 45 deletions
@@ -1,11 +1,13 @@
 import json
+from functools import partial

 import jax
 import jax.numpy as jnp
 import os
 import matplotlib.pyplot as plt
 from tps import first_order as tps1
 import numpy as np
+import utils.toy_plot_helpers as toy

 minima_points = jnp.array([[-0.55828035, 1.44169], [-0.05004308, 0.46666032], [0.62361133, 0.02804632]])
 A, B = minima_points[None, 0], minima_points[None, 2]
@@ -44,51 +46,8 @@ def interpolate_two_points(start, stop, steps):
     return interpolation


-def plot_energy_surface(points=[], trajectories=[], bins=150, alpha=0.7):
-    xlim, ylim = jnp.array((-1.5, 0.9)), jnp.array((-0.5, 1.7))
-
-    x, y = jnp.linspace(xlim[0], xlim[1], bins), jnp.linspace(ylim[0], ylim[1], bins)
-    x, y = jnp.meshgrid(x, y, indexing='ij')
-    z = U(jnp.stack([x, y], -1).reshape(-1, 2)).reshape([bins, bins])
-
-    # black and white contour plot
-    plt.contour(x, y, z, levels=30, cmap='gray')
-
-    plt.xlim(xlim[0], xlim[1])
-    plt.ylim(ylim[0], ylim[1])
-
-    if len(trajectories) > 0:
-        from openpathsampling.analysis import PathHistogram
-        from openpathsampling.numerics import HistogramPlotter2D
-
-        hist = PathHistogram(
-            left_bin_edges=(xlim[0], ylim[0]),
-            bin_widths=(jnp.diff(xlim)[0] / bins, jnp.diff(ylim)[0] / bins),
-            interpolate=True, per_traj=True
-        )
-
-        [hist.add_trajectory(t) for t in trajectories]
-
-        plotter = HistogramPlotter2D(hist, xlim=xlim, ylim=ylim)
-        df = hist().df_2d(x_range=plotter.xrange_, y_range=plotter.yrange_)
-        plt.pcolormesh(
-            jnp.linspace(xlim[0], xlim[1], df.shape[0]),
-            jnp.linspace(ylim[0], ylim[1], df.shape[1]),
-            df.values.T.astype(dtype=float),
-            vmin=0, vmax=3, cmap='Blues',
-            rasterized=True
-        )
-
-        plt.colorbar()
-
-    for p in points:
-        plt.scatter(p[0], p[1], marker='*')
-
-    for name, pos in zip(['A', 'B', 'C'], minima_points):
-        c = plt.Circle(pos, radius=0.1, edgecolor='gray', alpha=alpha, facecolor='white', ls='--', lw=0.7)
-        plt.gca().add_patch(c)
-        plt.gca().annotate(name, xy=pos, ha="center", va="center")
-
+plot_energy_surface = partial(toy.plot_energy_surface, U=U, states=zip(['A', 'B', 'C'], minima_points),
+                              xlim=jnp.array((-1.5, 0.9)), ylim=jnp.array((-0.5, 1.7)))

 if __name__ == '__main__':
     savedir = f"out/baselines/mueller"
@@ -101,6 +60,7 @@ def plot_energy_surface(points=[], trajectories=[], bins=150, alpha=0.7):
     N = int(T / dt)
     initial_trajectory = [t.reshape(1, 2) for t in interpolate(minima_points, N)]

+
     @jax.jit
     def step(_x, _key):
         """Perform one step of forward euler"""

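The third changed file (presumably `utils/toy_plot_helpers.py`, which both scripts now import) is not shown in this view. From the two `partial(...)` call sites and the deleted body above, the extracted helper plausibly looks like the sketch below; treat the signature and defaults as assumptions, not the committed code:

# utils/toy_plot_helpers.py -- assumed shape of the shared helper (sketch only)
import jax.numpy as jnp
import matplotlib.pyplot as plt


def plot_energy_surface(U, states, xlim, ylim, points=[], trajectories=[], bins=150, levels=30, alpha=0.7):
    x, y = jnp.linspace(xlim[0], xlim[1], bins), jnp.linspace(ylim[0], ylim[1], bins)
    x, y = jnp.meshgrid(x, y, indexing='ij')
    z = U(jnp.stack([x, y], -1).reshape(-1, 2)).reshape([bins, bins])
    plt.contour(x, y, z, levels=levels, cmap='gray')
    plt.xlim(xlim[0], xlim[1])
    plt.ylim(ylim[0], ylim[1])

    # A trajectory-histogram branch would follow the openpathsampling code
    # deleted above; omitted here to keep the sketch dependency-free.

    for p in points:
        plt.scatter(p[0], p[1], marker='*')

    for name, pos in states:
        c = plt.Circle(pos, radius=0.1, edgecolor='gray', alpha=alpha, facecolor='white', ls='--', lw=0.7)
        plt.gca().add_patch(c)
        plt.gca().annotate(name, xy=pos, ha="center", va="center")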