|
| 1 | +from functools import partial |
| 2 | +import utils.toy_plot_helpers as toy |
| 3 | +from flax import linen as nn |
| 4 | +from flax.training import train_state |
| 5 | +import optax |
| 6 | +import jax |
| 7 | +import jax.numpy as jnp |
| 8 | +from tqdm import trange |
| 9 | +import matplotlib.pyplot as plt |
| 10 | +import os |
| 11 | + |
| 12 | + |
| 13 | +@jax.jit |
| 14 | +def U(xs, beta=1.0): |
| 15 | + x, y = xs[:, 0], xs[:, 1] |
| 16 | + borders = x ** 6 + y ** 6 |
| 17 | + e1 = +2.0 * jnp.exp(-(12.0 * (x - 0.00) ** 2 + 12.0 * (y - 0.00) ** 2)) |
| 18 | + e2 = -1.0 * jnp.exp(-(12.0 * (x + 0.50) ** 2 + 12.0 * (y + 0.00) ** 2)) |
| 19 | + e3 = -1.0 * jnp.exp(-(12.0 * (x - 0.50) ** 2 + 12.0 * (y + 0.00) ** 2)) |
| 20 | + return beta * (borders + e1 + e2 + e3) |
| 21 | + |
| 22 | + |
| 23 | +dUdx_fn = jax.jit(jax.grad(lambda _x: U(_x).sum())) |
| 24 | + |
| 25 | +plot_energy_surface = partial(toy.plot_energy_surface, U, [], jnp.array((-1, 1)), jnp.array((-1, 1)), levels=20) |
| 26 | + |
| 27 | + |
| 28 | +def create_mlp_q(A, B, T, num_mixtures): |
| 29 | + class MLPq(nn.Module): |
| 30 | + @nn.compact |
| 31 | + def __call__(self, t): |
| 32 | + """ |
| 33 | + in_shape: (batch, t) |
| 34 | + out_shape: (batch, num_mixtures, data) |
| 35 | + """ |
| 36 | + t = t / T |
| 37 | + h = nn.Dense(128)(t - 0.5) |
| 38 | + h = nn.swish(h) |
| 39 | + h = nn.Dense(128)(h) |
| 40 | + h = nn.swish(h) |
| 41 | + h = nn.Dense(128)(h) |
| 42 | + h = nn.swish(h) |
| 43 | + h = nn.Dense(2 * num_mixtures + num_mixtures)(h) |
| 44 | + |
| 45 | + mu = (((1 - t) * A)[:, None, :] + (t * B)[:, None, :] + |
| 46 | + ((1 - t) * t * h[:, :2 * num_mixtures]).reshape(-1, num_mixtures, A.shape[-1])) |
| 47 | + sigma = (1 - t) * 1e-2 * 2.5 + t * 1e-2 * 2.5 + (1 - t) * t * jnp.exp(h[:, 2 * num_mixtures:]) |
| 48 | + return mu, sigma[:, :, None] |
| 49 | + |
| 50 | + return MLPq(), jnp.zeros(num_mixtures) |
| 51 | + |
| 52 | + |
| 53 | +def train(q, w_logits, epochs): |
| 54 | + BS = 512 |
| 55 | + key = jax.random.PRNGKey(1) |
| 56 | + key, *init_key = jax.random.split(key, 3) |
| 57 | + params_q = q.init(init_key[0], jnp.ones([BS, 1])) |
| 58 | + |
| 59 | + optimizer_q = optax.adam(learning_rate=1e-4) |
| 60 | + state_q = train_state.TrainState.create(apply_fn=q.apply, |
| 61 | + params=params_q, |
| 62 | + tx=optimizer_q) |
| 63 | + |
| 64 | + def loss_fn(params_q, key): |
| 65 | + key = jax.random.split(key) |
| 66 | + |
| 67 | + t = T * jax.random.uniform(key[0], [BS, 1]) |
| 68 | + eps = jax.random.normal(key[1], [BS, 1, A.shape[-1]]) |
| 69 | + i = jax.random.categorical(key[2], w_logits, shape=(BS,)) |
| 70 | + |
| 71 | + mu_t = lambda _t: state_q.apply_fn(params_q, _t)[0] |
| 72 | + sigma_t = lambda _t: state_q.apply_fn(params_q, _t)[1] |
| 73 | + |
| 74 | + def dmudt(_t): |
| 75 | + _dmudt = jax.jacrev(lambda _t: mu_t(_t).sum(0).T) |
| 76 | + return _dmudt(_t).squeeze(axis=-1).T |
| 77 | + |
| 78 | + def dsigmadt(_t): |
| 79 | + _dsigmadt = jax.jacrev(lambda _t: sigma_t(_t).sum(0).T) |
| 80 | + return _dsigmadt(_t).squeeze(axis=-1).T |
| 81 | + |
| 82 | + dUdx_fn = jax.grad(lambda _x: U(_x).sum()) |
| 83 | + |
| 84 | + def v_t(_eps, _t, _i, _w_logits): |
| 85 | + """This function is equal to v_t * xi ** 2.""" |
| 86 | + _mu_t = mu_t(_t) |
| 87 | + _sigma_t = sigma_t(_t) |
| 88 | + _x = _mu_t[jnp.arange(BS), _i, None] + _sigma_t[jnp.arange(BS), _i, None] * eps |
| 89 | + |
| 90 | + log_q_i = jax.scipy.stats.norm.logpdf(_x, _mu_t, _sigma_t).sum(-1) |
| 91 | + relative_mixture_weights = jax.nn.softmax(_w_logits + log_q_i)[:, :, None] |
| 92 | + |
| 93 | + log_q_t = -(relative_mixture_weights / (_sigma_t ** 2) * (_x - _mu_t)).sum(axis=1) |
| 94 | + u_t = (relative_mixture_weights * (1 / _sigma_t * dsigmadt(_t) * (_x - _mu_t) + dmudt(_t))).sum(axis=1) |
| 95 | + b_t = -dUdx_fn(_x.reshape(BS, A.shape[-1])) |
| 96 | + |
| 97 | + return u_t - b_t + 0.5 * (xi ** 2) * log_q_t |
| 98 | + |
| 99 | + loss = 0.5 * ((v_t(eps, t, i, w_logits) / xi) ** 2).sum(-1, keepdims=True) |
| 100 | + print(loss.shape, 'loss.shape', flush=True) |
| 101 | + return loss.mean() |
| 102 | + |
| 103 | + @jax.jit |
| 104 | + def train_step(state_q, key): |
| 105 | + grad_fn = jax.value_and_grad(loss_fn, argnums=0) |
| 106 | + loss, grads = grad_fn(state_q.params, key) |
| 107 | + state_q = state_q.apply_gradients(grads=grads) |
| 108 | + return state_q, loss |
| 109 | + |
| 110 | + key, loc_key = jax.random.split(key) |
| 111 | + state_q, loss = train_step(state_q, loc_key) |
| 112 | + |
| 113 | + loss_plot = [] |
| 114 | + for _ in trange(epochs): |
| 115 | + key, loc_key = jax.random.split(key) |
| 116 | + state_q, loss = train_step(state_q, loc_key) |
| 117 | + loss_plot.append(loss) |
| 118 | + |
| 119 | + return state_q, loss_plot |
| 120 | + |
| 121 | + |
| 122 | +def draw_samples_qt(state_q, w_logits, num_samples, T, A, key): |
| 123 | + num_mixtures = len(w_logits) |
| 124 | + w = jax.nn.softmax(w_logits)[None, :, None] |
| 125 | + t = T * jnp.linspace(0, 1, num_samples).reshape((-1, 1)) |
| 126 | + eps = jax.random.normal(key, [num_samples, num_mixtures, A.shape[-1]]) |
| 127 | + mu_t, sigma_t = state_q.apply_fn(state_q.params, t) |
| 128 | + return (w * (mu_t + sigma_t * eps)).sum(axis=1) |
| 129 | + |
| 130 | + |
| 131 | +def sample_stochastically(state_q, w_logits, num_samples, T, A, xi, dt, key): |
| 132 | + mu_t = lambda _t: state_q.apply_fn(state_q.params, _t)[0] |
| 133 | + sigma_t = lambda _t: state_q.apply_fn(state_q.params, _t)[1] |
| 134 | + |
| 135 | + def dmudt(_t): |
| 136 | + _dmudt = jax.jacrev(lambda _t: mu_t(_t).sum(0).T, argnums=0) |
| 137 | + return _dmudt(_t).squeeze(axis=-1).T |
| 138 | + |
| 139 | + def dsigmadt(_t): |
| 140 | + _dsigmadt = jax.jacrev(lambda _t: sigma_t(_t).sum(0).T) |
| 141 | + return _dsigmadt(_t).squeeze(axis=-1).T |
| 142 | + |
| 143 | + @jax.jit |
| 144 | + def u_t(_t, _x): |
| 145 | + _mu_t = mu_t(_t) |
| 146 | + _sigma_t = sigma_t(_t) |
| 147 | + _x = _x[:, None, :] |
| 148 | + |
| 149 | + log_q_i = jax.scipy.stats.norm.logpdf(_x, _mu_t, _sigma_t).sum(-1) |
| 150 | + relative_mixture_weights = jax.nn.softmax(w_logits + log_q_i)[:, :, None] |
| 151 | + |
| 152 | + log_q_t = -(relative_mixture_weights / (_sigma_t ** 2) * (_x - _mu_t)).sum(axis=1) |
| 153 | + _u_t = (relative_mixture_weights * (1 / _sigma_t * dsigmadt(_t) * (_x - _mu_t) + dmudt(_t))).sum(axis=1) |
| 154 | + |
| 155 | + return _u_t + 0.5 * (xi ** 2) * log_q_t |
| 156 | + |
| 157 | + N = int(T / dt) |
| 158 | + |
| 159 | + key, loc_key = jax.random.split(key) |
| 160 | + |
| 161 | + x_t = jnp.ones((num_samples, N + 1, 2)) * A |
| 162 | + eps = jax.random.normal(key, shape=(num_samples, A.shape[-1])) |
| 163 | + x_t = x_t.at[:, 0, :].set(x_t[:, 0, :] + sigma_t(jnp.zeros((num_samples, 1)))[:, 0, :] * eps) |
| 164 | + |
| 165 | + t = jnp.zeros((num_samples, 1)) |
| 166 | + for i in trange(N): |
| 167 | + key, loc_key = jax.random.split(key) |
| 168 | + eps = jax.random.normal(key, shape=(num_samples, A.shape[-1])) |
| 169 | + |
| 170 | + dx = dt * u_t(t, x_t[:, i, :]) + jnp.sqrt(dt) * xi * eps |
| 171 | + x_t = x_t.at[:, i + 1, :].set(x_t[:, i, :] + dx) |
| 172 | + t += dt |
| 173 | + |
| 174 | + return x_t |
| 175 | + |
| 176 | + |
| 177 | +if __name__ == '__main__': |
| 178 | + savedir = './out/var_doobs/mixture' |
| 179 | + os.makedirs(savedir, exist_ok=True) |
| 180 | + |
| 181 | + A = jnp.array([[-0.5, 0]]) |
| 182 | + B = jnp.array([[0.5, 0]]) |
| 183 | + dt = 5e-4 |
| 184 | + T = 1.0 |
| 185 | + xi = 0.1 |
| 186 | + epochs = 20_000 |
| 187 | + |
| 188 | + q_single, w_logits_single = create_mlp_q(A, B, T, 1) |
| 189 | + q_mixture, w_logits_mixture = create_mlp_q(A, B, T, 2) |
| 190 | + |
| 191 | + state_q_single, loss_plot_single = train(q_single, w_logits_single, epochs=epochs) |
| 192 | + state_q_mixture, loss_plot_mixture = train(q_mixture, w_logits_mixture, epochs=epochs) |
| 193 | + |
| 194 | + plt.plot(loss_plot_single, label='single') |
| 195 | + plt.plot(loss_plot_mixture, label='mixture') |
| 196 | + plt.legend() |
| 197 | + plt.show() |
| 198 | + |
| 199 | + samples_qt_single = draw_samples_qt(state_q_single, w_logits_single, num_samples=1000, T=T, A=A, |
| 200 | + key=jax.random.PRNGKey(0)) |
| 201 | + samples_qt_mixture = draw_samples_qt(state_q_mixture, w_logits_mixture, num_samples=1000, T=T, A=A, |
| 202 | + key=jax.random.PRNGKey(0)) |
| 203 | + |
| 204 | + plot_energy_surface() |
| 205 | + plt.scatter(samples_qt_single[:, 0], samples_qt_single[:, 1], label='single') |
| 206 | + plt.scatter(samples_qt_mixture[:, 0], samples_qt_mixture[:, 1], label='mixture') |
| 207 | + plt.legend() |
| 208 | + plt.show() |
| 209 | + |
| 210 | + samples_single = sample_stochastically(state_q_single, w_logits_single, num_samples=1000, T=T, A=A, xi=xi, dt=dt, |
| 211 | + key=jax.random.PRNGKey(0)) |
| 212 | + samples_mixture = sample_stochastically(state_q_mixture, w_logits_mixture, num_samples=1000, T=T, A=A, xi=xi, dt=dt, |
| 213 | + key=jax.random.PRNGKey(0)) |
| 214 | + |
| 215 | + plot_energy_surface(trajectories=samples_single) |
| 216 | + plt.savefig(f'{savedir}/toy-gaussian-single.pdf', bbox_inches='tight') |
| 217 | + plt.show() |
| 218 | + |
| 219 | + plot_energy_surface(trajectories=samples_mixture) |
| 220 | + plt.savefig(f'{savedir}/toy-gaussian-mixture.pdf', bbox_inches='tight') |
| 221 | + plt.show() |
| 222 | + |
| 223 | + t = T * jnp.linspace(0, 1, 10 * int(T / dt)).reshape((-1, 1)) |
| 224 | + _mu_t_single, _sigma_t_single = state_q_single.apply_fn(state_q_single.params, t) |
| 225 | + _mu_t_mixture, _sigma_t_mixture = state_q_mixture.apply_fn(state_q_mixture.params, t) |
| 226 | + |
| 227 | + vmin = min(_sigma_t_single.min(), _sigma_t_mixture.min()) |
| 228 | + vmax = max(_sigma_t_single.max(), _sigma_t_mixture.max()) |
| 229 | + |
| 230 | + plot_energy_surface() |
| 231 | + plt.scatter(_mu_t_single[:, :, 0], _mu_t_single[:, :, 1], c=_sigma_t_single, vmin=vmin, vmax=vmax, rasterized=True) |
| 232 | + plt.colorbar(label=r'$\sigma$') |
| 233 | + plt.savefig(f'{savedir}/toy-gaussian-single-mu.pdf', bbox_inches='tight') |
| 234 | + plt.show() |
| 235 | + |
| 236 | + plot_energy_surface() |
| 237 | + plt.scatter(_mu_t_mixture[:, :, 0], _mu_t_mixture[:, :, 1], c=_sigma_t_mixture, vmin=vmin, vmax=vmax, rasterized=True) |
| 238 | + plt.colorbar(label=r'$\sigma$') |
| 239 | + plt.savefig(f'{savedir}/toy-gaussian-mixture-mu.pdf', bbox_inches='tight') |
| 240 | + plt.show() |
0 commit comments