27
27
from utils .animation import save_trajectory , to_md_traj
28
28
from utils .rmsd import kabsch_align , kabsch_rmsd
29
29
30
+ from argparse import ArgumentParser
31
+
32
+ parser = ArgumentParser ()
33
+ parser .add_argument ('--mechanism' , type = str , choices = ['one-way-shooting' , 'two-way-shooting' ], required = True )
34
+ parser .add_argument ('--fixed_length' , type = int , default = 0 )
35
+ parser .add_argument ('--num_steps' , type = int , default = 10 ,
36
+ help = 'The number of MD steps taken at once. More takes longer to compile but runs faster in the end.' )
37
+
30
38
31
39
def human_format (num ):
32
40
"""https://stackoverflow.com/a/45846841/4417954"""
@@ -121,12 +129,14 @@ def step_n(step, _x, _v, n, _key):
121
129
122
130
123
131
if __name__ == '__main__' :
132
+ args = parser .parse_args ()
133
+
124
134
init_pdb = app .PDBFile ("./files/AD_A.pdb" )
125
135
target_pdb = app .PDBFile ("./files/AD_B.pdb" )
126
136
mdtraj_topology = md .Topology .from_openmm (init_pdb .topology )
127
137
phis_psis = phi_psi_from_mdtraj (mdtraj_topology )
128
138
129
- savedir = f"out/baselines/alanine"
139
+ savedir = f"out/baselines/alanine- { args . mechanism } "
130
140
os .makedirs (savedir , exist_ok = True )
131
141
132
142
# Construct the mass matrix
@@ -171,13 +181,8 @@ def U(_x):
171
181
172
182
@jax .jit
173
183
@jax .vmap
174
- def dUdx_fn_unscaled (_x ):
175
- return jax .grad (lambda _x : U (_x ).sum ())(_x )
176
-
177
-
178
- @jax .jit
179
184
def dUdx_fn (_x ):
180
- return dUdx_fn_unscaled ( _x ) / mass / gamma
185
+ return jax . grad ( lambda _x : U ( _x ). sum ())( _x ) / mass / gamma_in_ps
181
186
182
187
183
188
@jax .jit
@@ -188,10 +193,10 @@ def step(_x, _key):
188
193
189
194
@jax .jit
190
195
def step_langevin_forward (_x , _v , _key ):
191
- """Perform one step of forward langevin"""
196
+ """Perform one step of forward langevin as implemented in openmm """
192
197
alpha = jnp .exp (- gamma_in_ps * dt_in_ps )
193
198
f_scale = (1 - alpha ) / gamma_in_ps
194
- new_v_det = alpha * _v + f_scale * - dUdx_fn_unscaled (_x ) / mass
199
+ new_v_det = alpha * _v + f_scale * - dUdx_fn (_x )
195
200
new_v = new_v_det + jnp .sqrt (kbT * (1 - alpha ** 2 ) / mass ) * jax .random .normal (_key , _x .shape )
196
201
197
202
return _x + dt_in_ps * new_v , new_v
@@ -201,7 +206,7 @@ def step_langevin_forward(_x, _v, _key):
201
206
def step_langevin_log_density (_x , _v , _new_x , _new_v ):
202
207
alpha = jnp .exp (- gamma_in_ps * dt_in_ps )
203
208
f_scale = (1 - alpha ) / gamma_in_ps
204
- new_v_det = alpha * _v + f_scale * - dUdx_fn_unscaled (_x ) / mass
209
+ new_v_det = alpha * _v + f_scale * - dUdx_fn (_x )
205
210
new_v_rand = new_v_det - _new_v
206
211
207
212
return jax .scipy .stats .norm .logpdf (new_v_rand , 0 , jnp .sqrt (kbT * (1 - alpha ** 2 ) / mass )).sum ()
@@ -225,7 +230,7 @@ def step_langevin_backward(_x, _v, _key):
225
230
alpha = jnp .exp (- gamma_in_ps * dt_in_ps )
226
231
f_scale = (1 - alpha ) / gamma_in_ps
227
232
prev_x = _x - dt_in_ps * _v
228
- prev_v = 1 / alpha * (_v + f_scale * dUdx_fn_unscaled (prev_x ) / mass - jnp .sqrt (
233
+ prev_v = 1 / alpha * (_v + f_scale * dUdx_fn (prev_x ) - jnp .sqrt (
229
234
kbT * (1 - alpha ** 2 ) / mass ) * jax .random .normal (_key , _x .shape ))
230
235
231
236
return prev_x , prev_v
@@ -281,8 +286,8 @@ def step_langevin_backward(_x, _v, _key):
281
286
# jax.jit(jax.vmap(lambda s: kabsch_rmsd(B.reshape(22, 3), s.reshape(22, 3)) <= 7.5e-2)),
282
287
# step_langevin_forward,
283
288
# step_langevin_backward,
284
- jax .jit (lambda _x , _v , _key : step_n (step_langevin_forward , _x , _v , 40 , _key )),
285
- jax .jit (lambda _x , _v , _key : step_n (step_langevin_backward , _x , _v , 40 , _key )),
289
+ jax .jit (lambda _x , _v , _key : step_n (step_langevin_forward , _x , _v , args . num_steps , _key )),
290
+ jax .jit (lambda _x , _v , _key : step_n (step_langevin_backward , _x , _v , args . num_steps , _key )),
286
291
jax .jit (lambda key : jnp .sqrt (kbT / mass ) * jax .random .normal (key , (1 , 66 )))
287
292
)
288
293
@@ -311,9 +316,17 @@ def step_langevin_backward(_x, _v, _key):
311
316
with open (f'{ savedir } /stats.json' , 'r' ) as fp :
312
317
statistics = json .load (fp )
313
318
else :
319
+ if args .mechanism == 'one-way-shooting' :
320
+ mechanism = tps2 .one_way_shooting
321
+ elif args .mechanism == 'two-way-shooting' :
322
+ mechanism = tps2 .two_way_shooting
323
+ else :
324
+ raise ValueError (f"Unknown mechanism { args .mechanism } " )
325
+
314
326
try :
315
- paths , velocities , statistics = tps2 .mcmc_shooting (system , tps2 .two_way_shooting , initial_trajectory ,
316
- 100 , jax .random .PRNGKey (1 ), warmup = 0 , fixed_length = 1000 )
327
+ paths , velocities , statistics = tps2 .mcmc_shooting (system , mechanism , initial_trajectory ,
328
+ 100 , dt_in_ps , jax .random .PRNGKey (1 ), warmup = 0 ,
329
+ fixed_length = args .fixed_length )
317
330
# paths = tps2.unguided_md(system, B, 1, key)
318
331
paths = [jnp .array (p ) for p in paths ]
319
332
velocities = [jnp .array (p ) for p in velocities ]
@@ -324,6 +337,7 @@ def step_langevin_backward(_x, _v, _key):
324
337
with open (f'{ savedir } /stats.json' , 'w' ) as fp :
325
338
json .dump (statistics , fp )
326
339
except Exception as e :
340
+ print (e )
327
341
breakpoint ()
328
342
329
343
print (statistics )
0 commit comments