7
7
import matplotlib .pyplot as plt
8
8
import os
9
9
10
+ num_paths = 1000
11
+ xi = 5
12
+ kbT = xi ** 2 / 2
13
+ dt = 1e-4
14
+ T = 275e-4
15
+ N = int (T / dt )
16
+
17
+
10
18
def load (path ):
11
19
return jnp .array (np .load (path , allow_pickle = True ).astype (np .float32 )).squeeze ()
12
20
13
21
14
22
@jax .jit
15
23
def log_prob_path (path ):
16
24
rand = path [1 :] - path [:- 1 ] + dt * dUdx_fn (path [:- 1 ])
17
- return U (path [0 ]) + jax .scipy .stats .norm .logpdf (rand , scale = jnp .sqrt (dt ) * xi ).sum ()
25
+ return U (path [0 ]) / kbT + jax .scipy .stats .norm .logpdf (rand , scale = jnp .sqrt (dt ) * xi ).sum ()
18
26
19
27
20
28
if __name__ == '__main__' :
@@ -27,12 +35,6 @@ def log_prob_path(path):
27
35
('var-doobs' , './out/var_doobs/mueller/paths.npy' ),
28
36
]
29
37
30
- num_paths = 1000
31
- xi = 5
32
- dt = 1e-4
33
- T = 275e-4
34
- N = int (T / dt )
35
-
36
38
global_minimum_energy = U (minima_points [0 ])
37
39
for point in minima_points :
38
40
global_minimum_energy = min (global_minimum_energy , minimize (U , point ).fun )
@@ -59,7 +61,7 @@ def log_prob_path(path):
59
61
60
62
for name , paths in all_paths :
61
63
plot_path_energy (paths , log_prob_path , reduce = lambda x : x , label = name )
62
- print ('Median energy of:' , name , jnp .median (jnp .array ([log_prob_path (path ) for path in paths ])))
64
+ print ('Median log-likelihood of:' , name , jnp .median (jnp .array ([log_prob_path (path ) for path in paths ])))
63
65
64
66
plt .legend ()
65
67
plt .ylabel ('log path likelihood' )
0 commit comments