19
19
# helper function for plotting in 2D
20
20
from matplotlib import colors
21
21
22
+ from utils .angles import phi_psi_from_mdtraj
22
23
from utils .animation import save_trajectory , to_md_traj
23
24
from utils .rmsd import kabsch_align , kabsch_rmsd
24
25
@@ -58,13 +59,6 @@ def interpolate_two_points(start, stop, steps):
58
59
return interpolation
59
60
60
61
61
- def phis_psis (position , mdtraj_topology ):
62
- traj = to_md_traj (mdtraj_topology , position )
63
- phi = md .compute_phi (traj )[1 ].squeeze ()
64
- psi = md .compute_psi (traj )[1 ].squeeze ()
65
- return jnp .array ([phi , psi ]).T
66
-
67
-
68
62
def ramachandran (samples , bins = 100 , path = None , paths = None , states = None , alpha = 1.0 ):
69
63
if samples is not None :
70
64
plt .hist2d (samples [:, 0 ], samples [:, 1 ], bins = bins , norm = colors .LogNorm (), rasterized = True )
@@ -97,7 +91,8 @@ def draw_path(_path, **kwargs):
97
91
draw_path (path , color = 'blue' )
98
92
99
93
for state in (states if states is not None else []):
100
- c = plt .Circle (state ['center' ], radius = state ['radius' ], edgecolor = 'gray' , facecolor = 'white' , ls = '--' , lw = 0.7 , alpha = alpha )
94
+ c = plt .Circle (state ['center' ], radius = state ['radius' ], edgecolor = 'gray' , facecolor = 'white' , ls = '--' , lw = 0.7 ,
95
+ alpha = alpha )
101
96
plt .gca ().add_patch (c )
102
97
plt .gca ().annotate (state ['name' ], xy = state ['center' ], ha = "center" , va = "center" )
103
98
@@ -131,6 +126,7 @@ def is_within(_phis_psis, _center, _radius, _period=2 * jnp.pi):
131
126
init_pdb = app .PDBFile ("./files/AD_A.pdb" )
132
127
target_pdb = app .PDBFile ("./files/AD_B.pdb" )
133
128
mdtraj_topology = md .Topology .from_openmm (init_pdb .topology )
129
+ phis_psis = phi_psi_from_mdtraj (mdtraj_topology )
134
130
135
131
savedir = f"baselines/alanine"
136
132
os .makedirs (savedir , exist_ok = True )
@@ -151,7 +147,6 @@ def is_within(_phis_psis, _center, _radius, _period=2 * jnp.pi):
151
147
A , B = kabsch_align (A , B )
152
148
A , B = A .reshape (1 , - 1 ), B .reshape (1 , - 1 )
153
149
154
-
155
150
# Initialize the potential energy with amber forcefields
156
151
ff = Hamiltonian ('amber14/protein.ff14SB.xml' , 'amber14/tip3p.xml' )
157
152
potentials = ff .createPotential (init_pdb .topology ,
@@ -236,12 +231,12 @@ def step_langevin_backward(_x, _v, _key):
236
231
237
232
# we only need to check whether the last frame contains nan, is it propagates
238
233
assert not jnp .isnan (trajectory [- 1 ]).any ()
239
- trajectory_phi_psi = phis_psis (trajectory , mdtraj_topology )
234
+ trajectory_phi_psi = phis_psis (trajectory )
240
235
241
236
plt .title (f"{ human_format (steps )} steps @ { temp } K, dt = { human_format (dt )} s" )
242
237
ramachandran (trajectory_phi_psi )
243
- plt .scatter (phis_psis (A , mdtraj_topology )[0 ], phis_psis (A , mdtraj_topology )[1 ], color = 'red' , marker = '*' )
244
- plt .scatter (phis_psis (B , mdtraj_topology )[0 ], phis_psis (B , mdtraj_topology )[1 ], color = 'green' , marker = '*' )
238
+ plt .scatter (phis_psis (A )[0 ], phis_psis (A )[1 ], color = 'red' , marker = '*' )
239
+ plt .scatter (phis_psis (B )[0 ], phis_psis (B )[1 ], color = 'green' , marker = '*' )
245
240
plt .show ()
246
241
247
242
# Choose a system, either phi psi, or rmsd
@@ -254,23 +249,23 @@ def step_langevin_backward(_x, _v, _key):
254
249
radius = 20 / deg
255
250
256
251
system = tps1 .FirstOrderSystem (
257
- lambda s : is_within (phis_psis (s , mdtraj_topology ).reshape (- 1 , 2 ), phis_psis (A , mdtraj_topology ), radius ),
258
- lambda s : is_within (phis_psis (s , mdtraj_topology ).reshape (- 1 , 2 ), phis_psis (B , mdtraj_topology ), radius ),
252
+ lambda s : is_within (phis_psis (s ).reshape (- 1 , 2 ), phis_psis (A ), radius ),
253
+ lambda s : is_within (phis_psis (s ).reshape (- 1 , 2 ), phis_psis (B ), radius ),
259
254
step
260
255
)
261
256
262
257
system = tps2 .SecondOrderSystem (
263
- # lambda s: is_within(phis_psis(s, mdtraj_topology ).reshape(-1, 2), phis_psis(A, mdtraj_topology ), radius),
264
- # lambda s: is_within(phis_psis(s, mdtraj_topology ).reshape(-1, 2), phis_psis(B, mdtraj_topology ), radius),
265
- jax .jit (jax .vmap (lambda s : kabsch_rmsd (A .reshape (22 , 3 ), s .reshape (22 , 3 )) <= 7.5e-2 )),
266
- jax .jit (jax .vmap (lambda s : kabsch_rmsd (B .reshape (22 , 3 ), s .reshape (22 , 3 )) <= 7.5e-2 )),
258
+ jax . jit ( lambda s : is_within (phis_psis (s ).reshape (- 1 , 2 ), phis_psis (A ), radius ) ),
259
+ jax . jit ( lambda s : is_within (phis_psis (s ).reshape (- 1 , 2 ), phis_psis (B ), radius ) ),
260
+ # jax.jit(jax.vmap(lambda s: kabsch_rmsd(A.reshape(22, 3), s.reshape(22, 3)) <= 7.5e-2)),
261
+ # jax.jit(jax.vmap(lambda s: kabsch_rmsd(B.reshape(22, 3), s.reshape(22, 3)) <= 7.5e-2)),
267
262
step_langevin_forward ,
268
263
step_langevin_backward ,
269
264
jax .jit (lambda key : jnp .sqrt (kbT / mass ) * jax .random .normal (key , (1 , 66 )))
270
265
)
271
266
272
- print ("A" , phis_psis (A , mdtraj_topology ))
273
- print ("B" , phis_psis (B , mdtraj_topology ))
267
+ print ("A" , phis_psis (A ))
268
+ print ("B" , phis_psis (B ))
274
269
275
270
filter1 = system .start_state (trajectory )
276
271
filter2 = system .target_state (trajectory )
@@ -291,7 +286,8 @@ def step_langevin_backward(_x, _v, _key):
291
286
if load :
292
287
paths = np .load (f'{ savedir } /paths.npy' , allow_pickle = True )
293
288
else :
294
- paths = tps2 .mcmc_shooting (system , tps2 .two_way_shooting , initial_trajectory , 5 , jax .random .PRNGKey (1 ), warmup = 0 )
289
+ paths = tps2 .mcmc_shooting (system , tps2 .two_way_shooting , initial_trajectory , 100 , jax .random .PRNGKey (1 ),
290
+ warmup = 10 )
295
291
# paths = tps2.unguided_md(system, B, 1, key)
296
292
paths = [jnp .array (p ) for p in paths ]
297
293
# store paths
@@ -303,16 +299,17 @@ def step_langevin_backward(_x, _v, _key):
303
299
304
300
path_hist = PeriodicPathHistogram ()
305
301
for i , path in tqdm (enumerate (paths )):
306
- path_hist .add_path (np .array (phis_psis (path , mdtraj_topology )))
302
+ path_hist .add_path (np .array (phis_psis (path )))
307
303
308
304
plt .title (f"{ human_format (len (paths ))} paths @ { temp } K, dt = { human_format (dt )} s" )
309
305
path_hist .plot (cmin = 0.001 )
310
306
ramachandran (None , states = [
311
- {'name' : 'A' , 'center' : phis_psis (A , mdtraj_topology ), 'radius' : radius },
312
- {'name' : 'B' , 'center' : phis_psis (B , mdtraj_topology ), 'radius' : radius },
307
+ {'name' : 'A' , 'center' : phis_psis (A ). squeeze ( ), 'radius' : radius },
308
+ {'name' : 'B' , 'center' : phis_psis (B ). squeeze ( ), 'radius' : radius },
313
309
], alpha = 0.7 )
314
310
plt .savefig (f'{ savedir } /paths.png' , bbox_inches = 'tight' )
315
311
plt .show ()
316
312
317
313
for i , path in tqdm (enumerate (paths )):
318
- save_trajectory (mdtraj_topology , jnp .array ([kabsch_align (p .reshape (- 1 , 3 ), B .reshape (- 1 , 3 ))[0 ] for p in path ]), f'{ savedir } /trajectory_{ i } .pdb' )
314
+ save_trajectory (mdtraj_topology , jnp .array ([kabsch_align (p .reshape (- 1 , 3 ), B .reshape (- 1 , 3 ))[0 ] for p in path ]),
315
+ f'{ savedir } /trajectory_{ i } .pdb' )
0 commit comments