
Commit 69e238e

Author: Yuanqi Du
Merge branch 'main' of github.com:plainerman/TPS-Flow-JAX-Demo into main
2 parents: 94854a7 + 7862d89

File tree: 10 files changed (+1220, -275 lines)


eval/path_metrics.py (4 additions, 4 deletions)

@@ -3,12 +3,12 @@
 from tqdm import tqdm


-def plot_path_energy(paths, U, reduce=jnp.max, already_ln=False):
-    reduced = jnp.array([reduce(U(path)) for path in tqdm(paths, 'Computing path metric')])
+def plot_path_energy(paths, U, reduce=jnp.max, add=0, already_ln=False, **kwargs):
+    reduced = jnp.array([reduce(U(path)) for path in paths]) + add

     if already_ln:
         # Convert reduced to log10
         reduced = reduced / jnp.log(10)
-        plt.plot(jnp.arange(0, len(reduced), 1), reduced)
+        plt.plot(jnp.arange(0, len(reduced), 1), reduced, **kwargs)
     else:
-        plt.semilogy(jnp.arange(0, len(reduced), 1), reduced)
+        plt.semilogy(jnp.arange(0, len(reduced), 1), reduced, **kwargs)
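The new add offset and **kwargs pass-through let callers shift each curve by a reference energy and forward styling such as label to matplotlib. A minimal usage sketch (hypothetical potential and paths, not from this commit):

import jax.numpy as jnp
import matplotlib.pyplot as plt
from eval.path_metrics import plot_path_energy

U = lambda xs: (xs ** 2).sum(-1)  # hypothetical quadratic potential
paths = [s * jnp.linspace(-1.0, 1.0, 100)[:, None] for s in jnp.linspace(0.5, 1.5, 10)]

plot_path_energy(paths, U, add=-0.2, label='demo')  # shifted, labeled curve
plt.legend()
plt.show()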

evaluate_mueller.py (new file: 69 additions, 0 deletions)

import numpy as np
import jax.numpy as jnp
import jax
from eval.path_metrics import plot_path_energy
from tps_baseline_mueller import U, dUdx_fn, minima_points
from scipy.optimize import minimize
import matplotlib.pyplot as plt
import os

num_paths = 1000
xi = 5
kbT = xi ** 2 / 2
dt = 1e-4
T = 275e-4
N = int(T / dt)


def load(path):
    return jnp.array(np.load(path, allow_pickle=True).astype(np.float32)).squeeze()


@jax.jit
def log_prob_path(path):
    # Noise increments implied by the path under the Euler-Maruyama discretization.
    rand = path[1:] - path[:-1] + dt * dUdx_fn(path[:-1])
    return U(path[0]) / kbT + jax.scipy.stats.norm.logpdf(rand, scale=jnp.sqrt(dt) * xi).sum()
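log_prob_path scores a path under the discretization x_{k+1} = x_k - dt * grad U(x_k) + sqrt(dt) * xi * eps_k: rand recovers the Gaussian increments the path implies, and their log-densities are summed. A minimal sketch (not part of this commit) that simulates one such path and scores it, assuming minima_points[0] is a single 2-D point as used below:

key = jax.random.PRNGKey(0)
x = jnp.asarray(minima_points[0]).reshape(1, -1)  # start in a minimum
frames = [x]
for _ in range(N - 1):
    key, subkey = jax.random.split(key)
    # One Euler-Maruyama step of the overdamped Langevin dynamics.
    x = x - dt * dUdx_fn(x) + jnp.sqrt(dt) * xi * jax.random.normal(subkey, x.shape)
    frames.append(x)
demo_path = jnp.concatenate(frames, axis=0)  # shape (N, 2)
print('log-likelihood of a simulated path:', log_prob_path(demo_path))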
if __name__ == '__main__':
    savedir = './out/evaluation/mueller/'
    os.makedirs(savedir, exist_ok=True)

    all_paths = [
        ('one-way-shooting', './out/baselines/mueller/paths-one-way-shooting.npy'),
        ('two-way-shooting', './out/baselines/mueller/paths-two-way-shooting.npy'),
        ('var-doobs', './out/var_doobs/mueller/paths.npy'),
    ]

    # Locate the global minimum energy by refining each known minimum with scipy.
    global_minimum_energy = U(minima_points[0])
    for point in minima_points:
        global_minimum_energy = min(global_minimum_energy, minimize(U, point).fun)
    print("Global minimum energy", global_minimum_energy)

    all_paths = [(name, load(path)) for name, path in all_paths]
    [print(name, path.shape) for name, path in all_paths]

    for name, paths in all_paths:
        plot_path_energy(paths, U, add=-global_minimum_energy, label=name)

    plt.legend()
    plt.ylabel('Maximum energy')
    plt.savefig(f'{savedir}/mueller-max-energy.pdf', bbox_inches='tight')
    plt.show()

    for name, paths in all_paths:
        plot_path_energy(paths, U, add=-global_minimum_energy, reduce=jnp.median, label=name)

    plt.legend()
    plt.ylabel('Median energy')
    plt.savefig(f'{savedir}/mueller-median-energy.pdf', bbox_inches='tight')
    plt.show()

    for name, paths in all_paths:
        plot_path_energy(paths, log_prob_path, reduce=lambda x: x, label=name)
        print('Median log-likelihood of:', name, jnp.median(jnp.array([log_prob_path(path) for path in paths])))

    plt.legend()
    plt.ylabel('log path likelihood')
    plt.savefig(f'{savedir}/mueller-log-path-likelihood.pdf', bbox_inches='tight')
    plt.show()

gaussian_mixture.py (new file: 240 additions, 0 deletions)
from functools import partial
import utils.toy_plot_helpers as toy
from flax import linen as nn
from flax.training import train_state
import optax
import jax
import jax.numpy as jnp
from tqdm import trange
import matplotlib.pyplot as plt
import os


@jax.jit
def U(xs, beta=1.0):
    # Two Gaussian wells at (-0.5, 0) and (0.5, 0) separated by a bump at the
    # origin; the x^6 + y^6 term confines trajectories to the box.
    x, y = xs[:, 0], xs[:, 1]
    borders = x ** 6 + y ** 6
    e1 = +2.0 * jnp.exp(-(12.0 * (x - 0.00) ** 2 + 12.0 * (y - 0.00) ** 2))
    e2 = -1.0 * jnp.exp(-(12.0 * (x + 0.50) ** 2 + 12.0 * (y + 0.00) ** 2))
    e3 = -1.0 * jnp.exp(-(12.0 * (x - 0.50) ** 2 + 12.0 * (y + 0.00) ** 2))
    return beta * (borders + e1 + e2 + e3)


dUdx_fn = jax.jit(jax.grad(lambda _x: U(_x).sum()))

plot_energy_surface = partial(toy.plot_energy_surface, U, [], jnp.array((-1, 1)), jnp.array((-1, 1)), levels=20)
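As a quick orientation (hypothetical check, not in the commit), evaluating U at the two wells and at the origin shows the landscape a path has to cross:

print(U(jnp.array([[-0.5, 0.0], [0.0, 0.0], [0.5, 0.0]])))
# roughly [-0.89, 1.90, -0.89]: low energy in both wells, a barrier in between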
def create_mlp_q(A, B, T, num_mixtures):
    class MLPq(nn.Module):
        @nn.compact
        def __call__(self, t):
            """
            in_shape: (batch, t)
            out_shape: (batch, num_mixtures, data)
            """
            t = t / T
            h = nn.Dense(128)(t - 0.5)
            h = nn.swish(h)
            h = nn.Dense(128)(h)
            h = nn.swish(h)
            h = nn.Dense(128)(h)
            h = nn.swish(h)
            h = nn.Dense(2 * num_mixtures + num_mixtures)(h)

            # The (1 - t) * t factor vanishes at both endpoints, pinning the
            # component means to A at t = 0 and B at t = T.
            mu = (((1 - t) * A)[:, None, :] + (t * B)[:, None, :] +
                  ((1 - t) * t * h[:, :2 * num_mixtures]).reshape(-1, num_mixtures, A.shape[-1]))
            sigma = (1 - t) * 1e-2 * 2.5 + t * 1e-2 * 2.5 + (1 - t) * t * jnp.exp(h[:, 2 * num_mixtures:])
            return mu, sigma[:, :, None]

    # Uniform mixture weights (zero logits).
    return MLPq(), jnp.zeros(num_mixtures)
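A quick sanity check of those boundary conditions (hypothetical, not part of the commit), with A, B, T chosen as in the __main__ block below:

A, B, T = jnp.array([[-0.5, 0.0]]), jnp.array([[0.5, 0.0]]), 1.0
q, w_logits = create_mlp_q(A, B, T, num_mixtures=2)
params = q.init(jax.random.PRNGKey(0), jnp.zeros((1, 1)))
mu_0, _ = q.apply(params, jnp.zeros((1, 1)))     # t = 0
mu_T, _ = q.apply(params, T * jnp.ones((1, 1)))  # t = T
assert jnp.allclose(mu_0, A) and jnp.allclose(mu_T, B)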
def train(q, w_logits, epochs):
    BS = 512
    key = jax.random.PRNGKey(1)
    key, *init_key = jax.random.split(key, 3)
    params_q = q.init(init_key[0], jnp.ones([BS, 1]))

    optimizer_q = optax.adam(learning_rate=1e-4)
    state_q = train_state.TrainState.create(apply_fn=q.apply,
                                            params=params_q,
                                            tx=optimizer_q)

    def loss_fn(params_q, key):
        # Three independent keys: time samples, Gaussian noise, mixture indices.
        # (split(key) without a count yields only two keys, so indexing key[2]
        # would silently clamp to key[1].)
        key = jax.random.split(key, 3)

        t = T * jax.random.uniform(key[0], [BS, 1])
        eps = jax.random.normal(key[1], [BS, 1, A.shape[-1]])
        i = jax.random.categorical(key[2], w_logits, shape=(BS,))

        mu_t = lambda _t: state_q.apply_fn(params_q, _t)[0]
        sigma_t = lambda _t: state_q.apply_fn(params_q, _t)[1]

        def dmudt(_t):
            _dmudt = jax.jacrev(lambda _t: mu_t(_t).sum(0).T)
            return _dmudt(_t).squeeze(axis=-1).T

        def dsigmadt(_t):
            _dsigmadt = jax.jacrev(lambda _t: sigma_t(_t).sum(0).T)
            return _dsigmadt(_t).squeeze(axis=-1).T

        dUdx_fn = jax.grad(lambda _x: U(_x).sum())

        def v_t(_eps, _t, _i, _w_logits):
            """This function is equal to v_t * xi ** 2."""
            _mu_t = mu_t(_t)
            _sigma_t = sigma_t(_t)
            # Reparameterized sample from the i-th mixture component.
            _x = _mu_t[jnp.arange(BS), _i, None] + _sigma_t[jnp.arange(BS), _i, None] * _eps

            log_q_i = jax.scipy.stats.norm.logpdf(_x, _mu_t, _sigma_t).sum(-1)
            relative_mixture_weights = jax.nn.softmax(_w_logits + log_q_i)[:, :, None]

            # Score of the mixture, grad_x log q_t(x).
            log_q_t = -(relative_mixture_weights / (_sigma_t ** 2) * (_x - _mu_t)).sum(axis=1)
            u_t = (relative_mixture_weights * (1 / _sigma_t * dsigmadt(_t) * (_x - _mu_t) + dmudt(_t))).sum(axis=1)
            b_t = -dUdx_fn(_x.reshape(BS, A.shape[-1]))

            return u_t - b_t + 0.5 * (xi ** 2) * log_q_t

        loss = 0.5 * ((v_t(eps, t, i, w_logits) / xi) ** 2).sum(-1, keepdims=True)
        print(loss.shape, 'loss.shape', flush=True)  # printed once at trace time
        return loss.mean()

    @jax.jit
    def train_step(state_q, key):
        grad_fn = jax.value_and_grad(loss_fn, argnums=0)
        loss, grads = grad_fn(state_q.params, key)
        state_q = state_q.apply_gradients(grads=grads)
        return state_q, loss

    # Warm-up step to trigger JIT compilation before the timed loop.
    key, loc_key = jax.random.split(key)
    state_q, loss = train_step(state_q, loc_key)

    loss_plot = []
    for _ in trange(epochs):
        key, loc_key = jax.random.split(key)
        state_q, loss = train_step(state_q, loc_key)
        loss_plot.append(loss)

    return state_q, loss_plot
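Loosely, in the notation of the code: v_t evaluates u_t(x) - b(x) + (xi^2 / 2) * grad_x log q_t(x) with b(x) = -grad U(x), at the reparameterized sample x = mu_i(t) + sigma_i(t) * eps, so the training objective is the expected squared control

    loss = 0.5 * E_{t, i, eps} || v_t / xi ||^2,

which vanishes exactly when the drift u_t + (xi^2 / 2) * grad log q_t used in sample_stochastically below coincides with the reference dynamics b.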
def draw_samples_qt(state_q, w_logits, num_samples, T, A, key):
    # Note: this returns a softmax-weighted average of per-component samples,
    # a visualization shortcut rather than an exact draw from the mixture.
    num_mixtures = len(w_logits)
    w = jax.nn.softmax(w_logits)[None, :, None]
    t = T * jnp.linspace(0, 1, num_samples).reshape((-1, 1))
    eps = jax.random.normal(key, [num_samples, num_mixtures, A.shape[-1]])
    mu_t, sigma_t = state_q.apply_fn(state_q.params, t)
    return (w * (mu_t + sigma_t * eps)).sum(axis=1)
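A hypothetical exact-sampling variant (not part of this commit) would instead pick one component per sample and draw from it:

def draw_samples_qt_exact(state_q, w_logits, num_samples, T, A, key):
    t = T * jnp.linspace(0, 1, num_samples).reshape((-1, 1))
    key_i, key_eps = jax.random.split(key)
    i = jax.random.categorical(key_i, w_logits, shape=(num_samples,))  # component per sample
    eps = jax.random.normal(key_eps, [num_samples, A.shape[-1]])
    mu_t, sigma_t = state_q.apply_fn(state_q.params, t)
    rows = jnp.arange(num_samples)
    return mu_t[rows, i] + sigma_t[rows, i] * eps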
def sample_stochastically(state_q, w_logits, num_samples, T, A, xi, dt, key):
    mu_t = lambda _t: state_q.apply_fn(state_q.params, _t)[0]
    sigma_t = lambda _t: state_q.apply_fn(state_q.params, _t)[1]

    def dmudt(_t):
        _dmudt = jax.jacrev(lambda _t: mu_t(_t).sum(0).T, argnums=0)
        return _dmudt(_t).squeeze(axis=-1).T

    def dsigmadt(_t):
        _dsigmadt = jax.jacrev(lambda _t: sigma_t(_t).sum(0).T)
        return _dsigmadt(_t).squeeze(axis=-1).T

    @jax.jit
    def u_t(_t, _x):
        _mu_t = mu_t(_t)
        _sigma_t = sigma_t(_t)
        _x = _x[:, None, :]

        log_q_i = jax.scipy.stats.norm.logpdf(_x, _mu_t, _sigma_t).sum(-1)
        relative_mixture_weights = jax.nn.softmax(w_logits + log_q_i)[:, :, None]

        # Score of the mixture, grad_x log q_t(x).
        log_q_t = -(relative_mixture_weights / (_sigma_t ** 2) * (_x - _mu_t)).sum(axis=1)
        _u_t = (relative_mixture_weights * (1 / _sigma_t * dsigmadt(_t) * (_x - _mu_t) + dmudt(_t))).sum(axis=1)

        return _u_t + 0.5 * (xi ** 2) * log_q_t

    N = int(T / dt)

    # Draw the initial state from q_0; consume the freshly split subkey rather
    # than the carried key, which is split again below.
    key, loc_key = jax.random.split(key)
    x_t = jnp.ones((num_samples, N + 1, 2)) * A
    eps = jax.random.normal(loc_key, shape=(num_samples, A.shape[-1]))
    x_t = x_t.at[:, 0, :].set(x_t[:, 0, :] + sigma_t(jnp.zeros((num_samples, 1)))[:, 0, :] * eps)

    # Euler-Maruyama integration of dx = u_t dt + xi dW.
    t = jnp.zeros((num_samples, 1))
    for i in trange(N):
        key, loc_key = jax.random.split(key)
        eps = jax.random.normal(loc_key, shape=(num_samples, A.shape[-1]))

        dx = dt * u_t(t, x_t[:, i, :]) + jnp.sqrt(dt) * xi * eps
        x_t = x_t.at[:, i + 1, :].set(x_t[:, i, :] + dx)
        t += dt

    return x_t
if __name__ == '__main__':
    savedir = './out/var_doobs/mixture'
    os.makedirs(savedir, exist_ok=True)

    A = jnp.array([[-0.5, 0]])
    B = jnp.array([[0.5, 0]])
    dt = 5e-4
    T = 1.0
    xi = 0.1
    epochs = 20_000

    q_single, w_logits_single = create_mlp_q(A, B, T, 1)
    q_mixture, w_logits_mixture = create_mlp_q(A, B, T, 2)

    state_q_single, loss_plot_single = train(q_single, w_logits_single, epochs=epochs)
    state_q_mixture, loss_plot_mixture = train(q_mixture, w_logits_mixture, epochs=epochs)

    plt.plot(loss_plot_single, label='single')
    plt.plot(loss_plot_mixture, label='mixture')
    plt.legend()
    plt.show()

    samples_qt_single = draw_samples_qt(state_q_single, w_logits_single, num_samples=1000, T=T, A=A,
                                        key=jax.random.PRNGKey(0))
    samples_qt_mixture = draw_samples_qt(state_q_mixture, w_logits_mixture, num_samples=1000, T=T, A=A,
                                         key=jax.random.PRNGKey(0))

    plot_energy_surface()
    plt.scatter(samples_qt_single[:, 0], samples_qt_single[:, 1], label='single')
    plt.scatter(samples_qt_mixture[:, 0], samples_qt_mixture[:, 1], label='mixture')
    plt.legend()
    plt.show()

    samples_single = sample_stochastically(state_q_single, w_logits_single, num_samples=1000, T=T, A=A, xi=xi, dt=dt,
                                           key=jax.random.PRNGKey(0))
    samples_mixture = sample_stochastically(state_q_mixture, w_logits_mixture, num_samples=1000, T=T, A=A, xi=xi, dt=dt,
                                            key=jax.random.PRNGKey(0))

    plot_energy_surface(trajectories=samples_single)
    plt.savefig(f'{savedir}/toy-gaussian-single.pdf', bbox_inches='tight')
    plt.show()

    plot_energy_surface(trajectories=samples_mixture)
    plt.savefig(f'{savedir}/toy-gaussian-mixture.pdf', bbox_inches='tight')
    plt.show()

    # Visualize the learned means colored by sigma on a shared color scale.
    t = T * jnp.linspace(0, 1, 10 * int(T / dt)).reshape((-1, 1))
    _mu_t_single, _sigma_t_single = state_q_single.apply_fn(state_q_single.params, t)
    _mu_t_mixture, _sigma_t_mixture = state_q_mixture.apply_fn(state_q_mixture.params, t)

    vmin = min(_sigma_t_single.min(), _sigma_t_mixture.min())
    vmax = max(_sigma_t_single.max(), _sigma_t_mixture.max())

    plot_energy_surface()
    plt.scatter(_mu_t_single[:, :, 0], _mu_t_single[:, :, 1], c=_sigma_t_single, vmin=vmin, vmax=vmax, rasterized=True)
    plt.colorbar(label=r'$\sigma$')
    plt.savefig(f'{savedir}/toy-gaussian-single-mu.pdf', bbox_inches='tight')
    plt.show()

    plot_energy_surface()
    plt.scatter(_mu_t_mixture[:, :, 0], _mu_t_mixture[:, :, 1], c=_sigma_t_mixture, vmin=vmin, vmax=vmax, rasterized=True)
    plt.colorbar(label=r'$\sigma$')
    plt.savefig(f'{savedir}/toy-gaussian-mixture-mu.pdf', bbox_inches='tight')
    plt.show()
