1
1
import os
2
2
from functools import partial
3
-
3
+ import traceback
4
4
import jax
5
5
import numpy as np
6
6
import matplotlib .pyplot as plt
31
31
32
32
parser = ArgumentParser ()
33
33
parser .add_argument ('--mechanism' , type = str , choices = ['one-way-shooting' , 'two-way-shooting' ], required = True )
34
+ parser .add_argument ('--states' , type = str , default = 'phi-psi' , choices = ['phi-psi' , 'rmsd' ])
34
35
parser .add_argument ('--fixed_length' , type = int , default = 0 )
36
+ parser .add_argument ('--num_paths' , type = int , required = True )
35
37
parser .add_argument ('--num_steps' , type = int , default = 10 ,
36
38
help = 'The number of MD steps taken at once. More takes longer to compile but runs faster in the end.' )
37
39
parser .add_argument ('--resume' , action = 'store_true' )
@@ -138,7 +140,12 @@ def step_n(step, _x, _v, n, _key):
138
140
mdtraj_topology = md .Topology .from_openmm (init_pdb .topology )
139
141
phis_psis = phi_psi_from_mdtraj (mdtraj_topology )
140
142
141
- savedir = f"out/baselines/alanine-{ args .mechanism } -{ args .fixed_length } "
143
+ savedir = f"out/baselines/alanine-{ args .mechanism } "
144
+ if args .fixed_length > 0 :
145
+ savedir += f'-{ args .fixed_length } steps'
146
+ if args .states == 'rmsd' :
147
+ savedir += '-rmsd'
148
+
142
149
os .makedirs (savedir , exist_ok = True )
143
150
144
151
# Construct the mass matrix
@@ -253,11 +260,17 @@ def step_langevin_backward(_x, _v, _key):
253
260
# step
254
261
# )
255
262
263
+ if args .states == 'rmsd' :
264
+ state_A = jax .jit (jax .vmap (lambda s : kabsch_rmsd (A .reshape (22 , 3 ), s .reshape (22 , 3 )) <= 7.5e-2 ))
265
+ state_B = jax .jit (jax .vmap (lambda s : kabsch_rmsd (B .reshape (22 , 3 ), s .reshape (22 , 3 )) <= 7.5e-2 ))
266
+ elif args .states == 'phi-psi' :
267
+ state_A = jax .jit (lambda s : is_within (phis_psis (s ).reshape (- 1 , 2 ), phis_psis (A ), radius ))
268
+ state_B = jax .jit (lambda s : is_within (phis_psis (s ).reshape (- 1 , 2 ), phis_psis (B ), radius ))
269
+ else :
270
+ raise ValueError (f"Unknown states { args .states } " )
271
+
256
272
system = tps2 .SecondOrderSystem (
257
- jax .jit (lambda s : is_within (phis_psis (s ).reshape (- 1 , 2 ), phis_psis (A ), radius )),
258
- jax .jit (lambda s : is_within (phis_psis (s ).reshape (- 1 , 2 ), phis_psis (B ), radius )),
259
- # jax.jit(jax.vmap(lambda s: kabsch_rmsd(A.reshape(22, 3), s.reshape(22, 3)) <= 7.5e-2)),
260
- # jax.jit(jax.vmap(lambda s: kabsch_rmsd(B.reshape(22, 3), s.reshape(22, 3)) <= 7.5e-2)),
273
+ state_A , state_B ,
261
274
jax .jit (lambda _x , _v , _key : step_n (step_langevin_forward , _x , _v , args .num_steps , _key )),
262
275
jax .jit (lambda _x , _v , _key : step_n (step_langevin_backward , _x , _v , args .num_steps , _key )),
263
276
jax .jit (lambda key : jnp .sqrt (kbT / mass ) * jax .random .normal (key , (1 , 66 )))
@@ -269,7 +282,8 @@ def step_langevin_backward(_x, _v, _key):
269
282
270
283
if args .resume :
271
284
paths = [[x for x in p .astype (np .float32 )] for p in np .load (f'{ savedir } /paths.npy' , allow_pickle = True )]
272
- velocities = [[v for v in p .astype (np .float32 )] for p in np .load (f'{ savedir } /velocities.npy' , allow_pickle = True )]
285
+ velocities = [[v for v in p .astype (np .float32 )] for p in
286
+ np .load (f'{ savedir } /velocities.npy' , allow_pickle = True )]
273
287
with open (f'{ savedir } /stats.json' , 'r' ) as fp :
274
288
statistics = json .load (fp )
275
289
@@ -295,7 +309,7 @@ def step_langevin_backward(_x, _v, _key):
295
309
296
310
try :
297
311
paths , velocities , statistics = tps2 .mcmc_shooting (system , mechanism , initial_trajectory ,
298
- 100 , dt_in_ps , jax .random .PRNGKey (1 ), warmup = 0 ,
312
+ args . num_paths , dt_in_ps , jax .random .PRNGKey (1 ), warmup = 0 ,
299
313
fixed_length = args .fixed_length ,
300
314
stored = stored )
301
315
# paths = tps2.unguided_md(system, B, 1, key)
@@ -308,13 +322,16 @@ def step_langevin_backward(_x, _v, _key):
308
322
with open (f'{ savedir } /stats.json' , 'w' ) as fp :
309
323
json .dump (statistics , fp )
310
324
except Exception as e :
311
- print (e )
325
+ print (traceback . format_exc () )
312
326
breakpoint ()
313
327
314
328
print (statistics )
315
- print ([len (p ) for p in paths ])
316
- plt .hist ([len (p ) for p in paths ], bins = jnp .sqrt (len (paths )).astype (int ).item ())
317
- plt .show ()
329
+
330
+ if args .fixed_length == 0 :
331
+ print ([len (p ) for p in paths ])
332
+ plt .hist ([len (p ) for p in paths ], bins = jnp .sqrt (len (paths )).astype (int ).item ())
333
+ plt .savefig (f'{ savedir } /lengths.png' , bbox_inches = 'tight' )
334
+ plt .show ()
318
335
319
336
path_hist = PeriodicPathHistogram ()
320
337
for i , path in tqdm (enumerate (paths ), desc = 'Adding paths to histogram' , total = len (paths )):
0 commit comments