21
21
from utils .PlotPathsAlanine_jax import PlotPathsAlanine
22
22
from matplotlib import colors
23
23
24
- from utils .rmsd import kabsch
25
- from scipy . constants import physical_constants
24
+ from utils .animation import save_trajectory , to_md_traj
25
+ from utils . rmsd import kabsch_align , kabsch_rmsd
26
26
27
27
28
28
def human_format (num ):
@@ -61,7 +61,7 @@ def interpolate_two_points(start, stop, steps):
61
61
62
62
63
63
def phis_psis (position , mdtraj_topology ):
64
- traj = md . Trajectory ( position . reshape ( - 1 , mdtraj_topology . n_atoms , 3 ), mdtraj_topology )
64
+ traj = to_md_traj ( mdtraj_topology , position )
65
65
phi = md .compute_phi (traj )[1 ].squeeze ()
66
66
psi = md .compute_psi (traj )[1 ].squeeze ()
67
67
return jnp .array ([phi , psi ]).T
@@ -99,22 +99,31 @@ def draw_path(_path, **kwargs):
99
99
draw_path (path , color = 'blue' )
100
100
101
101
102
- T = 2.0
103
- dt_as_unit = unit .Quantity (value = 1.0 , unit = unit .femtoseconds )
102
+ dt_as_unit = unit .Quantity (value = 1 , unit = unit .microsecond )
104
103
dt_in_ps = dt_as_unit .value_in_unit (unit .picosecond )
105
104
dt = dt_as_unit .value_in_unit (unit .second )
106
105
107
- gamma_as_unit = 1.0 / unit .picosecond
106
+ gamma_as_unit = 1.0 / unit .second
108
107
# actually gamma is 1/s, but we are working without units and just need the correct scaling
109
108
# TODO: try to get rid of this duplicate definition
110
- gamma = 1.0 * unit .picosecond
109
+ gamma = 1.0 * unit .second
111
110
gamma_in_ps = gamma .value_in_unit (unit .picosecond )
112
111
gamma = gamma .value_in_unit (unit .second )
113
112
114
- temp = 298.15
113
+ temp = 300
115
114
kbT = 1.380649 * 6.02214076 * 1e-3 * temp
116
115
117
116
117
+ @jax .jit
118
+ def is_within (_phis_psis , _center , _radius , _period = 2 * jnp .pi ):
119
+ delta = jnp .abs (_center - _phis_psis )
120
+ delta = jnp .where (delta > _period / 2 , delta - _period , delta )
121
+
122
+ return jnp .hypot (delta [:, 0 ], delta [:, 1 ]) < _radius
123
+
124
+
125
+ deg = 180.0 / jnp .pi
126
+
118
127
if __name__ == '__main__' :
119
128
init_pdb = app .PDBFile ("./files/AD_c7eq.pdb" )
120
129
target_pdb = app .PDBFile ("./files/AD_c7ax.pdb" )
@@ -130,12 +139,14 @@ def draw_path(_path, **kwargs):
130
139
for _ in range (3 ):
131
140
new_mass .append (mass_ )
132
141
mass = jnp .array (new_mass )
133
- # Obtain sigma , gamma is by default 1
134
- sigma = jnp .sqrt (2 * kbT / mass / gamma )
142
+ # Obtain xi , gamma is by default 1
143
+ xi = jnp .sqrt (2 * kbT / mass / gamma )
135
144
136
145
# Initial and target shape [BS, 66]
137
- A = jnp .array (init_pdb .getPositions (asNumpy = True ).value_in_unit (unit .nanometer )).reshape (1 , - 1 )
138
- B = jnp .array (target_pdb .getPositions (asNumpy = True ).value_in_unit (unit .nanometer )).reshape (1 , - 1 )
146
+ A = jnp .array (init_pdb .getPositions (asNumpy = True ).value_in_unit (unit .nanometer ))
147
+ B = jnp .array (target_pdb .getPositions (asNumpy = True ).value_in_unit (unit .nanometer ))
148
+ A , B = kabsch_align (A , B )
149
+ A , B = A .reshape (1 , - 1 ), B .reshape (1 , - 1 )
139
150
140
151
# Initialize the potential energy with amber forcefields
141
152
ff = Hamiltonian ('amber14/protein.ff14SB.xml' , 'amber14/tip3p.xml' )
@@ -160,150 +171,82 @@ def U(_x):
160
171
return _U (_x .reshape (22 , 3 ), box , pairs , ff .paramset .parameters )
161
172
162
173
174
+ @jax .jit
175
+ @jax .vmap
163
176
def dUdx_fn (_x ):
164
177
return jax .grad (lambda _x : U (_x ).sum ())(_x ) / mass / gamma
165
178
166
179
167
- dUdx_fn = jax .vmap (dUdx_fn )
168
- dUdx_fn = jax .jit (dUdx_fn )
169
-
170
-
171
180
@jax .jit
172
181
def step (_x , _key ):
173
182
"""Perform one step of forward euler"""
174
- return _x - dt * dUdx_fn (_x ) + jnp .sqrt (dt ) * sigma * jax .random .normal (_key , _x .shape )
175
-
176
-
177
- def dUdx_fn_unscaled (_x ):
178
- return jax .grad (lambda _x : U (_x ).sum ())(_x )
183
+ return _x - dt * dUdx_fn (_x ) + jnp .sqrt (dt ) * xi * jax .random .normal (_key , _x .shape )
179
184
180
- dUdx_fn_unscaled = jax .vmap (dUdx_fn_unscaled )
181
- dUdx_fn_unscaled = jax .jit (dUdx_fn_unscaled )
182
-
183
- @jax .jit
184
- def step_langevin (_x , _v , _key ):
185
- alpha = jnp .exp (- gamma_in_ps * dt_in_ps )
186
- f_scale = (1 - alpha ) / gamma_in_ps
187
- new_v_det = alpha * _v + f_scale * - dUdx_fn_unscaled (_x ) / mass
188
- new_v = new_v_det + jnp .sqrt (kbT * (1 - alpha ** 2 ) / mass ) * jax .random .normal (_key , _x .shape )
189
-
190
- return _x + dt_in_ps * new_v , new_v
191
185
192
186
key = jax .random .PRNGKey (1 )
193
187
key , velocity_key = jax .random .split (key )
194
- steps = 1_000_000
188
+ steps = 100_000
195
189
196
190
trajectory = [A ]
197
191
_x = trajectory [- 1 ]
198
192
199
- velocity_variance = unit .Quantity (1 / mass , unit = 1 / unit .dalton ) * unit .BOLTZMANN_CONSTANT_kB * unit .Quantity (temp , unit = unit .kelvin )
200
- # Although velocity+variance is of the unit J / Da = m^2 / s^2, openmm cannot handle this directly and we need to convert it
201
- velocity_variance_in_si = 1 / physical_constants ['unified atomic mass unit' ][
202
- 0 ] * velocity_variance .value_in_unit (unit .joule / unit .dalton )
203
- # velocity_variance_in_si = unit.Quantity(velocity_variance_in_si, unit.meter / unit.second)
204
-
205
- _v = jnp .sqrt (velocity_variance_in_si ) * jax .random .normal (velocity_key , _x .shape )
206
- _v = unit .Quantity (_v , unit .meter / unit .second ).value_in_unit (unit .nanometer / unit .picosecond )
207
-
208
193
for i in trange (steps ):
209
194
key , iter_key = jax .random .split (key )
210
- _x , _v = step_langevin (_x , _v , iter_key )
195
+ _x = step (_x , iter_key )
211
196
212
197
trajectory .append (_x )
213
198
214
199
trajectory = jnp .array (trajectory ).reshape (- 1 , 66 )
215
200
201
+ # save_trajectory(mdtraj_topology, trajectory[-1000:], 'simulation.pdb')
202
+
216
203
# we only need to check whether the last frame contains nan, is it propagates
217
204
assert not jnp .isnan (trajectory [- 1 ]).any ()
218
205
trajectory_phi_psi = phis_psis (trajectory , mdtraj_topology )
219
206
220
- trajs = None
221
- for i in range (10000 , 11000 ):
222
- traj = md .load_pdb ('./files/AD_c7eq.pdb' )
223
- traj .xyz = trajectory [i ].reshape (22 , 3 )
224
- if trajs is None :
225
- trajs = traj
226
- else :
227
- trajs = trajs .join (traj )
228
- trajs .save (f'{ savedir } /ALDP_forward_euler.pdb' )
229
-
230
207
plt .title (f"{ human_format (steps )} steps @ { temp } K, dt = { human_format (dt )} s" )
231
208
ramachandran (trajectory_phi_psi )
232
209
plt .show ()
233
210
234
- # TODO: this is work in progress. Get some baselines with tps
235
-
236
- # l2_system = tps.System(
237
- # jax.jit(
238
- # lambda s: jnp.all(jnp.linalg.norm(A.reshape(-1, 22, 3) - s.reshape(-1, 22, 3), axis=2) <= 5e-2, axis=1)),
239
- # jax.jit(
240
- # lambda s: jnp.all(jnp.linalg.norm(B.reshape(-1, 22, 3) - s.reshape(-1, 22, 3), axis=2) <= 5e-2, axis=1)),
211
+ # Choose a system, either phi psi, or rmsd
212
+ # system = tps.System(
213
+ # jax.jit(jax.vmap(lambda s: kabsch_rmsd(A.reshape(22, 3), s.reshape(22, 3)) < 0.1)),
214
+ # jax.jit(jax.vmap(lambda s: kabsch_rmsd(B.reshape(22, 3), s.reshape(22, 3)) < 0.1)),
241
215
# step
242
216
# )
243
- #
244
- # rmsd_system = tps.System(
245
- # jax.jit(lambda s: kabsch(A.reshape(22, 3), s.reshape(22, 3)) < 0.15),
246
- # jax.jit(lambda s: kabsch(B.reshape(22, 3), s.reshape(22, 3)) < 0.15),
247
- # step
248
- # )
249
- #
250
- # # @jax.jit
251
- # def is_within_phi_psi(s, center, radius, period=2 * jnp.pi):
252
- # points = phis_psis(s, mdtraj_topology)
253
- # delta = jnp.abs(center - points)
254
- # delta = jnp.where(delta > period / 2, delta - period, delta)
255
- #
256
- # return jnp.hypot(delta[:, 0], delta[:, 1]) < radius
257
- #
258
- #
259
- # deg = 180.0 / jnp.pi
260
- # # State('A', torch.tensor([-150, 150]) / deg, torch.tensor([20, 45, 65, 80]) / deg),
261
- # # State('B', torch.tensor([-70, 135]) / deg, torch.tensor([20, 45, 65, 75]) / deg),
262
- # # State('C', torch.tensor([-150, -65]) / deg, torch.tensor([20, 45, 60]) / deg),
263
- # # State('D', torch.tensor([-70, -50]) / deg, torch.tensor([20, 45, 60]) / deg),
264
- # # State('E', torch.tensor([50, -100]) / deg, torch.tensor([20, 45, 65, 80]) / deg),
265
- # # State('F', torch.tensor([40, 65]) / deg, torch.tensor([20, 45, 65, 80]) / deg),
266
- #
267
- # phi_psi_system = tps.System(
268
- # lambda s: is_within_phi_psi(s, jnp.array([-150, 150]) / deg, 20 / deg),
269
- # lambda s: is_within_phi_psi(s, jnp.array([50, -100]) / deg, 20 / deg),
270
- # step
271
- # )
272
- #
273
- # # TODO: fix vmap
274
- # filter1 = jax.vmap(phi_psi_system.start_state)(trajectory)
275
- # filter2 = jax.vmap(phi_psi_system.target_state)(trajectory)
276
- #
277
- # plt.title('start')
278
- # ramachandran(trajectory_phi_psi[filter1])
279
- # plt.show()
280
- #
281
- # plt.title('target')
282
- # ramachandran(trajectory_phi_psi[filter2])
283
- # plt.show()
284
-
285
- # initial_trajectory = [t.reshape(1, -1) for t in interpolate([A, B], 100)]
286
-
287
- #
288
- # for i in range(10):
289
- # key, iter_key = jax.random.split(key)
290
- #
291
- # # ramachandran(None, path=phis_psis(jnp.vstack(initial_trajectory), mdtraj_topology))
292
- # # plt.show()
293
- #
294
- #
295
- # ok, trajectory = tps.one_way_shooting(system, initial_trajectory, 0, key)
296
- # trajectory = jnp.array(trajectory)
297
- # trajectory = phis_psis(trajectory, mdtraj_topology)
298
- # print('ok?', ok)
299
- #
300
- # ramachandran(None, path=phis_psis(jnp.vstack(initial_trajectory), mdtraj_topology), paths=[trajectory])
301
- # plt.show()
302
- #
303
-
304
- # paths = tps.mcmc_shooting(system, tps.one_way_shooting, initial_trajectory, 10, key, warmup=0)
305
- # paths = [jnp.array(p) for p in paths]
306
- #
307
- # print(paths)
308
- # ramachandran(None, path=[phis_psis(p, mdtraj_topology) for p in paths][-1])
309
- # plt.show()
217
+
218
+ system = tps .System (
219
+ lambda s : is_within (phis_psis (s , mdtraj_topology ).reshape (- 1 , 2 ), phis_psis (A , mdtraj_topology ), 20 / deg ),
220
+ lambda s : is_within (phis_psis (s , mdtraj_topology ).reshape (- 1 , 2 ), phis_psis (B , mdtraj_topology ), 20 / deg ),
221
+ step
222
+ )
223
+
224
+ filter1 = system .start_state (trajectory )
225
+ filter2 = system .target_state (trajectory )
226
+
227
+ plt .title ('start' )
228
+ ramachandran (trajectory_phi_psi [filter1 ])
229
+ plt .show ()
230
+
231
+ plt .title ('target' )
232
+ ramachandran (trajectory_phi_psi [filter2 ])
233
+ plt .show ()
234
+
235
+ initial_trajectory = [t .reshape (1 , - 1 ) for t in interpolate ([A , B ], 100 )]
236
+ save_trajectory (mdtraj_topology , jnp .array (initial_trajectory ), f'{ savedir } /initial_trajectory.pdb' )
237
+
238
+ paths = tps .mcmc_shooting (system , tps .two_way_shooting , initial_trajectory , 5 , key , warmup = 2 )
239
+ paths = [jnp .array (p ) for p in paths ]
240
+ # store paths
241
+ np .save (f'{ savedir } /paths.npy' , np .array (paths , dtype = object ), allow_pickle = True )
242
+
243
+ print ([len (p ) for p in paths ])
244
+ plt .hist ([len (p ) for p in paths ], bins = jnp .sqrt (len (paths )).astype (int ).item ())
245
+ plt .show ()
246
+
247
+ plt .title (f"{ human_format (len (paths ))} steps @ { temp } K, dt = { human_format (dt )} s" )
248
+ ramachandran (jnp .concatenate ([phis_psis (p , mdtraj_topology ) for p in paths ]),
249
+ path = phis_psis (jnp .array (initial_trajectory ), mdtraj_topology ))
250
+ plt .show ()
251
+
252
+ save_trajectory (mdtraj_topology , paths [- 1 ], f'{ savedir } /final_trajectory.pdb' )
0 commit comments