21
21
from utils .PlotPathsAlanine_jax import PlotPathsAlanine
22
22
from matplotlib import colors
23
23
24
+ from utils .rmsd import kabsch
25
+ from scipy .constants import physical_constants
26
+
24
27
25
28
def human_format (num ):
26
29
"""https://stackoverflow.com/a/45846841/4417954"""
@@ -88,7 +91,6 @@ def draw_path(_path, **kwargs):
88
91
masked_path_x , masked_path_y = np .ma .MaskedArray (_path [:, 0 ], mask ), np .ma .MaskedArray (_path [:, 1 ], mask )
89
92
plt .plot (masked_path_x , masked_path_y , ** kwargs )
90
93
91
-
92
94
if path is not None :
93
95
draw_path (path , color = 'red' )
94
96
@@ -98,13 +100,21 @@ def draw_path(_path, **kwargs):
98
100
99
101
100
102
T = 2.0
101
- dt = 1.0 * unit .microsecond
102
- dt = dt .value_in_unit (unit .second )
103
+ dt_as_unit = unit .Quantity (value = 1.0 , unit = unit .femtoseconds )
104
+ dt_in_ps = dt_as_unit .value_in_unit (unit .picosecond )
105
+ dt = dt_as_unit .value_in_unit (unit .second )
106
+
107
+ gamma_as_unit = 1.0 / unit .picosecond
108
+ # actually gamma is 1/s, but we are working without units and just need the correct scaling
109
+ # TODO: try to get rid of this duplicate definition
110
+ gamma = 1.0 * unit .picosecond
111
+ gamma_in_ps = gamma .value_in_unit (unit .picosecond )
112
+ gamma = gamma .value_in_unit (unit .second )
103
113
104
114
temp = 298.15
105
- temp = 10000 # TODO: remove this
106
115
kbT = 1.380649 * 6.02214076 * 1e-3 * temp
107
116
117
+
108
118
if __name__ == '__main__' :
109
119
init_pdb = app .PDBFile ("./files/AD_c7eq.pdb" )
110
120
target_pdb = app .PDBFile ("./files/AD_c7ax.pdb" )
@@ -121,7 +131,7 @@ def draw_path(_path, **kwargs):
121
131
new_mass .append (mass_ )
122
132
mass = jnp .array (new_mass )
123
133
# Obtain sigma, gamma is by default 1
124
- sigma = jnp .sqrt (2 * kbT / mass )
134
+ sigma = jnp .sqrt (2 * kbT / mass / gamma )
125
135
126
136
# Initial and target shape [BS, 66]
127
137
A = jnp .array (init_pdb .getPositions (asNumpy = True ).value_in_unit (unit .nanometer )).reshape (1 , - 1 )
@@ -135,44 +145,76 @@ def draw_path(_path, **kwargs):
135
145
constraints = None ,
136
146
ewaldErrorTolerance = 0.0005 )
137
147
# Create a box used when calling
138
- # Calling U by U(x, box, pairs, ff.paramset.parameters), x is [22, 3] and output the energy, if it is batched, use vmap
139
148
box = np .array ([[50.0 , 0.0 , 0.0 ], [0.0 , 50.0 , 0.0 ], [0.0 , 0.0 , 50.0 ]])
140
149
nbList = NeighborList (box , 4.0 , potentials .meta ["cov_map" ])
141
150
nbList .allocate (init_pdb .getPositions (asNumpy = True ).value_in_unit (unit .nanometer ))
142
151
pairs = nbList .pairs
143
152
144
153
145
154
def U (_x ):
155
+ """
156
+ Calling U by U(x, box, pairs, ff.paramset.parameters), x is [22, 3] and output the energy, if it is batched, use vmap
157
+ """
146
158
_U = potentials .getPotentialFunc ()
147
159
148
160
return _U (_x .reshape (22 , 3 ), box , pairs , ff .paramset .parameters )
149
161
150
162
151
- #TODO: we can introduce gamma here
152
163
def dUdx_fn (_x ):
153
- return jax .grad (lambda _x : U (_x ).sum ())(_x ) / mass
164
+ return jax .grad (lambda _x : U (_x ).sum ())(_x ) / mass / gamma
154
165
155
166
156
167
dUdx_fn = jax .vmap (dUdx_fn )
157
168
dUdx_fn = jax .jit (dUdx_fn )
158
169
170
+
159
171
@jax .jit
160
172
def step (_x , _key ):
161
173
"""Perform one step of forward euler"""
162
174
return _x - dt * dUdx_fn (_x ) + jnp .sqrt (dt ) * sigma * jax .random .normal (_key , _x .shape )
163
175
176
+
177
+ def dUdx_fn_unscaled (_x ):
178
+ return jax .grad (lambda _x : U (_x ).sum ())(_x )
179
+
180
+ dUdx_fn_unscaled = jax .vmap (dUdx_fn_unscaled )
181
+ dUdx_fn_unscaled = jax .jit (dUdx_fn_unscaled )
182
+
183
+ @jax .jit
184
+ def step_langevin (_x , _v , _key ):
185
+ alpha = jnp .exp (- gamma_in_ps * dt_in_ps )
186
+ f_scale = (1 - alpha ) / gamma_in_ps
187
+ new_v_det = alpha * _v + f_scale * - dUdx_fn_unscaled (_x ) / mass
188
+ new_v = new_v_det + jnp .sqrt (kbT * (1 - alpha ** 2 ) / mass ) * jax .random .normal (_key , _x .shape )
189
+
190
+ return _x + dt_in_ps * new_v , new_v
191
+
164
192
key = jax .random .PRNGKey (1 )
193
+ key , velocity_key = jax .random .split (key )
194
+ steps = 1_000_000
165
195
166
- trajectory = [A ] # or [B]
196
+ trajectory = [A ]
167
197
_x = trajectory [- 1 ]
168
- steps = 20_000
198
+
199
+ velocity_variance = unit .Quantity (1 / mass , unit = 1 / unit .dalton ) * unit .BOLTZMANN_CONSTANT_kB * unit .Quantity (temp , unit = unit .kelvin )
200
+ # Although velocity+variance is of the unit J / Da = m^2 / s^2, openmm cannot handle this directly and we need to convert it
201
+ velocity_variance_in_si = 1 / physical_constants ['unified atomic mass unit' ][
202
+ 0 ] * velocity_variance .value_in_unit (unit .joule / unit .dalton )
203
+ # velocity_variance_in_si = unit.Quantity(velocity_variance_in_si, unit.meter / unit.second)
204
+
205
+ _v = jnp .sqrt (velocity_variance_in_si ) * jax .random .normal (velocity_key , _x .shape )
206
+ _v = unit .Quantity (_v , unit .meter / unit .second ).value_in_unit (unit .nanometer / unit .picosecond )
207
+
169
208
for i in trange (steps ):
170
209
key , iter_key = jax .random .split (key )
171
- _x = step (_x , iter_key )
210
+ _x , _v = step_langevin (_x , _v , iter_key )
211
+
172
212
trajectory .append (_x )
173
213
174
214
trajectory = jnp .array (trajectory ).reshape (- 1 , 66 )
175
- assert not jnp .isnan (trajectory ).any ()
215
+
216
+ # we only need to check whether the last frame contains nan, is it propagates
217
+ assert not jnp .isnan (trajectory [- 1 ]).any ()
176
218
trajectory_phi_psi = phis_psis (trajectory , mdtraj_topology )
177
219
178
220
trajs = None
@@ -189,15 +231,56 @@ def step(_x, _key):
189
231
ramachandran (trajectory_phi_psi )
190
232
plt .show ()
191
233
192
-
193
234
# TODO: this is work in progress. Get some baselines with tps
194
- system = tps .System (
195
- jax .jit (
196
- lambda s : jnp .all (jnp .linalg .norm (A .reshape (- 1 , 22 , 3 ) - s .reshape (- 1 , 22 , 3 ), axis = 2 ) <= 5e-2 , axis = 1 )),
197
- jax .jit (
198
- lambda s : jnp .all (jnp .linalg .norm (B .reshape (- 1 , 22 , 3 ) - s .reshape (- 1 , 22 , 3 ), axis = 2 ) <= 5e-2 , axis = 1 )),
199
- step
200
- )
235
+
236
+ # l2_system = tps.System(
237
+ # jax.jit(
238
+ # lambda s: jnp.all(jnp.linalg.norm(A.reshape(-1, 22, 3) - s.reshape(-1, 22, 3), axis=2) <= 5e-2, axis=1)),
239
+ # jax.jit(
240
+ # lambda s: jnp.all(jnp.linalg.norm(B.reshape(-1, 22, 3) - s.reshape(-1, 22, 3), axis=2) <= 5e-2, axis=1)),
241
+ # step
242
+ # )
243
+ #
244
+ # rmsd_system = tps.System(
245
+ # jax.jit(lambda s: kabsch(A.reshape(22, 3), s.reshape(22, 3)) < 0.15),
246
+ # jax.jit(lambda s: kabsch(B.reshape(22, 3), s.reshape(22, 3)) < 0.15),
247
+ # step
248
+ # )
249
+ #
250
+ # # @jax.jit
251
+ # def is_within_phi_psi(s, center, radius, period=2 * jnp.pi):
252
+ # points = phis_psis(s, mdtraj_topology)
253
+ # delta = jnp.abs(center - points)
254
+ # delta = jnp.where(delta > period / 2, delta - period, delta)
255
+ #
256
+ # return jnp.hypot(delta[:, 0], delta[:, 1]) < radius
257
+ #
258
+ #
259
+ # deg = 180.0 / jnp.pi
260
+ # # State('A', torch.tensor([-150, 150]) / deg, torch.tensor([20, 45, 65, 80]) / deg),
261
+ # # State('B', torch.tensor([-70, 135]) / deg, torch.tensor([20, 45, 65, 75]) / deg),
262
+ # # State('C', torch.tensor([-150, -65]) / deg, torch.tensor([20, 45, 60]) / deg),
263
+ # # State('D', torch.tensor([-70, -50]) / deg, torch.tensor([20, 45, 60]) / deg),
264
+ # # State('E', torch.tensor([50, -100]) / deg, torch.tensor([20, 45, 65, 80]) / deg),
265
+ # # State('F', torch.tensor([40, 65]) / deg, torch.tensor([20, 45, 65, 80]) / deg),
266
+ #
267
+ # phi_psi_system = tps.System(
268
+ # lambda s: is_within_phi_psi(s, jnp.array([-150, 150]) / deg, 20 / deg),
269
+ # lambda s: is_within_phi_psi(s, jnp.array([50, -100]) / deg, 20 / deg),
270
+ # step
271
+ # )
272
+ #
273
+ # # TODO: fix vmap
274
+ # filter1 = jax.vmap(phi_psi_system.start_state)(trajectory)
275
+ # filter2 = jax.vmap(phi_psi_system.target_state)(trajectory)
276
+ #
277
+ # plt.title('start')
278
+ # ramachandran(trajectory_phi_psi[filter1])
279
+ # plt.show()
280
+ #
281
+ # plt.title('target')
282
+ # ramachandran(trajectory_phi_psi[filter2])
283
+ # plt.show()
201
284
202
285
# initial_trajectory = [t.reshape(1, -1) for t in interpolate([A, B], 100)]
203
286
0 commit comments