34
34
parser .add_argument ('--fixed_length' , type = int , default = 0 )
35
35
parser .add_argument ('--num_steps' , type = int , default = 10 ,
36
36
help = 'The number of MD steps taken at once. More takes longer to compile but runs faster in the end.' )
37
+ parser .add_argument ('--resume' , action = 'store_true' )
38
+ parser .add_argument ('--override' , action = 'store_true' )
37
39
38
40
39
41
def human_format (num ):
@@ -136,7 +138,7 @@ def step_n(step, _x, _v, n, _key):
136
138
mdtraj_topology = md .Topology .from_openmm (init_pdb .topology )
137
139
phis_psis = phi_psi_from_mdtraj (mdtraj_topology )
138
140
139
- savedir = f"out/baselines/alanine-{ args .mechanism } "
141
+ savedir = f"out/baselines/alanine-{ args .mechanism } - { args . fixed_length } "
140
142
os .makedirs (savedir , exist_ok = True )
141
143
142
144
# Construct the mass matrix
@@ -203,7 +205,7 @@ def step_langevin_forward(_x, _v, _key):
203
205
204
206
205
207
@jax .jit
206
- def step_langevin_log_density (_x , _v , _new_x , _new_v ):
208
+ def step_langevin_log_prob (_x , _v , _new_x , _new_v ):
207
209
alpha = jnp .exp (- gamma_in_ps * dt_in_ps )
208
210
f_scale = (1 - alpha ) / gamma_in_ps
209
211
new_v_det = alpha * _v + f_scale * - dUdx_fn (_x )
@@ -212,14 +214,14 @@ def step_langevin_log_density(_x, _v, _new_x, _new_v):
212
214
return jax .scipy .stats .norm .logpdf (new_v_rand , 0 , jnp .sqrt (kbT * (1 - alpha ** 2 ) / mass )).sum ()
213
215
214
216
215
- def langevin_log_path_density (path_and_velocities ):
217
+ def langevin_log_path_likelihood (path_and_velocities ):
216
218
path , velocities = path_and_velocities
217
219
218
220
log_prob = (- U (path [0 ]) / kbT ).sum ()
219
221
log_prob += jax .scipy .stats .norm .logpdf (velocities [0 ], 0 , jnp .sqrt (kbT / mass )).sum ()
220
222
221
223
for i in range (1 , len (path )):
222
- log_prob += step_langevin_log_density (path [i - 1 ], velocities [i - 1 ], path [i ], velocities [i ])
224
+ log_prob += step_langevin_log_prob (path [i - 1 ], velocities [i - 1 ], path [i ], velocities [i ])
223
225
224
226
return log_prob
225
227
@@ -236,34 +238,6 @@ def step_langevin_backward(_x, _v, _key):
236
238
return prev_x , prev_v
237
239
238
240
239
- key = jax .random .PRNGKey (1 )
240
- key , velocity_key = jax .random .split (key )
241
- steps = 10_000
242
-
243
- trajectory = [A ]
244
- _x = trajectory [- 1 ]
245
- _v = jnp .sqrt (kbT / mass ) * jax .random .normal (velocity_key , (1 , 66 ))
246
-
247
- for i in trange (steps ):
248
- key , iter_key = jax .random .split (key )
249
- _x , _v = step_langevin_forward (_x , _v , iter_key )
250
-
251
- trajectory .append (_x )
252
-
253
- trajectory = jnp .array (trajectory ).reshape (- 1 , 66 )
254
-
255
- # save_trajectory(mdtraj_topology, trajectory[-1000:], 'simulation.pdb')
256
-
257
- # we only need to check whether the last frame contains nan, is it propagates
258
- assert not jnp .isnan (trajectory [- 1 ]).any ()
259
- trajectory_phi_psi = phis_psis (trajectory )
260
-
261
- plt .title (f"{ human_format (steps )} steps @ { temp } K, dt = { human_format (dt )} s" )
262
- ramachandran (trajectory_phi_psi )
263
- plt .scatter (phis_psis (A )[0 , 0 ], phis_psis (A )[0 , 1 ], color = 'red' , marker = '*' )
264
- plt .scatter (phis_psis (B )[0 , 0 ], phis_psis (B )[0 , 1 ], color = 'green' , marker = '*' )
265
- plt .show ()
266
-
267
241
# Choose a system, either phi psi, or rmsd
268
242
# system = tps1.System(
269
243
# jax.jit(jax.vmap(lambda s: kabsch_rmsd(A.reshape(22, 3), s.reshape(22, 3)) < 0.1)),
@@ -273,72 +247,69 @@ def step_langevin_backward(_x, _v, _key):
273
247
274
248
radius = 20 / deg
275
249
276
- system = tps1 .FirstOrderSystem (
277
- lambda s : is_within (phis_psis (s ).reshape (- 1 , 2 ), phis_psis (A ), radius ),
278
- lambda s : is_within (phis_psis (s ).reshape (- 1 , 2 ), phis_psis (B ), radius ),
279
- step
280
- )
250
+ # system = tps1.FirstOrderSystem(
251
+ # lambda s: is_within(phis_psis(s).reshape(-1, 2), phis_psis(A), radius),
252
+ # lambda s: is_within(phis_psis(s).reshape(-1, 2), phis_psis(B), radius),
253
+ # step
254
+ # )
281
255
282
256
system = tps2 .SecondOrderSystem (
283
257
jax .jit (lambda s : is_within (phis_psis (s ).reshape (- 1 , 2 ), phis_psis (A ), radius )),
284
258
jax .jit (lambda s : is_within (phis_psis (s ).reshape (- 1 , 2 ), phis_psis (B ), radius )),
285
259
# jax.jit(jax.vmap(lambda s: kabsch_rmsd(A.reshape(22, 3), s.reshape(22, 3)) <= 7.5e-2)),
286
260
# jax.jit(jax.vmap(lambda s: kabsch_rmsd(B.reshape(22, 3), s.reshape(22, 3)) <= 7.5e-2)),
287
- # step_langevin_forward,
288
- # step_langevin_backward,
289
261
jax .jit (lambda _x , _v , _key : step_n (step_langevin_forward , _x , _v , args .num_steps , _key )),
290
262
jax .jit (lambda _x , _v , _key : step_n (step_langevin_backward , _x , _v , args .num_steps , _key )),
291
263
jax .jit (lambda key : jnp .sqrt (kbT / mass ) * jax .random .normal (key , (1 , 66 )))
292
264
)
293
265
294
- print ("A" , phis_psis (A ))
295
- print ("B" , phis_psis (B ))
296
-
297
- filter1 = system .start_state (trajectory )
298
- filter2 = system .target_state (trajectory )
299
-
300
- plt .title ('start' )
301
- ramachandran (trajectory_phi_psi [filter1 ])
302
- plt .show ()
303
-
304
- plt .title ('target' )
305
- ramachandran (trajectory_phi_psi [filter2 ])
306
- plt .show ()
307
-
308
266
initial_trajectory = md .load ('./files/AD_A_B_500K_initial_trajectory.pdb' ).xyz .reshape (- 1 , 1 , 66 )
309
267
initial_trajectory = [p for p in initial_trajectory ]
310
268
save_trajectory (mdtraj_topology , jnp .array (initial_trajectory ), f'{ savedir } /initial_trajectory.pdb' )
311
269
312
- load = False
313
- if load :
314
- paths = np .load (f'{ savedir } /paths.npy' , allow_pickle = True )
315
- velocities = np .load (f'{ savedir } /velocities.npy' , allow_pickle = True )
270
+ if args .resume :
271
+ paths = [[x for x in p .astype (np .float32 )] for p in np .load (f'{ savedir } /paths.npy' , allow_pickle = True )]
272
+ velocities = [[v for v in p .astype (np .float32 )] for p in np .load (f'{ savedir } /velocities.npy' , allow_pickle = True )]
316
273
with open (f'{ savedir } /stats.json' , 'r' ) as fp :
317
274
statistics = json .load (fp )
275
+
276
+ stored = {
277
+ 'trajectories' : [initial_trajectory ] + paths ,
278
+ 'velocities' : velocities ,
279
+ 'statistics' : statistics
280
+ }
281
+ else :
282
+ if os .path .exists (f'{ savedir } /paths.npy' ) and not args .override :
283
+ print (f"The target directory is not empy."
284
+ f"Please use --override to overwrite the existing data or --resume to continue." )
285
+ exit (1 )
286
+
287
+ stored = None
288
+
289
+ if args .mechanism == 'one-way-shooting' :
290
+ mechanism = tps2 .one_way_shooting
291
+ elif args .mechanism == 'two-way-shooting' :
292
+ mechanism = tps2 .two_way_shooting
318
293
else :
319
- if args .mechanism == 'one-way-shooting' :
320
- mechanism = tps2 .one_way_shooting
321
- elif args .mechanism == 'two-way-shooting' :
322
- mechanism = tps2 .two_way_shooting
323
- else :
324
- raise ValueError (f"Unknown mechanism { args .mechanism } " )
325
-
326
- try :
327
- paths , velocities , statistics = tps2 .mcmc_shooting (system , mechanism , initial_trajectory ,
328
- 100 , dt_in_ps , jax .random .PRNGKey (1 ), warmup = 0 ,
329
- fixed_length = args .fixed_length )
330
- # paths = tps2.unguided_md(system, B, 1, key)
331
- paths = [jnp .array (p ) for p in paths ]
332
- velocities = [jnp .array (p ) for p in velocities ]
333
- # store paths
334
- np .save (f'{ savedir } /paths.npy' , np .array (paths , dtype = object ), allow_pickle = True )
335
- np .save (f'{ savedir } /velocities.npy' , np .array (velocities , dtype = object ), allow_pickle = True )
336
- # save statistics, which is a dictionary
337
- with open (f'{ savedir } /stats.json' , 'w' ) as fp :
338
- json .dump (statistics , fp )
339
- except Exception as e :
340
- print (e )
341
- breakpoint ()
294
+ raise ValueError (f"Unknown mechanism { args .mechanism } " )
295
+
296
+ try :
297
+ paths , velocities , statistics = tps2 .mcmc_shooting (system , mechanism , initial_trajectory ,
298
+ 100 , dt_in_ps , jax .random .PRNGKey (1 ), warmup = 0 ,
299
+ fixed_length = args .fixed_length ,
300
+ stored = stored )
301
+ # paths = tps2.unguided_md(system, B, 1, key)
302
+ paths = [jnp .array (p ) for p in paths ]
303
+ velocities = [jnp .array (p ) for p in velocities ]
304
+ # store paths
305
+ np .save (f'{ savedir } /paths.npy' , np .array (paths , dtype = object ), allow_pickle = True )
306
+ np .save (f'{ savedir } /velocities.npy' , np .array (velocities , dtype = object ), allow_pickle = True )
307
+ # save statistics, which is a dictionary
308
+ with open (f'{ savedir } /stats.json' , 'w' ) as fp :
309
+ json .dump (statistics , fp )
310
+ except Exception as e :
311
+ print (e )
312
+ breakpoint ()
342
313
343
314
print (statistics )
344
315
print ([len (p ) for p in paths ])
@@ -368,8 +339,8 @@ def step_langevin_backward(_x, _v, _key):
368
339
plt .savefig (f'{ savedir } /median_energy.png' , bbox_inches = 'tight' )
369
340
plt .show ()
370
341
371
- plot_path_energy (list (zip (paths , velocities )), langevin_log_path_density , reduce = lambda x : x , already_ln = True )
372
- plt .ylabel ('Path Density ' )
342
+ plot_path_energy (list (zip (paths , velocities )), langevin_log_path_likelihood , reduce = lambda x : x , already_ln = True )
343
+ plt .ylabel ('Path Likelihood ' )
373
344
plt .savefig (f'{ savedir } /path_density.png' , bbox_inches = 'tight' )
374
345
plt .show ()
375
346
0 commit comments