Skip to content

Commit 7b8e404

Browse files
committed
Fix log_path_likelihood in evaluate_mueller
1 parent 69e238e commit 7b8e404

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

evaluate_mueller.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,9 @@ def load(path):
2020

2121

2222
@jax.jit
23-
def log_prob_path(path):
23+
def log_path_likelihood(path):
2424
rand = path[1:] - path[:-1] + dt * dUdx_fn(path[:-1])
25-
return U(path[0]) / kbT + jax.scipy.stats.norm.logpdf(rand, scale=jnp.sqrt(dt) * xi).sum()
25+
return (-U(path[0]) / kbT).sum() + jax.scipy.stats.norm.logpdf(rand, scale=jnp.sqrt(dt) * xi).sum()
2626

2727

2828
if __name__ == '__main__':
@@ -60,8 +60,8 @@ def log_prob_path(path):
6060
plt.show()
6161

6262
for name, paths in all_paths:
63-
plot_path_energy(paths, log_prob_path, reduce=lambda x: x, label=name)
64-
print('Median log-likelihood of:', name, jnp.median(jnp.array([log_prob_path(path) for path in paths])))
63+
plot_path_energy(paths, log_path_likelihood, reduce=lambda x: x, label=name)
64+
print('Median log-likelihood of:', name, jnp.median(jnp.array([log_path_likelihood(path) for path in paths])))
6565

6666
plt.legend()
6767
plt.ylabel('log path likelihood')

0 commit comments

Comments
 (0)