Skip to content

Commit d6e3a1e

Browse files
committed
Fix path likelihood for euler
1 parent 8a906da commit d6e3a1e

File tree

2 files changed

+10
-8
lines changed

2 files changed

+10
-8
lines changed

evaluate_mueller.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,22 @@
77
import matplotlib.pyplot as plt
88
import os
99

10+
num_paths = 1000
11+
xi = 5
12+
kbT = xi ** 2 / 2
13+
dt = 1e-4
14+
T = 275e-4
15+
N = int(T / dt)
16+
17+
1018
def load(path):
1119
return jnp.array(np.load(path, allow_pickle=True).astype(np.float32)).squeeze()
1220

1321

1422
@jax.jit
1523
def log_prob_path(path):
1624
rand = path[1:] - path[:-1] + dt * dUdx_fn(path[:-1])
17-
return U(path[0]) + jax.scipy.stats.norm.logpdf(rand, scale=jnp.sqrt(dt) * xi).sum()
25+
return U(path[0]) / kbT + jax.scipy.stats.norm.logpdf(rand, scale=jnp.sqrt(dt) * xi).sum()
1826

1927

2028
if __name__ == '__main__':
@@ -27,12 +35,6 @@ def log_prob_path(path):
2735
('var-doobs', './out/var_doobs/mueller/paths.npy'),
2836
]
2937

30-
num_paths = 1000
31-
xi = 5
32-
dt = 1e-4
33-
T = 275e-4
34-
N = int(T / dt)
35-
3638
global_minimum_energy = U(minima_points[0])
3739
for point in minima_points:
3840
global_minimum_energy = min(global_minimum_energy, minimize(U, point).fun)
@@ -59,7 +61,7 @@ def log_prob_path(path):
5961

6062
for name, paths in all_paths:
6163
plot_path_energy(paths, log_prob_path, reduce=lambda x: x, label=name)
62-
print('Median energy of:', name, jnp.median(jnp.array([log_prob_path(path) for path in paths])))
64+
print('Median log-likelihood of:', name, jnp.median(jnp.array([log_prob_path(path) for path in paths])))
6365

6466
plt.legend()
6567
plt.ylabel('log path likelihood')

gaussian_mixture.py

Whitespace-only changes.

0 commit comments

Comments
 (0)